diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index 675ad870..198df2c8 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -87,7 +87,7 @@ func (this *OriginDAO) FindOriginName(tx *dbs.Tx, id int64) (string, error) { } // CreateOrigin 创建源站 -func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32) (originId int64, err error) { +func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32, domains []string) (originId int64, err error) { op := NewOriginOperator() op.AdminId = adminId op.UserId = userId @@ -132,6 +132,17 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, nam weight = 0 } op.Weight = weight + + if len(domains) > 0 { + domainsJSON, err := json.Marshal(domains) + if err != nil { + return 0, err + } + op.Domains = domainsJSON + } else { + op.Domains = "[]" + } + op.State = OriginStateEnabled err = this.Save(tx, op) if err != nil { @@ -141,7 +152,7 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, nam } // UpdateOrigin 修改源站 -func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32) error { +func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32, domains []string) error { if originId <= 0 { return errors.New("invalid originId") } @@ -189,6 +200,17 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, add op.IsOn = isOn op.Version = dbs.SQL("version+1") + + if len(domains) > 0 { + domainsJSON, err := json.Marshal(domains) + if err != nil { + return err + } + op.Domains = domainsJSON + } else { + op.Domains = "[]" + } + err := this.Save(tx, op) if err != nil { return err @@ -229,6 +251,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap MaxIdleConns: int(origin.MaxIdleConns), RequestURI: origin.HttpRequestURI, RequestHost: origin.Host, + Domains: origin.DecodeDomains(), } if IsNotNull(origin.Addr) { diff --git a/internal/db/models/origin_model.go b/internal/db/models/origin_model.go index f79fb8a3..55d9f1a2 100644 --- a/internal/db/models/origin_model.go +++ b/internal/db/models/origin_model.go @@ -1,6 +1,6 @@ package models -// 源站 +// Origin 源站 type Origin struct { Id uint32 `field:"id"` // ID AdminId uint32 `field:"adminId"` // 管理员ID @@ -26,6 +26,7 @@ type Origin struct { Cert string `field:"cert"` // 证书设置 Ftp string `field:"ftp"` // FTP相关设置 CreatedAt uint64 `field:"createdAt"` // 创建时间 + Domains string `field:"domains"` // 所属域名 State uint8 `field:"state"` // 状态 } @@ -54,6 +55,7 @@ type OriginOperator struct { Cert interface{} // 证书设置 Ftp interface{} // FTP相关设置 CreatedAt interface{} // 创建时间 + Domains interface{} // 所属域名 State interface{} // 状态 } diff --git a/internal/db/models/origin_model_ext.go b/internal/db/models/origin_model_ext.go index f2809a88..b47302fa 100644 --- a/internal/db/models/origin_model_ext.go +++ b/internal/db/models/origin_model_ext.go @@ -3,10 +3,11 @@ package models import ( "encoding/json" "errors" + "github.com/TeaOSLab/EdgeAPI/internal/remotelogs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" ) -// 解析地址 +// DecodeAddr 解析地址 func (this *Origin) DecodeAddr() (*serverconfigs.NetworkAddressConfig, error) { if len(this.Addr) == 0 || this.Addr == "null" { return nil, errors.New("addr is empty") @@ -15,3 +16,14 @@ func (this *Origin) DecodeAddr() (*serverconfigs.NetworkAddressConfig, error) { err := json.Unmarshal([]byte(this.Addr), addr) return addr, err } + +func (this *Origin) DecodeDomains() []string { + var result = []string{} + if len(this.Domains) > 0 { + err := json.Unmarshal([]byte(this.Domains), &result) + if err != nil { + remotelogs.Error("Origin.DecodeDomains", err.Error()) + } + } + return result +} diff --git a/internal/rpc/services/service_origin.go b/internal/rpc/services/service_origin.go index 0eb15c2d..2255bff0 100644 --- a/internal/rpc/services/service_origin.go +++ b/internal/rpc/services/service_origin.go @@ -10,12 +10,12 @@ import ( "github.com/iwind/TeaGo/maps" ) -// 源站相关管理 +// OriginService 源站相关管理 type OriginService struct { BaseService } -// 创建源站 +// CreateOrigin 创建源站 func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOriginRequest) (*pb.CreateOriginResponse, error) { adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { @@ -58,7 +58,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi } } - originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns) + originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, req.Domains) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi return &pb.CreateOriginResponse{OriginId: originId}, nil } -// 修改源站 +// UpdateOrigin 修改源站 func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOriginRequest) (*pb.RPCSuccess, error) { _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { @@ -112,7 +112,7 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi } } - err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns) + err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, req.Domains) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi return this.Success() } -// 查找单个源站信息 +// FindEnabledOrigin 查找单个源站信息 func (this *OriginService) FindEnabledOrigin(ctx context.Context, req *pb.FindEnabledOriginRequest) (*pb.FindEnabledOriginResponse, error) { _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { @@ -157,11 +157,12 @@ func (this *OriginService) FindEnabledOrigin(ctx context.Context, req *pb.FindEn PortRange: addr.PortRange, }, Description: origin.Description, + Domains: origin.DecodeDomains(), } return &pb.FindEnabledOriginResponse{Origin: result}, nil } -// 查找源站配置 +// FindEnabledOriginConfig 查找源站配置 func (this *OriginService) FindEnabledOriginConfig(ctx context.Context, req *pb.FindEnabledOriginConfigRequest) (*pb.FindEnabledOriginConfigResponse, error) { _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil {