diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 4aa417bb..483871e9 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -70,10 +70,10 @@ func (this *ServerDAO) DisableServer(tx *dbs.Tx, id int64) (err error) { return } -// 查找启用中的条目 -func (this *ServerDAO) FindEnabledServer(tx *dbs.Tx, id int64) (*Server, error) { +// 查找启用中的服务 +func (this *ServerDAO) FindEnabledServer(tx *dbs.Tx, serverId int64) (*Server, error) { result, err := this.Query(tx). - Pk(id). + Pk(serverId). Attr("state", ServerStateEnabled). Find() if result == nil { @@ -82,6 +82,19 @@ func (this *ServerDAO) FindEnabledServer(tx *dbs.Tx, id int64) (*Server, error) return result.(*Server), err } +// 查找服务基本信息 +func (this *ServerDAO) FindEnabledServerBasic(tx *dbs.Tx, serverId int64) (*Server, error) { + result, err := this.Query(tx). + Pk(serverId). + State(ServerStateEnabled). + Result("id", "name", "description", "isOn", "type", "clusterId"). + Find() + if result == nil { + return nil, err + } + return result.(*Server), err +} + // 查找服务类型 func (this *ServerDAO) FindEnabledServerType(tx *dbs.Tx, serverId int64) (string, error) { return this.Query(tx). @@ -232,6 +245,28 @@ func (this *ServerDAO) UpdateServerBasic(tx *dbs.Tx, serverId int64, name string return this.createEvent() } +// 设置用户相关的基本信息 +func (this *ServerDAO) UpdateUserServerBasic(tx *dbs.Tx, serverId int64, name string) error { + if serverId <= 0 { + return errors.New("serverId should not be smaller than 0") + } + op := NewServerOperator() + op.Id = serverId + op.Name = name + + err := this.Save(tx, op) + if err != nil { + return err + } + + _, err = this.RenewServerConfig(tx, serverId, false) + if err != nil { + return err + } + + return this.createEvent() +} + // 修复服务是否启用 func (this *ServerDAO) UpdateServerIsOn(tx *dbs.Tx, serverId int64, isOn bool) error { _, err := this.Query(tx). @@ -558,7 +593,7 @@ func (this *ServerDAO) UpdateServerReverseProxy(tx *dbs.Tx, serverId int64, conf } // 计算所有可用服务数量 -func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState) (int64, error) { +func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamily string) (int64, error) { query := this.Query(tx). State(ServerStateEnabled) if groupId > 0 { @@ -578,11 +613,16 @@ func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, ke if auditingFlag == configutils.BoolStateYes { query.Attr("isAuditing", true) } + if protocolFamily == "http" { + query.Where("(http IS NOT NULL OR https IS NOT NULL)") + } else if protocolFamily == "tcp" { + query.Where("(tcp IS NOT NULL OR tls IS NOT NULL)") + } return query.Count() } // 列出单页的服务 -func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size int64, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag int32) (result []*Server, err error) { +func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size int64, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag int32, protocolFamily string) (result []*Server, err error) { query := this.Query(tx). State(ServerStateEnabled). Offset(offset). @@ -607,6 +647,11 @@ func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size in if auditingFlag == 1 { query.Attr("isAuditing", true) } + if protocolFamily == "http" { + query.Where("(http IS NOT NULL OR https IS NOT NULL)") + } else if protocolFamily == "tcp" { + query.Where("(tcp IS NOT NULL OR tls IS NOT NULL)") + } _, err = query.FindAll() return @@ -1089,6 +1134,20 @@ func (this *ServerDAO) FindEnabledServerIdWithWebId(tx *dbs.Tx, webId int64) (se FindInt64Col(0) } +// 检查端口是否被使用 +func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, port int) (bool, error) { + listen := maps.Map{ + "portRange": strconv.Itoa(port), + } + + return this.Query(tx). + Attr("clusterId", clusterId). + State(ServerStateEnabled). + Where("(JSON_CONTAINS(http, :listen) OR JSON_CONTAINS(https, :listen) OR JSON_CONTAINS(tcp, :listen) OR JSON_CONTAINS(tls, :listen))"). + Param(":listen", string(listen.AsJSON())). + Exist() +} + // 生成DNS Name func (this *ServerDAO) genDNSName(tx *dbs.Tx) (string, error) { for { diff --git a/internal/db/models/sys_locker_dao.go b/internal/db/models/sys_locker_dao.go index caaed9f9..65a3101e 100644 --- a/internal/db/models/sys_locker_dao.go +++ b/internal/db/models/sys_locker_dao.go @@ -30,7 +30,7 @@ func init() { } // 开锁 -func (this *SysLockerDAO) Lock(tx *dbs.Tx, key string, timeout int64) (bool, error) { +func (this *SysLockerDAO) Lock(tx *dbs.Tx, key string, timeout int64) (ok bool, err error) { maxErrors := 5 for { one, err := this.Query(tx). diff --git a/internal/db/models/user_features.go b/internal/db/models/user_features.go index 779684dc..3bbfe6c5 100644 --- a/internal/db/models/user_features.go +++ b/internal/db/models/user_features.go @@ -15,11 +15,21 @@ var ( Code: "server.accessLog.forward", Description: "用户可以配置访问日志转发到自定义的API", }, + { + Name: "负载均衡", + Code: "server.tcp", + Description: "用户可以添加TCP/TLS负载均衡服务", + }, { Name: "开启WAF", Code: "server.waf", Description: "用户可以开启WAF功能并可以设置黑白名单等", }, + { + Name: "费用账单", + Code: "finance", + Description: "开启费用账单相关功能", + }, } ) diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index c785d332..45c91c49 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -223,6 +223,7 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err pb.RegisterUserNodeServiceServer(rpcServer, &services.UserNodeService{}) pb.RegisterLoginServiceServer(rpcServer, &services.LoginService{}) pb.RegisterUserAccessKeyServiceServer(rpcServer, &services.UserAccessKeyService{}) + pb.RegisterSysLockerServiceServer(rpcServer, &services.SysLockerService{}) err := rpcServer.Serve(listener) if err != nil { return errors.New("[API_NODE]start rpc failed: " + err.Error()) diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index 24f50c85..58c270bf 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -10,7 +10,9 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/tasks" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/maps" + "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" "strconv" ) @@ -845,3 +847,61 @@ func (this *NodeClusterService) FindNodeClusterSystemService(ctx context.Context } return &pb.FindNodeClusterSystemServiceResponse{ParamsJSON: paramsJSON}, nil } + +// 获取集群中可以使用的端口 +func (this *NodeClusterService) FindFreePortInNodeCluster(ctx context.Context, req *pb.FindFreePortInNodeClusterRequest) (*pb.FindFreePortInNodeClusterResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig(tx) + if err != nil { + return nil, err + } + + // 检查端口 + portMin := globalConfig.TCPAll.PortRangeMin + portMax := globalConfig.TCPAll.PortRangeMax + denyPorts := globalConfig.TCPAll.DenyPorts + + if portMin == 0 && portMax == 0 { + portMin = 10_000 + portMax = 40_000 + } + if portMin < 1024 { + portMin = 10_000 + } + if portMin > 65534 { + portMin = 65534 + } + if portMax < 1024 { + portMax = 30_000 + } + if portMax > 65534 { + portMax = 65534 + } + + if portMin > portMax { + portMax, portMin = portMin, portMax + } + + // 最多尝试N次 + for i := 0; i < 60; i++ { + port := rands.Int(portMin, portMax) + if len(denyPorts) > 0 && lists.ContainsInt(denyPorts, port) { + continue + } + + isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, port) + if err != nil { + return nil, err + } + if !isUsing { + return &pb.FindFreePortInNodeClusterResponse{Port: int32(port)}, nil + } + } + + return nil, errors.New("can not find random port") +} diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index c5bd1768..f4119469 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -43,6 +43,21 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe } } } + + // TLS + if len(req.TlsJSON) > 0 { + tlsConfig := &serverconfigs.TLSProtocolConfig{} + err = json.Unmarshal(req.TlsJSON, tlsConfig) + if err != nil { + return nil, err + } + if tlsConfig.SSLPolicyRef != nil && tlsConfig.SSLPolicyRef.SSLPolicyId > 0 { + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, tlsConfig.SSLPolicyRef.SSLPolicyId, userId) + if err != nil { + return nil, err + } + } + } } // 是否需要审核 @@ -50,14 +65,17 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe serverNamesJSON := req.ServerNamesJON auditingServerNamesJSON := []byte("[]") if userId > 0 { - globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig(tx) - if err != nil { - return nil, err - } - if globalConfig != nil && globalConfig.HTTPAll.DomainAuditingIsOn { - isAuditing = true - serverNamesJSON = []byte("[]") - auditingServerNamesJSON = req.ServerNamesJON + // 如果域名不为空的时候需要审核 + if len(serverNamesJSON) > 0 && string(serverNamesJSON) != "[]" { + globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig(tx) + if err != nil { + return nil, err + } + if globalConfig != nil && globalConfig.HTTPAll.DomainAuditingIsOn { + isAuditing = true + serverNamesJSON = []byte("[]") + auditingServerNamesJSON = req.ServerNamesJON + } } } @@ -487,7 +505,7 @@ func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req tx := this.NullTx() - count, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, req.GroupId, req.Keyword, req.UserId, req.ClusterId, types.Int8(req.AuditingFlag)) + count, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, req.GroupId, req.Keyword, req.UserId, req.ClusterId, types.Int8(req.AuditingFlag), req.ProtocolFamily) if err != nil { return nil, err } @@ -505,7 +523,7 @@ func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb. tx := this.NullTx() - servers, err := models.SharedServerDAO.ListEnabledServersMatch(tx, req.Offset, req.Size, req.GroupId, req.Keyword, req.UserId, req.ClusterId, req.AuditingFlag) + servers, err := models.SharedServerDAO.ListEnabledServersMatch(tx, req.Offset, req.Size, req.GroupId, req.Keyword, req.UserId, req.ClusterId, req.AuditingFlag, req.ProtocolFamily) if err != nil { return nil, err } @@ -1188,3 +1206,60 @@ func (this *ServerService) FindAllEnabledServerNamesWithUserId(ctx context.Conte } return &pb.FindAllEnabledServerNamesWithUserIdResponse{ServerNames: serverNames}, nil } + +// 查找服务基本信息 +func (this *ServerService) FindEnabledUserServerBasic(ctx context.Context, req *pb.FindEnabledUserServerBasicRequest) (*pb.FindEnabledUserServerBasicResponse, error) { + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + + if userId > 0 { + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + if err != nil { + return nil, err + } + } + + server, err := models.SharedServerDAO.FindEnabledServerBasic(tx, req.ServerId) + if err != nil { + return nil, err + } + if server == nil { + return &pb.FindEnabledUserServerBasicResponse{Server: nil}, nil + } + + return &pb.FindEnabledUserServerBasicResponse{Server: &pb.Server{ + Id: int64(server.Id), + Name: server.Name, + Description: server.Description, + IsOn: server.IsOn == 1, + Type: server.Type, + }}, nil +} + +// 修改用户服务基本信息 +func (this *ServerService) UpdateEnabledUserServerBasic(ctx context.Context, req *pb.UpdateEnabledUserServerBasicRequest) (*pb.RPCSuccess, error) { + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + + if userId > 0 { + err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + if err != nil { + return nil, err + } + } + + err = models.SharedServerDAO.UpdateUserServerBasic(tx, req.ServerId, req.Name) + if err != nil { + return nil, err + } + + return this.Success() +} diff --git a/internal/rpc/services/service_sys_locker.go b/internal/rpc/services/service_sys_locker.go new file mode 100644 index 00000000..23f6c8ad --- /dev/null +++ b/internal/rpc/services/service_sys_locker.go @@ -0,0 +1,58 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +// 互斥锁管理 +type SysLockerService struct { + BaseService +} + +// 获得锁 +func (this *SysLockerService) SysLockerLock(ctx context.Context, req *pb.SysLockerLockRequest) (*pb.SysLockerLockResponse, error) { + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + key := req.Key + if userId > 0 { + key = "@user" // 这里不加入用户ID,防止多个用户间冲突 + } + + timeout := req.TimeoutSeconds + if timeout <= 0 { + timeout = 60 + } else if timeout > 86400 { // 最多不能超过1天 + timeout = 86400 + } + + var tx = this.NullTx() + ok, err := models.SharedSysLockerDAO.Lock(tx, key, timeout) + if err != nil { + return nil, err + } + return &pb.SysLockerLockResponse{Ok: ok}, nil +} + +// 释放锁 +func (this *SysLockerService) SysLockerUnlock(ctx context.Context, req *pb.SysLockerUnlockRequest) (*pb.RPCSuccess, error) { + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + key := req.Key + if userId > 0 { + key = "@user" + } + var tx = this.NullTx() + err = models.SharedSysLockerDAO.Unlock(tx, key) + if err != nil { + return nil, err + } + return this.Success() +} diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index 18d40b53..94cc92ea 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -306,7 +306,7 @@ func (this *UserService) ComposeUserDashboard(ctx context.Context, req *pb.Compo tx := this.NullTx() // 网站数量 - countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll) + countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll, "") if err != nil { return nil, err }