diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index ba73f4af..b7150ab7 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -220,6 +220,12 @@ func (this *ServerDAO) CreateServer(tx *dbs.Tx, serverId = types.Int64(op.Id) + // 更新端口 + err = this.NotifyServerPortsUpdate(tx, serverId) + if err != nil { + return serverId, err + } + // 通知配置更改 err = this.NotifyUpdate(tx, serverId) if err != nil { @@ -323,6 +329,12 @@ func (this *ServerDAO) UpdateServerHTTP(tx *dbs.Tx, serverId int64, config []byt return err } + // 更新端口 + err = this.NotifyServerPortsUpdate(tx, serverId) + if err != nil { + return err + } + return this.NotifyUpdate(tx, serverId) } @@ -342,6 +354,12 @@ func (this *ServerDAO) UpdateServerHTTPS(tx *dbs.Tx, serverId int64, httpsJSON [ return err } + // 更新端口 + err = this.NotifyServerPortsUpdate(tx, serverId) + if err != nil { + return err + } + return this.NotifyUpdate(tx, serverId) } @@ -361,6 +379,12 @@ func (this *ServerDAO) UpdateServerTCP(tx *dbs.Tx, serverId int64, config []byte return err } + // 更新端口 + err = this.NotifyServerPortsUpdate(tx, serverId) + if err != nil { + return err + } + return this.NotifyUpdate(tx, serverId) } @@ -380,6 +404,12 @@ func (this *ServerDAO) UpdateServerTLS(tx *dbs.Tx, serverId int64, config []byte return err } + // 更新端口 + err = this.NotifyServerPortsUpdate(tx, serverId) + if err != nil { + return err + } + return this.NotifyUpdate(tx, serverId) } @@ -418,6 +448,12 @@ func (this *ServerDAO) UpdateServerUDP(tx *dbs.Tx, serverId int64, config []byte return err } + // 更新端口 + err = this.NotifyServerPortsUpdate(tx, serverId) + if err != nil { + return err + } + return this.NotifyUpdate(tx, serverId) } @@ -1333,43 +1369,19 @@ func (this *ServerDAO) FindEnabledServerIdWithReverseProxyId(tx *dbs.Tx, reverse FindInt64Col(0) } -// CheckPortIsUsing 检查端口是否被使用 -func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, port int, excludeServerId int64, excludeProtocol string) (bool, error) { - listen := maps.Map{ - "portRange": strconv.Itoa(port), - } +// CheckTCPPortIsUsing 检查TCP端口是否被使用 +func (this *ServerDAO) CheckTCPPortIsUsing(tx *dbs.Tx, clusterId int64, port int, excludeServerId int64, excludeProtocol string) (bool, error) { query := this.Query(tx). Attr("clusterId", clusterId). - State(ServerStateEnabled) - protocols := []string{"http", "https", "tcp", "tls", "udp"} - where := "" + State(ServerStateEnabled). + Param("port", types.String(port)) if excludeServerId <= 0 { - conds := []string{} - for _, p := range protocols { - conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')") - } - where = strings.Join(conds, " OR ") + query.Where("JSON_CONTAINS(tcpPorts, :port)") } else { - conds := []string{} - for _, p := range protocols { - conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')") - } - where1 := "(id!=:serverId AND (" + strings.Join(conds, " OR ") + "))" - - conds = []string{} - for _, p := range protocols { - if p == excludeProtocol { - continue - } - conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')") - } - where2 := "(id=:serverId AND (" + strings.Join(conds, " OR ") + "))" - where = where1 + " OR " + where2 + query.Where("(id!=:serverId AND JSON_CONTAINS(tcpPorts, :port))") query.Param("serverId", excludeServerId) } return query. - Where("("+where+")"). - Param("listen", string(listen.AsJSON())). Exist() } @@ -1516,7 +1528,11 @@ func (this *ServerDAO) FindFirstHTTPOrHTTPSPortWithClusterId(tx *dbs.Tx, cluster return 0, err } if len(ports) > 0 { - return types.Int(ports[0]), nil + var port = ports[0] + if strings.Contains(port, "-") { // IP范围 + return types.Int(port[:strings.Index(port, "-")]), nil + } + return types.Int(port), nil } } @@ -1527,14 +1543,124 @@ func (this *ServerDAO) FindFirstHTTPOrHTTPSPortWithClusterId(tx *dbs.Tx, cluster if err != nil { return 0, err } - if len(ports) > 0 { - return types.Int(ports[0]), nil + var port = ports[0] + + if strings.Contains(port, "-") { // IP范围 + return types.Int(port[:strings.Index(port, "-")]), nil } + return types.Int(port), nil } return 0, nil } +// NotifyServerPortsUpdate 通知服务端口变化 +func (this *ServerDAO) NotifyServerPortsUpdate(tx *dbs.Tx, serverId int64) error { + one, err := this.Query(tx). + Pk(serverId). + Result("tcp", "tls", "udp", "http", "https"). + Find() + if err != nil { + return err + } + if one == nil { + return nil + } + var server = one.(*Server) + + // HTTP + var tcpListens = []*serverconfigs.NetworkAddressConfig{} + var udpListens = []*serverconfigs.NetworkAddressConfig{} + if len(server.Http) > 0 && server.Http != "null" { + httpConfig := &serverconfigs.HTTPProtocolConfig{} + err := json.Unmarshal([]byte(server.Http), httpConfig) + if err != nil { + return err + } + tcpListens = append(tcpListens, httpConfig.Listen...) + } + + // HTTPS + if len(server.Https) > 0 && server.Https != "null" { + httpsConfig := &serverconfigs.HTTPSProtocolConfig{} + err := json.Unmarshal([]byte(server.Https), httpsConfig) + if err != nil { + return err + } + tcpListens = append(tcpListens, httpsConfig.Listen...) + } + + // TCP + if len(server.Tcp) > 0 && server.Tcp != "null" { + tcpConfig := &serverconfigs.TCPProtocolConfig{} + err := json.Unmarshal([]byte(server.Tcp), tcpConfig) + if err != nil { + return err + } + tcpListens = append(tcpListens, tcpConfig.Listen...) + } + + // TLS + if len(server.Tls) > 0 && server.Tls != "null" { + tlsConfig := &serverconfigs.TLSProtocolConfig{} + err := json.Unmarshal([]byte(server.Tls), tlsConfig) + if err != nil { + return err + } + tcpListens = append(tcpListens, tlsConfig.Listen...) + } + + // UDP + if len(server.Udp) > 0 && server.Udp != "null" { + udpConfig := &serverconfigs.UDPProtocolConfig{} + err := json.Unmarshal([]byte(server.Udp), udpConfig) + if err != nil { + return err + } + udpListens = append(udpListens, udpConfig.Listen...) + } + + var tcpPorts = []int{} + for _, listen := range tcpListens { + _ = listen.Init() + if listen.MinPort > 0 && listen.MaxPort > 0 && listen.MinPort <= listen.MaxPort { + for i := listen.MinPort; i <= listen.MaxPort; i++ { + if !lists.ContainsInt(tcpPorts, i) { + tcpPorts = append(tcpPorts, i) + } + } + } + } + + tcpPortsJSON, err := json.Marshal(tcpPorts) + if err != nil { + return err + } + + var udpPorts = []int{} + for _, listen := range udpListens { + _ = listen.Init() + if listen.MinPort > 0 && listen.MaxPort > 0 && listen.MinPort <= listen.MaxPort { + for i := listen.MinPort; i <= listen.MaxPort; i++ { + if !lists.ContainsInt(udpPorts, i) { + udpPorts = append(udpPorts, i) + } + } + } + } + + udpPortsJSON, err := json.Marshal(udpPorts) + if err != nil { + return err + } + + return this.Query(tx). + Pk(serverId). + Set("tcpPorts", string(tcpPortsJSON)). + Set("udpPorts", string(udpPortsJSON)). + UpdateQuickly() +} + // NotifyUpdate 同步集群 func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error { // 创建任务 diff --git a/internal/db/models/server_dao_test.go b/internal/db/models/server_dao_test.go index 232847b4..aa76edfa 100644 --- a/internal/db/models/server_dao_test.go +++ b/internal/db/models/server_dao_test.go @@ -96,7 +96,7 @@ func TestServerDAO_CheckPortIsUsing(t *testing.T) { // t.Log("isUsing:", isUsing) //} { - isUsing, err := SharedServerDAO.CheckPortIsUsing(tx, 18, 1234, 44, "tcp") + isUsing, err := SharedServerDAO.CheckTCPPortIsUsing(tx, 18, 3306, 0, "tcp") if err != nil { t.Fatal(err) } diff --git a/internal/db/models/server_model.go b/internal/db/models/server_model.go index 326cd8cd..e07b7a00 100644 --- a/internal/db/models/server_model.go +++ b/internal/db/models/server_model.go @@ -1,6 +1,6 @@ package models -// 服务 +// Server 服务 type Server struct { Id uint32 `field:"id"` // ID IsOn uint8 `field:"isOn"` // 是否启用 @@ -31,6 +31,8 @@ type Server struct { CreatedAt uint64 `field:"createdAt"` // 创建时间 State uint8 `field:"state"` // 状态 DnsName string `field:"dnsName"` // DNS名称 + TcpPorts string `field:"tcpPorts"` // 所包含TCP端口 + UdpPorts string `field:"udpPorts"` // 所包含UDP端口 } type ServerOperator struct { @@ -63,6 +65,8 @@ type ServerOperator struct { CreatedAt interface{} // 创建时间 State interface{} // 状态 DnsName interface{} // DNS名称 + TcpPorts interface{} // 所包含TCP端口 + UdpPorts interface{} // 所包含UDP端口 } func NewServerOperator() *ServerOperator { diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index 205ebce9..32b969c9 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -903,7 +903,7 @@ func (this *NodeClusterService) FindFreePortInNodeCluster(ctx context.Context, r continue } - isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, port, 0, "") + isUsing, err := models.SharedServerDAO.CheckTCPPortIsUsing(tx, req.NodeClusterId, port, 0, "") if err != nil { return nil, err } @@ -923,7 +923,7 @@ func (this *NodeClusterService) CheckPortIsUsingInNodeCluster(ctx context.Contex } var tx = this.NullTx() - isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, int(req.Port), req.ExcludeServerId, req.ExcludeProtocol) + isUsing, err := models.SharedServerDAO.CheckTCPPortIsUsing(tx, req.NodeClusterId, int(req.Port), req.ExcludeServerId, req.ExcludeProtocol) if err != nil { return nil, err } diff --git a/internal/setup/sql_upgrade.go b/internal/setup/sql_upgrade.go index e5633001..53e2b2f0 100644 --- a/internal/setup/sql_upgrade.go +++ b/internal/setup/sql_upgrade.go @@ -493,5 +493,21 @@ func upgradeV0_3_2(db *dbs.DB) error { } } + // 更新服务端口 + var serverDAO = models.NewServerDAO() + ones, err := serverDAO.Query(nil). + ResultPk(). + FindAll() + if err != nil { + return err + } + for _, one := range ones { + var serverId = int64(one.(*models.Server).Id) + err = serverDAO.NotifyServerPortsUpdate(nil, serverId) + if err != nil { + return err + } + } + return nil }