mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	阶段性提交
This commit is contained in:
		
							
								
								
									
										155
									
								
								internal/configs/serverconfigs/cache_policy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										155
									
								
								internal/configs/serverconfigs/cache_policy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,155 @@
 | 
			
		||||
package serverconfigs
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/configutils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/files"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var DefaultSkippedResponseCacheControlValues = []string{"private", "no-cache", "no-store"}
 | 
			
		||||
 | 
			
		||||
// 缓存策略配置
 | 
			
		||||
type CachePolicy struct {
 | 
			
		||||
	Id   int    `yaml:"id" json:"id"`
 | 
			
		||||
	IsOn bool   `yaml:"isOn" json:"isOn"` // 是否开启 TODO
 | 
			
		||||
	Name string `yaml:"name" json:"name"` // 名称
 | 
			
		||||
 | 
			
		||||
	Key      string              `yaml:"key" json:"key"`           // 每个缓存的Key规则,里面可以有变量
 | 
			
		||||
	Capacity shared.SizeCapacity `yaml:"capacity" json:"capacity"` // 最大内容容量
 | 
			
		||||
	Life     shared.TimeDuration `yaml:"life" json:"life"`         // 时间
 | 
			
		||||
	Status   []int               `yaml:"status" json:"status"`     // 缓存的状态码列表
 | 
			
		||||
	MaxSize  shared.SizeCapacity `yaml:"maxSize" json:"maxSize"`   // 能够请求的最大尺寸
 | 
			
		||||
 | 
			
		||||
	SkipResponseCacheControlValues []string `yaml:"skipCacheControlValues" json:"skipCacheControlValues"`     // 可以跳过的响应的Cache-Control值
 | 
			
		||||
	SkipResponseSetCookie          bool     `yaml:"skipSetCookie" json:"skipSetCookie"`                       // 是否跳过响应的Set-Cookie Header
 | 
			
		||||
	EnableRequestCachePragma       bool     `yaml:"enableRequestCachePragma" json:"enableRequestCachePragma"` // 是否支持客户端的Pragma: no-cache
 | 
			
		||||
 | 
			
		||||
	Cond []*shared.RequestCond `yaml:"cond" json:"cond"`
 | 
			
		||||
 | 
			
		||||
	life     time.Duration
 | 
			
		||||
	maxSize  int64
 | 
			
		||||
	capacity int64
 | 
			
		||||
 | 
			
		||||
	uppercaseSkipCacheControlValues []string
 | 
			
		||||
 | 
			
		||||
	Type    string                 `yaml:"type" json:"type"`       // 类型
 | 
			
		||||
	Options map[string]interface{} `yaml:"options" json:"options"` // 选项
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取新对象
 | 
			
		||||
func NewCachePolicy() *CachePolicy {
 | 
			
		||||
	return &CachePolicy{
 | 
			
		||||
		SkipResponseCacheControlValues: DefaultSkippedResponseCacheControlValues,
 | 
			
		||||
		SkipResponseSetCookie:          true,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从文件中读取缓存策略
 | 
			
		||||
func NewCachePolicyFromFile(file string) *CachePolicy {
 | 
			
		||||
	if len(file) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	reader, err := files.NewReader(Tea.ConfigFile(file))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Error(err)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = reader.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	p := NewCachePolicy()
 | 
			
		||||
	err = reader.ReadYAML(p)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Error(err)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return p
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 校验
 | 
			
		||||
func (this *CachePolicy) Validate() error {
 | 
			
		||||
	var err error
 | 
			
		||||
	this.maxSize = this.MaxSize.Bytes()
 | 
			
		||||
	this.life = this.Life.Duration()
 | 
			
		||||
	this.capacity = this.Capacity.Bytes()
 | 
			
		||||
 | 
			
		||||
	this.uppercaseSkipCacheControlValues = []string{}
 | 
			
		||||
	for _, value := range this.SkipResponseCacheControlValues {
 | 
			
		||||
		this.uppercaseSkipCacheControlValues = append(this.uppercaseSkipCacheControlValues, strings.ToUpper(value))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// cond
 | 
			
		||||
	if len(this.Cond) > 0 {
 | 
			
		||||
		for _, cond := range this.Cond {
 | 
			
		||||
			err := cond.Validate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 最大数据尺寸
 | 
			
		||||
func (this *CachePolicy) MaxDataSize() int64 {
 | 
			
		||||
	return this.maxSize
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 容量
 | 
			
		||||
func (this *CachePolicy) CapacitySize() int64 {
 | 
			
		||||
	return this.capacity
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 生命周期
 | 
			
		||||
func (this *CachePolicy) LifeDuration() time.Duration {
 | 
			
		||||
	return this.life
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 保存
 | 
			
		||||
func (this *CachePolicy) Save() error {
 | 
			
		||||
	shared.Locker.Lock()
 | 
			
		||||
	defer shared.Locker.Unlock()
 | 
			
		||||
 | 
			
		||||
	filename := "cache.policy." + strconv.Itoa(this.Id) + ".conf"
 | 
			
		||||
	writer, err := files.NewWriter(Tea.ConfigFile(filename))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = writer.Close()
 | 
			
		||||
	}()
 | 
			
		||||
	_, err = writer.WriteYAML(this)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 删除
 | 
			
		||||
func (this *CachePolicy) Delete() error {
 | 
			
		||||
	filename := "cache.policy." + strconv.Itoa(this.Id) + ".conf"
 | 
			
		||||
	return files.NewFile(Tea.ConfigFile(filename)).Delete()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 是否包含某个Cache-Control值
 | 
			
		||||
func (this *CachePolicy) ContainsCacheControl(value string) bool {
 | 
			
		||||
	return lists.ContainsString(this.uppercaseSkipCacheControlValues, strings.ToUpper(value))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查是否匹配关键词
 | 
			
		||||
func (this *CachePolicy) MatchKeyword(keyword string) (matched bool, name string, tags []string) {
 | 
			
		||||
	if configutils.MatchKeyword(this.Name, keyword) || configutils.MatchKeyword(this.Type, keyword) {
 | 
			
		||||
		matched = true
 | 
			
		||||
		name = this.Name
 | 
			
		||||
		if len(this.Type) > 0 {
 | 
			
		||||
			tags = []string{"类型:" + this.Type}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -11,15 +11,15 @@ var globalConfigFile = "global.yaml"
 | 
			
		||||
// 全局设置
 | 
			
		||||
type GlobalConfig struct {
 | 
			
		||||
	HTTPAll struct {
 | 
			
		||||
		MatchDomainStrictly bool `yaml:"matchDomainStrictly"`
 | 
			
		||||
	} `yaml:"httpAll"`
 | 
			
		||||
	HTTP   struct{} `yaml:"http"`
 | 
			
		||||
	HTTPS  struct{} `yaml:"https"`
 | 
			
		||||
	TCPAll struct{} `yaml:"tcpAll"`
 | 
			
		||||
	TCP    struct{} `yaml:"tcp"`
 | 
			
		||||
	TLS    struct{} `yaml:"tls"`
 | 
			
		||||
	Unix   struct{} `yaml:"unix"`
 | 
			
		||||
	UDP    struct{} `yaml:"udp"`
 | 
			
		||||
		MatchDomainStrictly bool `yaml:"matchDomainStrictly" json:"matchDomainStrictly"`
 | 
			
		||||
	} `yaml:"httpAll" json:"httpAll"`
 | 
			
		||||
	HTTP   struct{} `yaml:"http" json:"http"`
 | 
			
		||||
	HTTPS  struct{} `yaml:"https" json:"https"`
 | 
			
		||||
	TCPAll struct{} `yaml:"tcpAll" json:"tcpAll"`
 | 
			
		||||
	TCP    struct{} `yaml:"tcp" json:"tcp"`
 | 
			
		||||
	TLS    struct{} `yaml:"tls" json:"tls"`
 | 
			
		||||
	Unix   struct{} `yaml:"unix" json:"unix"`
 | 
			
		||||
	UDP    struct{} `yaml:"udp" json:"udp"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SharedGlobalConfig() *GlobalConfig {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,181 @@
 | 
			
		||||
package serverconfigs
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/sslconfigs"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 源站服务配置
 | 
			
		||||
type OriginServerConfig struct {
 | 
			
		||||
	Id          string                `yaml:"id" json:"id"`                   // ID
 | 
			
		||||
	IsOn        bool                  `yaml:"isOn" json:"isOn"`               // 是否启用
 | 
			
		||||
	HeaderList *shared.HeaderList `yaml:"headers" json:"headers"`
 | 
			
		||||
 | 
			
		||||
	Id          int64                 `yaml:"id" json:"id"`                   // ID
 | 
			
		||||
	IsOn        bool                  `yaml:"isOn" json:"isOn"`               // 是否启用 TODO
 | 
			
		||||
	Version     int                   `yaml:"version" json:"version"`         // 版本
 | 
			
		||||
	Name        string                `yaml:"name" json:"name"`               // 名称 TODO
 | 
			
		||||
	Addr        *NetworkAddressConfig `yaml:"addr" json:"addr"`               // 地址
 | 
			
		||||
	Description string                `yaml:"description" json:"description"` // 描述 TODO
 | 
			
		||||
	Code        string                `yaml:"code" json:"code"`               // 代号 TODO
 | 
			
		||||
	Scheme      string                `yaml:"scheme" json:"scheme"`           // 协议 TODO
 | 
			
		||||
 | 
			
		||||
	Weight       uint   `yaml:"weight" json:"weight"`             // 权重 TODO
 | 
			
		||||
	IsBackup     bool   `yaml:"backup" json:"isBackup"`           // 是否为备份 TODO
 | 
			
		||||
	FailTimeout  string `yaml:"failTimeout" json:"failTimeout"`   // 连接失败超时 TODO
 | 
			
		||||
	ReadTimeout  string `yaml:"readTimeout" json:"readTimeout"`   // 读取超时时间 TODO
 | 
			
		||||
	IdleTimeout  string `yaml:"idleTimeout" json:"idleTimeout"`   // 空闲连接超时时间 TODO
 | 
			
		||||
	MaxFails     int32  `yaml:"maxFails" json:"maxFails"`         // 最多失败次数 TODO
 | 
			
		||||
	CurrentFails int32  `yaml:"currentFails" json:"currentFails"` // 当前已失败次数 TODO
 | 
			
		||||
	MaxConns     int32  `yaml:"maxConns" json:"maxConns"`         // 最大并发连接数 TODO
 | 
			
		||||
	CurrentConns int32  `yaml:"currentConns" json:"currentConns"` // 当前连接数 TODO
 | 
			
		||||
	IdleConns    int32  `yaml:"idleConns" json:"idleConns"`       // 最大空闲连接数 TODO
 | 
			
		||||
 | 
			
		||||
	IsDown   bool      `yaml:"down" json:"isDown"`                           // 是否下线 TODO
 | 
			
		||||
	DownTime time.Time `yaml:"downTime,omitempty" json:"downTime,omitempty"` // 下线时间 TODO
 | 
			
		||||
 | 
			
		||||
	RequestURI      string                 `yaml:"requestURI" json:"requestURI"`           // 转发后的请求URI TODO
 | 
			
		||||
	ResponseHeaders []*shared.HeaderConfig `yaml:"responseHeaders" json:"responseHeaders"` // 响应Header TODO
 | 
			
		||||
	Host            string                 `yaml:"host" json:"host"`                       // 自定义主机名 TODO
 | 
			
		||||
 | 
			
		||||
	// 健康检查URL,目前支持:
 | 
			
		||||
	// - http|https 返回2xx-3xx认为成功
 | 
			
		||||
	HealthCheck struct {
 | 
			
		||||
		IsOn        bool                `yaml:"isOn" json:"isOn"`               // 是否开启 TODO
 | 
			
		||||
		URL         string              `yaml:"url" json:"url"`                 // TODO
 | 
			
		||||
		Interval    int                 `yaml:"interval" json:"interval"`       // TODO
 | 
			
		||||
		StatusCodes []int               `yaml:"statusCodes" json:"statusCodes"` // TODO
 | 
			
		||||
		Timeout     shared.TimeDuration `yaml:"timeout" json:"timeout"`         // 超时时间 TODO
 | 
			
		||||
	} `yaml:"healthCheck" json:"healthCheck"`
 | 
			
		||||
 | 
			
		||||
	Cert *sslconfigs.SSLCertConfig `yaml:"cert" json:"cert"` // 请求源服务器用的证书
 | 
			
		||||
 | 
			
		||||
	// ftp
 | 
			
		||||
	FTP *OriginServerFTPConfig `yaml:"ftp" json:"ftp"`
 | 
			
		||||
 | 
			
		||||
	failTimeoutDuration time.Duration
 | 
			
		||||
	readTimeoutDuration time.Duration
 | 
			
		||||
	idleTimeoutDuration time.Duration
 | 
			
		||||
 | 
			
		||||
	hasRequestURI bool
 | 
			
		||||
	requestPath   string
 | 
			
		||||
	requestArgs   string
 | 
			
		||||
 | 
			
		||||
	hasRequestHeaders  bool
 | 
			
		||||
	hasResponseHeaders bool
 | 
			
		||||
 | 
			
		||||
	hasHost bool
 | 
			
		||||
 | 
			
		||||
	uniqueKey string
 | 
			
		||||
 | 
			
		||||
	hasAddrVariables bool // 地址中是否含有变量
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 校验
 | 
			
		||||
func (this *OriginServerConfig) Init() error {
 | 
			
		||||
	// 证书
 | 
			
		||||
	if this.Cert != nil {
 | 
			
		||||
		err := this.Cert.Init()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// unique key
 | 
			
		||||
	this.uniqueKey = strconv.FormatInt(this.Id, 10) + "@" + fmt.Sprintf("%d", this.Version)
 | 
			
		||||
 | 
			
		||||
	// failTimeout
 | 
			
		||||
	if len(this.FailTimeout) > 0 {
 | 
			
		||||
		this.failTimeoutDuration, _ = time.ParseDuration(this.FailTimeout)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// readTimeout
 | 
			
		||||
	if len(this.ReadTimeout) > 0 {
 | 
			
		||||
		this.readTimeoutDuration, _ = time.ParseDuration(this.ReadTimeout)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// idleTimeout
 | 
			
		||||
	if len(this.IdleTimeout) > 0 {
 | 
			
		||||
		this.idleTimeoutDuration, _ = time.ParseDuration(this.IdleTimeout)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Headers
 | 
			
		||||
	if this.HeaderList != nil {
 | 
			
		||||
		err := this.HeaderList.Init()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// request uri
 | 
			
		||||
	if len(this.RequestURI) == 0 || this.RequestURI == "${requestURI}" {
 | 
			
		||||
		this.hasRequestURI = false
 | 
			
		||||
	} else {
 | 
			
		||||
		this.hasRequestURI = true
 | 
			
		||||
 | 
			
		||||
		if strings.Contains(this.RequestURI, "?") {
 | 
			
		||||
			pieces := strings.SplitN(this.RequestURI, "?", -1)
 | 
			
		||||
			this.requestPath = pieces[0]
 | 
			
		||||
			this.requestArgs = pieces[1]
 | 
			
		||||
		} else {
 | 
			
		||||
			this.requestPath = this.RequestURI
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO init health check
 | 
			
		||||
 | 
			
		||||
	// headers
 | 
			
		||||
	if this.HeaderList != nil {
 | 
			
		||||
		this.hasRequestHeaders = len(this.HeaderList.RequestHeaders) > 0
 | 
			
		||||
	}
 | 
			
		||||
	this.hasResponseHeaders = len(this.ResponseHeaders) > 0
 | 
			
		||||
 | 
			
		||||
	// host
 | 
			
		||||
	this.hasHost = len(this.Host) > 0
 | 
			
		||||
 | 
			
		||||
	// variables
 | 
			
		||||
	// TODO 在host和port中支持变量
 | 
			
		||||
	this.hasAddrVariables = false
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 候选对象代号
 | 
			
		||||
func (this *OriginServerConfig) CandidateCodes() []string {
 | 
			
		||||
	codes := []string{strconv.FormatInt(this.Id, 10)}
 | 
			
		||||
	if len(this.Code) > 0 {
 | 
			
		||||
		codes = append(codes, this.Code)
 | 
			
		||||
	}
 | 
			
		||||
	return codes
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 候选对象权重
 | 
			
		||||
func (this *OriginServerConfig) CandidateWeight() uint {
 | 
			
		||||
	return this.Weight
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 连接源站
 | 
			
		||||
func (this *OriginServerConfig) Connect() (net.Conn, error) {
 | 
			
		||||
	switch this.Scheme {
 | 
			
		||||
	case "", ProtocolTCP:
 | 
			
		||||
		// TODO 支持TCP4/TCP6
 | 
			
		||||
		// TODO 支持指定特定网卡
 | 
			
		||||
		// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
 | 
			
		||||
		return net.DialTimeout("tcp", this.Addr.Host+":"+this.Addr.PortRange, this.failTimeoutDuration)
 | 
			
		||||
	case ProtocolTLS:
 | 
			
		||||
		// TODO 支持TCP4/TCP6
 | 
			
		||||
		// TODO 支持指定特定网卡
 | 
			
		||||
		// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
 | 
			
		||||
		// TODO 支持使用证书
 | 
			
		||||
		return tls.Dial("tcp", this.Addr.Host+":"+this.Addr.PortRange, &tls.Config{})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
 | 
			
		||||
 | 
			
		||||
	return nil, errors.New("invalid scheme '" + this.Scheme + "'")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										8
									
								
								internal/configs/serverconfigs/origin_server_ftp.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								internal/configs/serverconfigs/origin_server_ftp.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,8 @@
 | 
			
		||||
package serverconfigs
 | 
			
		||||
 | 
			
		||||
// FTP源站配置
 | 
			
		||||
type OriginServerFTPConfig struct {
 | 
			
		||||
	Username string `yaml:"username" json:"username"` // 用户名
 | 
			
		||||
	Password string `yaml:"password" json:"password"` // 密码
 | 
			
		||||
	Dir      string `yaml:"dir" json:"dir"`           // 目录
 | 
			
		||||
}
 | 
			
		||||
@@ -5,9 +5,10 @@ import "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/sslconfigs"
 | 
			
		||||
type TLSProtocolConfig struct {
 | 
			
		||||
	BaseProtocol `yaml:",inline"`
 | 
			
		||||
 | 
			
		||||
	SSL *sslconfigs.SSLConfig `yaml:"ssl"`
 | 
			
		||||
	SSL *sslconfigs.SSLConfig `yaml:"ssl" json:"ssl"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 初始化
 | 
			
		||||
func (this *TLSProtocolConfig) Init() error {
 | 
			
		||||
	err := this.InitBase()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,105 @@
 | 
			
		||||
package serverconfigs
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/scheduling"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 反向代理设置
 | 
			
		||||
type ReverseProxyConfig struct {
 | 
			
		||||
	IsOn    bool                  `yaml:"isOn" json:"isOn"`       // 是否启用
 | 
			
		||||
	Origins []*OriginServerConfig `yaml:"origins" json:"origins"` // 源站列表
 | 
			
		||||
	IsOn       bool                  `yaml:"isOn" json:"isOn"`             // 是否启用 TODO
 | 
			
		||||
	Origins    []*OriginServerConfig `yaml:"origins" json:"origins"`       // 源站列表
 | 
			
		||||
	Scheduling *SchedulingConfig     `yaml:"scheduling" json:"scheduling"` // 调度算法选项
 | 
			
		||||
 | 
			
		||||
	hasOrigins         bool
 | 
			
		||||
	schedulingIsBackup bool
 | 
			
		||||
	schedulingObject   scheduling.SchedulingInterface
 | 
			
		||||
	schedulingLocker   sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 初始化
 | 
			
		||||
func (this *ReverseProxyConfig) Init() error {
 | 
			
		||||
	this.hasOrigins = len(this.Origins) > 0
 | 
			
		||||
 | 
			
		||||
	for _, origin := range this.Origins {
 | 
			
		||||
		err := origin.Init()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// scheduling
 | 
			
		||||
	this.SetupScheduling(false)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 取得下一个可用的后端服务
 | 
			
		||||
func (this *ReverseProxyConfig) NextOrigin(call *shared.RequestCall) *OriginServerConfig {
 | 
			
		||||
	this.schedulingLocker.Lock()
 | 
			
		||||
	defer this.schedulingLocker.Unlock()
 | 
			
		||||
 | 
			
		||||
	if this.schedulingObject == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if this.Scheduling != nil && call != nil && call.Options != nil {
 | 
			
		||||
		for k, v := range this.Scheduling.Options {
 | 
			
		||||
			call.Options[k] = v
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	candidate := this.schedulingObject.Next(call)
 | 
			
		||||
	if candidate == nil {
 | 
			
		||||
		// 启用备用服务器
 | 
			
		||||
		if !this.schedulingIsBackup {
 | 
			
		||||
			this.SetupScheduling(true)
 | 
			
		||||
 | 
			
		||||
			candidate = this.schedulingObject.Next(call)
 | 
			
		||||
			if candidate == nil {
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if candidate == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return candidate.(*OriginServerConfig)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 设置调度算法
 | 
			
		||||
func (this *ReverseProxyConfig) SetupScheduling(isBackup bool) {
 | 
			
		||||
	if !isBackup {
 | 
			
		||||
		this.schedulingLocker.Lock()
 | 
			
		||||
		defer this.schedulingLocker.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
	this.schedulingIsBackup = isBackup
 | 
			
		||||
 | 
			
		||||
	if this.Scheduling == nil {
 | 
			
		||||
		this.schedulingObject = &scheduling.RandomScheduling{}
 | 
			
		||||
	} else {
 | 
			
		||||
		typeCode := this.Scheduling.Code
 | 
			
		||||
		s := scheduling.FindSchedulingType(typeCode)
 | 
			
		||||
		if s == nil {
 | 
			
		||||
			this.Scheduling = nil
 | 
			
		||||
			this.schedulingObject = &scheduling.RandomScheduling{}
 | 
			
		||||
		} else {
 | 
			
		||||
			this.schedulingObject = s["instance"].(scheduling.SchedulingInterface)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, origin := range this.Origins {
 | 
			
		||||
		if origin.IsOn && !origin.IsDown {
 | 
			
		||||
			if isBackup && origin.IsBackup {
 | 
			
		||||
				this.schedulingObject.Add(origin)
 | 
			
		||||
			} else if !isBackup && !origin.IsBackup {
 | 
			
		||||
				this.schedulingObject.Add(origin)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.schedulingObject.Start()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										10
									
								
								internal/configs/serverconfigs/scheduling/candidate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								internal/configs/serverconfigs/scheduling/candidate.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
// 候选对象接口
 | 
			
		||||
type CandidateInterface interface {
 | 
			
		||||
	// 权重
 | 
			
		||||
	CandidateWeight() uint
 | 
			
		||||
 | 
			
		||||
	// 代号
 | 
			
		||||
	CandidateCodes() []string
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										39
									
								
								internal/configs/serverconfigs/scheduling/scheduling.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								internal/configs/serverconfigs/scheduling/scheduling.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,39 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 调度算法接口
 | 
			
		||||
type SchedulingInterface interface {
 | 
			
		||||
	// 是否有候选对象
 | 
			
		||||
	HasCandidates() bool
 | 
			
		||||
 | 
			
		||||
	// 添加候选对象
 | 
			
		||||
	Add(candidate ...CandidateInterface)
 | 
			
		||||
 | 
			
		||||
	// 启动
 | 
			
		||||
	Start()
 | 
			
		||||
 | 
			
		||||
	// 查找下一个候选对象
 | 
			
		||||
	Next(call *shared.RequestCall) CandidateInterface
 | 
			
		||||
 | 
			
		||||
	// 获取简要信息
 | 
			
		||||
	Summary() maps.Map
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 调度算法基础类
 | 
			
		||||
type Scheduling struct {
 | 
			
		||||
	Candidates []CandidateInterface
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断是否有候选对象
 | 
			
		||||
func (this *Scheduling) HasCandidates() bool {
 | 
			
		||||
	return len(this.Candidates) > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 添加候选对象
 | 
			
		||||
func (this *Scheduling) Add(candidate ...CandidateInterface) {
 | 
			
		||||
	this.Candidates = append(this.Candidates, candidate...)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										45
									
								
								internal/configs/serverconfigs/scheduling/scheduling_hash.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								internal/configs/serverconfigs/scheduling/scheduling_hash.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,45 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"hash/crc32"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Hash调度算法
 | 
			
		||||
type HashScheduling struct {
 | 
			
		||||
	Scheduling
 | 
			
		||||
 | 
			
		||||
	count uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 启动
 | 
			
		||||
func (this *HashScheduling) Start() {
 | 
			
		||||
	this.count = uint32(len(this.Candidates))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取下一个候选对象
 | 
			
		||||
func (this *HashScheduling) Next(call *shared.RequestCall) CandidateInterface {
 | 
			
		||||
	if this.count == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	key := call.Options.GetString("key")
 | 
			
		||||
 | 
			
		||||
	if call.Formatter != nil {
 | 
			
		||||
		key = call.Formatter(key)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sum := crc32.ChecksumIEEE([]byte(key))
 | 
			
		||||
	return this.Candidates[sum%this.count]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取简要信息
 | 
			
		||||
func (this *HashScheduling) Summary() maps.Map {
 | 
			
		||||
	return maps.Map{
 | 
			
		||||
		"code":        "hash",
 | 
			
		||||
		"name":        "Hash算法",
 | 
			
		||||
		"description": "根据自定义的键值的Hash值分配后端服务器",
 | 
			
		||||
		"networks":    []string{"http"},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,45 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHashScheduling_Next(t *testing.T) {
 | 
			
		||||
	s := &HashScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "c",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "d",
 | 
			
		||||
		Weight: 30,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	hits := map[string]uint{}
 | 
			
		||||
	for _, c := range s.Candidates {
 | 
			
		||||
		hits[c.(*TestCandidate).Name] = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rand.Seed(time.Now().UnixNano())
 | 
			
		||||
	for i := 0; i < 1000000; i ++ {
 | 
			
		||||
		call := shared.NewRequestCall()
 | 
			
		||||
		call.Options["key"] = "192.168.1." + fmt.Sprintf("%d", rand.Int())
 | 
			
		||||
 | 
			
		||||
		c := s.Next(call)
 | 
			
		||||
		hits[c.(*TestCandidate).Name] ++
 | 
			
		||||
	}
 | 
			
		||||
	t.Log(hits)
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,78 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"math"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 随机调度算法
 | 
			
		||||
type RandomScheduling struct {
 | 
			
		||||
	Scheduling
 | 
			
		||||
 | 
			
		||||
	array []CandidateInterface
 | 
			
		||||
	count uint // 实际总的服务器数
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 启动
 | 
			
		||||
func (this *RandomScheduling) Start() {
 | 
			
		||||
	sumWeight := uint(0)
 | 
			
		||||
	for _, c := range this.Candidates {
 | 
			
		||||
		weight := c.CandidateWeight()
 | 
			
		||||
		if weight == 0 {
 | 
			
		||||
			weight = 1
 | 
			
		||||
		} else if weight > 10000 {
 | 
			
		||||
			weight = 10000
 | 
			
		||||
		}
 | 
			
		||||
		sumWeight += weight
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sumWeight == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, c := range this.Candidates {
 | 
			
		||||
		weight := c.CandidateWeight()
 | 
			
		||||
		if weight == 0 {
 | 
			
		||||
			weight = 1
 | 
			
		||||
		} else if weight > 10000 {
 | 
			
		||||
			weight = 10000
 | 
			
		||||
		}
 | 
			
		||||
		count := uint(0)
 | 
			
		||||
		if sumWeight <= 1000 {
 | 
			
		||||
			count = weight
 | 
			
		||||
		} else {
 | 
			
		||||
			count = uint(math.Round(float64(weight*10000) / float64(sumWeight))) // 1% 产生 100个数据,最多支持10000个服务器
 | 
			
		||||
		}
 | 
			
		||||
		for i := uint(0); i < count; i++ {
 | 
			
		||||
			this.array = append(this.array, c)
 | 
			
		||||
		}
 | 
			
		||||
		this.count += count
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rand.Seed(time.Now().UnixNano())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取下一个候选对象
 | 
			
		||||
func (this *RandomScheduling) Next(call *shared.RequestCall) CandidateInterface {
 | 
			
		||||
	if this.count == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if this.count == 1 {
 | 
			
		||||
		return this.array[0]
 | 
			
		||||
	}
 | 
			
		||||
	index := rand.Int() % int(this.count)
 | 
			
		||||
	return this.array[index]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取简要信息
 | 
			
		||||
func (this *RandomScheduling) Summary() maps.Map {
 | 
			
		||||
	return maps.Map{
 | 
			
		||||
		"code":        "random",
 | 
			
		||||
		"name":        "Random随机算法",
 | 
			
		||||
		"description": "根据权重设置随机分配后端服务器",
 | 
			
		||||
		"networks":    []string{"http", "tcp"},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,79 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TestCandidate struct {
 | 
			
		||||
	Name   string
 | 
			
		||||
	Weight uint
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *TestCandidate) CandidateWeight() uint {
 | 
			
		||||
	return this.Weight
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *TestCandidate) CandidateCodes() []string {
 | 
			
		||||
	return []string{this.Name}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRandomScheduling_Next(t *testing.T) {
 | 
			
		||||
	s := &RandomScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "c",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "d",
 | 
			
		||||
		Weight: 30,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	/**for _, c := range s.array {
 | 
			
		||||
		t.Log(c.(*TestCandidate).Name, ":", c.CandidateWeight())
 | 
			
		||||
	}**/
 | 
			
		||||
 | 
			
		||||
	hits := map[string]uint{}
 | 
			
		||||
	for _, c := range s.array {
 | 
			
		||||
		hits[c.(*TestCandidate).Name] = 0
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Log("count:", s.count, "array length:", len(s.array))
 | 
			
		||||
 | 
			
		||||
	var locker sync.Mutex
 | 
			
		||||
	var wg = sync.WaitGroup{}
 | 
			
		||||
	wg.Add(100 * 10000)
 | 
			
		||||
	for i := 0; i < 100*10000; i ++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
 | 
			
		||||
			c := s.Next(nil)
 | 
			
		||||
 | 
			
		||||
			locker.Lock()
 | 
			
		||||
			defer locker.Unlock()
 | 
			
		||||
			hits[c.(*TestCandidate).Name] ++
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	t.Log(hits)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRandomScheduling_NextZero(t *testing.T) {
 | 
			
		||||
	s := &RandomScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 0,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
	t.Log(s.Next(nil))
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,80 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 轮询调度算法
 | 
			
		||||
type RoundRobinScheduling struct {
 | 
			
		||||
	Scheduling
 | 
			
		||||
 | 
			
		||||
	rawWeights     []uint
 | 
			
		||||
	currentWeights []uint
 | 
			
		||||
	count          uint
 | 
			
		||||
	index          uint
 | 
			
		||||
 | 
			
		||||
	locker sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 启动
 | 
			
		||||
func (this *RoundRobinScheduling) Start() {
 | 
			
		||||
	lists.Sort(this.Candidates, func(i int, j int) bool {
 | 
			
		||||
		c1 := this.Candidates[i]
 | 
			
		||||
		c2 := this.Candidates[j]
 | 
			
		||||
		return c1.CandidateWeight() > c2.CandidateWeight()
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	for _, c := range this.Candidates {
 | 
			
		||||
		weight := c.CandidateWeight()
 | 
			
		||||
		if weight == 0 {
 | 
			
		||||
			weight = 1
 | 
			
		||||
		} else if weight > 10000 {
 | 
			
		||||
			weight = 10000
 | 
			
		||||
		}
 | 
			
		||||
		this.rawWeights = append(this.rawWeights, weight)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.currentWeights = append([]uint{}, this.rawWeights...)
 | 
			
		||||
	this.count = uint(len(this.Candidates))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取下一个候选对象
 | 
			
		||||
func (this *RoundRobinScheduling) Next(call *shared.RequestCall) CandidateInterface {
 | 
			
		||||
	if this.count == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	this.locker.Lock()
 | 
			
		||||
	defer this.locker.Unlock()
 | 
			
		||||
 | 
			
		||||
	if this.index > this.count-1 {
 | 
			
		||||
		this.index = 0
 | 
			
		||||
	}
 | 
			
		||||
	weight := this.currentWeights[this.index]
 | 
			
		||||
 | 
			
		||||
	// 已经一轮了,则重置状态
 | 
			
		||||
	if weight == 0 {
 | 
			
		||||
		if this.currentWeights[0] == 0 {
 | 
			
		||||
			this.currentWeights = append([]uint{}, this.rawWeights...)
 | 
			
		||||
		}
 | 
			
		||||
		this.index = 0
 | 
			
		||||
		weight = this.currentWeights[this.index]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c := this.Candidates[this.index]
 | 
			
		||||
	this.currentWeights[this.index] --
 | 
			
		||||
	this.index++
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取简要信息
 | 
			
		||||
func (this *RoundRobinScheduling) Summary() maps.Map {
 | 
			
		||||
	return maps.Map{
 | 
			
		||||
		"code":        "roundRobin",
 | 
			
		||||
		"name":        "RoundRobin轮询算法",
 | 
			
		||||
		"description": "根据权重,依次分配后端服务器",
 | 
			
		||||
		"networks":    []string{"http", "tcp"},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,101 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
func TestRoundRobinScheduling_Next(t *testing.T) {
 | 
			
		||||
	s := &RoundRobinScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 5,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "c",
 | 
			
		||||
		Weight: 20,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "d",
 | 
			
		||||
		Weight: 30,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	for _, c := range s.Candidates {
 | 
			
		||||
		t.Log(c.(*TestCandidate).Name, c.CandidateWeight())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Log(s.currentWeights)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 100; i ++ {
 | 
			
		||||
		t.Log("===", "round", i, "===")
 | 
			
		||||
		t.Log(s.Next(nil))
 | 
			
		||||
		t.Log(s.currentWeights)
 | 
			
		||||
		t.Log(s.rawWeights)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRoundRobinScheduling_Two(t *testing.T) {
 | 
			
		||||
	s := &RoundRobinScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 10,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	for _, c := range s.Candidates {
 | 
			
		||||
		t.Log(c.(*TestCandidate).Name, c.CandidateWeight())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Log(s.currentWeights)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 100; i ++ {
 | 
			
		||||
		t.Log("===", "round", i, "===")
 | 
			
		||||
		t.Log(s.Next(nil))
 | 
			
		||||
		t.Log(s.currentWeights)
 | 
			
		||||
		t.Log(s.rawWeights)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRoundRobinScheduling_NextPerformance(t *testing.T) {
 | 
			
		||||
	s := &RoundRobinScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 1,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 2,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "c",
 | 
			
		||||
		Weight: 3,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "d",
 | 
			
		||||
		Weight: 6,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	for _, c := range s.Candidates {
 | 
			
		||||
		t.Log(c.(*TestCandidate).Name, c.CandidateWeight())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Log(s.currentWeights)
 | 
			
		||||
 | 
			
		||||
	hits := map[string]uint{}
 | 
			
		||||
	for _, c := range s.Candidates {
 | 
			
		||||
		hits[c.(*TestCandidate).Name] = 0
 | 
			
		||||
	}
 | 
			
		||||
	for i := 0; i < 100*10000; i ++ {
 | 
			
		||||
		c := s.Next(nil)
 | 
			
		||||
		hits[c.(*TestCandidate).Name] ++
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	t.Log(hits)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										106
									
								
								internal/configs/serverconfigs/scheduling/scheduling_sticky.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								internal/configs/serverconfigs/scheduling/scheduling_sticky.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,106 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Sticky调度算法
 | 
			
		||||
type StickyScheduling struct {
 | 
			
		||||
	Scheduling
 | 
			
		||||
 | 
			
		||||
	count   uint32
 | 
			
		||||
	mapping map[string]CandidateInterface // code => candidate
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 启动
 | 
			
		||||
func (this *StickyScheduling) Start() {
 | 
			
		||||
	this.mapping = map[string]CandidateInterface{}
 | 
			
		||||
	for _, c := range this.Candidates {
 | 
			
		||||
		for _, code := range c.CandidateCodes() {
 | 
			
		||||
			this.mapping[code] = c
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.count = uint32(len(this.Candidates))
 | 
			
		||||
	rand.Seed(time.Now().UnixNano())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取下一个候选对象
 | 
			
		||||
func (this *StickyScheduling) Next(call *shared.RequestCall) CandidateInterface {
 | 
			
		||||
	if this.count == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	typeCode := call.Options.GetString("type")
 | 
			
		||||
	param := call.Options.GetString("param")
 | 
			
		||||
 | 
			
		||||
	if call.Request == nil {
 | 
			
		||||
		return this.Candidates[uint32(rand.Int())%this.count]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	code := ""
 | 
			
		||||
	if typeCode == "cookie" {
 | 
			
		||||
		cookie, err := call.Request.Cookie(param)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			code = cookie.Value
 | 
			
		||||
		}
 | 
			
		||||
	} else if typeCode == "header" {
 | 
			
		||||
		code = call.Request.Header.Get(param)
 | 
			
		||||
	} else if typeCode == "argument" {
 | 
			
		||||
		code = call.Request.URL.Query().Get(param)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	matched := false
 | 
			
		||||
	var c CandidateInterface = nil
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if !matched && c != nil {
 | 
			
		||||
			codes := c.CandidateCodes()
 | 
			
		||||
			if len(codes) == 0 {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if typeCode == "cookie" {
 | 
			
		||||
				call.AddResponseCall(func(resp http.ResponseWriter) {
 | 
			
		||||
					http.SetCookie(resp, &http.Cookie{
 | 
			
		||||
						Name:    param,
 | 
			
		||||
						Value:   codes[0],
 | 
			
		||||
						Path:    "/",
 | 
			
		||||
						Expires: time.Now().AddDate(0, 1, 0),
 | 
			
		||||
					})
 | 
			
		||||
				})
 | 
			
		||||
			} else {
 | 
			
		||||
				call.AddResponseCall(func(resp http.ResponseWriter) {
 | 
			
		||||
					resp.Header().Set(param, codes[0])
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if len(code) == 0 {
 | 
			
		||||
		c = this.Candidates[uint32(rand.Int())%this.count]
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	found := false
 | 
			
		||||
	c, found = this.mapping[code]
 | 
			
		||||
	if !found {
 | 
			
		||||
		c = this.Candidates[uint32(rand.Int())%this.count]
 | 
			
		||||
		return c
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	matched = true
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取简要信息
 | 
			
		||||
func (this *StickyScheduling) Summary() maps.Map {
 | 
			
		||||
	return maps.Map{
 | 
			
		||||
		"code":        "sticky",
 | 
			
		||||
		"name":        "Sticky算法",
 | 
			
		||||
		"description": "利用Cookie、URL参数或者HTTP Header来指定后端服务器",
 | 
			
		||||
		"networks":    []string{"http"},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -0,0 +1,128 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestStickyScheduling_NextArgument(t *testing.T) {
 | 
			
		||||
	s := &StickyScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 1,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 2,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "c",
 | 
			
		||||
		Weight: 3,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "d",
 | 
			
		||||
		Weight: 6,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	t.Log(s.mapping)
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	options := maps.Map{
 | 
			
		||||
		"type":  "argument",
 | 
			
		||||
		"param": "backend",
 | 
			
		||||
	}
 | 
			
		||||
	call := shared.NewRequestCall()
 | 
			
		||||
	call.Request = req
 | 
			
		||||
	call.Options = options
 | 
			
		||||
	t.Log(s.Next(call))
 | 
			
		||||
	t.Log(options)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStickyScheduling_NextCookie(t *testing.T) {
 | 
			
		||||
	s := &StickyScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 1,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 2,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "c",
 | 
			
		||||
		Weight: 3,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "d",
 | 
			
		||||
		Weight: 6,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	t.Log(s.mapping)
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req.AddCookie(&http.Cookie{
 | 
			
		||||
		Name:  "backend",
 | 
			
		||||
		Value: "c",
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	options := maps.Map{
 | 
			
		||||
		"type":  "cookie",
 | 
			
		||||
		"param": "backend",
 | 
			
		||||
	}
 | 
			
		||||
	call := shared.NewRequestCall()
 | 
			
		||||
	call.Request = req
 | 
			
		||||
	call.Options = options
 | 
			
		||||
	t.Log(s.Next(call))
 | 
			
		||||
	t.Log(options)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStickyScheduling_NextHeader(t *testing.T) {
 | 
			
		||||
	s := &StickyScheduling{}
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "a",
 | 
			
		||||
		Weight: 1,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "b",
 | 
			
		||||
		Weight: 2,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "c",
 | 
			
		||||
		Weight: 3,
 | 
			
		||||
	})
 | 
			
		||||
	s.Add(&TestCandidate{
 | 
			
		||||
		Name:   "d",
 | 
			
		||||
		Weight: 6,
 | 
			
		||||
	})
 | 
			
		||||
	s.Start()
 | 
			
		||||
 | 
			
		||||
	t.Log(s.mapping)
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("backend", "c")
 | 
			
		||||
 | 
			
		||||
	options := maps.Map{
 | 
			
		||||
		"type":  "header",
 | 
			
		||||
		"param": "backend",
 | 
			
		||||
	}
 | 
			
		||||
	call := shared.NewRequestCall()
 | 
			
		||||
	call.Request = req
 | 
			
		||||
	call.Options = options
 | 
			
		||||
	t.Log(s.Next(call))
 | 
			
		||||
	t.Log(options)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										28
									
								
								internal/configs/serverconfigs/scheduling/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								internal/configs/serverconfigs/scheduling/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
			
		||||
package scheduling
 | 
			
		||||
 | 
			
		||||
import "github.com/iwind/TeaGo/maps"
 | 
			
		||||
 | 
			
		||||
// 所有请求类型
 | 
			
		||||
func AllSchedulingTypes() []maps.Map {
 | 
			
		||||
	types := []maps.Map{}
 | 
			
		||||
	for _, s := range []SchedulingInterface{
 | 
			
		||||
		new(RandomScheduling),
 | 
			
		||||
		new(RoundRobinScheduling),
 | 
			
		||||
		new(HashScheduling),
 | 
			
		||||
		new(StickyScheduling),
 | 
			
		||||
	} {
 | 
			
		||||
		summary := s.Summary()
 | 
			
		||||
		summary["instance"] = s
 | 
			
		||||
		types = append(types, summary)
 | 
			
		||||
	}
 | 
			
		||||
	return types
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FindSchedulingType(code string) maps.Map {
 | 
			
		||||
	for _, summary := range AllSchedulingTypes() {
 | 
			
		||||
		if summary["code"] == code {
 | 
			
		||||
			return summary
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										14
									
								
								internal/configs/serverconfigs/scheduling_config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								internal/configs/serverconfigs/scheduling_config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,14 @@
 | 
			
		||||
package serverconfigs
 | 
			
		||||
 | 
			
		||||
import "github.com/iwind/TeaGo/maps"
 | 
			
		||||
 | 
			
		||||
// 调度算法配置
 | 
			
		||||
type SchedulingConfig struct {
 | 
			
		||||
	Code    string   `yaml:"code" json:"code"`       // 类型
 | 
			
		||||
	Options maps.Map `yaml:"options" json:"options"` // 选项
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取新对象
 | 
			
		||||
func NewSchedulingConfig() *SchedulingConfig {
 | 
			
		||||
	return &SchedulingConfig{}
 | 
			
		||||
}
 | 
			
		||||
@@ -76,6 +76,13 @@ func (this *ServerConfig) Init() error {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if this.ReverseProxy != nil {
 | 
			
		||||
		err := this.ReverseProxy.Init()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										68
									
								
								internal/configs/serverconfigs/shared/header.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								internal/configs/serverconfigs/shared/header.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,68 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/iwind/TeaGo/utils/string"
 | 
			
		||||
	"regexp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var regexpNamedVariable = regexp.MustCompile("\\${[\\w.-]+}")
 | 
			
		||||
 | 
			
		||||
// 头部信息定义
 | 
			
		||||
type HeaderConfig struct {
 | 
			
		||||
	IsOn   bool   `yaml:"isOn" json:"isOn"`     // 是否开启
 | 
			
		||||
	Id     string `yaml:"id" json:"id"`         // ID
 | 
			
		||||
	Name   string `yaml:"name" json:"name"`     // Name
 | 
			
		||||
	Value  string `yaml:"value" json:"value"`   // Value
 | 
			
		||||
	Always bool   `yaml:"always" json:"always"` // 是否忽略状态码
 | 
			
		||||
	Status []int  `yaml:"status" json:"status"` // 支持的状态码
 | 
			
		||||
 | 
			
		||||
	statusMap    map[int]bool
 | 
			
		||||
	hasVariables bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取新Header对象
 | 
			
		||||
func NewHeaderConfig() *HeaderConfig {
 | 
			
		||||
	return &HeaderConfig{
 | 
			
		||||
		IsOn: true,
 | 
			
		||||
		Id:   stringutil.Rand(16),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 校验
 | 
			
		||||
func (this *HeaderConfig) Validate() error {
 | 
			
		||||
	this.statusMap = map[int]bool{}
 | 
			
		||||
	this.hasVariables = regexpNamedVariable.MatchString(this.Value)
 | 
			
		||||
 | 
			
		||||
	if this.Status == nil {
 | 
			
		||||
		this.Status = []int{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, status := range this.Status {
 | 
			
		||||
		this.statusMap[status] = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断是否匹配状态码
 | 
			
		||||
func (this *HeaderConfig) Match(statusCode int) bool {
 | 
			
		||||
	if !this.IsOn {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if this.Always {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if this.statusMap != nil {
 | 
			
		||||
		_, found := this.statusMap[statusCode]
 | 
			
		||||
		return found
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 是否有变量
 | 
			
		||||
func (this *HeaderConfig) HasVariables() bool {
 | 
			
		||||
	return this.hasVariables
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										245
									
								
								internal/configs/serverconfigs/shared/header_list.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										245
									
								
								internal/configs/serverconfigs/shared/header_list.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,245 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// HeaderList相关操作接口
 | 
			
		||||
type HeaderListInterface interface {
 | 
			
		||||
	// 校验
 | 
			
		||||
	ValidateHeaders() error
 | 
			
		||||
 | 
			
		||||
	// 取得所有的IgnoreHeader
 | 
			
		||||
	AllIgnoreResponseHeaders() []string
 | 
			
		||||
 | 
			
		||||
	// 添加IgnoreHeader
 | 
			
		||||
	AddIgnoreResponseHeader(name string)
 | 
			
		||||
 | 
			
		||||
	// 判断是否包含IgnoreHeader
 | 
			
		||||
	ContainsIgnoreResponseHeader(name string) bool
 | 
			
		||||
 | 
			
		||||
	// 移除IgnoreHeader
 | 
			
		||||
	RemoveIgnoreResponseHeader(name string)
 | 
			
		||||
 | 
			
		||||
	// 修改IgnoreHeader
 | 
			
		||||
	UpdateIgnoreResponseHeader(oldName string, newName string)
 | 
			
		||||
 | 
			
		||||
	// 取得所有的Header
 | 
			
		||||
	AllResponseHeaders() []*HeaderConfig
 | 
			
		||||
 | 
			
		||||
	// 添加Header
 | 
			
		||||
	AddResponseHeader(header *HeaderConfig)
 | 
			
		||||
 | 
			
		||||
	// 判断是否包含Header
 | 
			
		||||
	ContainsResponseHeader(name string) bool
 | 
			
		||||
 | 
			
		||||
	// 查找Header
 | 
			
		||||
	FindResponseHeader(headerId string) *HeaderConfig
 | 
			
		||||
 | 
			
		||||
	// 移除Header
 | 
			
		||||
	RemoveResponseHeader(headerId string)
 | 
			
		||||
 | 
			
		||||
	// 取得所有的请求Header
 | 
			
		||||
	AllRequestHeaders() []*HeaderConfig
 | 
			
		||||
 | 
			
		||||
	// 添加请求Header
 | 
			
		||||
	AddRequestHeader(header *HeaderConfig)
 | 
			
		||||
 | 
			
		||||
	// 查找请求Header
 | 
			
		||||
	FindRequestHeader(headerId string) *HeaderConfig
 | 
			
		||||
 | 
			
		||||
	// 移除请求Header
 | 
			
		||||
	RemoveRequestHeader(headerId string)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HeaderList定义
 | 
			
		||||
type HeaderList struct {
 | 
			
		||||
	// 添加的响应Headers
 | 
			
		||||
	Headers []*HeaderConfig `yaml:"headers" json:"headers"`
 | 
			
		||||
 | 
			
		||||
	// 忽略的响应Headers
 | 
			
		||||
	IgnoreHeaders []string `yaml:"ignoreHeaders" json:"ignoreHeaders"`
 | 
			
		||||
 | 
			
		||||
	// 自定义请求Headers
 | 
			
		||||
	RequestHeaders []*HeaderConfig `yaml:"requestHeaders" json:"requestHeaders"`
 | 
			
		||||
 | 
			
		||||
	hasResponseHeaders bool
 | 
			
		||||
	hasRequestHeaders  bool
 | 
			
		||||
 | 
			
		||||
	hasIgnoreHeaders       bool
 | 
			
		||||
	uppercaseIgnoreHeaders []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 校验
 | 
			
		||||
func (this *HeaderList) Init() error {
 | 
			
		||||
	this.hasResponseHeaders = len(this.Headers) > 0
 | 
			
		||||
	this.hasRequestHeaders = len(this.RequestHeaders) > 0
 | 
			
		||||
 | 
			
		||||
	for _, h := range this.Headers {
 | 
			
		||||
		err := h.Validate()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, h := range this.RequestHeaders {
 | 
			
		||||
		err := h.Validate()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.hasIgnoreHeaders = len(this.IgnoreHeaders) > 0
 | 
			
		||||
	this.uppercaseIgnoreHeaders = []string{}
 | 
			
		||||
	for _, headerKey := range this.IgnoreHeaders {
 | 
			
		||||
		this.uppercaseIgnoreHeaders = append(this.uppercaseIgnoreHeaders, strings.ToUpper(headerKey))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 是否有Headers
 | 
			
		||||
func (this *HeaderList) HasResponseHeaders() bool {
 | 
			
		||||
	return this.hasResponseHeaders
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 取得所有的IgnoreHeader
 | 
			
		||||
func (this *HeaderList) AllIgnoreResponseHeaders() []string {
 | 
			
		||||
	if this.IgnoreHeaders == nil {
 | 
			
		||||
		return []string{}
 | 
			
		||||
	}
 | 
			
		||||
	return this.IgnoreHeaders
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 添加IgnoreHeader
 | 
			
		||||
func (this *HeaderList) AddIgnoreResponseHeader(name string) {
 | 
			
		||||
	if !lists.ContainsString(this.IgnoreHeaders, name) {
 | 
			
		||||
		this.IgnoreHeaders = append(this.IgnoreHeaders, name)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断是否包含IgnoreHeader
 | 
			
		||||
func (this *HeaderList) ContainsIgnoreResponseHeader(name string) bool {
 | 
			
		||||
	if len(this.IgnoreHeaders) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return lists.ContainsString(this.IgnoreHeaders, name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 修改IgnoreHeader
 | 
			
		||||
func (this *HeaderList) UpdateIgnoreResponseHeader(oldName string, newName string) {
 | 
			
		||||
	result := []string{}
 | 
			
		||||
	for _, h := range this.IgnoreHeaders {
 | 
			
		||||
		if h == oldName {
 | 
			
		||||
			result = append(result, newName)
 | 
			
		||||
		} else {
 | 
			
		||||
			result = append(result, h)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	this.IgnoreHeaders = result
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 移除IgnoreHeader
 | 
			
		||||
func (this *HeaderList) RemoveIgnoreResponseHeader(name string) {
 | 
			
		||||
	result := []string{}
 | 
			
		||||
	for _, n := range this.IgnoreHeaders {
 | 
			
		||||
		if n == name {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		result = append(result, n)
 | 
			
		||||
	}
 | 
			
		||||
	this.IgnoreHeaders = result
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 取得所有的Header
 | 
			
		||||
func (this *HeaderList) AllResponseHeaders() []*HeaderConfig {
 | 
			
		||||
	if this.Headers == nil {
 | 
			
		||||
		return []*HeaderConfig{}
 | 
			
		||||
	}
 | 
			
		||||
	return this.Headers
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 添加Header
 | 
			
		||||
func (this *HeaderList) AddResponseHeader(header *HeaderConfig) {
 | 
			
		||||
	this.Headers = append(this.Headers, header)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断是否包含Header
 | 
			
		||||
func (this *HeaderList) ContainsResponseHeader(name string) bool {
 | 
			
		||||
	for _, h := range this.Headers {
 | 
			
		||||
		if h.Name == name {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找Header
 | 
			
		||||
func (this *HeaderList) FindResponseHeader(headerId string) *HeaderConfig {
 | 
			
		||||
	for _, h := range this.Headers {
 | 
			
		||||
		if h.Id == headerId {
 | 
			
		||||
			return h
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 移除Header
 | 
			
		||||
func (this *HeaderList) RemoveResponseHeader(headerId string) {
 | 
			
		||||
	result := []*HeaderConfig{}
 | 
			
		||||
	for _, h := range this.Headers {
 | 
			
		||||
		if h.Id == headerId {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		result = append(result, h)
 | 
			
		||||
	}
 | 
			
		||||
	this.Headers = result
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 添加请求Header
 | 
			
		||||
func (this *HeaderList) AddRequestHeader(header *HeaderConfig) {
 | 
			
		||||
	this.RequestHeaders = append(this.RequestHeaders, header)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断是否有请求Header
 | 
			
		||||
func (this *HeaderList) HasRequestHeaders() bool {
 | 
			
		||||
	return this.hasRequestHeaders
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 取得所有的请求Header
 | 
			
		||||
func (this *HeaderList) AllRequestHeaders() []*HeaderConfig {
 | 
			
		||||
	return this.RequestHeaders
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找请求Header
 | 
			
		||||
func (this *HeaderList) FindRequestHeader(headerId string) *HeaderConfig {
 | 
			
		||||
	for _, h := range this.RequestHeaders {
 | 
			
		||||
		if h.Id == headerId {
 | 
			
		||||
			return h
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 移除请求Header
 | 
			
		||||
func (this *HeaderList) RemoveRequestHeader(headerId string) {
 | 
			
		||||
	result := []*HeaderConfig{}
 | 
			
		||||
	for _, h := range this.RequestHeaders {
 | 
			
		||||
		if h.Id == headerId {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		result = append(result, h)
 | 
			
		||||
	}
 | 
			
		||||
	this.RequestHeaders = result
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断是否有Ignore Headers
 | 
			
		||||
func (this *HeaderList) HasIgnoreHeaders() bool {
 | 
			
		||||
	return this.hasIgnoreHeaders
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找大写的Ignore Headers
 | 
			
		||||
func (this *HeaderList) UppercaseIgnoreHeaders() []string {
 | 
			
		||||
	return this.uppercaseIgnoreHeaders
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										23
									
								
								internal/configs/serverconfigs/shared/header_list_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								internal/configs/serverconfigs/shared/header_list_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHeaderList_FormatHeaders(t *testing.T) {
 | 
			
		||||
	list := &HeaderList{}
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 5; i++ {
 | 
			
		||||
		list.AddRequestHeader(&HeaderConfig{
 | 
			
		||||
			IsOn:  true,
 | 
			
		||||
			Name:  "A" + fmt.Sprintf("%d", i),
 | 
			
		||||
			Value: "ABCDEFGHIJ${name}KLM${hello}NEFGHIJILKKKk",
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := list.Init()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										28
									
								
								internal/configs/serverconfigs/shared/header_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								internal/configs/serverconfigs/shared/header_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/iwind/TeaGo/assert"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHeaderConfig_Match(t *testing.T) {
 | 
			
		||||
	a := assert.NewAssertion(t)
 | 
			
		||||
	h := NewHeaderConfig()
 | 
			
		||||
	err := h.Validate()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	a.IsFalse(h.Match(200))
 | 
			
		||||
	a.IsFalse(h.Match(400))
 | 
			
		||||
 | 
			
		||||
	h.Status = []int{200, 201, 400}
 | 
			
		||||
	err = h.Validate()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	a.IsTrue(h.Match(400))
 | 
			
		||||
	a.IsFalse(h.Match(500))
 | 
			
		||||
 | 
			
		||||
	h.Always = true
 | 
			
		||||
	a.IsTrue(h.Match(500))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										13
									
								
								internal/configs/serverconfigs/shared/regexp.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								internal/configs/serverconfigs/shared/regexp.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import "regexp"
 | 
			
		||||
 | 
			
		||||
// 常用的正则表达式
 | 
			
		||||
var (
 | 
			
		||||
	RegexpDigitNumber    = regexp.MustCompile(`^\d+$`)                    // 正整数
 | 
			
		||||
	RegexpFloatNumber    = regexp.MustCompile(`^\d+(\.\d+)?$`)            // 正浮点数,不支持e
 | 
			
		||||
	RegexpAllDigitNumber = regexp.MustCompile(`^[+-]?\d+$`)               // 整数,支持正负数
 | 
			
		||||
	RegexpAllFloatNumber = regexp.MustCompile(`^[+-]?\d+(\.\d+)?$`)       // 浮点数,支持正负数,不支持e
 | 
			
		||||
	RegexpExternalURL    = regexp.MustCompile("(?i)^(http|https|ftp)://") // URL
 | 
			
		||||
	RegexpNamedVariable  = regexp.MustCompile("\\${[\\w.-]+}")            // 命名变量
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										17
									
								
								internal/configs/serverconfigs/shared/regexp_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								internal/configs/serverconfigs/shared/regexp_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/iwind/TeaGo/assert"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestRegexp(t *testing.T) {
 | 
			
		||||
	a := assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	a.IsTrue(RegexpFloatNumber.MatchString("123"))
 | 
			
		||||
	a.IsTrue(RegexpFloatNumber.MatchString("123.456"))
 | 
			
		||||
	a.IsFalse(RegexpFloatNumber.MatchString(".456"))
 | 
			
		||||
	a.IsFalse(RegexpFloatNumber.MatchString("abc"))
 | 
			
		||||
	a.IsFalse(RegexpFloatNumber.MatchString("123."))
 | 
			
		||||
	a.IsFalse(RegexpFloatNumber.MatchString("123.456e7"))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										41
									
								
								internal/configs/serverconfigs/shared/request_call.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								internal/configs/serverconfigs/shared/request_call.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 请求调用
 | 
			
		||||
type RequestCall struct {
 | 
			
		||||
	Formatter         func(source string) string
 | 
			
		||||
	Request           *http.Request
 | 
			
		||||
	ResponseCallbacks []func(resp http.ResponseWriter)
 | 
			
		||||
	Options           maps.Map
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取新对象
 | 
			
		||||
func NewRequestCall() *RequestCall {
 | 
			
		||||
	return &RequestCall{
 | 
			
		||||
		Options: maps.Map{},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 重置
 | 
			
		||||
func (this *RequestCall) Reset() {
 | 
			
		||||
	this.Formatter = nil
 | 
			
		||||
	this.Request = nil
 | 
			
		||||
	this.ResponseCallbacks = nil
 | 
			
		||||
	this.Options = maps.Map{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 添加响应回调
 | 
			
		||||
func (this *RequestCall) AddResponseCall(callback func(resp http.ResponseWriter)) {
 | 
			
		||||
	this.ResponseCallbacks = append(this.ResponseCallbacks, callback)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行响应回调
 | 
			
		||||
func (this *RequestCall) CallResponseCallbacks(resp http.ResponseWriter) {
 | 
			
		||||
	for _, callback := range this.ResponseCallbacks {
 | 
			
		||||
		callback(resp)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										372
									
								
								internal/configs/serverconfigs/shared/request_cond.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										372
									
								
								internal/configs/serverconfigs/shared/request_cond.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,372 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"github.com/iwind/TeaGo/utils/string"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 重写条件定义
 | 
			
		||||
type RequestCond struct {
 | 
			
		||||
	Id string `yaml:"id" json:"id"` // ID
 | 
			
		||||
 | 
			
		||||
	// 要测试的字符串
 | 
			
		||||
	// 其中可以使用跟请求相关的参数,比如:
 | 
			
		||||
	// ${arg.name}, ${requestPath}
 | 
			
		||||
	Param string `yaml:"param" json:"param"`
 | 
			
		||||
 | 
			
		||||
	// 运算符
 | 
			
		||||
	Operator RequestCondOperator `yaml:"operator" json:"operator"`
 | 
			
		||||
 | 
			
		||||
	// 对比
 | 
			
		||||
	Value string `yaml:"value" json:"value"`
 | 
			
		||||
 | 
			
		||||
	isInt   bool
 | 
			
		||||
	isFloat bool
 | 
			
		||||
	isIP    bool
 | 
			
		||||
 | 
			
		||||
	regValue   *regexp.Regexp
 | 
			
		||||
	floatValue float64
 | 
			
		||||
	ipValue    net.IP
 | 
			
		||||
	arrayValue []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 取得新对象
 | 
			
		||||
func NewRequestCond() *RequestCond {
 | 
			
		||||
	return &RequestCond{
 | 
			
		||||
		Id: stringutil.Rand(16),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 校验配置
 | 
			
		||||
func (this *RequestCond) Validate() error {
 | 
			
		||||
	this.isInt = RegexpDigitNumber.MatchString(this.Value)
 | 
			
		||||
	this.isFloat = RegexpFloatNumber.MatchString(this.Value)
 | 
			
		||||
 | 
			
		||||
	if lists.ContainsString([]string{
 | 
			
		||||
		RequestCondOperatorRegexp,
 | 
			
		||||
		RequestCondOperatorNotRegexp,
 | 
			
		||||
	}, this.Operator) {
 | 
			
		||||
		reg, err := regexp.Compile(this.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		this.regValue = reg
 | 
			
		||||
	} else if lists.ContainsString([]string{
 | 
			
		||||
		RequestCondOperatorEqFloat,
 | 
			
		||||
		RequestCondOperatorGtFloat,
 | 
			
		||||
		RequestCondOperatorGteFloat,
 | 
			
		||||
		RequestCondOperatorLtFloat,
 | 
			
		||||
		RequestCondOperatorLteFloat,
 | 
			
		||||
	}, this.Operator) {
 | 
			
		||||
		this.floatValue = types.Float64(this.Value)
 | 
			
		||||
	} else if lists.ContainsString([]string{
 | 
			
		||||
		RequestCondOperatorEqIP,
 | 
			
		||||
		RequestCondOperatorGtIP,
 | 
			
		||||
		RequestCondOperatorGteIP,
 | 
			
		||||
		RequestCondOperatorLtIP,
 | 
			
		||||
		RequestCondOperatorLteIP,
 | 
			
		||||
	}, this.Operator) {
 | 
			
		||||
		this.ipValue = net.ParseIP(this.Value)
 | 
			
		||||
		this.isIP = this.ipValue != nil
 | 
			
		||||
 | 
			
		||||
		if !this.isIP {
 | 
			
		||||
			return errors.New("value should be a valid ip")
 | 
			
		||||
		}
 | 
			
		||||
	} else if lists.ContainsString([]string{
 | 
			
		||||
		RequestCondOperatorIPRange,
 | 
			
		||||
	}, this.Operator) {
 | 
			
		||||
		if strings.Contains(this.Value, ",") {
 | 
			
		||||
			ipList := strings.SplitN(this.Value, ",", 2)
 | 
			
		||||
			ipString1 := strings.TrimSpace(ipList[0])
 | 
			
		||||
			ipString2 := strings.TrimSpace(ipList[1])
 | 
			
		||||
 | 
			
		||||
			if len(ipString1) > 0 {
 | 
			
		||||
				ip1 := net.ParseIP(ipString1)
 | 
			
		||||
				if ip1 == nil {
 | 
			
		||||
					return errors.New("start ip is invalid")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(ipString2) > 0 {
 | 
			
		||||
				ip2 := net.ParseIP(ipString2)
 | 
			
		||||
				if ip2 == nil {
 | 
			
		||||
					return errors.New("end ip is invalid")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else if strings.Contains(this.Value, "/") {
 | 
			
		||||
			_, _, err := net.ParseCIDR(this.Value)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			return errors.New("invalid ip range")
 | 
			
		||||
		}
 | 
			
		||||
	} else if lists.ContainsString([]string{
 | 
			
		||||
		RequestCondOperatorIn,
 | 
			
		||||
		RequestCondOperatorNotIn,
 | 
			
		||||
		RequestCondOperatorFileExt,
 | 
			
		||||
	}, this.Operator) {
 | 
			
		||||
		stringsValue := []string{}
 | 
			
		||||
		err := json.Unmarshal([]byte(this.Value), &stringsValue)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		this.arrayValue = stringsValue
 | 
			
		||||
	} else if lists.ContainsString([]string{
 | 
			
		||||
		RequestCondOperatorFileMimeType,
 | 
			
		||||
	}, this.Operator) {
 | 
			
		||||
		stringsValue := []string{}
 | 
			
		||||
		err := json.Unmarshal([]byte(this.Value), &stringsValue)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		for k, v := range stringsValue {
 | 
			
		||||
			if strings.Contains(v, "*") {
 | 
			
		||||
				v = regexp.QuoteMeta(v)
 | 
			
		||||
				v = strings.Replace(v, `\*`, ".*", -1)
 | 
			
		||||
				stringsValue[k] = v
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		this.arrayValue = stringsValue
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 将此条件应用于请求,检查是否匹配
 | 
			
		||||
func (this *RequestCond) Match(formatter func(source string) string) bool {
 | 
			
		||||
	paramValue := formatter(this.Param)
 | 
			
		||||
	switch this.Operator {
 | 
			
		||||
	case RequestCondOperatorRegexp:
 | 
			
		||||
		if this.regValue == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return this.regValue.MatchString(paramValue)
 | 
			
		||||
	case RequestCondOperatorNotRegexp:
 | 
			
		||||
		if this.regValue == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return !this.regValue.MatchString(paramValue)
 | 
			
		||||
	case RequestCondOperatorEqInt:
 | 
			
		||||
		return this.isInt && paramValue == this.Value
 | 
			
		||||
	case RequestCondOperatorEqFloat:
 | 
			
		||||
		return this.isFloat && types.Float64(paramValue) == this.floatValue
 | 
			
		||||
	case RequestCondOperatorGtFloat:
 | 
			
		||||
		return this.isFloat && types.Float64(paramValue) > this.floatValue
 | 
			
		||||
	case RequestCondOperatorGteFloat:
 | 
			
		||||
		return this.isFloat && types.Float64(paramValue) >= this.floatValue
 | 
			
		||||
	case RequestCondOperatorLtFloat:
 | 
			
		||||
		return this.isFloat && types.Float64(paramValue) < this.floatValue
 | 
			
		||||
	case RequestCondOperatorLteFloat:
 | 
			
		||||
		return this.isFloat && types.Float64(paramValue) <= this.floatValue
 | 
			
		||||
	case RequestCondOperatorMod:
 | 
			
		||||
		pieces := strings.SplitN(this.Value, ",", 2)
 | 
			
		||||
		if len(pieces) == 1 {
 | 
			
		||||
			rem := types.Int64(pieces[0])
 | 
			
		||||
			return types.Int64(paramValue)%10 == rem
 | 
			
		||||
		}
 | 
			
		||||
		div := types.Int64(pieces[0])
 | 
			
		||||
		if div == 0 {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		rem := types.Int64(pieces[1])
 | 
			
		||||
		return types.Int64(paramValue)%div == rem
 | 
			
		||||
	case RequestCondOperatorMod10:
 | 
			
		||||
		return types.Int64(paramValue)%10 == types.Int64(this.Value)
 | 
			
		||||
	case RequestCondOperatorMod100:
 | 
			
		||||
		return types.Int64(paramValue)%100 == types.Int64(this.Value)
 | 
			
		||||
	case RequestCondOperatorEqString:
 | 
			
		||||
		return paramValue == this.Value
 | 
			
		||||
	case RequestCondOperatorNeqString:
 | 
			
		||||
		return paramValue != this.Value
 | 
			
		||||
	case RequestCondOperatorHasPrefix:
 | 
			
		||||
		return strings.HasPrefix(paramValue, this.Value)
 | 
			
		||||
	case RequestCondOperatorHasSuffix:
 | 
			
		||||
		return strings.HasSuffix(paramValue, this.Value)
 | 
			
		||||
	case RequestCondOperatorContainsString:
 | 
			
		||||
		return strings.Contains(paramValue, this.Value)
 | 
			
		||||
	case RequestCondOperatorNotContainsString:
 | 
			
		||||
		return !strings.Contains(paramValue, this.Value)
 | 
			
		||||
	case RequestCondOperatorEqIP:
 | 
			
		||||
		ip := net.ParseIP(paramValue)
 | 
			
		||||
		if ip == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return this.isIP && bytes.Compare(this.ipValue, ip) == 0
 | 
			
		||||
	case RequestCondOperatorGtIP:
 | 
			
		||||
		ip := net.ParseIP(paramValue)
 | 
			
		||||
		if ip == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return this.isIP && bytes.Compare(ip, this.ipValue) > 0
 | 
			
		||||
	case RequestCondOperatorGteIP:
 | 
			
		||||
		ip := net.ParseIP(paramValue)
 | 
			
		||||
		if ip == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
 | 
			
		||||
	case RequestCondOperatorLtIP:
 | 
			
		||||
		ip := net.ParseIP(paramValue)
 | 
			
		||||
		if ip == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return this.isIP && bytes.Compare(ip, this.ipValue) < 0
 | 
			
		||||
	case RequestCondOperatorLteIP:
 | 
			
		||||
		ip := net.ParseIP(paramValue)
 | 
			
		||||
		if ip == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return this.isIP && bytes.Compare(ip, this.ipValue) <= 0
 | 
			
		||||
	case RequestCondOperatorIPRange:
 | 
			
		||||
		ip := net.ParseIP(paramValue)
 | 
			
		||||
		if ip == nil {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 检查IP范围格式
 | 
			
		||||
		if strings.Contains(this.Value, ",") {
 | 
			
		||||
			ipList := strings.SplitN(this.Value, ",", 2)
 | 
			
		||||
			ipString1 := strings.TrimSpace(ipList[0])
 | 
			
		||||
			ipString2 := strings.TrimSpace(ipList[1])
 | 
			
		||||
 | 
			
		||||
			if len(ipString1) > 0 {
 | 
			
		||||
				ip1 := net.ParseIP(ipString1)
 | 
			
		||||
				if ip1 == nil {
 | 
			
		||||
					return false
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if bytes.Compare(ip, ip1) < 0 {
 | 
			
		||||
					return false
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(ipString2) > 0 {
 | 
			
		||||
				ip2 := net.ParseIP(ipString2)
 | 
			
		||||
				if ip2 == nil {
 | 
			
		||||
					return false
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if bytes.Compare(ip, ip2) > 0 {
 | 
			
		||||
					return false
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return true
 | 
			
		||||
		} else if strings.Contains(this.Value, "/") {
 | 
			
		||||
			_, ipNet, err := net.ParseCIDR(this.Value)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
			return ipNet.Contains(ip)
 | 
			
		||||
		} else {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	case RequestCondOperatorIn:
 | 
			
		||||
		return lists.ContainsString(this.arrayValue, paramValue)
 | 
			
		||||
	case RequestCondOperatorNotIn:
 | 
			
		||||
		return !lists.ContainsString(this.arrayValue, paramValue)
 | 
			
		||||
	case RequestCondOperatorFileExt:
 | 
			
		||||
		ext := filepath.Ext(paramValue)
 | 
			
		||||
		if len(ext) > 0 {
 | 
			
		||||
			ext = ext[1:] // remove dot
 | 
			
		||||
		}
 | 
			
		||||
		return lists.ContainsString(this.arrayValue, strings.ToLower(ext))
 | 
			
		||||
	case RequestCondOperatorFileMimeType:
 | 
			
		||||
		index := strings.Index(paramValue, ";")
 | 
			
		||||
		if index >= 0 {
 | 
			
		||||
			paramValue = strings.TrimSpace(paramValue[:index])
 | 
			
		||||
		}
 | 
			
		||||
		if len(this.arrayValue) == 0 {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		for _, v := range this.arrayValue {
 | 
			
		||||
			if strings.Contains(v, "*") {
 | 
			
		||||
				reg, err := stringutil.RegexpCompile("^" + v + "$")
 | 
			
		||||
				if err == nil && reg.MatchString(paramValue) {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
			} else if paramValue == v {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case RequestCondOperatorVersionRange:
 | 
			
		||||
		if strings.Contains(this.Value, ",") {
 | 
			
		||||
			versions := strings.SplitN(this.Value, ",", 2)
 | 
			
		||||
			version1 := strings.TrimSpace(versions[0])
 | 
			
		||||
			version2 := strings.TrimSpace(versions[1])
 | 
			
		||||
			if len(version1) > 0 && stringutil.VersionCompare(paramValue, version1) < 0 {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
			if len(version2) > 0 && stringutil.VersionCompare(paramValue, version2) > 0 {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
			return true
 | 
			
		||||
		} else {
 | 
			
		||||
			return stringutil.VersionCompare(paramValue, this.Value) >= 0
 | 
			
		||||
		}
 | 
			
		||||
	case RequestCondOperatorIPMod:
 | 
			
		||||
		pieces := strings.SplitN(this.Value, ",", 2)
 | 
			
		||||
		if len(pieces) == 1 {
 | 
			
		||||
			rem := types.Int64(pieces[0])
 | 
			
		||||
			return this.ipToInt64(net.ParseIP(paramValue))%10 == rem
 | 
			
		||||
		}
 | 
			
		||||
		div := types.Int64(pieces[0])
 | 
			
		||||
		if div == 0 {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		rem := types.Int64(pieces[1])
 | 
			
		||||
		return this.ipToInt64(net.ParseIP(paramValue))%div == rem
 | 
			
		||||
	case RequestCondOperatorIPMod10:
 | 
			
		||||
		return this.ipToInt64(net.ParseIP(paramValue))%10 == types.Int64(this.Value)
 | 
			
		||||
	case RequestCondOperatorIPMod100:
 | 
			
		||||
		return this.ipToInt64(net.ParseIP(paramValue))%100 == types.Int64(this.Value)
 | 
			
		||||
	case RequestCondOperatorFileExist:
 | 
			
		||||
		index := strings.Index(paramValue, "?")
 | 
			
		||||
		if index > -1 {
 | 
			
		||||
			paramValue = paramValue[:index]
 | 
			
		||||
		}
 | 
			
		||||
		if len(paramValue) == 0 {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		if !filepath.IsAbs(paramValue) {
 | 
			
		||||
			paramValue = Tea.Root + Tea.DS + paramValue
 | 
			
		||||
		}
 | 
			
		||||
		stat, err := os.Stat(paramValue)
 | 
			
		||||
		return err == nil && !stat.IsDir()
 | 
			
		||||
	case RequestCondOperatorFileNotExist:
 | 
			
		||||
		index := strings.Index(paramValue, "?")
 | 
			
		||||
		if index > -1 {
 | 
			
		||||
			paramValue = paramValue[:index]
 | 
			
		||||
		}
 | 
			
		||||
		if len(paramValue) == 0 {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		if !filepath.IsAbs(paramValue) {
 | 
			
		||||
			paramValue = Tea.Root + Tea.DS + paramValue
 | 
			
		||||
		}
 | 
			
		||||
		stat, err := os.Stat(paramValue)
 | 
			
		||||
		return err != nil || stat.IsDir()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *RequestCond) ipToInt64(ip net.IP) int64 {
 | 
			
		||||
	if len(ip) == 0 {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	if len(ip) == 16 {
 | 
			
		||||
		return int64(binary.BigEndian.Uint32(ip[12:16]))
 | 
			
		||||
	}
 | 
			
		||||
	return int64(binary.BigEndian.Uint32(ip))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										1000
									
								
								internal/configs/serverconfigs/shared/request_cond_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1000
									
								
								internal/configs/serverconfigs/shared/request_cond_test.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										226
									
								
								internal/configs/serverconfigs/shared/request_operators.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										226
									
								
								internal/configs/serverconfigs/shared/request_operators.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,226 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import "github.com/iwind/TeaGo/maps"
 | 
			
		||||
 | 
			
		||||
// 运算符定义
 | 
			
		||||
type RequestCondOperator = string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// 正则
 | 
			
		||||
	RequestCondOperatorRegexp    RequestCondOperator = "regexp"
 | 
			
		||||
	RequestCondOperatorNotRegexp RequestCondOperator = "not regexp"
 | 
			
		||||
 | 
			
		||||
	// 数字相关
 | 
			
		||||
	RequestCondOperatorEqInt    RequestCondOperator = "eq int"   // 整数等于
 | 
			
		||||
	RequestCondOperatorEqFloat  RequestCondOperator = "eq float" // 浮点数等于
 | 
			
		||||
	RequestCondOperatorGtFloat  RequestCondOperator = "gt"
 | 
			
		||||
	RequestCondOperatorGteFloat RequestCondOperator = "gte"
 | 
			
		||||
	RequestCondOperatorLtFloat  RequestCondOperator = "lt"
 | 
			
		||||
	RequestCondOperatorLteFloat RequestCondOperator = "lte"
 | 
			
		||||
 | 
			
		||||
	// 取模
 | 
			
		||||
	RequestCondOperatorMod10  RequestCondOperator = "mod 10"
 | 
			
		||||
	RequestCondOperatorMod100 RequestCondOperator = "mod 100"
 | 
			
		||||
	RequestCondOperatorMod    RequestCondOperator = "mod"
 | 
			
		||||
 | 
			
		||||
	// 字符串相关
 | 
			
		||||
	RequestCondOperatorEqString          RequestCondOperator = "eq"
 | 
			
		||||
	RequestCondOperatorNeqString         RequestCondOperator = "not"
 | 
			
		||||
	RequestCondOperatorHasPrefix         RequestCondOperator = "prefix"
 | 
			
		||||
	RequestCondOperatorHasSuffix         RequestCondOperator = "suffix"
 | 
			
		||||
	RequestCondOperatorContainsString    RequestCondOperator = "contains"
 | 
			
		||||
	RequestCondOperatorNotContainsString RequestCondOperator = "not contains"
 | 
			
		||||
	RequestCondOperatorIn                RequestCondOperator = "in"
 | 
			
		||||
	RequestCondOperatorNotIn             RequestCondOperator = "not in"
 | 
			
		||||
	RequestCondOperatorFileExt           RequestCondOperator = "file ext"
 | 
			
		||||
	RequestCondOperatorFileMimeType      RequestCondOperator = "mime type"
 | 
			
		||||
	RequestCondOperatorVersionRange      RequestCondOperator = "version range"
 | 
			
		||||
 | 
			
		||||
	// IP相关
 | 
			
		||||
	RequestCondOperatorEqIP     RequestCondOperator = "eq ip"
 | 
			
		||||
	RequestCondOperatorGtIP     RequestCondOperator = "gt ip"
 | 
			
		||||
	RequestCondOperatorGteIP    RequestCondOperator = "gte ip"
 | 
			
		||||
	RequestCondOperatorLtIP     RequestCondOperator = "lt ip"
 | 
			
		||||
	RequestCondOperatorLteIP    RequestCondOperator = "lte ip"
 | 
			
		||||
	RequestCondOperatorIPRange  RequestCondOperator = "ip range"
 | 
			
		||||
	RequestCondOperatorIPMod10  RequestCondOperator = "ip mod 10"
 | 
			
		||||
	RequestCondOperatorIPMod100 RequestCondOperator = "ip mod 100"
 | 
			
		||||
	RequestCondOperatorIPMod    RequestCondOperator = "ip mod"
 | 
			
		||||
 | 
			
		||||
	// 文件相关
 | 
			
		||||
	RequestCondOperatorFileExist    RequestCondOperator = "file exist"
 | 
			
		||||
	RequestCondOperatorFileNotExist RequestCondOperator = "file not exist"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 所有的运算符
 | 
			
		||||
func AllRequestOperators() []maps.Map {
 | 
			
		||||
	return []maps.Map{
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "正则表达式匹配",
 | 
			
		||||
			"op":          RequestCondOperatorRegexp,
 | 
			
		||||
			"description": "判断是否正则表达式匹配",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "正则表达式不匹配",
 | 
			
		||||
			"op":          RequestCondOperatorNotRegexp,
 | 
			
		||||
			"description": "判断是否正则表达式不匹配",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "字符串等于",
 | 
			
		||||
			"op":          RequestCondOperatorEqString,
 | 
			
		||||
			"description": "使用字符串对比参数值是否相等于某个值",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "字符串前缀",
 | 
			
		||||
			"op":          RequestCondOperatorHasPrefix,
 | 
			
		||||
			"description": "参数值包含某个前缀",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "字符串后缀",
 | 
			
		||||
			"op":          RequestCondOperatorHasSuffix,
 | 
			
		||||
			"description": "参数值包含某个后缀",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "字符串包含",
 | 
			
		||||
			"op":          RequestCondOperatorContainsString,
 | 
			
		||||
			"description": "参数值包含另外一个字符串",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "字符串不包含",
 | 
			
		||||
			"op":          RequestCondOperatorNotContainsString,
 | 
			
		||||
			"description": "参数值不包含另外一个字符串",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "字符串不等于",
 | 
			
		||||
			"op":          RequestCondOperatorNeqString,
 | 
			
		||||
			"description": "使用字符串对比参数值是否不相等于某个值",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "在列表中",
 | 
			
		||||
			"op":          RequestCondOperatorIn,
 | 
			
		||||
			"description": "判断参数值在某个列表中",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "不在列表中",
 | 
			
		||||
			"op":          RequestCondOperatorNotIn,
 | 
			
		||||
			"description": "判断参数值不在某个列表中",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "扩展名",
 | 
			
		||||
			"op":          RequestCondOperatorFileExt,
 | 
			
		||||
			"description": "判断小写的扩展名(不带点)在某个列表中",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "MimeType",
 | 
			
		||||
			"op":          RequestCondOperatorFileMimeType,
 | 
			
		||||
			"description": "判断MimeType在某个列表中,支持类似于image/*的语法",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "版本号范围",
 | 
			
		||||
			"op":          RequestCondOperatorVersionRange,
 | 
			
		||||
			"description": "判断版本号在某个范围内,格式为version1,version2",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "整数等于",
 | 
			
		||||
			"op":          RequestCondOperatorEqInt,
 | 
			
		||||
			"description": "将参数转换为整数数字后进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "浮点数等于",
 | 
			
		||||
			"op":          RequestCondOperatorEqFloat,
 | 
			
		||||
			"description": "将参数转换为可以有小数的浮点数字进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "数字大于",
 | 
			
		||||
			"op":          RequestCondOperatorGtFloat,
 | 
			
		||||
			"description": "将参数转换为数字进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "数字大于等于",
 | 
			
		||||
			"op":          RequestCondOperatorGteFloat,
 | 
			
		||||
			"description": "将参数转换为数字进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "数字小于",
 | 
			
		||||
			"op":          RequestCondOperatorLtFloat,
 | 
			
		||||
			"description": "将参数转换为数字进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "数字小于等于",
 | 
			
		||||
			"op":          RequestCondOperatorLteFloat,
 | 
			
		||||
			"description": "将参数转换为数字进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "整数取模10",
 | 
			
		||||
			"op":          RequestCondOperatorMod10,
 | 
			
		||||
			"description": "对整数参数值取模,除数为10,对比值为余数",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "整数取模100",
 | 
			
		||||
			"op":          RequestCondOperatorMod100,
 | 
			
		||||
			"description": "对整数参数值取模,除数为100,对比值为余数",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "整数取模",
 | 
			
		||||
			"op":          RequestCondOperatorMod,
 | 
			
		||||
			"description": "对整数参数值取模,对比值格式为:除数,余数,比如10,1",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP等于",
 | 
			
		||||
			"op":          RequestCondOperatorEqIP,
 | 
			
		||||
			"description": "将参数转换为IP进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP大于",
 | 
			
		||||
			"op":          RequestCondOperatorGtIP,
 | 
			
		||||
			"description": "将参数转换为IP进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP大于等于",
 | 
			
		||||
			"op":          RequestCondOperatorGteIP,
 | 
			
		||||
			"description": "将参数转换为IP进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP小于",
 | 
			
		||||
			"op":          RequestCondOperatorLtIP,
 | 
			
		||||
			"description": "将参数转换为IP进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP小于等于",
 | 
			
		||||
			"op":          RequestCondOperatorLteIP,
 | 
			
		||||
			"description": "将参数转换为IP进行对比",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP范围",
 | 
			
		||||
			"op":          RequestCondOperatorIPRange,
 | 
			
		||||
			"description": "IP在某个范围之内,范围格式可以是英文逗号分隔的ip1,ip2,或者CIDR格式的ip/bits",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP取模10",
 | 
			
		||||
			"op":          RequestCondOperatorIPMod10,
 | 
			
		||||
			"description": "对IP参数值取模,除数为10,对比值为余数",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP取模100",
 | 
			
		||||
			"op":          RequestCondOperatorIPMod100,
 | 
			
		||||
			"description": "对IP参数值取模,除数为100,对比值为余数",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "IP取模",
 | 
			
		||||
			"op":          RequestCondOperatorIPMod,
 | 
			
		||||
			"description": "对IP参数值取模,对比值格式为:除数,余数,比如10,1",
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "文件存在",
 | 
			
		||||
			"op":          RequestCondOperatorFileExist,
 | 
			
		||||
			"description": "判断参数值解析后的文件是否存在",
 | 
			
		||||
		},
 | 
			
		||||
 | 
			
		||||
		{
 | 
			
		||||
			"name":        "文件不存在",
 | 
			
		||||
			"op":          RequestCondOperatorFileNotExist,
 | 
			
		||||
			"description": "判断参数值解析后的文件是否不存在",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										30
									
								
								internal/configs/serverconfigs/shared/size_capacity.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								internal/configs/serverconfigs/shared/size_capacity.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,30 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
type SizeCapacityUnit = string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	SizeCapacityUnitByte SizeCapacityUnit = "byte"
 | 
			
		||||
	SizeCapacityUnitKB   SizeCapacityUnit = "kb"
 | 
			
		||||
	SizeCapacityUnitMB   SizeCapacityUnit = "mb"
 | 
			
		||||
	SizeCapacityUnitGB   SizeCapacityUnit = "gb"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SizeCapacity struct {
 | 
			
		||||
	Count int64            `json:"count" yaml:"count"`
 | 
			
		||||
	Unit  SizeCapacityUnit `json:"unit" yaml:"unit"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *SizeCapacity) Bytes() int64 {
 | 
			
		||||
	switch this.Unit {
 | 
			
		||||
	case SizeCapacityUnitByte:
 | 
			
		||||
		return this.Count
 | 
			
		||||
	case SizeCapacityUnitKB:
 | 
			
		||||
		return this.Count * 1024
 | 
			
		||||
	case SizeCapacityUnitMB:
 | 
			
		||||
		return this.Count * 1024 * 1024
 | 
			
		||||
	case SizeCapacityUnitGB:
 | 
			
		||||
		return this.Count * 1024 * 1024 * 1024
 | 
			
		||||
	default:
 | 
			
		||||
		return this.Count
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										36
									
								
								internal/configs/serverconfigs/shared/time_duration.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/configs/serverconfigs/shared/time_duration.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
package shared
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
type TimeDurationUnit = string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TimeDurationUnitMS     TimeDurationUnit = "ms"
 | 
			
		||||
	TimeDurationUnitSecond TimeDurationUnit = "second"
 | 
			
		||||
	TimeDurationUnitMinute TimeDurationUnit = "minute"
 | 
			
		||||
	TimeDurationUnitHour   TimeDurationUnit = "hour"
 | 
			
		||||
	TimeDurationUnitDay    TimeDurationUnit = "day"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 时间间隔
 | 
			
		||||
type TimeDuration struct {
 | 
			
		||||
	Count int64            `yaml:"count" json:"count"` // 数量
 | 
			
		||||
	Unit  TimeDurationUnit `yaml:"unit" json:"unit"`   // 单位
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *TimeDuration) Duration() time.Duration {
 | 
			
		||||
	switch this.Unit {
 | 
			
		||||
	case TimeDurationUnitMS:
 | 
			
		||||
		return time.Duration(this.Count) * time.Millisecond
 | 
			
		||||
	case TimeDurationUnitSecond:
 | 
			
		||||
		return time.Duration(this.Count) * time.Second
 | 
			
		||||
	case TimeDurationUnitMinute:
 | 
			
		||||
		return time.Duration(this.Count) * time.Minute
 | 
			
		||||
	case TimeDurationUnitHour:
 | 
			
		||||
		return time.Duration(this.Count) * time.Hour
 | 
			
		||||
	case TimeDurationUnitDay:
 | 
			
		||||
		return time.Duration(this.Count) * 24 * time.Hour
 | 
			
		||||
	default:
 | 
			
		||||
		return time.Duration(this.Count) * time.Second
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -97,7 +97,7 @@ func (this *SSLConfig) Init() error {
 | 
			
		||||
			if cert == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if !cert.On {
 | 
			
		||||
			if !cert.IsOn {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data, err := ioutil.ReadFile(cert.FullCertPath())
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,7 @@ import (
 | 
			
		||||
// SSL证书
 | 
			
		||||
type SSLCertConfig struct {
 | 
			
		||||
	Id          string `yaml:"id" json:"id"`
 | 
			
		||||
	On          bool   `yaml:"on" json:"on"`
 | 
			
		||||
	IsOn        bool   `yaml:"isOn" json:"isOn"`
 | 
			
		||||
	Description string `yaml:"description" json:"description"` // 说明
 | 
			
		||||
	CertFile    string `yaml:"certFile" json:"certFile"`
 | 
			
		||||
	KeyFile     string `yaml:"keyFile" json:"keyFile"`
 | 
			
		||||
@@ -39,7 +39,7 @@ type SSLCertConfig struct {
 | 
			
		||||
// 获取新的SSL证书
 | 
			
		||||
func NewSSLCertConfig(certFile string, keyFile string) *SSLCertConfig {
 | 
			
		||||
	return &SSLCertConfig{
 | 
			
		||||
		On:       true,
 | 
			
		||||
		IsOn:     true,
 | 
			
		||||
		Id:       stringutil.Rand(16),
 | 
			
		||||
		CertFile: certFile,
 | 
			
		||||
		KeyFile:  keyFile,
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ import (
 | 
			
		||||
// HSTS设置
 | 
			
		||||
// 参考: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security
 | 
			
		||||
type HSTSConfig struct {
 | 
			
		||||
	On                bool     `yaml:"on" json:"on"`
 | 
			
		||||
	IsOn              bool     `yaml:"isOn" json:"isOn"`
 | 
			
		||||
	MaxAge            int      `yaml:"maxAge" json:"maxAge"` // 单位秒
 | 
			
		||||
	IncludeSubDomains bool     `yaml:"includeSubDomains" json:"includeSubDomains"`
 | 
			
		||||
	Preload           bool     `yaml:"preload" json:"preload"`
 | 
			
		||||
 
 | 
			
		||||
@@ -59,8 +59,10 @@ func (this *HTTPListener) Serve() error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *HTTPListener) Close() error {
 | 
			
		||||
	// TODO
 | 
			
		||||
	return nil
 | 
			
		||||
	if this.httpServer != nil {
 | 
			
		||||
		_ = this.httpServer.Close()
 | 
			
		||||
	}
 | 
			
		||||
	return this.Listener.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *HTTPListener) handleHTTP(writer http.ResponseWriter, req *http.Request) {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,9 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -13,11 +15,99 @@ type TCPListener struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *TCPListener) Serve() error {
 | 
			
		||||
	// TODO
 | 
			
		||||
	for {
 | 
			
		||||
		conn, err := this.Listener.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		err = this.handleConn(conn)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logs.Println("[TCP_LISTENER]" + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *TCPListener) handleConn(conn net.Conn) error {
 | 
			
		||||
	firstServer := this.Group.FirstServer()
 | 
			
		||||
	if firstServer == nil {
 | 
			
		||||
		return errors.New("no server available")
 | 
			
		||||
	}
 | 
			
		||||
	if firstServer.ReverseProxy == nil {
 | 
			
		||||
		return errors.New("no ReverseProxy configured for the server")
 | 
			
		||||
	}
 | 
			
		||||
	originConn, err := this.connectOrigin(firstServer.ReverseProxy)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var closer = func() {
 | 
			
		||||
		_ = conn.Close()
 | 
			
		||||
		_ = originConn.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		originBuffer := make([]byte, 32*1024) // TODO 需要可以设置,并可以使用Pool
 | 
			
		||||
		for {
 | 
			
		||||
			n, err := originConn.Read(originBuffer)
 | 
			
		||||
			if n > 0 {
 | 
			
		||||
				_, err = conn.Write(originBuffer[:n])
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					closer()
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				closer()
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	clientBuffer := make([]byte, 32*1024) // TODO 需要可以设置,并可以使用Pool
 | 
			
		||||
	for {
 | 
			
		||||
		n, err := conn.Read(clientBuffer)
 | 
			
		||||
		if n > 0 {
 | 
			
		||||
			_, err = originConn.Write(clientBuffer[:n])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 关闭连接
 | 
			
		||||
	closer()
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *TCPListener) Close() error {
 | 
			
		||||
	// TODO
 | 
			
		||||
	return nil
 | 
			
		||||
	return this.Listener.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig) (conn net.Conn, err error) {
 | 
			
		||||
	if reverseProxy == nil {
 | 
			
		||||
		return nil, errors.New("no reverse proxy config")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	retries := 3
 | 
			
		||||
	for i := 0; i < retries; i++ {
 | 
			
		||||
		origin := reverseProxy.NextOrigin(nil)
 | 
			
		||||
		if origin == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		conn, err = origin.Connect()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logs.Println("[TCP_LISTENER]unable to connect origin: " + origin.Addr.Host + ":" + origin.Addr.PortRange + ": " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		} else {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	err = errors.New("no origin can be used")
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user