反向代理源站实现使用域名分组

This commit is contained in:
刘祥超
2021-09-20 11:54:45 +08:00
parent f1af151080
commit 7a1bd29f6f
4 changed files with 49 additions and 11 deletions

View File

@@ -87,7 +87,7 @@ func (this *OriginDAO) FindOriginName(tx *dbs.Tx, id int64) (string, error) {
} }
// CreateOrigin 创建源站 // CreateOrigin 创建源站
func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32) (originId int64, err error) { func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32, domains []string) (originId int64, err error) {
op := NewOriginOperator() op := NewOriginOperator()
op.AdminId = adminId op.AdminId = adminId
op.UserId = userId op.UserId = userId
@@ -132,6 +132,17 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, nam
weight = 0 weight = 0
} }
op.Weight = weight op.Weight = weight
if len(domains) > 0 {
domainsJSON, err := json.Marshal(domains)
if err != nil {
return 0, err
}
op.Domains = domainsJSON
} else {
op.Domains = "[]"
}
op.State = OriginStateEnabled op.State = OriginStateEnabled
err = this.Save(tx, op) err = this.Save(tx, op)
if err != nil { if err != nil {
@@ -141,7 +152,7 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, nam
} }
// UpdateOrigin 修改源站 // UpdateOrigin 修改源站
func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32) error { func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32, domains []string) error {
if originId <= 0 { if originId <= 0 {
return errors.New("invalid originId") return errors.New("invalid originId")
} }
@@ -189,6 +200,17 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, add
op.IsOn = isOn op.IsOn = isOn
op.Version = dbs.SQL("version+1") op.Version = dbs.SQL("version+1")
if len(domains) > 0 {
domainsJSON, err := json.Marshal(domains)
if err != nil {
return err
}
op.Domains = domainsJSON
} else {
op.Domains = "[]"
}
err := this.Save(tx, op) err := this.Save(tx, op)
if err != nil { if err != nil {
return err return err
@@ -229,6 +251,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
MaxIdleConns: int(origin.MaxIdleConns), MaxIdleConns: int(origin.MaxIdleConns),
RequestURI: origin.HttpRequestURI, RequestURI: origin.HttpRequestURI,
RequestHost: origin.Host, RequestHost: origin.Host,
Domains: origin.DecodeDomains(),
} }
if IsNotNull(origin.Addr) { if IsNotNull(origin.Addr) {

View File

@@ -1,6 +1,6 @@
package models package models
// 源站 // Origin 源站
type Origin struct { type Origin struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
@@ -26,6 +26,7 @@ type Origin struct {
Cert string `field:"cert"` // 证书设置 Cert string `field:"cert"` // 证书设置
Ftp string `field:"ftp"` // FTP相关设置 Ftp string `field:"ftp"` // FTP相关设置
CreatedAt uint64 `field:"createdAt"` // 创建时间 CreatedAt uint64 `field:"createdAt"` // 创建时间
Domains string `field:"domains"` // 所属域名
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
} }
@@ -54,6 +55,7 @@ type OriginOperator struct {
Cert interface{} // 证书设置 Cert interface{} // 证书设置
Ftp interface{} // FTP相关设置 Ftp interface{} // FTP相关设置
CreatedAt interface{} // 创建时间 CreatedAt interface{} // 创建时间
Domains interface{} // 所属域名
State interface{} // 状态 State interface{} // 状态
} }

View File

@@ -3,10 +3,11 @@ package models
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
) )
// 解析地址 // DecodeAddr 解析地址
func (this *Origin) DecodeAddr() (*serverconfigs.NetworkAddressConfig, error) { func (this *Origin) DecodeAddr() (*serverconfigs.NetworkAddressConfig, error) {
if len(this.Addr) == 0 || this.Addr == "null" { if len(this.Addr) == 0 || this.Addr == "null" {
return nil, errors.New("addr is empty") return nil, errors.New("addr is empty")
@@ -15,3 +16,14 @@ func (this *Origin) DecodeAddr() (*serverconfigs.NetworkAddressConfig, error) {
err := json.Unmarshal([]byte(this.Addr), addr) err := json.Unmarshal([]byte(this.Addr), addr)
return addr, err return addr, err
} }
func (this *Origin) DecodeDomains() []string {
var result = []string{}
if len(this.Domains) > 0 {
err := json.Unmarshal([]byte(this.Domains), &result)
if err != nil {
remotelogs.Error("Origin.DecodeDomains", err.Error())
}
}
return result
}

View File

@@ -10,12 +10,12 @@ import (
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
) )
// 源站相关管理 // OriginService 源站相关管理
type OriginService struct { type OriginService struct {
BaseService BaseService
} }
// 创建源站 // CreateOrigin 创建源站
func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOriginRequest) (*pb.CreateOriginResponse, error) { func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOriginRequest) (*pb.CreateOriginResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
@@ -58,7 +58,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi
} }
} }
originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns) originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, req.Domains)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -66,7 +66,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi
return &pb.CreateOriginResponse{OriginId: originId}, nil return &pb.CreateOriginResponse{OriginId: originId}, nil
} }
// 修改源站 // UpdateOrigin 修改源站
func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOriginRequest) (*pb.RPCSuccess, error) { func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOriginRequest) (*pb.RPCSuccess, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
@@ -112,7 +112,7 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi
} }
} }
err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns) err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, req.Domains)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -120,7 +120,7 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi
return this.Success() return this.Success()
} }
// 查找单个源站信息 // FindEnabledOrigin 查找单个源站信息
func (this *OriginService) FindEnabledOrigin(ctx context.Context, req *pb.FindEnabledOriginRequest) (*pb.FindEnabledOriginResponse, error) { func (this *OriginService) FindEnabledOrigin(ctx context.Context, req *pb.FindEnabledOriginRequest) (*pb.FindEnabledOriginResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
@@ -157,11 +157,12 @@ func (this *OriginService) FindEnabledOrigin(ctx context.Context, req *pb.FindEn
PortRange: addr.PortRange, PortRange: addr.PortRange,
}, },
Description: origin.Description, Description: origin.Description,
Domains: origin.DecodeDomains(),
} }
return &pb.FindEnabledOriginResponse{Origin: result}, nil return &pb.FindEnabledOriginResponse{Origin: result}, nil
} }
// 查找源站配置 // FindEnabledOriginConfig 查找源站配置
func (this *OriginService) FindEnabledOriginConfig(ctx context.Context, req *pb.FindEnabledOriginConfigRequest) (*pb.FindEnabledOriginConfigResponse, error) { func (this *OriginService) FindEnabledOriginConfig(ctx context.Context, req *pb.FindEnabledOriginConfigRequest) (*pb.FindEnabledOriginConfigResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {