From 555662ae2a739e2b5b386a3ce762142b093e0211 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Fri, 12 Jan 2024 11:50:10 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=BF=AB=E6=8D=B7=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E5=92=8C=E5=88=A0=E9=99=A4=E7=BD=91=E7=AB=99=E6=BA=90?= =?UTF-8?q?=E7=AB=99API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/origin_dao.go | 11 + internal/db/models/reverse_proxy_dao.go | 6 +- internal/db/models/reverse_proxy_model_ext.go | 30 +++ internal/db/models/server_dao.go | 45 +++- internal/rpc/services/service_server.go | 218 +++++++++++++++++- 5 files changed, 298 insertions(+), 12 deletions(-) diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index de82632b..b9786022 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -538,6 +538,17 @@ func (this *OriginDAO) CheckUserOrigin(tx *dbs.Tx, userId int64, originId int64) return SharedReverseProxyDAO.CheckUserReverseProxy(tx, userId, reverseProxyId) } +// ExistsOrigin 检查源站是否存在 +func (this *OriginDAO) ExistsOrigin(tx *dbs.Tx, originId int64) (bool, error) { + if originId <= 0 { + return false, nil + } + return this.Query(tx). + Pk(originId). + State(OriginStateEnabled). + Exist() +} + // NotifyUpdate 通知更新 func (this *OriginDAO) NotifyUpdate(tx *dbs.Tx, originId int64) error { reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId) diff --git a/internal/db/models/reverse_proxy_dao.go b/internal/db/models/reverse_proxy_dao.go index 0e0752f5..bdd22a17 100644 --- a/internal/db/models/reverse_proxy_dao.go +++ b/internal/db/models/reverse_proxy_dao.go @@ -376,14 +376,14 @@ func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(tx *dbs.Tx, reversePro } // UpdateReverseProxyPrimaryOrigins 修改主要源站 -func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(tx *dbs.Tx, reverseProxyId int64, origins []byte) error { +func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(tx *dbs.Tx, reverseProxyId int64, originRefs []byte) error { if reverseProxyId <= 0 { return errors.New("invalid reverseProxyId") } var op = NewReverseProxyOperator() op.Id = reverseProxyId - if len(origins) > 0 { - op.PrimaryOrigins = origins + if len(originRefs) > 0 { + op.PrimaryOrigins = originRefs } else { op.PrimaryOrigins = "[]" } diff --git a/internal/db/models/reverse_proxy_model_ext.go b/internal/db/models/reverse_proxy_model_ext.go index 2640e7f9..648a8664 100644 --- a/internal/db/models/reverse_proxy_model_ext.go +++ b/internal/db/models/reverse_proxy_model_ext.go @@ -1 +1,31 @@ package models + +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/iwind/TeaGo/logs" +) + +// DecodePrimaryOrigins 解析主要源站 +func (this *ReverseProxy) DecodePrimaryOrigins() []*serverconfigs.OriginRef { + var refs = []*serverconfigs.OriginRef{} + if IsNotNull(this.PrimaryOrigins) { + err := json.Unmarshal(this.PrimaryOrigins, &refs) + if err != nil { + logs.Error(err) + } + } + return refs +} + +// DecodeBackupOrigins 解析备用源站 +func (this *ReverseProxy) DecodeBackupOrigins() []*serverconfigs.OriginRef { + var refs = []*serverconfigs.OriginRef{} + if IsNotNull(this.BackupOrigins) { + err := json.Unmarshal(this.BackupOrigins, &refs) + if err != nil { + logs.Error(err) + } + } + return refs +} diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 519cca20..9c725d4a 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -757,14 +757,14 @@ func (this *ServerDAO) UpdateServerAuditing(tx *dbs.Tx, serverId int64, result * return this.NotifyDNSUpdate(tx, serverId) } -// UpdateServerReverseProxy 修改反向代理配置 -func (this *ServerDAO) UpdateServerReverseProxy(tx *dbs.Tx, serverId int64, config []byte) error { +// UpdateServerReverseProxyRef 修改反向代理配置 +func (this *ServerDAO) UpdateServerReverseProxyRef(tx *dbs.Tx, serverId int64, reverseProxyRefJSON []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } var op = NewServerOperator() op.Id = serverId - op.ReverseProxy = JSONBytes(config) + op.ReverseProxy = JSONBytes(reverseProxyRefJSON) err := this.Save(tx, op) if err != nil { return err @@ -773,6 +773,28 @@ func (this *ServerDAO) UpdateServerReverseProxy(tx *dbs.Tx, serverId int64, conf return this.NotifyUpdate(tx, serverId) } +// CreateServerReverseProxyRef 创建反向代理配置 +func (this *ServerDAO) CreateServerReverseProxyRef(tx *dbs.Tx, userId int64, serverId int64) (reverseProxyId int64, err error) { + reverseProxyId, err = SharedReverseProxyDAO.CreateReverseProxy(tx, 0, userId, nil, []byte("[]"), []byte("[]")) + if err != nil { + return 0, err + } + var reverseProxyRef = &serverconfigs.ReverseProxyRef{ + IsPrior: false, + IsOn: true, + ReverseProxyId: reverseProxyId, + } + reverseProxyRefJSON, err := json.Marshal(reverseProxyRef) + if err != nil { + return 0, err + } + err = this.UpdateServerReverseProxyRef(tx, serverId, reverseProxyRefJSON) + if err != nil { + return 0, err + } + return reverseProxyId, nil +} + // CountAllEnabledServers 计算所有可用服务数量 func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) { return this.Query(tx). @@ -1362,8 +1384,8 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer return config, nil } -// FindReverseProxyRef 根据条件获取反向代理配置 -func (this *ServerDAO) FindReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverconfigs.ReverseProxyRef, error) { +// FindServerReverseProxyRef 根据条件获取反向代理配置 +func (this *ServerDAO) FindServerReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverconfigs.ReverseProxyRef, error) { reverseProxy, err := this.Query(tx). Pk(serverId). Result("reverseProxy"). @@ -1374,7 +1396,7 @@ func (this *ServerDAO) FindReverseProxyRef(tx *dbs.Tx, serverId int64) (*serverc if len(reverseProxy) == 0 || reverseProxy == "null" { return nil, nil } - config := &serverconfigs.ReverseProxyRef{} + var config = &serverconfigs.ReverseProxyRef{} err = json.Unmarshal([]byte(reverseProxy), config) return config, err } @@ -2998,6 +3020,17 @@ func (this *ServerDAO) CheckServerPlanQuota(tx *dbs.Tx, serverId int64, countSer return nil } +// ExistsServer 检查网站是否存在 +func (this *ServerDAO) ExistsServer(tx *dbs.Tx, serverId int64) (bool, error) { + if serverId <= 0 { + return false, nil + } + return this.Query(tx). + Pk(serverId). + State(ServerStateEnabled). + Exist() +} + // NotifyUpdate 同步服务所在的集群 func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error { if serverId <= 0 { diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 51a4cc92..15ca1c4d 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -695,6 +695,218 @@ func (this *ServerService) CreateBasicTCPServer(ctx context.Context, req *pb.Cre return &pb.CreateBasicTCPServerResponse{ServerId: serverId}, nil } +// AddServerOrigin 为网站添加源站 +func (this *ServerService) AddServerOrigin(ctx context.Context, req *pb.AddServerOriginRequest) (*pb.RPCSuccess, error) { + _, userId, err := this.ValidateAdminAndUser(ctx, true) + if err != nil { + return nil, err + } + + if req.ServerId <= 0 { + return nil, errors.New("require 'serverId'") + } + if req.OriginId <= 0 { + return nil, errors.New("require 'originId'") + } + + var tx = this.NullTx() + + // check user + if userId > 0 { + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) + if err != nil { + return nil, err + } + + err = models.SharedOriginDAO.CheckUserOrigin(tx, userId, req.OriginId) + if err != nil { + return nil, err + } + } else { + // check server + existsServer, err := models.SharedServerDAO.ExistsServer(tx, req.ServerId) + if err != nil { + return nil, err + } + if !existsServer { + return nil, errors.New("server '" + types.String(req.ServerId) + "' not found") + } + + // check origin + existsOrigin, err := models.SharedOriginDAO.ExistsOrigin(tx, req.OriginId) + if err != nil { + return nil, err + } + if !existsOrigin { + return nil, errors.New("origin '" + types.String(req.OriginId) + "' not found") + } + } + + reverseProxyRef, err := models.SharedServerDAO.FindServerReverseProxyRef(tx, req.ServerId) + if err != nil { + return nil, err + } + if reverseProxyRef == nil || reverseProxyRef.ReverseProxyId <= 0 { + reverseProxyId, err := models.SharedServerDAO.CreateServerReverseProxyRef(tx, userId, req.ServerId) + if err != nil { + return nil, err + } + reverseProxyRef = &serverconfigs.ReverseProxyRef{ + IsPrior: false, + IsOn: true, + ReverseProxyId: reverseProxyId, + } + } + + reverseProxy, err := models.SharedReverseProxyDAO.FindEnabledReverseProxy(tx, reverseProxyRef.ReverseProxyId) + if err != nil { + return nil, err + } + if reverseProxy == nil { + return nil, errors.New("can not found reverse proxy") + } + + if req.IsPrimary { + var refs = reverseProxy.DecodePrimaryOrigins() + refs = append(refs, &serverconfigs.OriginRef{ + IsOn: true, + OriginId: req.OriginId, + }) + refsJSON, err := json.Marshal(refs) + if err != nil { + return nil, err + } + err = models.SharedReverseProxyDAO.UpdateReverseProxyPrimaryOrigins(tx, int64(reverseProxy.Id), refsJSON) + if err != nil { + return nil, err + } + } else { + var refs = reverseProxy.DecodeBackupOrigins() + refs = append(refs, &serverconfigs.OriginRef{ + IsOn: true, + OriginId: req.OriginId, + }) + refsJSON, err := json.Marshal(refs) + if err != nil { + return nil, err + } + err = models.SharedReverseProxyDAO.UpdateReverseProxyBackupOrigins(tx, int64(reverseProxy.Id), refsJSON) + if err != nil { + return nil, err + } + } + + return this.Success() +} + +// DeleteServerOrigin 从网站中删除某个源站 +func (this *ServerService) DeleteServerOrigin(ctx context.Context, req *pb.DeleteServerOriginRequest) (*pb.RPCSuccess, error) { + _, userId, err := this.ValidateAdminAndUser(ctx, true) + if err != nil { + return nil, err + } + + if req.ServerId <= 0 { + return nil, errors.New("require 'serverId'") + } + if req.OriginId <= 0 { + return nil, errors.New("require 'originId'") + } + + var tx = this.NullTx() + + // check user + if userId > 0 { + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) + if err != nil { + return nil, err + } + + err = models.SharedOriginDAO.CheckUserOrigin(tx, userId, req.OriginId) + if err != nil { + return nil, err + } + } else { + // check server + existsServer, err := models.SharedServerDAO.ExistsServer(tx, req.ServerId) + if err != nil { + return nil, err + } + if !existsServer { + return nil, errors.New("server '" + types.String(req.ServerId) + "' not found") + } + + // check origin + existsOrigin, err := models.SharedOriginDAO.ExistsOrigin(tx, req.OriginId) + if err != nil { + return nil, err + } + if !existsOrigin { + return nil, errors.New("origin '" + types.String(req.OriginId) + "' not found") + } + } + + reverseProxyRef, err := models.SharedServerDAO.FindServerReverseProxyRef(tx, req.ServerId) + if err != nil { + return nil, err + } + if reverseProxyRef == nil || reverseProxyRef.ReverseProxyId <= 0 { + return this.Success() + } + + reverseProxy, err := models.SharedReverseProxyDAO.FindEnabledReverseProxy(tx, reverseProxyRef.ReverseProxyId) + if err != nil { + return nil, err + } + if reverseProxy == nil { + return this.Success() + } + + var primaryOrigins = reverseProxy.DecodePrimaryOrigins() + var newPrimaryOrigins = []*serverconfigs.OriginRef{} + var found = false + for _, origin := range primaryOrigins { + if origin.OriginId == req.OriginId { + found = true + continue + } + newPrimaryOrigins = append(newPrimaryOrigins, origin) + } + if found { + newPrimaryOriginsJSON, err := json.Marshal(newPrimaryOrigins) + if err != nil { + return nil, err + } + err = models.SharedReverseProxyDAO.UpdateReverseProxyPrimaryOrigins(tx, int64(reverseProxy.Id), newPrimaryOriginsJSON) + if err != nil { + return nil, err + } + } + + var backupOrigins = reverseProxy.DecodeBackupOrigins() + var newBackupOrigins = []*serverconfigs.OriginRef{} + found = false + for _, origin := range backupOrigins { + if origin.OriginId == req.OriginId { + found = true + continue + } + newBackupOrigins = append(newBackupOrigins, origin) + } + if found { + newBackupOriginsJSON, err := json.Marshal(newBackupOrigins) + if err != nil { + return nil, err + } + err = models.SharedReverseProxyDAO.UpdateReverseProxyBackupOrigins(tx, int64(reverseProxy.Id), newBackupOriginsJSON) + if err != nil { + return nil, err + } + } + + return this.Success() +} + // UpdateServerBasic 修改服务基本信息 func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.UpdateServerBasicRequest) (*pb.RPCSuccess, error) { // 校验请求 @@ -1001,7 +1213,7 @@ func (this *ServerService) UpdateServerReverseProxy(ctx context.Context, req *pb } // 修改配置 - err = models.SharedServerDAO.UpdateServerReverseProxy(tx, req.ServerId, req.ReverseProxyJSON) + err = models.SharedServerDAO.UpdateServerReverseProxyRef(tx, req.ServerId, req.ReverseProxyJSON) if err != nil { return nil, err } @@ -1705,7 +1917,7 @@ func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Conte var tx = this.NullTx() - reverseProxyRef, err := models.SharedServerDAO.FindReverseProxyRef(tx, req.ServerId) + reverseProxyRef, err := models.SharedServerDAO.FindServerReverseProxyRef(tx, req.ServerId) if err != nil { return nil, err } @@ -1724,7 +1936,7 @@ func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Conte if err != nil { return nil, err } - err = models.SharedServerDAO.UpdateServerReverseProxy(tx, req.ServerId, refJSON) + err = models.SharedServerDAO.UpdateServerReverseProxyRef(tx, req.ServerId, refJSON) if err != nil { return nil, err }