diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 3be5ab1a..5fe41e24 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -1232,16 +1232,42 @@ func (this *ServerDAO) FindEnabledServerIdWithReverseProxyId(tx *dbs.Tx, reverse } // 检查端口是否被使用 -func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, port int) (bool, error) { +func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, port int, excludeServerId int64, excludeProtocol string) (bool, error) { listen := maps.Map{ "portRange": strconv.Itoa(port), } - - return this.Query(tx). + query := this.Query(tx). Attr("clusterId", clusterId). - State(ServerStateEnabled). - Where("(JSON_CONTAINS(http, :listen) OR JSON_CONTAINS(https, :listen) OR JSON_CONTAINS(tcp, :listen) OR JSON_CONTAINS(tls, :listen))"). - Param(":listen", string(listen.AsJSON())). + State(ServerStateEnabled) + protocols := []string{"http", "https", "tcp", "tls", "udp"} + where := "" + if excludeServerId <= 0 { + conds := []string{} + for _, p := range protocols { + conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')") + } + where = strings.Join(conds, " OR ") + } else { + conds := []string{} + for _, p := range protocols { + conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')") + } + where1 := "(id!=:serverId AND (" + strings.Join(conds, " OR ") + "))" + + conds = []string{} + for _, p := range protocols { + if p == excludeProtocol { + continue + } + conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')") + } + where2 := "(id=:serverId AND (" + strings.Join(conds, " OR ") + "))" + where = where1 + " OR " + where2 + query.Param("serverId", excludeServerId) + } + return query. + Where("("+where+")"). + Param("listen", string(listen.AsJSON())). Exist() } diff --git a/internal/db/models/server_dao_test.go b/internal/db/models/server_dao_test.go index d378cedb..d92af2e8 100644 --- a/internal/db/models/server_dao_test.go +++ b/internal/db/models/server_dao_test.go @@ -86,4 +86,23 @@ func TestServerDAO_FindAllEnabledServerIdsWithSSLPolicyIds(t *testing.T) { t.Fatal(err) } t.Log("serverIds:", serverIds) -} \ No newline at end of file +} + +func TestServerDAO_CheckPortIsUsing(t *testing.T) { + dbs.NotifyReady() + var tx *dbs.Tx + //{ + // isUsing, err := SharedServerDAO.CheckPortIsUsing(tx, 18, 1234, 0, "") + // if err != nil { + // t.Fatal(err) + // } + // t.Log("isUsing:", isUsing) + //} + { + isUsing, err := SharedServerDAO.CheckPortIsUsing(tx, 18, 1234, 44, "tcp") + if err != nil { + t.Fatal(err) + } + t.Log("isUsing:", isUsing) + } +} diff --git a/internal/db/models/user_features.go b/internal/db/models/user_features.go index bd33ac4d..62b22ab7 100644 --- a/internal/db/models/user_features.go +++ b/internal/db/models/user_features.go @@ -25,6 +25,11 @@ var ( Code: "server.tcp", Description: "用户可以添加TCP/TLS负载均衡服务", }, + { + Name: "自定义负载均衡端口", + Code: "server.tcp.port", + Description: "用户可以自定义TCP端口", + }, { Name: "开启WAF", Code: "server.waf", diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index fc42cd72..86e8dea8 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -839,7 +839,7 @@ func (this *NodeClusterService) FindFreePortInNodeCluster(ctx context.Context, r continue } - isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, port) + isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, port, 0, "") if err != nil { return nil, err } @@ -850,3 +850,18 @@ func (this *NodeClusterService) FindFreePortInNodeCluster(ctx context.Context, r return nil, errors.New("can not find random port") } + +// 检查端口是否已经被使用 +func (this *NodeClusterService) CheckPortIsUsingInNodeCluster(ctx context.Context, req *pb.CheckPortIsUsingInNodeClusterRequest) (*pb.CheckPortIsUsingInNodeClusterResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, int(req.Port), req.ExcludeServerId, req.ExcludeProtocol) + if err != nil { + return nil, err + } + return &pb.CheckPortIsUsingInNodeClusterResponse{IsUsing: isUsing}, nil +} diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index fb41085a..8c1e30fd 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -210,13 +210,16 @@ func (this *ServerService) UpdateServerHTTPS(ctx context.Context, req *pb.Update // 修改TCP服务 func (this *ServerService) UpdateServerTCP(ctx context.Context, req *pb.UpdateServerTCPRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - if req.ServerId <= 0 { - return nil, errors.New("invalid serverId") + if userId > 0 { + err = models.SharedServerDAO.CheckUserServer(nil, userId, req.ServerId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -233,13 +236,16 @@ func (this *ServerService) UpdateServerTCP(ctx context.Context, req *pb.UpdateSe // 修改TLS服务 func (this *ServerService) UpdateServerTLS(ctx context.Context, req *pb.UpdateServerTLSRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - if req.ServerId <= 0 { - return nil, errors.New("invalid serverId") + if userId > 0 { + err = models.SharedServerDAO.CheckUserServer(nil, userId, req.ServerId) + if err != nil { + return nil, err + } } tx := this.NullTx()