diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index add479c4..bf5f608f 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -652,7 +652,7 @@ func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) { } // CountAllEnabledServersMatch 计算所有可用服务数量 -func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamily string) (int64, error) { +func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string) (int64, error) { query := this.Query(tx). State(ServerStateEnabled) if groupId > 0 { @@ -678,16 +678,27 @@ func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, ke if auditingFlag == configutils.BoolStateYes { query.Attr("isAuditing", true) } - if protocolFamily == "http" { - query.Where("(http IS NOT NULL OR https IS NOT NULL)") - } else if protocolFamily == "tcp" { - query.Where("(tcp IS NOT NULL OR tls IS NOT NULL)") + + var protocolConds = []string{} + for _, family := range protocolFamilies { + switch family { + case "http": + protocolConds = append(protocolConds, "(http IS NOT NULL OR https IS NOT NULL)") + case "tcp": + protocolConds = append(protocolConds, "(tcp IS NOT NULL OR tls IS NOT NULL)") + case "udp": + protocolConds = append(protocolConds, "(udp IS NOT NULL)") + } } + if len(protocolConds) > 0 { + query.Where("(" + strings.Join(protocolConds, " OR ") + ")") + } + return query.Count() } // ListEnabledServersMatch 列出单页的服务 -func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size int64, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag int32, protocolFamily string) (result []*Server, err error) { +func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size int64, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag int32, protocolFamilies []string) (result []*Server, err error) { query := this.Query(tx). State(ServerStateEnabled). Offset(offset). @@ -718,10 +729,19 @@ func (this *ServerDAO) ListEnabledServersMatch(tx *dbs.Tx, offset int64, size in if auditingFlag == 1 { query.Attr("isAuditing", true) } - if protocolFamily == "http" { - query.Where("(http IS NOT NULL OR https IS NOT NULL)") - } else if protocolFamily == "tcp" { - query.Where("(tcp IS NOT NULL OR tls IS NOT NULL)") + var protocolConds = []string{} + for _, family := range protocolFamilies { + switch family { + case "http": + protocolConds = append(protocolConds, "(http IS NOT NULL OR https IS NOT NULL)") + case "tcp": + protocolConds = append(protocolConds, "(tcp IS NOT NULL OR tls IS NOT NULL)") + case "udp": + protocolConds = append(protocolConds, "(udp IS NOT NULL)") + } + } + if len(protocolConds) > 0 { + query.Where("(" + strings.Join(protocolConds, " OR ") + ")") } _, err = query.FindAll() @@ -1536,16 +1556,67 @@ func (this *ServerDAO) FindEnabledServerIdWithReverseProxyId(tx *dbs.Tx, reverse FindInt64Col(0) } -// CheckTCPPortIsUsing 检查TCP端口是否被使用 -func (this *ServerDAO) CheckTCPPortIsUsing(tx *dbs.Tx, clusterId int64, port int, excludeServerId int64, excludeProtocol string) (bool, error) { +// CheckPortIsUsing 检查端口是否被使用 +// protocolFamily支持tcp和udp +func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, protocolFamily string, port int, excludeServerId int64, excludeProtocol string) (bool, error) { + // 检查是否在别的协议中 + if excludeServerId > 0 { + one, err := this.Query(tx). + Pk(excludeServerId). + Result("tcp", "tls", "udp", "http", "https"). + Find() + if err != nil { + return false, err + } + if one != nil { + var server = one.(*Server) + for _, protocol := range []string{"http", "https", "tcp", "tls", "udp"} { + if protocol == excludeProtocol { + continue + } + switch protocol { + case "http": + if protocolFamily == "tcp" && lists.ContainsInt(server.DecodeHTTPPorts(), port) { + return true, nil + } + case "https": + if protocolFamily == "tcp" && lists.ContainsInt(server.DecodeHTTPSPorts(), port) { + return true, nil + } + case "tcp": + if protocolFamily == "tcp" && lists.ContainsInt(server.DecodeTCPPorts(), port) { + return true, nil + } + case "tls": + if protocolFamily == "tcp" && lists.ContainsInt(server.DecodeTLSPorts(), port) { + return true, nil + } + case "udp": + // 不需要判断 + } + } + } + } + + // 其他服务中 query := this.Query(tx). Attr("clusterId", clusterId). State(ServerStateEnabled). Param("port", types.String(port)) if excludeServerId <= 0 { - query.Where("JSON_CONTAINS(tcpPorts, :port)") + switch protocolFamily { + case "tcp", "http", "": + query.Where("JSON_CONTAINS(tcpPorts, :port)") + case "udp": + query.Where("JSON_CONTAINS(udpPorts, :port)") + } } else { - query.Where("(id!=:serverId AND JSON_CONTAINS(tcpPorts, :port))") + switch protocolFamily { + case "tcp", "http", "": + query.Where("(id!=:serverId AND JSON_CONTAINS(tcpPorts, :port))") + case "udp": + query.Where("(id!=:serverId AND JSON_CONTAINS(udpPorts, :port))") + } query.Param("serverId", excludeServerId) } return query. diff --git a/internal/db/models/server_model_ext.go b/internal/db/models/server_model_ext.go index 911b493f..95c4309d 100644 --- a/internal/db/models/server_model_ext.go +++ b/internal/db/models/server_model_ext.go @@ -3,6 +3,7 @@ package models import ( "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/remotelogs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" ) // DecodeGroupIds 解析服务所属分组ID @@ -19,3 +20,108 @@ func (this *Server) DecodeGroupIds() []int64 { } return result } + +// DecodeHTTPPorts 获取HTTP所有端口 +func (this *Server) DecodeHTTPPorts() (ports []int) { + if len(this.Http) > 0 && this.Http != "null" { + config := &serverconfigs.HTTPProtocolConfig{} + err := json.Unmarshal([]byte(this.Http), config) + if err != nil { + return nil + } + err = config.Init() + if err != nil { + return nil + } + for _, listen := range config.Listen { + for i := listen.MinPort; i <= listen.MaxPort; i++ { + ports = append(ports, i) + } + } + } + return +} + +// DecodeHTTPSPorts 获取HTTPS所有端口 +func (this *Server) DecodeHTTPSPorts() (ports []int) { + if len(this.Https) > 0 && this.Https != "null" { + config := &serverconfigs.HTTPSProtocolConfig{} + err := json.Unmarshal([]byte(this.Https), config) + if err != nil { + return nil + } + err = config.Init() + if err != nil { + return nil + } + for _, listen := range config.Listen { + for i := listen.MinPort; i <= listen.MaxPort; i++ { + ports = append(ports, i) + } + } + } + return +} + +// DecodeTCPPorts 获取TCP所有端口 +func (this *Server) DecodeTCPPorts() (ports []int) { + if len(this.Tcp) > 0 && this.Tcp != "null" { + config := &serverconfigs.TCPProtocolConfig{} + err := json.Unmarshal([]byte(this.Tcp), config) + if err != nil { + return nil + } + err = config.Init() + if err != nil { + return nil + } + for _, listen := range config.Listen { + for i := listen.MinPort; i <= listen.MaxPort; i++ { + ports = append(ports, i) + } + } + } + return +} + +// DecodeTLSPorts 获取TLS所有端口 +func (this *Server) DecodeTLSPorts() (ports []int) { + if len(this.Tls) > 0 && this.Tls != "null" { + config := &serverconfigs.TLSProtocolConfig{} + err := json.Unmarshal([]byte(this.Tls), config) + if err != nil { + return nil + } + err = config.Init() + if err != nil { + return nil + } + for _, listen := range config.Listen { + for i := listen.MinPort; i <= listen.MaxPort; i++ { + ports = append(ports, i) + } + } + } + return +} + +// DecodeUDPPorts 获取UDP所有端口 +func (this *Server) DecodeUDPPorts() (ports []int) { + if len(this.Udp) > 0 && this.Udp != "null" { + config := &serverconfigs.UDPProtocolConfig{} + err := json.Unmarshal([]byte(this.Udp), config) + if err != nil { + return nil + } + err = config.Init() + if err != nil { + return nil + } + for _, listen := range config.Listen { + for i := listen.MinPort; i <= listen.MaxPort; i++ { + ports = append(ports, i) + } + } + } + return +} diff --git a/internal/db/models/stats/server_region_country_monthly_stat_dao.go b/internal/db/models/stats/server_region_country_monthly_stat_dao.go index cdee8e6f..72e90aae 100644 --- a/internal/db/models/stats/server_region_country_monthly_stat_dao.go +++ b/internal/db/models/stats/server_region_country_monthly_stat_dao.go @@ -29,7 +29,7 @@ func init() { }) } -// 增加数量 +// IncreaseMonthlyCount 增加数量 func (this *ServerRegionCountryMonthlyStatDAO) IncreaseMonthlyCount(tx *dbs.Tx, serverId int64, countryId int64, month string, count int64) error { if len(month) != 6 { return errors.New("invalid month '" + month + "'") @@ -50,7 +50,7 @@ func (this *ServerRegionCountryMonthlyStatDAO) IncreaseMonthlyCount(tx *dbs.Tx, return nil } -// 查找单页数据 +// ListStats 查找单页数据 func (this *ServerRegionCountryMonthlyStatDAO) ListStats(tx *dbs.Tx, serverId int64, month string, offset int64, size int64) (result []*ServerRegionCountryMonthlyStat, err error) { query := this.Query(tx). Attr("serverId", serverId). diff --git a/internal/db/models/user_features.go b/internal/db/models/user_features.go index 62b22ab7..88a1cad8 100644 --- a/internal/db/models/user_features.go +++ b/internal/db/models/user_features.go @@ -21,15 +21,25 @@ var ( Description: "用户可以配置访问日志转发到自定义的API", }, { - Name: "负载均衡", + Name: "TCP负载均衡", Code: "server.tcp", Description: "用户可以添加TCP/TLS负载均衡服务", }, { - Name: "自定义负载均衡端口", + Name: "自定义TCP负载均衡端口", Code: "server.tcp.port", Description: "用户可以自定义TCP端口", }, + { + Name: "UDP负载均衡", + Code: "server.udp", + Description: "用户可以添加UDP负载均衡服务", + }, + { + Name: "自定义UDP负载均衡端口", + Code: "server.udp.port", + Description: "用户可以自定义UDP端口", + }, { Name: "开启WAF", Code: "server.waf", @@ -43,7 +53,7 @@ var ( } ) -// 用户功能 +// UserFeature 用户功能 type UserFeature struct { Name string `json:"name"` Code string `json:"code"` @@ -54,12 +64,12 @@ func (this *UserFeature) ToPB() *pb.UserFeature { return &pb.UserFeature{Name: this.Name, Code: this.Code, Description: this.Description} } -// 所有功能列表 +// FindAllUserFeatures 所有功能列表 func FindAllUserFeatures() []*UserFeature { return allUserFeatures } -// 查询单个功能 +// FindUserFeature 查询单个功能 func FindUserFeature(code string) *UserFeature { for _, feature := range allUserFeatures { if feature.Code == code { diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index f789ef54..91116e15 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -905,7 +905,7 @@ func (this *NodeClusterService) FindFreePortInNodeCluster(ctx context.Context, r continue } - isUsing, err := models.SharedServerDAO.CheckTCPPortIsUsing(tx, req.NodeClusterId, port, 0, "") + isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, req.ProtocolFamily, port, 0, "") if err != nil { return nil, err } @@ -925,7 +925,7 @@ func (this *NodeClusterService) CheckPortIsUsingInNodeCluster(ctx context.Contex } var tx = this.NullTx() - isUsing, err := models.SharedServerDAO.CheckTCPPortIsUsing(tx, req.NodeClusterId, int(req.Port), req.ExcludeServerId, req.ExcludeProtocol) + isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, req.ProtocolFamily, int(req.Port), req.ExcludeServerId, req.ExcludeProtocol) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index de1327f3..80f5fc20 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -330,11 +330,18 @@ func (this *ServerService) UpdateServerUnix(ctx context.Context, req *pb.UpdateS // UpdateServerUDP 修改UDP服务 func (this *ServerService) UpdateServerUDP(ctx context.Context, req *pb.UpdateServerUDPRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, err := this.ValidateAdmin(ctx, 0) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } + if userId > 0 { + err = models.SharedServerDAO.CheckUserServer(nil, userId, req.ServerId) + if err != nil { + return nil, err + } + } + if req.ServerId <= 0 { return nil, errors.New("invalid serverId") } @@ -567,7 +574,7 @@ func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req tx := this.NullTx() - count, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, req.ServerGroupId, req.Keyword, req.UserId, req.NodeClusterId, types.Int8(req.AuditingFlag), req.ProtocolFamily) + count, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, req.ServerGroupId, req.Keyword, req.UserId, req.NodeClusterId, types.Int8(req.AuditingFlag), utils.SplitStrings(req.ProtocolFamily, ",")) if err != nil { return nil, err } @@ -585,7 +592,7 @@ func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb. tx := this.NullTx() - servers, err := models.SharedServerDAO.ListEnabledServersMatch(tx, req.Offset, req.Size, req.ServerGroupId, req.Keyword, req.UserId, req.NodeClusterId, req.AuditingFlag, req.ProtocolFamily) + servers, err := models.SharedServerDAO.ListEnabledServersMatch(tx, req.Offset, req.Size, req.ServerGroupId, req.Keyword, req.UserId, req.NodeClusterId, req.AuditingFlag, utils.SplitStrings(req.ProtocolFamily, ",")) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index 397954d8..e9de0972 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -310,7 +310,7 @@ func (this *UserService) ComposeUserDashboard(ctx context.Context, req *pb.Compo tx := this.NullTx() // 网站数量 - countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll, "") + countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll, []string{}) if err != nil { return nil, err } diff --git a/internal/utils/strings.go b/internal/utils/strings.go new file mode 100644 index 00000000..754f268d --- /dev/null +++ b/internal/utils/strings.go @@ -0,0 +1,21 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package utils + +import "strings" + +// SplitStrings 分隔字符串 +// 忽略其中为空的片段 +func SplitStrings(s string, glue string) []string { + var result = []string{} + + if len(s) > 0 { + for _, p := range strings.Split(s, glue) { + p = strings.TrimSpace(p) + if len(p) > 0 { + result = append(result, p) + } + } + } + return result +} diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go new file mode 100644 index 00000000..6ef1ed7f --- /dev/null +++ b/internal/utils/strings_test.go @@ -0,0 +1,10 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package utils + +import "testing" + +func TestSplitStrings(t *testing.T) { + t.Log(SplitStrings("a, b, c", ",")) + t.Log(SplitStrings("a, b, c, ", ",")) +}