diff --git a/internal/db/models/log_dao.go b/internal/db/models/log_dao.go index 731f3c24..5c101962 100644 --- a/internal/db/models/log_dao.go +++ b/internal/db/models/log_dao.go @@ -36,8 +36,21 @@ func init() { // 创建管理员日志 func (this *LogDAO) CreateLog(adminType string, adminId int64, level string, description string, action string, ip string) error { op := NewLogOperator() + op.Level = level + op.Description = description + op.Action = action + op.Ip = ip op.Type = adminType - op.AdminId, op.Level, op.Description, op.Action, op.Ip = adminId, level, description, action, ip + + switch adminType { + case "admin": + op.AdminId = adminId + case "user": + op.UserId = adminId + case "provider": + op.ProviderId = adminId + } + op.Day = timeutil.Format("Ymd") op.Type = LogTypeAdmin err := this.Save(op) diff --git a/internal/db/models/user_dao.go b/internal/db/models/user_dao.go index 0d7a0802..5807edbd 100644 --- a/internal/db/models/user_dao.go +++ b/internal/db/models/user_dao.go @@ -179,3 +179,17 @@ func (this *UserDAO) ListEnabledUserIds(offset, size int64) ([]int64, error) { } return result, nil } + +// 检查用户名、密码 +func (this *UserDAO) CheckUserPassword(username string, encryptedPassword string) (int64, error) { + if len(username) == 0 || len(encryptedPassword) == 0 { + return 0, nil + } + return this.Query(). + Attr("username", username). + Attr("password", encryptedPassword). + Attr("state", UserStateEnabled). + Attr("isOn", true). + ResultPk(). + FindInt64Col(0) +} diff --git a/internal/db/models/user_node_dao.go b/internal/db/models/user_node_dao.go new file mode 100644 index 00000000..708eb90e --- /dev/null +++ b/internal/db/models/user_node_dao.go @@ -0,0 +1,234 @@ +package models + +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/errors" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/maps" + "github.com/iwind/TeaGo/rands" + "github.com/iwind/TeaGo/types" + "strconv" +) + +const ( + UserNodeStateEnabled = 1 // 已启用 + UserNodeStateDisabled = 0 // 已禁用 +) + +type UserNodeDAO dbs.DAO + +func NewUserNodeDAO() *UserNodeDAO { + return dbs.NewDAO(&UserNodeDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeUserNodes", + Model: new(UserNode), + PkName: "id", + }, + }).(*UserNodeDAO) +} + +var SharedUserNodeDAO *UserNodeDAO + +func init() { + dbs.OnReady(func() { + SharedUserNodeDAO = NewUserNodeDAO() + }) +} + +// 启用条目 +func (this *UserNodeDAO) EnableUserNode(id uint32) error { + _, err := this.Query(). + Pk(id). + Set("state", UserNodeStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *UserNodeDAO) DisableUserNode(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", UserNodeStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *UserNodeDAO) FindEnabledUserNode(id int64) (*UserNode, error) { + result, err := this.Query(). + Pk(id). + Attr("state", UserNodeStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*UserNode), err +} + +// 根据主键查找名称 +func (this *UserNodeDAO) FindUserNodeName(id int64) (string, error) { + return this.Query(). + Pk(id). + Result("name"). + FindStringCol("") +} + +// 列出所有可用用户节点 +func (this *UserNodeDAO) FindAllEnabledUserNodes() (result []*UserNode, err error) { + _, err = this.Query(). + State(UserNodeStateEnabled). + Desc("order"). + AscPk(). + Slice(&result). + FindAll() + return +} + +// 计算用户节点数量 +func (this *UserNodeDAO) CountAllEnabledUserNodes() (int64, error) { + return this.Query(). + State(UserNodeStateEnabled). + Count() +} + +// 列出单页的用户节点 +func (this *UserNodeDAO) ListEnabledUserNodes(offset int64, size int64) (result []*UserNode, err error) { + _, err = this.Query(). + State(UserNodeStateEnabled). + Offset(offset). + Limit(size). + Desc("order"). + DescPk(). + Slice(&result). + FindAll() + return +} + +// 根据主机名和端口获取ID +func (this *UserNodeDAO) FindEnabledUserNodeIdWithAddr(protocol string, host string, port int) (int64, error) { + addr := maps.Map{ + "protocol": protocol, + "host": host, + "portRange": strconv.Itoa(port), + } + addrJSON, err := json.Marshal(addr) + if err != nil { + return 0, err + } + + one, err := this.Query(). + State(UserNodeStateEnabled). + Where("JSON_CONTAINS(accessAddrs, :addr)"). + Param("addr", string(addrJSON)). + ResultPk(). + Find() + if err != nil { + return 0, err + } + if one == nil { + return 0, nil + } + return int64(one.(*UserNode).Id), nil +} + +// 创建用户节点 +func (this *UserNodeDAO) CreateUserNode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { + uniqueId, err := this.genUniqueId() + if err != nil { + return 0, err + } + secret := rands.String(32) + err = NewApiTokenDAO().CreateAPIToken(uniqueId, secret, NodeRoleUser) + if err != nil { + return + } + + op := NewUserNodeOperator() + op.IsOn = isOn + op.UniqueId = uniqueId + op.Secret = secret + op.Name = name + op.Description = description + + if len(httpJSON) > 0 { + op.Http = httpJSON + } + if len(httpsJSON) > 0 { + op.Https = httpsJSON + } + if len(accessAddrsJSON) > 0 { + op.AccessAddrs = accessAddrsJSON + } + + op.State = NodeStateEnabled + err = this.Save(op) + if err != nil { + return + } + + return types.Int64(op.Id), nil +} + +// 修改用户节点 +func (this *UserNodeDAO) UpdateUserNode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error { + if nodeId <= 0 { + return errors.New("invalid nodeId") + } + + op := NewUserNodeOperator() + op.Id = nodeId + op.Name = name + op.Description = description + op.IsOn = isOn + + if len(httpJSON) > 0 { + op.Http = httpJSON + } else { + op.Http = "null" + } + if len(httpsJSON) > 0 { + op.Https = httpsJSON + } else { + op.Https = "null" + } + if len(accessAddrsJSON) > 0 { + op.AccessAddrs = accessAddrsJSON + } else { + op.AccessAddrs = "null" + } + + err := this.Save(op) + return err +} + +// 根据唯一ID获取节点信息 +func (this *UserNodeDAO) FindEnabledUserNodeWithUniqueId(uniqueId string) (*UserNode, error) { + result, err := this.Query(). + Attr("uniqueId", uniqueId). + Attr("state", UserNodeStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*UserNode), err +} + +// 生成唯一ID +func (this *UserNodeDAO) genUniqueId() (string, error) { + for { + uniqueId := rands.HexString(32) + ok, err := this.Query(). + Attr("uniqueId", uniqueId). + Exist() + if err != nil { + return "", err + } + if ok { + continue + } + return uniqueId, nil + } +} diff --git a/internal/db/models/user_node_dao_test.go b/internal/db/models/user_node_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/user_node_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/user_node_model.go b/internal/db/models/user_node_model.go new file mode 100644 index 00000000..1b42e27d --- /dev/null +++ b/internal/db/models/user_node_model.go @@ -0,0 +1,42 @@ +package models + +// API节点 +type UserNode struct { + Id uint32 `field:"id"` // ID + IsOn uint8 `field:"isOn"` // 是否启用 + UniqueId string `field:"uniqueId"` // 唯一ID + Secret string `field:"secret"` // 密钥 + Name string `field:"name"` // 名称 + Description string `field:"description"` // 描述 + Http string `field:"http"` // 监听的HTTP配置 + Https string `field:"https"` // 监听的HTTPS配置 + AccessAddrs string `field:"accessAddrs"` // 外部访问地址 + Order uint32 `field:"order"` // 排序 + State uint8 `field:"state"` // 状态 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + AdminId uint32 `field:"adminId"` // 管理员ID + Weight uint32 `field:"weight"` // 权重 + Status string `field:"status"` // 运行状态 +} + +type UserNodeOperator struct { + Id interface{} // ID + IsOn interface{} // 是否启用 + UniqueId interface{} // 唯一ID + Secret interface{} // 密钥 + Name interface{} // 名称 + Description interface{} // 描述 + Http interface{} // 监听的HTTP配置 + Https interface{} // 监听的HTTPS配置 + AccessAddrs interface{} // 外部访问地址 + Order interface{} // 排序 + State interface{} // 状态 + CreatedAt interface{} // 创建时间 + AdminId interface{} // 管理员ID + Weight interface{} // 权重 + Status interface{} // 运行状态 +} + +func NewUserNodeOperator() *UserNodeOperator { + return &UserNodeOperator{} +} diff --git a/internal/db/models/user_node_model_ext.go b/internal/db/models/user_node_model_ext.go new file mode 100644 index 00000000..81ecb2aa --- /dev/null +++ b/internal/db/models/user_node_model_ext.go @@ -0,0 +1,95 @@ +package models + +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" +) + +// 解析HTTP配置 +func (this *UserNode) DecodeHTTP() (*serverconfigs.HTTPProtocolConfig, error) { + if !IsNotNull(this.Http) { + return nil, nil + } + config := &serverconfigs.HTTPProtocolConfig{} + err := json.Unmarshal([]byte(this.Http), config) + if err != nil { + return nil, err + } + + err = config.Init() + if err != nil { + return nil, err + } + + return config, nil +} + +// 解析HTTPS配置 +func (this *UserNode) DecodeHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) { + if !IsNotNull(this.Https) { + return nil, nil + } + config := &serverconfigs.HTTPSProtocolConfig{} + err := json.Unmarshal([]byte(this.Https), config) + if err != nil { + return nil, err + } + + err = config.Init() + if err != nil { + return nil, err + } + + if config.SSLPolicyRef != nil { + policyId := config.SSLPolicyRef.SSLPolicyId + if policyId > 0 { + sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId) + if err != nil { + return nil, err + } + if sslPolicy != nil { + config.SSLPolicy = sslPolicy + } + } + } + + err = config.Init() + if err != nil { + return nil, err + } + + return config, nil +} + +// 解析访问地址 +func (this *UserNode) DecodeAccessAddrs() ([]*serverconfigs.NetworkAddressConfig, error) { + if !IsNotNull(this.AccessAddrs) { + return nil, nil + } + + addrConfigs := []*serverconfigs.NetworkAddressConfig{} + err := json.Unmarshal([]byte(this.AccessAddrs), &addrConfigs) + if err != nil { + return nil, err + } + for _, addrConfig := range addrConfigs { + err = addrConfig.Init() + if err != nil { + return nil, err + } + } + return addrConfigs, nil +} + +// 解析访问地址,并返回字符串形式 +func (this *UserNode) DecodeAccessAddrStrings() ([]string, error) { + addrs, err := this.DecodeAccessAddrs() + if err != nil { + return nil, err + } + result := []string{} + for _, addr := range addrs { + result = append(result, addr.FullAddresses()...) + } + return result, nil +} diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index 071b534d..ac95447f 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -205,6 +205,7 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err pb.RegisterUserServiceServer(rpcServer, &services.UserService{}) pb.RegisterServerDailyStatServiceServer(rpcServer, &services.ServerDailyStatService{}) pb.RegisterUserBillServiceServer(rpcServer, &services.UserBillService{}) + pb.RegisterUserNodeServiceServer(rpcServer, &services.UserNodeService{}) err := rpcServer.Serve(listener) if err != nil { return errors.New("[API_NODE]start rpc failed: " + err.Error()) diff --git a/internal/rpc/services/service_admin.go b/internal/rpc/services/service_admin.go index 97ea7ee2..92608aaf 100644 --- a/internal/rpc/services/service_admin.go +++ b/internal/rpc/services/service_admin.go @@ -47,7 +47,7 @@ func (this *AdminService) LoginAdmin(ctx context.Context, req *pb.LoginAdminRequ } return &pb.LoginAdminResponse{ - AdminId: int64(adminId), + AdminId: adminId, IsOk: true, }, nil } diff --git a/internal/rpc/services/service_base.go b/internal/rpc/services/service_base.go index 69e8496e..5265b94f 100644 --- a/internal/rpc/services/service_base.go +++ b/internal/rpc/services/service_base.go @@ -57,12 +57,18 @@ func (this *BaseService) ValidateAdminAndUser(ctx context.Context, reqUserId int return } -// 校验节点 +// 校验边缘节点 func (this *BaseService) ValidateNode(ctx context.Context) (nodeId int64, err error) { _, nodeId, err = rpcutils.ValidateRequest(ctx, rpcutils.UserTypeNode) return } +// 校验用户节点 +func (this *BaseService) ValidateUser(ctx context.Context) (userId int64, err error) { + _, userId, err = rpcutils.ValidateRequest(ctx, rpcutils.UserTypeUser) + return +} + // 返回成功 func (this *BaseService) Success() (*pb.RPCSuccess, error) { return &pb.RPCSuccess{}, nil diff --git a/internal/rpc/services/service_ssl_policy.go b/internal/rpc/services/service_ssl_policy.go index c84709e5..d55b7197 100644 --- a/internal/rpc/services/service_ssl_policy.go +++ b/internal/rpc/services/service_ssl_policy.go @@ -47,7 +47,8 @@ func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.Updat // 查找Policy func (this *SSLPolicyService) FindEnabledSSLPolicyConfig(ctx context.Context, req *pb.FindEnabledSSLPolicyConfigRequest) (*pb.FindEnabledSSLPolicyConfigResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + // 这里不使用validateAdminAndUser(),是因为我们允许用户ID为0的时候也可以调用 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_sys_setting.go b/internal/rpc/services/service_sys_setting.go index fcdd6e67..fd691b1a 100644 --- a/internal/rpc/services/service_sys_setting.go +++ b/internal/rpc/services/service_sys_setting.go @@ -14,7 +14,7 @@ type SysSettingService struct { // 更改配置 func (this *SysSettingService) UpdateSysSetting(ctx context.Context, req *pb.UpdateSysSettingRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (this *SysSettingService) UpdateSysSetting(ctx context.Context, req *pb.Upd // 读取配置 func (this *SysSettingService) ReadSysSetting(ctx context.Context, req *pb.ReadSysSettingRequest) (*pb.ReadSysSettingResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index dd899b50..487ac08d 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -3,6 +3,8 @@ package services import ( "context" "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" ) @@ -137,3 +139,38 @@ func (this *UserService) CheckUsername(ctx context.Context, req *pb.CheckUsernam } return &pb.CheckUsernameResponse{Exists: b}, nil } + +// 登录 +func (this *UserService) LoginUser(ctx context.Context, req *pb.LoginUserRequest) (*pb.LoginUserResponse, error) { + _, _, err := rpcutils.ValidateRequest(ctx) + if err != nil { + return nil, err + } + + if len(req.Username) == 0 || len(req.Password) == 0 { + return &pb.LoginUserResponse{ + UserId: 0, + IsOk: false, + Message: "请输入正确的用户名密码", + }, nil + } + + userId, err := models.SharedUserDAO.CheckUserPassword(req.Username, req.Password) + if err != nil { + utils.PrintError(err) + return nil, err + } + + if userId <= 0 { + return &pb.LoginUserResponse{ + UserId: 0, + IsOk: false, + Message: "请输入正确的用户名密码", + }, nil + } + + return &pb.LoginUserResponse{ + UserId: userId, + IsOk: true, + }, nil +} diff --git a/internal/rpc/services/service_user_node.go b/internal/rpc/services/service_user_node.go new file mode 100644 index 00000000..fa6b4bcd --- /dev/null +++ b/internal/rpc/services/service_user_node.go @@ -0,0 +1,227 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/errors" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "google.golang.org/grpc/metadata" +) + +type UserNodeService struct { + BaseService +} + +// 创建用户节点 +func (this *UserNodeService) CreateUserNode(ctx context.Context, req *pb.CreateUserNodeRequest) (*pb.CreateUserNodeResponse, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + nodeId, err := models.SharedUserNodeDAO.CreateUserNode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) + if err != nil { + return nil, err + } + + return &pb.CreateUserNodeResponse{NodeId: nodeId}, nil +} + +// 修改用户节点 +func (this *UserNodeService) UpdateUserNode(ctx context.Context, req *pb.UpdateUserNodeRequest) (*pb.RPCSuccess, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedUserNodeDAO.UpdateUserNode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) + if err != nil { + return nil, err + } + + return this.Success() +} + +// 删除用户节点 +func (this *UserNodeService) DeleteUserNode(ctx context.Context, req *pb.DeleteUserNodeRequest) (*pb.RPCSuccess, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedUserNodeDAO.DisableUserNode(req.NodeId) + if err != nil { + return nil, err + } + + return this.Success() +} + +// 列出所有可用用户节点 +func (this *UserNodeService) FindAllEnabledUserNodes(ctx context.Context, req *pb.FindAllEnabledUserNodesRequest) (*pb.FindAllEnabledUserNodesResponse, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + nodes, err := models.SharedUserNodeDAO.FindAllEnabledUserNodes() + if err != nil { + return nil, err + } + + result := []*pb.UserNode{} + for _, node := range nodes { + accessAddrs, err := node.DecodeAccessAddrStrings() + if err != nil { + return nil, err + } + + result = append(result, &pb.UserNode{ + Id: int64(node.Id), + IsOn: node.IsOn == 1, + UniqueId: node.UniqueId, + Secret: node.Secret, + Name: node.Name, + Description: node.Description, + HttpJSON: []byte(node.Http), + HttpsJSON: []byte(node.Https), + AccessAddrsJSON: []byte(node.AccessAddrs), + AccessAddrs: accessAddrs, + }) + } + + return &pb.FindAllEnabledUserNodesResponse{Nodes: result}, nil +} + +// 计算用户节点数量 +func (this *UserNodeService) CountAllEnabledUserNodes(ctx context.Context, req *pb.CountAllEnabledUserNodesRequest) (*pb.RPCCountResponse, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + count, err := models.SharedUserNodeDAO.CountAllEnabledUserNodes() + if err != nil { + return nil, err + } + + return this.SuccessCount(count) +} + +// 列出单页的用户节点 +func (this *UserNodeService) ListEnabledUserNodes(ctx context.Context, req *pb.ListEnabledUserNodesRequest) (*pb.ListEnabledUserNodesResponse, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + nodes, err := models.SharedUserNodeDAO.ListEnabledUserNodes(req.Offset, req.Size) + if err != nil { + return nil, err + } + + result := []*pb.UserNode{} + for _, node := range nodes { + accessAddrs, err := node.DecodeAccessAddrStrings() + if err != nil { + return nil, err + } + + result = append(result, &pb.UserNode{ + Id: int64(node.Id), + IsOn: node.IsOn == 1, + UniqueId: node.UniqueId, + Secret: node.Secret, + Name: node.Name, + Description: node.Description, + HttpJSON: []byte(node.Http), + HttpsJSON: []byte(node.Https), + AccessAddrsJSON: []byte(node.AccessAddrs), + AccessAddrs: accessAddrs, + }) + } + + return &pb.ListEnabledUserNodesResponse{Nodes: result}, nil +} + +// 根据ID查找节点 +func (this *UserNodeService) FindEnabledUserNode(ctx context.Context, req *pb.FindEnabledUserNodeRequest) (*pb.FindEnabledUserNodeResponse, error) { + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + node, err := models.SharedUserNodeDAO.FindEnabledUserNode(req.NodeId) + if err != nil { + return nil, err + } + + if node == nil { + return &pb.FindEnabledUserNodeResponse{Node: nil}, nil + } + + accessAddrs, err := node.DecodeAccessAddrStrings() + if err != nil { + return nil, err + } + + result := &pb.UserNode{ + Id: int64(node.Id), + IsOn: node.IsOn == 1, + UniqueId: node.UniqueId, + Secret: node.Secret, + Name: node.Name, + Description: node.Description, + HttpJSON: []byte(node.Http), + HttpsJSON: []byte(node.Https), + AccessAddrsJSON: []byte(node.AccessAddrs), + AccessAddrs: accessAddrs, + } + return &pb.FindEnabledUserNodeResponse{Node: result}, nil +} + +// 获取当前用户节点的版本 +func (this *UserNodeService) FindCurrentUserNode(ctx context.Context, req *pb.FindCurrentUserNodeRequest) (*pb.FindCurrentUserNodeResponse, error) { + _, err := this.ValidateUser(ctx) + if err != nil { + return nil, err + } + + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("context: need 'nodeId'") + } + nodeIds := md.Get("nodeid") + if len(nodeIds) == 0 { + return nil, errors.New("invalid 'nodeId'") + } + nodeId := nodeIds[0] + node, err := models.SharedUserNodeDAO.FindEnabledUserNodeWithUniqueId(nodeId) + if err != nil { + return nil, err + } + + if node == nil { + return &pb.FindCurrentUserNodeResponse{Node: nil}, nil + } + + accessAddrs, err := node.DecodeAccessAddrStrings() + if err != nil { + return nil, err + } + + result := &pb.UserNode{ + Id: int64(node.Id), + IsOn: node.IsOn == 1, + UniqueId: node.UniqueId, + Secret: node.Secret, + Name: node.Name, + Description: node.Description, + HttpJSON: []byte(node.Http), + HttpsJSON: []byte(node.Https), + AccessAddrsJSON: []byte(node.AccessAddrs), + AccessAddrs: accessAddrs, + } + return &pb.FindCurrentUserNodeResponse{Node: result}, nil +} diff --git a/internal/rpc/utils/utils.go b/internal/rpc/utils/utils.go index f0ef85ba..be2b58ce 100644 --- a/internal/rpc/utils/utils.go +++ b/internal/rpc/utils/utils.go @@ -115,6 +115,7 @@ func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserT return UserTypeCluster, 0, errors.New("context: not found cluster with id '" + nodeId + "'") } nodeUserId = clusterId + case UserTypeUser: } if nodeUserId > 0 { diff --git a/internal/setup/sql_data.go b/internal/setup/sql_data.go index b2d7b60c..58a7aa85 100644 --- a/internal/setup/sql_data.go +++ b/internal/setup/sql_data.go @@ -4,6 +4,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/acme" "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" stringutil "github.com/iwind/TeaGo/utils/string" ) @@ -20,6 +21,9 @@ var upgradeFuncs = []*upgradeVersion{ { "0.0.5", upgradeV0_0_5, }, + { + "0.0.6", upgradeV0_0_6, + }, } // 升级SQL数据 @@ -94,7 +98,7 @@ func upgradeV0_0_3(db *dbs.DB) error { return nil } -// v0.0.4 +// v0.0.5 func upgradeV0_0_5(db *dbs.DB) error { // 升级edgeACMETasks _, err := db.Exec("UPDATE edgeACMETasks SET authType=? WHERE authType IS NULL OR LENGTH(authType)=0", acme.AuthTypeDNS) @@ -104,3 +108,31 @@ func upgradeV0_0_5(db *dbs.DB) error { return nil } + +// v0.0.6 +func upgradeV0_0_6(db *dbs.DB) error { + stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAPITokens WHERE role='user'") + if err != nil { + return err + } + defer func() { + _ = stmt.Close() + }() + col, err := stmt.FindCol(0) + if err != nil { + return err + } + count := types.Int(col) + if count > 0 { + return nil + } + + nodeId := rands.HexString(32) + secret := rands.String(32) + _, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "user") + if err != nil { + return err + } + + return nil +} diff --git a/internal/setup/sql_executor.go b/internal/setup/sql_executor.go index 507722ca..34a1d06c 100644 --- a/internal/setup/sql_executor.go +++ b/internal/setup/sql_executor.go @@ -77,6 +77,12 @@ func (this *SQLExecutor) checkData(db *dbs.DB) error { return err } + // 检查用户平台节点 + err = this.checkUserNode(db) + if err != nil { + return err + } + // 检查集群配置 err = this.checkCluster(db) if err != nil { @@ -143,6 +149,34 @@ func (this *SQLExecutor) checkAdminNode(db *dbs.DB) error { return nil } +// 检查用户平台节点 +func (this *SQLExecutor) checkUserNode(db *dbs.DB) error { + stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAPITokens WHERE role='user'") + if err != nil { + return err + } + defer func() { + _ = stmt.Close() + }() + col, err := stmt.FindCol(0) + if err != nil { + return err + } + count := types.Int(col) + if count > 0 { + return nil + } + + nodeId := rands.HexString(32) + secret := rands.String(32) + _, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "user") + if err != nil { + return err + } + + return nil +} + // 检查集群配置 func (this *SQLExecutor) checkCluster(db *dbs.DB) error { /// 检查是否有集群数字