TCP、TLS、UDP支持端口范围

This commit is contained in:
GoEdgeLab
2021-10-10 16:29:50 +08:00
parent 7d3fd6acc9
commit cc7fc9cb3e
5 changed files with 183 additions and 37 deletions

View File

@@ -220,6 +220,12 @@ func (this *ServerDAO) CreateServer(tx *dbs.Tx,
serverId = types.Int64(op.Id)
// 更新端口
err = this.NotifyServerPortsUpdate(tx, serverId)
if err != nil {
return serverId, err
}
// 通知配置更改
err = this.NotifyUpdate(tx, serverId)
if err != nil {
@@ -323,6 +329,12 @@ func (this *ServerDAO) UpdateServerHTTP(tx *dbs.Tx, serverId int64, config []byt
return err
}
// 更新端口
err = this.NotifyServerPortsUpdate(tx, serverId)
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
@@ -342,6 +354,12 @@ func (this *ServerDAO) UpdateServerHTTPS(tx *dbs.Tx, serverId int64, httpsJSON [
return err
}
// 更新端口
err = this.NotifyServerPortsUpdate(tx, serverId)
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
@@ -361,6 +379,12 @@ func (this *ServerDAO) UpdateServerTCP(tx *dbs.Tx, serverId int64, config []byte
return err
}
// 更新端口
err = this.NotifyServerPortsUpdate(tx, serverId)
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
@@ -380,6 +404,12 @@ func (this *ServerDAO) UpdateServerTLS(tx *dbs.Tx, serverId int64, config []byte
return err
}
// 更新端口
err = this.NotifyServerPortsUpdate(tx, serverId)
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
@@ -418,6 +448,12 @@ func (this *ServerDAO) UpdateServerUDP(tx *dbs.Tx, serverId int64, config []byte
return err
}
// 更新端口
err = this.NotifyServerPortsUpdate(tx, serverId)
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
@@ -1333,43 +1369,19 @@ func (this *ServerDAO) FindEnabledServerIdWithReverseProxyId(tx *dbs.Tx, reverse
FindInt64Col(0)
}
// CheckPortIsUsing 检查端口是否被使用
func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, port int, excludeServerId int64, excludeProtocol string) (bool, error) {
listen := maps.Map{
"portRange": strconv.Itoa(port),
}
// CheckTCPPortIsUsing 检查TCP端口是否被使用
func (this *ServerDAO) CheckTCPPortIsUsing(tx *dbs.Tx, clusterId int64, port int, excludeServerId int64, excludeProtocol string) (bool, error) {
query := this.Query(tx).
Attr("clusterId", clusterId).
State(ServerStateEnabled)
protocols := []string{"http", "https", "tcp", "tls", "udp"}
where := ""
State(ServerStateEnabled).
Param("port", types.String(port))
if excludeServerId <= 0 {
conds := []string{}
for _, p := range protocols {
conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')")
}
where = strings.Join(conds, " OR ")
query.Where("JSON_CONTAINS(tcpPorts, :port)")
} else {
conds := []string{}
for _, p := range protocols {
conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')")
}
where1 := "(id!=:serverId AND (" + strings.Join(conds, " OR ") + "))"
conds = []string{}
for _, p := range protocols {
if p == excludeProtocol {
continue
}
conds = append(conds, "JSON_CONTAINS("+p+", :listen, '$.listen')")
}
where2 := "(id=:serverId AND (" + strings.Join(conds, " OR ") + "))"
where = where1 + " OR " + where2
query.Where("(id!=:serverId AND JSON_CONTAINS(tcpPorts, :port))")
query.Param("serverId", excludeServerId)
}
return query.
Where("("+where+")").
Param("listen", string(listen.AsJSON())).
Exist()
}
@@ -1516,7 +1528,11 @@ func (this *ServerDAO) FindFirstHTTPOrHTTPSPortWithClusterId(tx *dbs.Tx, cluster
return 0, err
}
if len(ports) > 0 {
return types.Int(ports[0]), nil
var port = ports[0]
if strings.Contains(port, "-") { // IP范围
return types.Int(port[:strings.Index(port, "-")]), nil
}
return types.Int(port), nil
}
}
@@ -1527,14 +1543,124 @@ func (this *ServerDAO) FindFirstHTTPOrHTTPSPortWithClusterId(tx *dbs.Tx, cluster
if err != nil {
return 0, err
}
if len(ports) > 0 {
return types.Int(ports[0]), nil
var port = ports[0]
if strings.Contains(port, "-") { // IP范围
return types.Int(port[:strings.Index(port, "-")]), nil
}
return types.Int(port), nil
}
return 0, nil
}
// NotifyServerPortsUpdate 通知服务端口变化
func (this *ServerDAO) NotifyServerPortsUpdate(tx *dbs.Tx, serverId int64) error {
one, err := this.Query(tx).
Pk(serverId).
Result("tcp", "tls", "udp", "http", "https").
Find()
if err != nil {
return err
}
if one == nil {
return nil
}
var server = one.(*Server)
// HTTP
var tcpListens = []*serverconfigs.NetworkAddressConfig{}
var udpListens = []*serverconfigs.NetworkAddressConfig{}
if len(server.Http) > 0 && server.Http != "null" {
httpConfig := &serverconfigs.HTTPProtocolConfig{}
err := json.Unmarshal([]byte(server.Http), httpConfig)
if err != nil {
return err
}
tcpListens = append(tcpListens, httpConfig.Listen...)
}
// HTTPS
if len(server.Https) > 0 && server.Https != "null" {
httpsConfig := &serverconfigs.HTTPSProtocolConfig{}
err := json.Unmarshal([]byte(server.Https), httpsConfig)
if err != nil {
return err
}
tcpListens = append(tcpListens, httpsConfig.Listen...)
}
// TCP
if len(server.Tcp) > 0 && server.Tcp != "null" {
tcpConfig := &serverconfigs.TCPProtocolConfig{}
err := json.Unmarshal([]byte(server.Tcp), tcpConfig)
if err != nil {
return err
}
tcpListens = append(tcpListens, tcpConfig.Listen...)
}
// TLS
if len(server.Tls) > 0 && server.Tls != "null" {
tlsConfig := &serverconfigs.TLSProtocolConfig{}
err := json.Unmarshal([]byte(server.Tls), tlsConfig)
if err != nil {
return err
}
tcpListens = append(tcpListens, tlsConfig.Listen...)
}
// UDP
if len(server.Udp) > 0 && server.Udp != "null" {
udpConfig := &serverconfigs.UDPProtocolConfig{}
err := json.Unmarshal([]byte(server.Udp), udpConfig)
if err != nil {
return err
}
udpListens = append(udpListens, udpConfig.Listen...)
}
var tcpPorts = []int{}
for _, listen := range tcpListens {
_ = listen.Init()
if listen.MinPort > 0 && listen.MaxPort > 0 && listen.MinPort <= listen.MaxPort {
for i := listen.MinPort; i <= listen.MaxPort; i++ {
if !lists.ContainsInt(tcpPorts, i) {
tcpPorts = append(tcpPorts, i)
}
}
}
}
tcpPortsJSON, err := json.Marshal(tcpPorts)
if err != nil {
return err
}
var udpPorts = []int{}
for _, listen := range udpListens {
_ = listen.Init()
if listen.MinPort > 0 && listen.MaxPort > 0 && listen.MinPort <= listen.MaxPort {
for i := listen.MinPort; i <= listen.MaxPort; i++ {
if !lists.ContainsInt(udpPorts, i) {
udpPorts = append(udpPorts, i)
}
}
}
}
udpPortsJSON, err := json.Marshal(udpPorts)
if err != nil {
return err
}
return this.Query(tx).
Pk(serverId).
Set("tcpPorts", string(tcpPortsJSON)).
Set("udpPorts", string(udpPortsJSON)).
UpdateQuickly()
}
// NotifyUpdate 同步集群
func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error {
// 创建任务

View File

@@ -96,7 +96,7 @@ func TestServerDAO_CheckPortIsUsing(t *testing.T) {
// t.Log("isUsing:", isUsing)
//}
{
isUsing, err := SharedServerDAO.CheckPortIsUsing(tx, 18, 1234, 44, "tcp")
isUsing, err := SharedServerDAO.CheckTCPPortIsUsing(tx, 18, 3306, 0, "tcp")
if err != nil {
t.Fatal(err)
}

View File

@@ -1,6 +1,6 @@
package models
// 服务
// Server 服务
type Server struct {
Id uint32 `field:"id"` // ID
IsOn uint8 `field:"isOn"` // 是否启用
@@ -31,6 +31,8 @@ type Server struct {
CreatedAt uint64 `field:"createdAt"` // 创建时间
State uint8 `field:"state"` // 状态
DnsName string `field:"dnsName"` // DNS名称
TcpPorts string `field:"tcpPorts"` // 所包含TCP端口
UdpPorts string `field:"udpPorts"` // 所包含UDP端口
}
type ServerOperator struct {
@@ -63,6 +65,8 @@ type ServerOperator struct {
CreatedAt interface{} // 创建时间
State interface{} // 状态
DnsName interface{} // DNS名称
TcpPorts interface{} // 所包含TCP端口
UdpPorts interface{} // 所包含UDP端口
}
func NewServerOperator() *ServerOperator {

View File

@@ -903,7 +903,7 @@ func (this *NodeClusterService) FindFreePortInNodeCluster(ctx context.Context, r
continue
}
isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, port, 0, "")
isUsing, err := models.SharedServerDAO.CheckTCPPortIsUsing(tx, req.NodeClusterId, port, 0, "")
if err != nil {
return nil, err
}
@@ -923,7 +923,7 @@ func (this *NodeClusterService) CheckPortIsUsingInNodeCluster(ctx context.Contex
}
var tx = this.NullTx()
isUsing, err := models.SharedServerDAO.CheckPortIsUsing(tx, req.NodeClusterId, int(req.Port), req.ExcludeServerId, req.ExcludeProtocol)
isUsing, err := models.SharedServerDAO.CheckTCPPortIsUsing(tx, req.NodeClusterId, int(req.Port), req.ExcludeServerId, req.ExcludeProtocol)
if err != nil {
return nil, err
}

View File

@@ -493,5 +493,21 @@ func upgradeV0_3_2(db *dbs.DB) error {
}
}
// 更新服务端口
var serverDAO = models.NewServerDAO()
ones, err := serverDAO.Query(nil).
ResultPk().
FindAll()
if err != nil {
return err
}
for _, one := range ones {
var serverId = int64(one.(*models.Server).Id)
err = serverDAO.NotifyServerPortsUpdate(nil, serverId)
if err != nil {
return err
}
}
return nil
}