mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-02 22:10:25 +08:00
阶段性提交
This commit is contained in:
155
internal/configs/serverconfigs/cache_policy.go
Normal file
155
internal/configs/serverconfigs/cache_policy.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package serverconfigs
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/configutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/files"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DefaultSkippedResponseCacheControlValues = []string{"private", "no-cache", "no-store"}
|
||||
|
||||
// 缓存策略配置
|
||||
type CachePolicy struct {
|
||||
Id int `yaml:"id" json:"id"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启 TODO
|
||||
Name string `yaml:"name" json:"name"` // 名称
|
||||
|
||||
Key string `yaml:"key" json:"key"` // 每个缓存的Key规则,里面可以有变量
|
||||
Capacity shared.SizeCapacity `yaml:"capacity" json:"capacity"` // 最大内容容量
|
||||
Life shared.TimeDuration `yaml:"life" json:"life"` // 时间
|
||||
Status []int `yaml:"status" json:"status"` // 缓存的状态码列表
|
||||
MaxSize shared.SizeCapacity `yaml:"maxSize" json:"maxSize"` // 能够请求的最大尺寸
|
||||
|
||||
SkipResponseCacheControlValues []string `yaml:"skipCacheControlValues" json:"skipCacheControlValues"` // 可以跳过的响应的Cache-Control值
|
||||
SkipResponseSetCookie bool `yaml:"skipSetCookie" json:"skipSetCookie"` // 是否跳过响应的Set-Cookie Header
|
||||
EnableRequestCachePragma bool `yaml:"enableRequestCachePragma" json:"enableRequestCachePragma"` // 是否支持客户端的Pragma: no-cache
|
||||
|
||||
Cond []*shared.RequestCond `yaml:"cond" json:"cond"`
|
||||
|
||||
life time.Duration
|
||||
maxSize int64
|
||||
capacity int64
|
||||
|
||||
uppercaseSkipCacheControlValues []string
|
||||
|
||||
Type string `yaml:"type" json:"type"` // 类型
|
||||
Options map[string]interface{} `yaml:"options" json:"options"` // 选项
|
||||
}
|
||||
|
||||
// 获取新对象
|
||||
func NewCachePolicy() *CachePolicy {
|
||||
return &CachePolicy{
|
||||
SkipResponseCacheControlValues: DefaultSkippedResponseCacheControlValues,
|
||||
SkipResponseSetCookie: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 从文件中读取缓存策略
|
||||
func NewCachePolicyFromFile(file string) *CachePolicy {
|
||||
if len(file) == 0 {
|
||||
return nil
|
||||
}
|
||||
reader, err := files.NewReader(Tea.ConfigFile(file))
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
_ = reader.Close()
|
||||
}()
|
||||
|
||||
p := NewCachePolicy()
|
||||
err = reader.ReadYAML(p)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// 校验
|
||||
func (this *CachePolicy) Validate() error {
|
||||
var err error
|
||||
this.maxSize = this.MaxSize.Bytes()
|
||||
this.life = this.Life.Duration()
|
||||
this.capacity = this.Capacity.Bytes()
|
||||
|
||||
this.uppercaseSkipCacheControlValues = []string{}
|
||||
for _, value := range this.SkipResponseCacheControlValues {
|
||||
this.uppercaseSkipCacheControlValues = append(this.uppercaseSkipCacheControlValues, strings.ToUpper(value))
|
||||
}
|
||||
|
||||
// cond
|
||||
if len(this.Cond) > 0 {
|
||||
for _, cond := range this.Cond {
|
||||
err := cond.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// 最大数据尺寸
|
||||
func (this *CachePolicy) MaxDataSize() int64 {
|
||||
return this.maxSize
|
||||
}
|
||||
|
||||
// 容量
|
||||
func (this *CachePolicy) CapacitySize() int64 {
|
||||
return this.capacity
|
||||
}
|
||||
|
||||
// 生命周期
|
||||
func (this *CachePolicy) LifeDuration() time.Duration {
|
||||
return this.life
|
||||
}
|
||||
|
||||
// 保存
|
||||
func (this *CachePolicy) Save() error {
|
||||
shared.Locker.Lock()
|
||||
defer shared.Locker.Unlock()
|
||||
|
||||
filename := "cache.policy." + strconv.Itoa(this.Id) + ".conf"
|
||||
writer, err := files.NewWriter(Tea.ConfigFile(filename))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = writer.Close()
|
||||
}()
|
||||
_, err = writer.WriteYAML(this)
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除
|
||||
func (this *CachePolicy) Delete() error {
|
||||
filename := "cache.policy." + strconv.Itoa(this.Id) + ".conf"
|
||||
return files.NewFile(Tea.ConfigFile(filename)).Delete()
|
||||
}
|
||||
|
||||
// 是否包含某个Cache-Control值
|
||||
func (this *CachePolicy) ContainsCacheControl(value string) bool {
|
||||
return lists.ContainsString(this.uppercaseSkipCacheControlValues, strings.ToUpper(value))
|
||||
}
|
||||
|
||||
// 检查是否匹配关键词
|
||||
func (this *CachePolicy) MatchKeyword(keyword string) (matched bool, name string, tags []string) {
|
||||
if configutils.MatchKeyword(this.Name, keyword) || configutils.MatchKeyword(this.Type, keyword) {
|
||||
matched = true
|
||||
name = this.Name
|
||||
if len(this.Type) > 0 {
|
||||
tags = []string{"类型:" + this.Type}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -11,15 +11,15 @@ var globalConfigFile = "global.yaml"
|
||||
// 全局设置
|
||||
type GlobalConfig struct {
|
||||
HTTPAll struct {
|
||||
MatchDomainStrictly bool `yaml:"matchDomainStrictly"`
|
||||
} `yaml:"httpAll"`
|
||||
HTTP struct{} `yaml:"http"`
|
||||
HTTPS struct{} `yaml:"https"`
|
||||
TCPAll struct{} `yaml:"tcpAll"`
|
||||
TCP struct{} `yaml:"tcp"`
|
||||
TLS struct{} `yaml:"tls"`
|
||||
Unix struct{} `yaml:"unix"`
|
||||
UDP struct{} `yaml:"udp"`
|
||||
MatchDomainStrictly bool `yaml:"matchDomainStrictly" json:"matchDomainStrictly"`
|
||||
} `yaml:"httpAll" json:"httpAll"`
|
||||
HTTP struct{} `yaml:"http" json:"http"`
|
||||
HTTPS struct{} `yaml:"https" json:"https"`
|
||||
TCPAll struct{} `yaml:"tcpAll" json:"tcpAll"`
|
||||
TCP struct{} `yaml:"tcp" json:"tcp"`
|
||||
TLS struct{} `yaml:"tls" json:"tls"`
|
||||
Unix struct{} `yaml:"unix" json:"unix"`
|
||||
UDP struct{} `yaml:"udp" json:"udp"`
|
||||
}
|
||||
|
||||
func SharedGlobalConfig() *GlobalConfig {
|
||||
|
||||
@@ -1,10 +1,181 @@
|
||||
package serverconfigs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/sslconfigs"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 源站服务配置
|
||||
type OriginServerConfig struct {
|
||||
Id string `yaml:"id" json:"id"` // ID
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用
|
||||
HeaderList *shared.HeaderList `yaml:"headers" json:"headers"`
|
||||
|
||||
Id int64 `yaml:"id" json:"id"` // ID
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用 TODO
|
||||
Version int `yaml:"version" json:"version"` // 版本
|
||||
Name string `yaml:"name" json:"name"` // 名称 TODO
|
||||
Addr *NetworkAddressConfig `yaml:"addr" json:"addr"` // 地址
|
||||
Description string `yaml:"description" json:"description"` // 描述 TODO
|
||||
Code string `yaml:"code" json:"code"` // 代号 TODO
|
||||
Scheme string `yaml:"scheme" json:"scheme"` // 协议 TODO
|
||||
|
||||
Weight uint `yaml:"weight" json:"weight"` // 权重 TODO
|
||||
IsBackup bool `yaml:"backup" json:"isBackup"` // 是否为备份 TODO
|
||||
FailTimeout string `yaml:"failTimeout" json:"failTimeout"` // 连接失败超时 TODO
|
||||
ReadTimeout string `yaml:"readTimeout" json:"readTimeout"` // 读取超时时间 TODO
|
||||
IdleTimeout string `yaml:"idleTimeout" json:"idleTimeout"` // 空闲连接超时时间 TODO
|
||||
MaxFails int32 `yaml:"maxFails" json:"maxFails"` // 最多失败次数 TODO
|
||||
CurrentFails int32 `yaml:"currentFails" json:"currentFails"` // 当前已失败次数 TODO
|
||||
MaxConns int32 `yaml:"maxConns" json:"maxConns"` // 最大并发连接数 TODO
|
||||
CurrentConns int32 `yaml:"currentConns" json:"currentConns"` // 当前连接数 TODO
|
||||
IdleConns int32 `yaml:"idleConns" json:"idleConns"` // 最大空闲连接数 TODO
|
||||
|
||||
IsDown bool `yaml:"down" json:"isDown"` // 是否下线 TODO
|
||||
DownTime time.Time `yaml:"downTime,omitempty" json:"downTime,omitempty"` // 下线时间 TODO
|
||||
|
||||
RequestURI string `yaml:"requestURI" json:"requestURI"` // 转发后的请求URI TODO
|
||||
ResponseHeaders []*shared.HeaderConfig `yaml:"responseHeaders" json:"responseHeaders"` // 响应Header TODO
|
||||
Host string `yaml:"host" json:"host"` // 自定义主机名 TODO
|
||||
|
||||
// 健康检查URL,目前支持:
|
||||
// - http|https 返回2xx-3xx认为成功
|
||||
HealthCheck struct {
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启 TODO
|
||||
URL string `yaml:"url" json:"url"` // TODO
|
||||
Interval int `yaml:"interval" json:"interval"` // TODO
|
||||
StatusCodes []int `yaml:"statusCodes" json:"statusCodes"` // TODO
|
||||
Timeout shared.TimeDuration `yaml:"timeout" json:"timeout"` // 超时时间 TODO
|
||||
} `yaml:"healthCheck" json:"healthCheck"`
|
||||
|
||||
Cert *sslconfigs.SSLCertConfig `yaml:"cert" json:"cert"` // 请求源服务器用的证书
|
||||
|
||||
// ftp
|
||||
FTP *OriginServerFTPConfig `yaml:"ftp" json:"ftp"`
|
||||
|
||||
failTimeoutDuration time.Duration
|
||||
readTimeoutDuration time.Duration
|
||||
idleTimeoutDuration time.Duration
|
||||
|
||||
hasRequestURI bool
|
||||
requestPath string
|
||||
requestArgs string
|
||||
|
||||
hasRequestHeaders bool
|
||||
hasResponseHeaders bool
|
||||
|
||||
hasHost bool
|
||||
|
||||
uniqueKey string
|
||||
|
||||
hasAddrVariables bool // 地址中是否含有变量
|
||||
}
|
||||
|
||||
// 校验
|
||||
func (this *OriginServerConfig) Init() error {
|
||||
// 证书
|
||||
if this.Cert != nil {
|
||||
err := this.Cert.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// unique key
|
||||
this.uniqueKey = strconv.FormatInt(this.Id, 10) + "@" + fmt.Sprintf("%d", this.Version)
|
||||
|
||||
// failTimeout
|
||||
if len(this.FailTimeout) > 0 {
|
||||
this.failTimeoutDuration, _ = time.ParseDuration(this.FailTimeout)
|
||||
}
|
||||
|
||||
// readTimeout
|
||||
if len(this.ReadTimeout) > 0 {
|
||||
this.readTimeoutDuration, _ = time.ParseDuration(this.ReadTimeout)
|
||||
}
|
||||
|
||||
// idleTimeout
|
||||
if len(this.IdleTimeout) > 0 {
|
||||
this.idleTimeoutDuration, _ = time.ParseDuration(this.IdleTimeout)
|
||||
}
|
||||
|
||||
// Headers
|
||||
if this.HeaderList != nil {
|
||||
err := this.HeaderList.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// request uri
|
||||
if len(this.RequestURI) == 0 || this.RequestURI == "${requestURI}" {
|
||||
this.hasRequestURI = false
|
||||
} else {
|
||||
this.hasRequestURI = true
|
||||
|
||||
if strings.Contains(this.RequestURI, "?") {
|
||||
pieces := strings.SplitN(this.RequestURI, "?", -1)
|
||||
this.requestPath = pieces[0]
|
||||
this.requestArgs = pieces[1]
|
||||
} else {
|
||||
this.requestPath = this.RequestURI
|
||||
}
|
||||
}
|
||||
|
||||
// TODO init health check
|
||||
|
||||
// headers
|
||||
if this.HeaderList != nil {
|
||||
this.hasRequestHeaders = len(this.HeaderList.RequestHeaders) > 0
|
||||
}
|
||||
this.hasResponseHeaders = len(this.ResponseHeaders) > 0
|
||||
|
||||
// host
|
||||
this.hasHost = len(this.Host) > 0
|
||||
|
||||
// variables
|
||||
// TODO 在host和port中支持变量
|
||||
this.hasAddrVariables = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 候选对象代号
|
||||
func (this *OriginServerConfig) CandidateCodes() []string {
|
||||
codes := []string{strconv.FormatInt(this.Id, 10)}
|
||||
if len(this.Code) > 0 {
|
||||
codes = append(codes, this.Code)
|
||||
}
|
||||
return codes
|
||||
}
|
||||
|
||||
// 候选对象权重
|
||||
func (this *OriginServerConfig) CandidateWeight() uint {
|
||||
return this.Weight
|
||||
}
|
||||
|
||||
// 连接源站
|
||||
func (this *OriginServerConfig) Connect() (net.Conn, error) {
|
||||
switch this.Scheme {
|
||||
case "", ProtocolTCP:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
return net.DialTimeout("tcp", this.Addr.Host+":"+this.Addr.PortRange, this.failTimeoutDuration)
|
||||
case ProtocolTLS:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
// TODO 支持使用证书
|
||||
return tls.Dial("tcp", this.Addr.Host+":"+this.Addr.PortRange, &tls.Config{})
|
||||
}
|
||||
|
||||
// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
|
||||
|
||||
return nil, errors.New("invalid scheme '" + this.Scheme + "'")
|
||||
}
|
||||
|
||||
8
internal/configs/serverconfigs/origin_server_ftp.go
Normal file
8
internal/configs/serverconfigs/origin_server_ftp.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package serverconfigs
|
||||
|
||||
// FTP源站配置
|
||||
type OriginServerFTPConfig struct {
|
||||
Username string `yaml:"username" json:"username"` // 用户名
|
||||
Password string `yaml:"password" json:"password"` // 密码
|
||||
Dir string `yaml:"dir" json:"dir"` // 目录
|
||||
}
|
||||
@@ -5,9 +5,10 @@ import "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/sslconfigs"
|
||||
type TLSProtocolConfig struct {
|
||||
BaseProtocol `yaml:",inline"`
|
||||
|
||||
SSL *sslconfigs.SSLConfig `yaml:"ssl"`
|
||||
SSL *sslconfigs.SSLConfig `yaml:"ssl" json:"ssl"`
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *TLSProtocolConfig) Init() error {
|
||||
err := this.InitBase()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,105 @@
|
||||
package serverconfigs
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/scheduling"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// 反向代理设置
|
||||
type ReverseProxyConfig struct {
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用
|
||||
Origins []*OriginServerConfig `yaml:"origins" json:"origins"` // 源站列表
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用 TODO
|
||||
Origins []*OriginServerConfig `yaml:"origins" json:"origins"` // 源站列表
|
||||
Scheduling *SchedulingConfig `yaml:"scheduling" json:"scheduling"` // 调度算法选项
|
||||
|
||||
hasOrigins bool
|
||||
schedulingIsBackup bool
|
||||
schedulingObject scheduling.SchedulingInterface
|
||||
schedulingLocker sync.Mutex
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *ReverseProxyConfig) Init() error {
|
||||
this.hasOrigins = len(this.Origins) > 0
|
||||
|
||||
for _, origin := range this.Origins {
|
||||
err := origin.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// scheduling
|
||||
this.SetupScheduling(false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 取得下一个可用的后端服务
|
||||
func (this *ReverseProxyConfig) NextOrigin(call *shared.RequestCall) *OriginServerConfig {
|
||||
this.schedulingLocker.Lock()
|
||||
defer this.schedulingLocker.Unlock()
|
||||
|
||||
if this.schedulingObject == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if this.Scheduling != nil && call != nil && call.Options != nil {
|
||||
for k, v := range this.Scheduling.Options {
|
||||
call.Options[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
candidate := this.schedulingObject.Next(call)
|
||||
if candidate == nil {
|
||||
// 启用备用服务器
|
||||
if !this.schedulingIsBackup {
|
||||
this.SetupScheduling(true)
|
||||
|
||||
candidate = this.schedulingObject.Next(call)
|
||||
if candidate == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if candidate == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return candidate.(*OriginServerConfig)
|
||||
}
|
||||
|
||||
// 设置调度算法
|
||||
func (this *ReverseProxyConfig) SetupScheduling(isBackup bool) {
|
||||
if !isBackup {
|
||||
this.schedulingLocker.Lock()
|
||||
defer this.schedulingLocker.Unlock()
|
||||
}
|
||||
this.schedulingIsBackup = isBackup
|
||||
|
||||
if this.Scheduling == nil {
|
||||
this.schedulingObject = &scheduling.RandomScheduling{}
|
||||
} else {
|
||||
typeCode := this.Scheduling.Code
|
||||
s := scheduling.FindSchedulingType(typeCode)
|
||||
if s == nil {
|
||||
this.Scheduling = nil
|
||||
this.schedulingObject = &scheduling.RandomScheduling{}
|
||||
} else {
|
||||
this.schedulingObject = s["instance"].(scheduling.SchedulingInterface)
|
||||
}
|
||||
}
|
||||
|
||||
for _, origin := range this.Origins {
|
||||
if origin.IsOn && !origin.IsDown {
|
||||
if isBackup && origin.IsBackup {
|
||||
this.schedulingObject.Add(origin)
|
||||
} else if !isBackup && !origin.IsBackup {
|
||||
this.schedulingObject.Add(origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.schedulingObject.Start()
|
||||
}
|
||||
|
||||
10
internal/configs/serverconfigs/scheduling/candidate.go
Normal file
10
internal/configs/serverconfigs/scheduling/candidate.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package scheduling
|
||||
|
||||
// 候选对象接口
|
||||
type CandidateInterface interface {
|
||||
// 权重
|
||||
CandidateWeight() uint
|
||||
|
||||
// 代号
|
||||
CandidateCodes() []string
|
||||
}
|
||||
39
internal/configs/serverconfigs/scheduling/scheduling.go
Normal file
39
internal/configs/serverconfigs/scheduling/scheduling.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// 调度算法接口
|
||||
type SchedulingInterface interface {
|
||||
// 是否有候选对象
|
||||
HasCandidates() bool
|
||||
|
||||
// 添加候选对象
|
||||
Add(candidate ...CandidateInterface)
|
||||
|
||||
// 启动
|
||||
Start()
|
||||
|
||||
// 查找下一个候选对象
|
||||
Next(call *shared.RequestCall) CandidateInterface
|
||||
|
||||
// 获取简要信息
|
||||
Summary() maps.Map
|
||||
}
|
||||
|
||||
// 调度算法基础类
|
||||
type Scheduling struct {
|
||||
Candidates []CandidateInterface
|
||||
}
|
||||
|
||||
// 判断是否有候选对象
|
||||
func (this *Scheduling) HasCandidates() bool {
|
||||
return len(this.Candidates) > 0
|
||||
}
|
||||
|
||||
// 添加候选对象
|
||||
func (this *Scheduling) Add(candidate ...CandidateInterface) {
|
||||
this.Candidates = append(this.Candidates, candidate...)
|
||||
}
|
||||
45
internal/configs/serverconfigs/scheduling/scheduling_hash.go
Normal file
45
internal/configs/serverconfigs/scheduling/scheduling_hash.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"hash/crc32"
|
||||
)
|
||||
|
||||
// Hash调度算法
|
||||
type HashScheduling struct {
|
||||
Scheduling
|
||||
|
||||
count uint32
|
||||
}
|
||||
|
||||
// 启动
|
||||
func (this *HashScheduling) Start() {
|
||||
this.count = uint32(len(this.Candidates))
|
||||
}
|
||||
|
||||
// 获取下一个候选对象
|
||||
func (this *HashScheduling) Next(call *shared.RequestCall) CandidateInterface {
|
||||
if this.count == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := call.Options.GetString("key")
|
||||
|
||||
if call.Formatter != nil {
|
||||
key = call.Formatter(key)
|
||||
}
|
||||
|
||||
sum := crc32.ChecksumIEEE([]byte(key))
|
||||
return this.Candidates[sum%this.count]
|
||||
}
|
||||
|
||||
// 获取简要信息
|
||||
func (this *HashScheduling) Summary() maps.Map {
|
||||
return maps.Map{
|
||||
"code": "hash",
|
||||
"name": "Hash算法",
|
||||
"description": "根据自定义的键值的Hash值分配后端服务器",
|
||||
"networks": []string{"http"},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHashScheduling_Next(t *testing.T) {
|
||||
s := &HashScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "c",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "d",
|
||||
Weight: 30,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
hits := map[string]uint{}
|
||||
for _, c := range s.Candidates {
|
||||
hits[c.(*TestCandidate).Name] = 0
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
for i := 0; i < 1000000; i ++ {
|
||||
call := shared.NewRequestCall()
|
||||
call.Options["key"] = "192.168.1." + fmt.Sprintf("%d", rand.Int())
|
||||
|
||||
c := s.Next(call)
|
||||
hits[c.(*TestCandidate).Name] ++
|
||||
}
|
||||
t.Log(hits)
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"math"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 随机调度算法
|
||||
type RandomScheduling struct {
|
||||
Scheduling
|
||||
|
||||
array []CandidateInterface
|
||||
count uint // 实际总的服务器数
|
||||
}
|
||||
|
||||
// 启动
|
||||
func (this *RandomScheduling) Start() {
|
||||
sumWeight := uint(0)
|
||||
for _, c := range this.Candidates {
|
||||
weight := c.CandidateWeight()
|
||||
if weight == 0 {
|
||||
weight = 1
|
||||
} else if weight > 10000 {
|
||||
weight = 10000
|
||||
}
|
||||
sumWeight += weight
|
||||
}
|
||||
|
||||
if sumWeight == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, c := range this.Candidates {
|
||||
weight := c.CandidateWeight()
|
||||
if weight == 0 {
|
||||
weight = 1
|
||||
} else if weight > 10000 {
|
||||
weight = 10000
|
||||
}
|
||||
count := uint(0)
|
||||
if sumWeight <= 1000 {
|
||||
count = weight
|
||||
} else {
|
||||
count = uint(math.Round(float64(weight*10000) / float64(sumWeight))) // 1% 产生 100个数据,最多支持10000个服务器
|
||||
}
|
||||
for i := uint(0); i < count; i++ {
|
||||
this.array = append(this.array, c)
|
||||
}
|
||||
this.count += count
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// 获取下一个候选对象
|
||||
func (this *RandomScheduling) Next(call *shared.RequestCall) CandidateInterface {
|
||||
if this.count == 0 {
|
||||
return nil
|
||||
}
|
||||
if this.count == 1 {
|
||||
return this.array[0]
|
||||
}
|
||||
index := rand.Int() % int(this.count)
|
||||
return this.array[index]
|
||||
}
|
||||
|
||||
// 获取简要信息
|
||||
func (this *RandomScheduling) Summary() maps.Map {
|
||||
return maps.Map{
|
||||
"code": "random",
|
||||
"name": "Random随机算法",
|
||||
"description": "根据权重设置随机分配后端服务器",
|
||||
"networks": []string{"http", "tcp"},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestCandidate struct {
|
||||
Name string
|
||||
Weight uint
|
||||
}
|
||||
|
||||
func (this *TestCandidate) CandidateWeight() uint {
|
||||
return this.Weight
|
||||
}
|
||||
|
||||
func (this *TestCandidate) CandidateCodes() []string {
|
||||
return []string{this.Name}
|
||||
}
|
||||
|
||||
func TestRandomScheduling_Next(t *testing.T) {
|
||||
s := &RandomScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "c",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "d",
|
||||
Weight: 30,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
/**for _, c := range s.array {
|
||||
t.Log(c.(*TestCandidate).Name, ":", c.CandidateWeight())
|
||||
}**/
|
||||
|
||||
hits := map[string]uint{}
|
||||
for _, c := range s.array {
|
||||
hits[c.(*TestCandidate).Name] = 0
|
||||
}
|
||||
|
||||
t.Log("count:", s.count, "array length:", len(s.array))
|
||||
|
||||
var locker sync.Mutex
|
||||
var wg = sync.WaitGroup{}
|
||||
wg.Add(100 * 10000)
|
||||
for i := 0; i < 100*10000; i ++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
c := s.Next(nil)
|
||||
|
||||
locker.Lock()
|
||||
defer locker.Unlock()
|
||||
hits[c.(*TestCandidate).Name] ++
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
t.Log(hits)
|
||||
}
|
||||
|
||||
func TestRandomScheduling_NextZero(t *testing.T) {
|
||||
s := &RandomScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 0,
|
||||
})
|
||||
s.Start()
|
||||
t.Log(s.Next(nil))
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// 轮询调度算法
|
||||
type RoundRobinScheduling struct {
|
||||
Scheduling
|
||||
|
||||
rawWeights []uint
|
||||
currentWeights []uint
|
||||
count uint
|
||||
index uint
|
||||
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
// 启动
|
||||
func (this *RoundRobinScheduling) Start() {
|
||||
lists.Sort(this.Candidates, func(i int, j int) bool {
|
||||
c1 := this.Candidates[i]
|
||||
c2 := this.Candidates[j]
|
||||
return c1.CandidateWeight() > c2.CandidateWeight()
|
||||
})
|
||||
|
||||
for _, c := range this.Candidates {
|
||||
weight := c.CandidateWeight()
|
||||
if weight == 0 {
|
||||
weight = 1
|
||||
} else if weight > 10000 {
|
||||
weight = 10000
|
||||
}
|
||||
this.rawWeights = append(this.rawWeights, weight)
|
||||
}
|
||||
|
||||
this.currentWeights = append([]uint{}, this.rawWeights...)
|
||||
this.count = uint(len(this.Candidates))
|
||||
}
|
||||
|
||||
// 获取下一个候选对象
|
||||
func (this *RoundRobinScheduling) Next(call *shared.RequestCall) CandidateInterface {
|
||||
if this.count == 0 {
|
||||
return nil
|
||||
}
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
if this.index > this.count-1 {
|
||||
this.index = 0
|
||||
}
|
||||
weight := this.currentWeights[this.index]
|
||||
|
||||
// 已经一轮了,则重置状态
|
||||
if weight == 0 {
|
||||
if this.currentWeights[0] == 0 {
|
||||
this.currentWeights = append([]uint{}, this.rawWeights...)
|
||||
}
|
||||
this.index = 0
|
||||
weight = this.currentWeights[this.index]
|
||||
}
|
||||
|
||||
c := this.Candidates[this.index]
|
||||
this.currentWeights[this.index] --
|
||||
this.index++
|
||||
return c
|
||||
}
|
||||
|
||||
// 获取简要信息
|
||||
func (this *RoundRobinScheduling) Summary() maps.Map {
|
||||
return maps.Map{
|
||||
"code": "roundRobin",
|
||||
"name": "RoundRobin轮询算法",
|
||||
"description": "根据权重,依次分配后端服务器",
|
||||
"networks": []string{"http", "tcp"},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package scheduling
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRoundRobinScheduling_Next(t *testing.T) {
|
||||
s := &RoundRobinScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 5,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "c",
|
||||
Weight: 20,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "d",
|
||||
Weight: 30,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
for _, c := range s.Candidates {
|
||||
t.Log(c.(*TestCandidate).Name, c.CandidateWeight())
|
||||
}
|
||||
|
||||
t.Log(s.currentWeights)
|
||||
|
||||
for i := 0; i < 100; i ++ {
|
||||
t.Log("===", "round", i, "===")
|
||||
t.Log(s.Next(nil))
|
||||
t.Log(s.currentWeights)
|
||||
t.Log(s.rawWeights)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinScheduling_Two(t *testing.T) {
|
||||
s := &RoundRobinScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 10,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
for _, c := range s.Candidates {
|
||||
t.Log(c.(*TestCandidate).Name, c.CandidateWeight())
|
||||
}
|
||||
|
||||
t.Log(s.currentWeights)
|
||||
|
||||
for i := 0; i < 100; i ++ {
|
||||
t.Log("===", "round", i, "===")
|
||||
t.Log(s.Next(nil))
|
||||
t.Log(s.currentWeights)
|
||||
t.Log(s.rawWeights)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinScheduling_NextPerformance(t *testing.T) {
|
||||
s := &RoundRobinScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 1,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 2,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "c",
|
||||
Weight: 3,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "d",
|
||||
Weight: 6,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
for _, c := range s.Candidates {
|
||||
t.Log(c.(*TestCandidate).Name, c.CandidateWeight())
|
||||
}
|
||||
|
||||
t.Log(s.currentWeights)
|
||||
|
||||
hits := map[string]uint{}
|
||||
for _, c := range s.Candidates {
|
||||
hits[c.(*TestCandidate).Name] = 0
|
||||
}
|
||||
for i := 0; i < 100*10000; i ++ {
|
||||
c := s.Next(nil)
|
||||
hits[c.(*TestCandidate).Name] ++
|
||||
}
|
||||
|
||||
t.Log(hits)
|
||||
}
|
||||
106
internal/configs/serverconfigs/scheduling/scheduling_sticky.go
Normal file
106
internal/configs/serverconfigs/scheduling/scheduling_sticky.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Sticky调度算法
|
||||
type StickyScheduling struct {
|
||||
Scheduling
|
||||
|
||||
count uint32
|
||||
mapping map[string]CandidateInterface // code => candidate
|
||||
}
|
||||
|
||||
// 启动
|
||||
func (this *StickyScheduling) Start() {
|
||||
this.mapping = map[string]CandidateInterface{}
|
||||
for _, c := range this.Candidates {
|
||||
for _, code := range c.CandidateCodes() {
|
||||
this.mapping[code] = c
|
||||
}
|
||||
}
|
||||
|
||||
this.count = uint32(len(this.Candidates))
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// 获取下一个候选对象
|
||||
func (this *StickyScheduling) Next(call *shared.RequestCall) CandidateInterface {
|
||||
if this.count == 0 {
|
||||
return nil
|
||||
}
|
||||
typeCode := call.Options.GetString("type")
|
||||
param := call.Options.GetString("param")
|
||||
|
||||
if call.Request == nil {
|
||||
return this.Candidates[uint32(rand.Int())%this.count]
|
||||
}
|
||||
|
||||
code := ""
|
||||
if typeCode == "cookie" {
|
||||
cookie, err := call.Request.Cookie(param)
|
||||
if err == nil {
|
||||
code = cookie.Value
|
||||
}
|
||||
} else if typeCode == "header" {
|
||||
code = call.Request.Header.Get(param)
|
||||
} else if typeCode == "argument" {
|
||||
code = call.Request.URL.Query().Get(param)
|
||||
}
|
||||
|
||||
matched := false
|
||||
var c CandidateInterface = nil
|
||||
|
||||
defer func() {
|
||||
if !matched && c != nil {
|
||||
codes := c.CandidateCodes()
|
||||
if len(codes) == 0 {
|
||||
return
|
||||
}
|
||||
if typeCode == "cookie" {
|
||||
call.AddResponseCall(func(resp http.ResponseWriter) {
|
||||
http.SetCookie(resp, &http.Cookie{
|
||||
Name: param,
|
||||
Value: codes[0],
|
||||
Path: "/",
|
||||
Expires: time.Now().AddDate(0, 1, 0),
|
||||
})
|
||||
})
|
||||
} else {
|
||||
call.AddResponseCall(func(resp http.ResponseWriter) {
|
||||
resp.Header().Set(param, codes[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if len(code) == 0 {
|
||||
c = this.Candidates[uint32(rand.Int())%this.count]
|
||||
return c
|
||||
}
|
||||
|
||||
found := false
|
||||
c, found = this.mapping[code]
|
||||
if !found {
|
||||
c = this.Candidates[uint32(rand.Int())%this.count]
|
||||
return c
|
||||
}
|
||||
|
||||
matched = true
|
||||
return c
|
||||
}
|
||||
|
||||
// 获取简要信息
|
||||
func (this *StickyScheduling) Summary() maps.Map {
|
||||
return maps.Map{
|
||||
"code": "sticky",
|
||||
"name": "Sticky算法",
|
||||
"description": "利用Cookie、URL参数或者HTTP Header来指定后端服务器",
|
||||
"networks": []string{"http"},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStickyScheduling_NextArgument(t *testing.T) {
|
||||
s := &StickyScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 1,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 2,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "c",
|
||||
Weight: 3,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "d",
|
||||
Weight: 6,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
t.Log(s.mapping)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
options := maps.Map{
|
||||
"type": "argument",
|
||||
"param": "backend",
|
||||
}
|
||||
call := shared.NewRequestCall()
|
||||
call.Request = req
|
||||
call.Options = options
|
||||
t.Log(s.Next(call))
|
||||
t.Log(options)
|
||||
}
|
||||
|
||||
func TestStickyScheduling_NextCookie(t *testing.T) {
|
||||
s := &StickyScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 1,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 2,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "c",
|
||||
Weight: 3,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "d",
|
||||
Weight: 6,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
t.Log(s.mapping)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "backend",
|
||||
Value: "c",
|
||||
})
|
||||
|
||||
options := maps.Map{
|
||||
"type": "cookie",
|
||||
"param": "backend",
|
||||
}
|
||||
call := shared.NewRequestCall()
|
||||
call.Request = req
|
||||
call.Options = options
|
||||
t.Log(s.Next(call))
|
||||
t.Log(options)
|
||||
}
|
||||
|
||||
func TestStickyScheduling_NextHeader(t *testing.T) {
|
||||
s := &StickyScheduling{}
|
||||
s.Add(&TestCandidate{
|
||||
Name: "a",
|
||||
Weight: 1,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "b",
|
||||
Weight: 2,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "c",
|
||||
Weight: 3,
|
||||
})
|
||||
s.Add(&TestCandidate{
|
||||
Name: "d",
|
||||
Weight: 6,
|
||||
})
|
||||
s.Start()
|
||||
|
||||
t.Log(s.mapping)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("backend", "c")
|
||||
|
||||
options := maps.Map{
|
||||
"type": "header",
|
||||
"param": "backend",
|
||||
}
|
||||
call := shared.NewRequestCall()
|
||||
call.Request = req
|
||||
call.Options = options
|
||||
t.Log(s.Next(call))
|
||||
t.Log(options)
|
||||
}
|
||||
28
internal/configs/serverconfigs/scheduling/utils.go
Normal file
28
internal/configs/serverconfigs/scheduling/utils.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package scheduling
|
||||
|
||||
import "github.com/iwind/TeaGo/maps"
|
||||
|
||||
// 所有请求类型
|
||||
func AllSchedulingTypes() []maps.Map {
|
||||
types := []maps.Map{}
|
||||
for _, s := range []SchedulingInterface{
|
||||
new(RandomScheduling),
|
||||
new(RoundRobinScheduling),
|
||||
new(HashScheduling),
|
||||
new(StickyScheduling),
|
||||
} {
|
||||
summary := s.Summary()
|
||||
summary["instance"] = s
|
||||
types = append(types, summary)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
func FindSchedulingType(code string) maps.Map {
|
||||
for _, summary := range AllSchedulingTypes() {
|
||||
if summary["code"] == code {
|
||||
return summary
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
14
internal/configs/serverconfigs/scheduling_config.go
Normal file
14
internal/configs/serverconfigs/scheduling_config.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package serverconfigs
|
||||
|
||||
import "github.com/iwind/TeaGo/maps"
|
||||
|
||||
// 调度算法配置
|
||||
type SchedulingConfig struct {
|
||||
Code string `yaml:"code" json:"code"` // 类型
|
||||
Options maps.Map `yaml:"options" json:"options"` // 选项
|
||||
}
|
||||
|
||||
// 获取新对象
|
||||
func NewSchedulingConfig() *SchedulingConfig {
|
||||
return &SchedulingConfig{}
|
||||
}
|
||||
@@ -76,6 +76,13 @@ func (this *ServerConfig) Init() error {
|
||||
}
|
||||
}
|
||||
|
||||
if this.ReverseProxy != nil {
|
||||
err := this.ReverseProxy.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
68
internal/configs/serverconfigs/shared/header.go
Normal file
68
internal/configs/serverconfigs/shared/header.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/utils/string"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var regexpNamedVariable = regexp.MustCompile("\\${[\\w.-]+}")
|
||||
|
||||
// 头部信息定义
|
||||
type HeaderConfig struct {
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启
|
||||
Id string `yaml:"id" json:"id"` // ID
|
||||
Name string `yaml:"name" json:"name"` // Name
|
||||
Value string `yaml:"value" json:"value"` // Value
|
||||
Always bool `yaml:"always" json:"always"` // 是否忽略状态码
|
||||
Status []int `yaml:"status" json:"status"` // 支持的状态码
|
||||
|
||||
statusMap map[int]bool
|
||||
hasVariables bool
|
||||
}
|
||||
|
||||
// 获取新Header对象
|
||||
func NewHeaderConfig() *HeaderConfig {
|
||||
return &HeaderConfig{
|
||||
IsOn: true,
|
||||
Id: stringutil.Rand(16),
|
||||
}
|
||||
}
|
||||
|
||||
// 校验
|
||||
func (this *HeaderConfig) Validate() error {
|
||||
this.statusMap = map[int]bool{}
|
||||
this.hasVariables = regexpNamedVariable.MatchString(this.Value)
|
||||
|
||||
if this.Status == nil {
|
||||
this.Status = []int{}
|
||||
}
|
||||
|
||||
for _, status := range this.Status {
|
||||
this.statusMap[status] = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 判断是否匹配状态码
|
||||
func (this *HeaderConfig) Match(statusCode int) bool {
|
||||
if !this.IsOn {
|
||||
return false
|
||||
}
|
||||
|
||||
if this.Always {
|
||||
return true
|
||||
}
|
||||
|
||||
if this.statusMap != nil {
|
||||
_, found := this.statusMap[statusCode]
|
||||
return found
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// 是否有变量
|
||||
func (this *HeaderConfig) HasVariables() bool {
|
||||
return this.hasVariables
|
||||
}
|
||||
245
internal/configs/serverconfigs/shared/header_list.go
Normal file
245
internal/configs/serverconfigs/shared/header_list.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HeaderList相关操作接口
|
||||
type HeaderListInterface interface {
|
||||
// 校验
|
||||
ValidateHeaders() error
|
||||
|
||||
// 取得所有的IgnoreHeader
|
||||
AllIgnoreResponseHeaders() []string
|
||||
|
||||
// 添加IgnoreHeader
|
||||
AddIgnoreResponseHeader(name string)
|
||||
|
||||
// 判断是否包含IgnoreHeader
|
||||
ContainsIgnoreResponseHeader(name string) bool
|
||||
|
||||
// 移除IgnoreHeader
|
||||
RemoveIgnoreResponseHeader(name string)
|
||||
|
||||
// 修改IgnoreHeader
|
||||
UpdateIgnoreResponseHeader(oldName string, newName string)
|
||||
|
||||
// 取得所有的Header
|
||||
AllResponseHeaders() []*HeaderConfig
|
||||
|
||||
// 添加Header
|
||||
AddResponseHeader(header *HeaderConfig)
|
||||
|
||||
// 判断是否包含Header
|
||||
ContainsResponseHeader(name string) bool
|
||||
|
||||
// 查找Header
|
||||
FindResponseHeader(headerId string) *HeaderConfig
|
||||
|
||||
// 移除Header
|
||||
RemoveResponseHeader(headerId string)
|
||||
|
||||
// 取得所有的请求Header
|
||||
AllRequestHeaders() []*HeaderConfig
|
||||
|
||||
// 添加请求Header
|
||||
AddRequestHeader(header *HeaderConfig)
|
||||
|
||||
// 查找请求Header
|
||||
FindRequestHeader(headerId string) *HeaderConfig
|
||||
|
||||
// 移除请求Header
|
||||
RemoveRequestHeader(headerId string)
|
||||
}
|
||||
|
||||
// HeaderList定义
|
||||
type HeaderList struct {
|
||||
// 添加的响应Headers
|
||||
Headers []*HeaderConfig `yaml:"headers" json:"headers"`
|
||||
|
||||
// 忽略的响应Headers
|
||||
IgnoreHeaders []string `yaml:"ignoreHeaders" json:"ignoreHeaders"`
|
||||
|
||||
// 自定义请求Headers
|
||||
RequestHeaders []*HeaderConfig `yaml:"requestHeaders" json:"requestHeaders"`
|
||||
|
||||
hasResponseHeaders bool
|
||||
hasRequestHeaders bool
|
||||
|
||||
hasIgnoreHeaders bool
|
||||
uppercaseIgnoreHeaders []string
|
||||
}
|
||||
|
||||
// 校验
|
||||
func (this *HeaderList) Init() error {
|
||||
this.hasResponseHeaders = len(this.Headers) > 0
|
||||
this.hasRequestHeaders = len(this.RequestHeaders) > 0
|
||||
|
||||
for _, h := range this.Headers {
|
||||
err := h.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, h := range this.RequestHeaders {
|
||||
err := h.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
this.hasIgnoreHeaders = len(this.IgnoreHeaders) > 0
|
||||
this.uppercaseIgnoreHeaders = []string{}
|
||||
for _, headerKey := range this.IgnoreHeaders {
|
||||
this.uppercaseIgnoreHeaders = append(this.uppercaseIgnoreHeaders, strings.ToUpper(headerKey))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 是否有Headers
|
||||
func (this *HeaderList) HasResponseHeaders() bool {
|
||||
return this.hasResponseHeaders
|
||||
}
|
||||
|
||||
// 取得所有的IgnoreHeader
|
||||
func (this *HeaderList) AllIgnoreResponseHeaders() []string {
|
||||
if this.IgnoreHeaders == nil {
|
||||
return []string{}
|
||||
}
|
||||
return this.IgnoreHeaders
|
||||
}
|
||||
|
||||
// 添加IgnoreHeader
|
||||
func (this *HeaderList) AddIgnoreResponseHeader(name string) {
|
||||
if !lists.ContainsString(this.IgnoreHeaders, name) {
|
||||
this.IgnoreHeaders = append(this.IgnoreHeaders, name)
|
||||
}
|
||||
}
|
||||
|
||||
// 判断是否包含IgnoreHeader
|
||||
func (this *HeaderList) ContainsIgnoreResponseHeader(name string) bool {
|
||||
if len(this.IgnoreHeaders) == 0 {
|
||||
return false
|
||||
}
|
||||
return lists.ContainsString(this.IgnoreHeaders, name)
|
||||
}
|
||||
|
||||
// 修改IgnoreHeader
|
||||
func (this *HeaderList) UpdateIgnoreResponseHeader(oldName string, newName string) {
|
||||
result := []string{}
|
||||
for _, h := range this.IgnoreHeaders {
|
||||
if h == oldName {
|
||||
result = append(result, newName)
|
||||
} else {
|
||||
result = append(result, h)
|
||||
}
|
||||
}
|
||||
this.IgnoreHeaders = result
|
||||
}
|
||||
|
||||
// 移除IgnoreHeader
|
||||
func (this *HeaderList) RemoveIgnoreResponseHeader(name string) {
|
||||
result := []string{}
|
||||
for _, n := range this.IgnoreHeaders {
|
||||
if n == name {
|
||||
continue
|
||||
}
|
||||
result = append(result, n)
|
||||
}
|
||||
this.IgnoreHeaders = result
|
||||
}
|
||||
|
||||
// 取得所有的Header
|
||||
func (this *HeaderList) AllResponseHeaders() []*HeaderConfig {
|
||||
if this.Headers == nil {
|
||||
return []*HeaderConfig{}
|
||||
}
|
||||
return this.Headers
|
||||
}
|
||||
|
||||
// 添加Header
|
||||
func (this *HeaderList) AddResponseHeader(header *HeaderConfig) {
|
||||
this.Headers = append(this.Headers, header)
|
||||
}
|
||||
|
||||
// 判断是否包含Header
|
||||
func (this *HeaderList) ContainsResponseHeader(name string) bool {
|
||||
for _, h := range this.Headers {
|
||||
if h.Name == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 查找Header
|
||||
func (this *HeaderList) FindResponseHeader(headerId string) *HeaderConfig {
|
||||
for _, h := range this.Headers {
|
||||
if h.Id == headerId {
|
||||
return h
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 移除Header
|
||||
func (this *HeaderList) RemoveResponseHeader(headerId string) {
|
||||
result := []*HeaderConfig{}
|
||||
for _, h := range this.Headers {
|
||||
if h.Id == headerId {
|
||||
continue
|
||||
}
|
||||
result = append(result, h)
|
||||
}
|
||||
this.Headers = result
|
||||
}
|
||||
|
||||
// 添加请求Header
|
||||
func (this *HeaderList) AddRequestHeader(header *HeaderConfig) {
|
||||
this.RequestHeaders = append(this.RequestHeaders, header)
|
||||
}
|
||||
|
||||
// 判断是否有请求Header
|
||||
func (this *HeaderList) HasRequestHeaders() bool {
|
||||
return this.hasRequestHeaders
|
||||
}
|
||||
|
||||
// 取得所有的请求Header
|
||||
func (this *HeaderList) AllRequestHeaders() []*HeaderConfig {
|
||||
return this.RequestHeaders
|
||||
}
|
||||
|
||||
// 查找请求Header
|
||||
func (this *HeaderList) FindRequestHeader(headerId string) *HeaderConfig {
|
||||
for _, h := range this.RequestHeaders {
|
||||
if h.Id == headerId {
|
||||
return h
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 移除请求Header
|
||||
func (this *HeaderList) RemoveRequestHeader(headerId string) {
|
||||
result := []*HeaderConfig{}
|
||||
for _, h := range this.RequestHeaders {
|
||||
if h.Id == headerId {
|
||||
continue
|
||||
}
|
||||
result = append(result, h)
|
||||
}
|
||||
this.RequestHeaders = result
|
||||
}
|
||||
|
||||
// 判断是否有Ignore Headers
|
||||
func (this *HeaderList) HasIgnoreHeaders() bool {
|
||||
return this.hasIgnoreHeaders
|
||||
}
|
||||
|
||||
// 查找大写的Ignore Headers
|
||||
func (this *HeaderList) UppercaseIgnoreHeaders() []string {
|
||||
return this.uppercaseIgnoreHeaders
|
||||
}
|
||||
23
internal/configs/serverconfigs/shared/header_list_test.go
Normal file
23
internal/configs/serverconfigs/shared/header_list_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHeaderList_FormatHeaders(t *testing.T) {
|
||||
list := &HeaderList{}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
list.AddRequestHeader(&HeaderConfig{
|
||||
IsOn: true,
|
||||
Name: "A" + fmt.Sprintf("%d", i),
|
||||
Value: "ABCDEFGHIJ${name}KLM${hello}NEFGHIJILKKKk",
|
||||
})
|
||||
}
|
||||
|
||||
err := list.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
28
internal/configs/serverconfigs/shared/header_test.go
Normal file
28
internal/configs/serverconfigs/shared/header_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHeaderConfig_Match(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
h := NewHeaderConfig()
|
||||
err := h.Validate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a.IsFalse(h.Match(200))
|
||||
a.IsFalse(h.Match(400))
|
||||
|
||||
h.Status = []int{200, 201, 400}
|
||||
err = h.Validate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a.IsTrue(h.Match(400))
|
||||
a.IsFalse(h.Match(500))
|
||||
|
||||
h.Always = true
|
||||
a.IsTrue(h.Match(500))
|
||||
}
|
||||
13
internal/configs/serverconfigs/shared/regexp.go
Normal file
13
internal/configs/serverconfigs/shared/regexp.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package shared
|
||||
|
||||
import "regexp"
|
||||
|
||||
// 常用的正则表达式
|
||||
var (
|
||||
RegexpDigitNumber = regexp.MustCompile(`^\d+$`) // 正整数
|
||||
RegexpFloatNumber = regexp.MustCompile(`^\d+(\.\d+)?$`) // 正浮点数,不支持e
|
||||
RegexpAllDigitNumber = regexp.MustCompile(`^[+-]?\d+$`) // 整数,支持正负数
|
||||
RegexpAllFloatNumber = regexp.MustCompile(`^[+-]?\d+(\.\d+)?$`) // 浮点数,支持正负数,不支持e
|
||||
RegexpExternalURL = regexp.MustCompile("(?i)^(http|https|ftp)://") // URL
|
||||
RegexpNamedVariable = regexp.MustCompile("\\${[\\w.-]+}") // 命名变量
|
||||
)
|
||||
17
internal/configs/serverconfigs/shared/regexp_test.go
Normal file
17
internal/configs/serverconfigs/shared/regexp_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRegexp(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
a.IsTrue(RegexpFloatNumber.MatchString("123"))
|
||||
a.IsTrue(RegexpFloatNumber.MatchString("123.456"))
|
||||
a.IsFalse(RegexpFloatNumber.MatchString(".456"))
|
||||
a.IsFalse(RegexpFloatNumber.MatchString("abc"))
|
||||
a.IsFalse(RegexpFloatNumber.MatchString("123."))
|
||||
a.IsFalse(RegexpFloatNumber.MatchString("123.456e7"))
|
||||
}
|
||||
41
internal/configs/serverconfigs/shared/request_call.go
Normal file
41
internal/configs/serverconfigs/shared/request_call.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 请求调用
|
||||
type RequestCall struct {
|
||||
Formatter func(source string) string
|
||||
Request *http.Request
|
||||
ResponseCallbacks []func(resp http.ResponseWriter)
|
||||
Options maps.Map
|
||||
}
|
||||
|
||||
// 获取新对象
|
||||
func NewRequestCall() *RequestCall {
|
||||
return &RequestCall{
|
||||
Options: maps.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// 重置
|
||||
func (this *RequestCall) Reset() {
|
||||
this.Formatter = nil
|
||||
this.Request = nil
|
||||
this.ResponseCallbacks = nil
|
||||
this.Options = maps.Map{}
|
||||
}
|
||||
|
||||
// 添加响应回调
|
||||
func (this *RequestCall) AddResponseCall(callback func(resp http.ResponseWriter)) {
|
||||
this.ResponseCallbacks = append(this.ResponseCallbacks, callback)
|
||||
}
|
||||
|
||||
// 执行响应回调
|
||||
func (this *RequestCall) CallResponseCallbacks(resp http.ResponseWriter) {
|
||||
for _, callback := range this.ResponseCallbacks {
|
||||
callback(resp)
|
||||
}
|
||||
}
|
||||
372
internal/configs/serverconfigs/shared/request_cond.go
Normal file
372
internal/configs/serverconfigs/shared/request_cond.go
Normal file
@@ -0,0 +1,372 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/TeaGo/utils/string"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// 重写条件定义
|
||||
type RequestCond struct {
|
||||
Id string `yaml:"id" json:"id"` // ID
|
||||
|
||||
// 要测试的字符串
|
||||
// 其中可以使用跟请求相关的参数,比如:
|
||||
// ${arg.name}, ${requestPath}
|
||||
Param string `yaml:"param" json:"param"`
|
||||
|
||||
// 运算符
|
||||
Operator RequestCondOperator `yaml:"operator" json:"operator"`
|
||||
|
||||
// 对比
|
||||
Value string `yaml:"value" json:"value"`
|
||||
|
||||
isInt bool
|
||||
isFloat bool
|
||||
isIP bool
|
||||
|
||||
regValue *regexp.Regexp
|
||||
floatValue float64
|
||||
ipValue net.IP
|
||||
arrayValue []string
|
||||
}
|
||||
|
||||
// 取得新对象
|
||||
func NewRequestCond() *RequestCond {
|
||||
return &RequestCond{
|
||||
Id: stringutil.Rand(16),
|
||||
}
|
||||
}
|
||||
|
||||
// 校验配置
|
||||
func (this *RequestCond) Validate() error {
|
||||
this.isInt = RegexpDigitNumber.MatchString(this.Value)
|
||||
this.isFloat = RegexpFloatNumber.MatchString(this.Value)
|
||||
|
||||
if lists.ContainsString([]string{
|
||||
RequestCondOperatorRegexp,
|
||||
RequestCondOperatorNotRegexp,
|
||||
}, this.Operator) {
|
||||
reg, err := regexp.Compile(this.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.regValue = reg
|
||||
} else if lists.ContainsString([]string{
|
||||
RequestCondOperatorEqFloat,
|
||||
RequestCondOperatorGtFloat,
|
||||
RequestCondOperatorGteFloat,
|
||||
RequestCondOperatorLtFloat,
|
||||
RequestCondOperatorLteFloat,
|
||||
}, this.Operator) {
|
||||
this.floatValue = types.Float64(this.Value)
|
||||
} else if lists.ContainsString([]string{
|
||||
RequestCondOperatorEqIP,
|
||||
RequestCondOperatorGtIP,
|
||||
RequestCondOperatorGteIP,
|
||||
RequestCondOperatorLtIP,
|
||||
RequestCondOperatorLteIP,
|
||||
}, this.Operator) {
|
||||
this.ipValue = net.ParseIP(this.Value)
|
||||
this.isIP = this.ipValue != nil
|
||||
|
||||
if !this.isIP {
|
||||
return errors.New("value should be a valid ip")
|
||||
}
|
||||
} else if lists.ContainsString([]string{
|
||||
RequestCondOperatorIPRange,
|
||||
}, this.Operator) {
|
||||
if strings.Contains(this.Value, ",") {
|
||||
ipList := strings.SplitN(this.Value, ",", 2)
|
||||
ipString1 := strings.TrimSpace(ipList[0])
|
||||
ipString2 := strings.TrimSpace(ipList[1])
|
||||
|
||||
if len(ipString1) > 0 {
|
||||
ip1 := net.ParseIP(ipString1)
|
||||
if ip1 == nil {
|
||||
return errors.New("start ip is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
if len(ipString2) > 0 {
|
||||
ip2 := net.ParseIP(ipString2)
|
||||
if ip2 == nil {
|
||||
return errors.New("end ip is invalid")
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(this.Value, "/") {
|
||||
_, _, err := net.ParseCIDR(this.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.New("invalid ip range")
|
||||
}
|
||||
} else if lists.ContainsString([]string{
|
||||
RequestCondOperatorIn,
|
||||
RequestCondOperatorNotIn,
|
||||
RequestCondOperatorFileExt,
|
||||
}, this.Operator) {
|
||||
stringsValue := []string{}
|
||||
err := json.Unmarshal([]byte(this.Value), &stringsValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.arrayValue = stringsValue
|
||||
} else if lists.ContainsString([]string{
|
||||
RequestCondOperatorFileMimeType,
|
||||
}, this.Operator) {
|
||||
stringsValue := []string{}
|
||||
err := json.Unmarshal([]byte(this.Value), &stringsValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for k, v := range stringsValue {
|
||||
if strings.Contains(v, "*") {
|
||||
v = regexp.QuoteMeta(v)
|
||||
v = strings.Replace(v, `\*`, ".*", -1)
|
||||
stringsValue[k] = v
|
||||
}
|
||||
}
|
||||
this.arrayValue = stringsValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 将此条件应用于请求,检查是否匹配
|
||||
func (this *RequestCond) Match(formatter func(source string) string) bool {
|
||||
paramValue := formatter(this.Param)
|
||||
switch this.Operator {
|
||||
case RequestCondOperatorRegexp:
|
||||
if this.regValue == nil {
|
||||
return false
|
||||
}
|
||||
return this.regValue.MatchString(paramValue)
|
||||
case RequestCondOperatorNotRegexp:
|
||||
if this.regValue == nil {
|
||||
return false
|
||||
}
|
||||
return !this.regValue.MatchString(paramValue)
|
||||
case RequestCondOperatorEqInt:
|
||||
return this.isInt && paramValue == this.Value
|
||||
case RequestCondOperatorEqFloat:
|
||||
return this.isFloat && types.Float64(paramValue) == this.floatValue
|
||||
case RequestCondOperatorGtFloat:
|
||||
return this.isFloat && types.Float64(paramValue) > this.floatValue
|
||||
case RequestCondOperatorGteFloat:
|
||||
return this.isFloat && types.Float64(paramValue) >= this.floatValue
|
||||
case RequestCondOperatorLtFloat:
|
||||
return this.isFloat && types.Float64(paramValue) < this.floatValue
|
||||
case RequestCondOperatorLteFloat:
|
||||
return this.isFloat && types.Float64(paramValue) <= this.floatValue
|
||||
case RequestCondOperatorMod:
|
||||
pieces := strings.SplitN(this.Value, ",", 2)
|
||||
if len(pieces) == 1 {
|
||||
rem := types.Int64(pieces[0])
|
||||
return types.Int64(paramValue)%10 == rem
|
||||
}
|
||||
div := types.Int64(pieces[0])
|
||||
if div == 0 {
|
||||
return false
|
||||
}
|
||||
rem := types.Int64(pieces[1])
|
||||
return types.Int64(paramValue)%div == rem
|
||||
case RequestCondOperatorMod10:
|
||||
return types.Int64(paramValue)%10 == types.Int64(this.Value)
|
||||
case RequestCondOperatorMod100:
|
||||
return types.Int64(paramValue)%100 == types.Int64(this.Value)
|
||||
case RequestCondOperatorEqString:
|
||||
return paramValue == this.Value
|
||||
case RequestCondOperatorNeqString:
|
||||
return paramValue != this.Value
|
||||
case RequestCondOperatorHasPrefix:
|
||||
return strings.HasPrefix(paramValue, this.Value)
|
||||
case RequestCondOperatorHasSuffix:
|
||||
return strings.HasSuffix(paramValue, this.Value)
|
||||
case RequestCondOperatorContainsString:
|
||||
return strings.Contains(paramValue, this.Value)
|
||||
case RequestCondOperatorNotContainsString:
|
||||
return !strings.Contains(paramValue, this.Value)
|
||||
case RequestCondOperatorEqIP:
|
||||
ip := net.ParseIP(paramValue)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return this.isIP && bytes.Compare(this.ipValue, ip) == 0
|
||||
case RequestCondOperatorGtIP:
|
||||
ip := net.ParseIP(paramValue)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return this.isIP && bytes.Compare(ip, this.ipValue) > 0
|
||||
case RequestCondOperatorGteIP:
|
||||
ip := net.ParseIP(paramValue)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
|
||||
case RequestCondOperatorLtIP:
|
||||
ip := net.ParseIP(paramValue)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return this.isIP && bytes.Compare(ip, this.ipValue) < 0
|
||||
case RequestCondOperatorLteIP:
|
||||
ip := net.ParseIP(paramValue)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return this.isIP && bytes.Compare(ip, this.ipValue) <= 0
|
||||
case RequestCondOperatorIPRange:
|
||||
ip := net.ParseIP(paramValue)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查IP范围格式
|
||||
if strings.Contains(this.Value, ",") {
|
||||
ipList := strings.SplitN(this.Value, ",", 2)
|
||||
ipString1 := strings.TrimSpace(ipList[0])
|
||||
ipString2 := strings.TrimSpace(ipList[1])
|
||||
|
||||
if len(ipString1) > 0 {
|
||||
ip1 := net.ParseIP(ipString1)
|
||||
if ip1 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if bytes.Compare(ip, ip1) < 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if len(ipString2) > 0 {
|
||||
ip2 := net.ParseIP(ipString2)
|
||||
if ip2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if bytes.Compare(ip, ip2) > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
} else if strings.Contains(this.Value, "/") {
|
||||
_, ipNet, err := net.ParseCIDR(this.Value)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return ipNet.Contains(ip)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
case RequestCondOperatorIn:
|
||||
return lists.ContainsString(this.arrayValue, paramValue)
|
||||
case RequestCondOperatorNotIn:
|
||||
return !lists.ContainsString(this.arrayValue, paramValue)
|
||||
case RequestCondOperatorFileExt:
|
||||
ext := filepath.Ext(paramValue)
|
||||
if len(ext) > 0 {
|
||||
ext = ext[1:] // remove dot
|
||||
}
|
||||
return lists.ContainsString(this.arrayValue, strings.ToLower(ext))
|
||||
case RequestCondOperatorFileMimeType:
|
||||
index := strings.Index(paramValue, ";")
|
||||
if index >= 0 {
|
||||
paramValue = strings.TrimSpace(paramValue[:index])
|
||||
}
|
||||
if len(this.arrayValue) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, v := range this.arrayValue {
|
||||
if strings.Contains(v, "*") {
|
||||
reg, err := stringutil.RegexpCompile("^" + v + "$")
|
||||
if err == nil && reg.MatchString(paramValue) {
|
||||
return true
|
||||
}
|
||||
} else if paramValue == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case RequestCondOperatorVersionRange:
|
||||
if strings.Contains(this.Value, ",") {
|
||||
versions := strings.SplitN(this.Value, ",", 2)
|
||||
version1 := strings.TrimSpace(versions[0])
|
||||
version2 := strings.TrimSpace(versions[1])
|
||||
if len(version1) > 0 && stringutil.VersionCompare(paramValue, version1) < 0 {
|
||||
return false
|
||||
}
|
||||
if len(version2) > 0 && stringutil.VersionCompare(paramValue, version2) > 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
return stringutil.VersionCompare(paramValue, this.Value) >= 0
|
||||
}
|
||||
case RequestCondOperatorIPMod:
|
||||
pieces := strings.SplitN(this.Value, ",", 2)
|
||||
if len(pieces) == 1 {
|
||||
rem := types.Int64(pieces[0])
|
||||
return this.ipToInt64(net.ParseIP(paramValue))%10 == rem
|
||||
}
|
||||
div := types.Int64(pieces[0])
|
||||
if div == 0 {
|
||||
return false
|
||||
}
|
||||
rem := types.Int64(pieces[1])
|
||||
return this.ipToInt64(net.ParseIP(paramValue))%div == rem
|
||||
case RequestCondOperatorIPMod10:
|
||||
return this.ipToInt64(net.ParseIP(paramValue))%10 == types.Int64(this.Value)
|
||||
case RequestCondOperatorIPMod100:
|
||||
return this.ipToInt64(net.ParseIP(paramValue))%100 == types.Int64(this.Value)
|
||||
case RequestCondOperatorFileExist:
|
||||
index := strings.Index(paramValue, "?")
|
||||
if index > -1 {
|
||||
paramValue = paramValue[:index]
|
||||
}
|
||||
if len(paramValue) == 0 {
|
||||
return false
|
||||
}
|
||||
if !filepath.IsAbs(paramValue) {
|
||||
paramValue = Tea.Root + Tea.DS + paramValue
|
||||
}
|
||||
stat, err := os.Stat(paramValue)
|
||||
return err == nil && !stat.IsDir()
|
||||
case RequestCondOperatorFileNotExist:
|
||||
index := strings.Index(paramValue, "?")
|
||||
if index > -1 {
|
||||
paramValue = paramValue[:index]
|
||||
}
|
||||
if len(paramValue) == 0 {
|
||||
return true
|
||||
}
|
||||
if !filepath.IsAbs(paramValue) {
|
||||
paramValue = Tea.Root + Tea.DS + paramValue
|
||||
}
|
||||
stat, err := os.Stat(paramValue)
|
||||
return err != nil || stat.IsDir()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *RequestCond) ipToInt64(ip net.IP) int64 {
|
||||
if len(ip) == 0 {
|
||||
return 0
|
||||
}
|
||||
if len(ip) == 16 {
|
||||
return int64(binary.BigEndian.Uint32(ip[12:16]))
|
||||
}
|
||||
return int64(binary.BigEndian.Uint32(ip))
|
||||
}
|
||||
1000
internal/configs/serverconfigs/shared/request_cond_test.go
Normal file
1000
internal/configs/serverconfigs/shared/request_cond_test.go
Normal file
File diff suppressed because it is too large
Load Diff
226
internal/configs/serverconfigs/shared/request_operators.go
Normal file
226
internal/configs/serverconfigs/shared/request_operators.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package shared
|
||||
|
||||
import "github.com/iwind/TeaGo/maps"
|
||||
|
||||
// 运算符定义
|
||||
type RequestCondOperator = string
|
||||
|
||||
const (
|
||||
// 正则
|
||||
RequestCondOperatorRegexp RequestCondOperator = "regexp"
|
||||
RequestCondOperatorNotRegexp RequestCondOperator = "not regexp"
|
||||
|
||||
// 数字相关
|
||||
RequestCondOperatorEqInt RequestCondOperator = "eq int" // 整数等于
|
||||
RequestCondOperatorEqFloat RequestCondOperator = "eq float" // 浮点数等于
|
||||
RequestCondOperatorGtFloat RequestCondOperator = "gt"
|
||||
RequestCondOperatorGteFloat RequestCondOperator = "gte"
|
||||
RequestCondOperatorLtFloat RequestCondOperator = "lt"
|
||||
RequestCondOperatorLteFloat RequestCondOperator = "lte"
|
||||
|
||||
// 取模
|
||||
RequestCondOperatorMod10 RequestCondOperator = "mod 10"
|
||||
RequestCondOperatorMod100 RequestCondOperator = "mod 100"
|
||||
RequestCondOperatorMod RequestCondOperator = "mod"
|
||||
|
||||
// 字符串相关
|
||||
RequestCondOperatorEqString RequestCondOperator = "eq"
|
||||
RequestCondOperatorNeqString RequestCondOperator = "not"
|
||||
RequestCondOperatorHasPrefix RequestCondOperator = "prefix"
|
||||
RequestCondOperatorHasSuffix RequestCondOperator = "suffix"
|
||||
RequestCondOperatorContainsString RequestCondOperator = "contains"
|
||||
RequestCondOperatorNotContainsString RequestCondOperator = "not contains"
|
||||
RequestCondOperatorIn RequestCondOperator = "in"
|
||||
RequestCondOperatorNotIn RequestCondOperator = "not in"
|
||||
RequestCondOperatorFileExt RequestCondOperator = "file ext"
|
||||
RequestCondOperatorFileMimeType RequestCondOperator = "mime type"
|
||||
RequestCondOperatorVersionRange RequestCondOperator = "version range"
|
||||
|
||||
// IP相关
|
||||
RequestCondOperatorEqIP RequestCondOperator = "eq ip"
|
||||
RequestCondOperatorGtIP RequestCondOperator = "gt ip"
|
||||
RequestCondOperatorGteIP RequestCondOperator = "gte ip"
|
||||
RequestCondOperatorLtIP RequestCondOperator = "lt ip"
|
||||
RequestCondOperatorLteIP RequestCondOperator = "lte ip"
|
||||
RequestCondOperatorIPRange RequestCondOperator = "ip range"
|
||||
RequestCondOperatorIPMod10 RequestCondOperator = "ip mod 10"
|
||||
RequestCondOperatorIPMod100 RequestCondOperator = "ip mod 100"
|
||||
RequestCondOperatorIPMod RequestCondOperator = "ip mod"
|
||||
|
||||
// 文件相关
|
||||
RequestCondOperatorFileExist RequestCondOperator = "file exist"
|
||||
RequestCondOperatorFileNotExist RequestCondOperator = "file not exist"
|
||||
)
|
||||
|
||||
// 所有的运算符
|
||||
func AllRequestOperators() []maps.Map {
|
||||
return []maps.Map{
|
||||
{
|
||||
"name": "正则表达式匹配",
|
||||
"op": RequestCondOperatorRegexp,
|
||||
"description": "判断是否正则表达式匹配",
|
||||
},
|
||||
{
|
||||
"name": "正则表达式不匹配",
|
||||
"op": RequestCondOperatorNotRegexp,
|
||||
"description": "判断是否正则表达式不匹配",
|
||||
},
|
||||
{
|
||||
"name": "字符串等于",
|
||||
"op": RequestCondOperatorEqString,
|
||||
"description": "使用字符串对比参数值是否相等于某个值",
|
||||
},
|
||||
{
|
||||
"name": "字符串前缀",
|
||||
"op": RequestCondOperatorHasPrefix,
|
||||
"description": "参数值包含某个前缀",
|
||||
},
|
||||
{
|
||||
"name": "字符串后缀",
|
||||
"op": RequestCondOperatorHasSuffix,
|
||||
"description": "参数值包含某个后缀",
|
||||
},
|
||||
{
|
||||
"name": "字符串包含",
|
||||
"op": RequestCondOperatorContainsString,
|
||||
"description": "参数值包含另外一个字符串",
|
||||
},
|
||||
{
|
||||
"name": "字符串不包含",
|
||||
"op": RequestCondOperatorNotContainsString,
|
||||
"description": "参数值不包含另外一个字符串",
|
||||
},
|
||||
{
|
||||
"name": "字符串不等于",
|
||||
"op": RequestCondOperatorNeqString,
|
||||
"description": "使用字符串对比参数值是否不相等于某个值",
|
||||
},
|
||||
{
|
||||
"name": "在列表中",
|
||||
"op": RequestCondOperatorIn,
|
||||
"description": "判断参数值在某个列表中",
|
||||
},
|
||||
{
|
||||
"name": "不在列表中",
|
||||
"op": RequestCondOperatorNotIn,
|
||||
"description": "判断参数值不在某个列表中",
|
||||
},
|
||||
{
|
||||
"name": "扩展名",
|
||||
"op": RequestCondOperatorFileExt,
|
||||
"description": "判断小写的扩展名(不带点)在某个列表中",
|
||||
},
|
||||
{
|
||||
"name": "MimeType",
|
||||
"op": RequestCondOperatorFileMimeType,
|
||||
"description": "判断MimeType在某个列表中,支持类似于image/*的语法",
|
||||
},
|
||||
{
|
||||
"name": "版本号范围",
|
||||
"op": RequestCondOperatorVersionRange,
|
||||
"description": "判断版本号在某个范围内,格式为version1,version2",
|
||||
},
|
||||
{
|
||||
"name": "整数等于",
|
||||
"op": RequestCondOperatorEqInt,
|
||||
"description": "将参数转换为整数数字后进行对比",
|
||||
},
|
||||
{
|
||||
"name": "浮点数等于",
|
||||
"op": RequestCondOperatorEqFloat,
|
||||
"description": "将参数转换为可以有小数的浮点数字进行对比",
|
||||
},
|
||||
{
|
||||
"name": "数字大于",
|
||||
"op": RequestCondOperatorGtFloat,
|
||||
"description": "将参数转换为数字进行对比",
|
||||
},
|
||||
{
|
||||
"name": "数字大于等于",
|
||||
"op": RequestCondOperatorGteFloat,
|
||||
"description": "将参数转换为数字进行对比",
|
||||
},
|
||||
{
|
||||
"name": "数字小于",
|
||||
"op": RequestCondOperatorLtFloat,
|
||||
"description": "将参数转换为数字进行对比",
|
||||
},
|
||||
{
|
||||
"name": "数字小于等于",
|
||||
"op": RequestCondOperatorLteFloat,
|
||||
"description": "将参数转换为数字进行对比",
|
||||
},
|
||||
{
|
||||
"name": "整数取模10",
|
||||
"op": RequestCondOperatorMod10,
|
||||
"description": "对整数参数值取模,除数为10,对比值为余数",
|
||||
},
|
||||
{
|
||||
"name": "整数取模100",
|
||||
"op": RequestCondOperatorMod100,
|
||||
"description": "对整数参数值取模,除数为100,对比值为余数",
|
||||
},
|
||||
{
|
||||
"name": "整数取模",
|
||||
"op": RequestCondOperatorMod,
|
||||
"description": "对整数参数值取模,对比值格式为:除数,余数,比如10,1",
|
||||
},
|
||||
{
|
||||
"name": "IP等于",
|
||||
"op": RequestCondOperatorEqIP,
|
||||
"description": "将参数转换为IP进行对比",
|
||||
},
|
||||
{
|
||||
"name": "IP大于",
|
||||
"op": RequestCondOperatorGtIP,
|
||||
"description": "将参数转换为IP进行对比",
|
||||
},
|
||||
{
|
||||
"name": "IP大于等于",
|
||||
"op": RequestCondOperatorGteIP,
|
||||
"description": "将参数转换为IP进行对比",
|
||||
},
|
||||
{
|
||||
"name": "IP小于",
|
||||
"op": RequestCondOperatorLtIP,
|
||||
"description": "将参数转换为IP进行对比",
|
||||
},
|
||||
{
|
||||
"name": "IP小于等于",
|
||||
"op": RequestCondOperatorLteIP,
|
||||
"description": "将参数转换为IP进行对比",
|
||||
},
|
||||
{
|
||||
"name": "IP范围",
|
||||
"op": RequestCondOperatorIPRange,
|
||||
"description": "IP在某个范围之内,范围格式可以是英文逗号分隔的ip1,ip2,或者CIDR格式的ip/bits",
|
||||
},
|
||||
{
|
||||
"name": "IP取模10",
|
||||
"op": RequestCondOperatorIPMod10,
|
||||
"description": "对IP参数值取模,除数为10,对比值为余数",
|
||||
},
|
||||
{
|
||||
"name": "IP取模100",
|
||||
"op": RequestCondOperatorIPMod100,
|
||||
"description": "对IP参数值取模,除数为100,对比值为余数",
|
||||
},
|
||||
{
|
||||
"name": "IP取模",
|
||||
"op": RequestCondOperatorIPMod,
|
||||
"description": "对IP参数值取模,对比值格式为:除数,余数,比如10,1",
|
||||
},
|
||||
|
||||
{
|
||||
"name": "文件存在",
|
||||
"op": RequestCondOperatorFileExist,
|
||||
"description": "判断参数值解析后的文件是否存在",
|
||||
},
|
||||
|
||||
{
|
||||
"name": "文件不存在",
|
||||
"op": RequestCondOperatorFileNotExist,
|
||||
"description": "判断参数值解析后的文件是否不存在",
|
||||
},
|
||||
}
|
||||
}
|
||||
30
internal/configs/serverconfigs/shared/size_capacity.go
Normal file
30
internal/configs/serverconfigs/shared/size_capacity.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package shared
|
||||
|
||||
type SizeCapacityUnit = string
|
||||
|
||||
const (
|
||||
SizeCapacityUnitByte SizeCapacityUnit = "byte"
|
||||
SizeCapacityUnitKB SizeCapacityUnit = "kb"
|
||||
SizeCapacityUnitMB SizeCapacityUnit = "mb"
|
||||
SizeCapacityUnitGB SizeCapacityUnit = "gb"
|
||||
)
|
||||
|
||||
type SizeCapacity struct {
|
||||
Count int64 `json:"count" yaml:"count"`
|
||||
Unit SizeCapacityUnit `json:"unit" yaml:"unit"`
|
||||
}
|
||||
|
||||
func (this *SizeCapacity) Bytes() int64 {
|
||||
switch this.Unit {
|
||||
case SizeCapacityUnitByte:
|
||||
return this.Count
|
||||
case SizeCapacityUnitKB:
|
||||
return this.Count * 1024
|
||||
case SizeCapacityUnitMB:
|
||||
return this.Count * 1024 * 1024
|
||||
case SizeCapacityUnitGB:
|
||||
return this.Count * 1024 * 1024 * 1024
|
||||
default:
|
||||
return this.Count
|
||||
}
|
||||
}
|
||||
36
internal/configs/serverconfigs/shared/time_duration.go
Normal file
36
internal/configs/serverconfigs/shared/time_duration.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package shared
|
||||
|
||||
import "time"
|
||||
|
||||
type TimeDurationUnit = string
|
||||
|
||||
const (
|
||||
TimeDurationUnitMS TimeDurationUnit = "ms"
|
||||
TimeDurationUnitSecond TimeDurationUnit = "second"
|
||||
TimeDurationUnitMinute TimeDurationUnit = "minute"
|
||||
TimeDurationUnitHour TimeDurationUnit = "hour"
|
||||
TimeDurationUnitDay TimeDurationUnit = "day"
|
||||
)
|
||||
|
||||
// 时间间隔
|
||||
type TimeDuration struct {
|
||||
Count int64 `yaml:"count" json:"count"` // 数量
|
||||
Unit TimeDurationUnit `yaml:"unit" json:"unit"` // 单位
|
||||
}
|
||||
|
||||
func (this *TimeDuration) Duration() time.Duration {
|
||||
switch this.Unit {
|
||||
case TimeDurationUnitMS:
|
||||
return time.Duration(this.Count) * time.Millisecond
|
||||
case TimeDurationUnitSecond:
|
||||
return time.Duration(this.Count) * time.Second
|
||||
case TimeDurationUnitMinute:
|
||||
return time.Duration(this.Count) * time.Minute
|
||||
case TimeDurationUnitHour:
|
||||
return time.Duration(this.Count) * time.Hour
|
||||
case TimeDurationUnitDay:
|
||||
return time.Duration(this.Count) * 24 * time.Hour
|
||||
default:
|
||||
return time.Duration(this.Count) * time.Second
|
||||
}
|
||||
}
|
||||
@@ -97,7 +97,7 @@ func (this *SSLConfig) Init() error {
|
||||
if cert == nil {
|
||||
continue
|
||||
}
|
||||
if !cert.On {
|
||||
if !cert.IsOn {
|
||||
continue
|
||||
}
|
||||
data, err := ioutil.ReadFile(cert.FullCertPath())
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
// SSL证书
|
||||
type SSLCertConfig struct {
|
||||
Id string `yaml:"id" json:"id"`
|
||||
On bool `yaml:"on" json:"on"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||
Description string `yaml:"description" json:"description"` // 说明
|
||||
CertFile string `yaml:"certFile" json:"certFile"`
|
||||
KeyFile string `yaml:"keyFile" json:"keyFile"`
|
||||
@@ -39,7 +39,7 @@ type SSLCertConfig struct {
|
||||
// 获取新的SSL证书
|
||||
func NewSSLCertConfig(certFile string, keyFile string) *SSLCertConfig {
|
||||
return &SSLCertConfig{
|
||||
On: true,
|
||||
IsOn: true,
|
||||
Id: stringutil.Rand(16),
|
||||
CertFile: certFile,
|
||||
KeyFile: keyFile,
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
// HSTS设置
|
||||
// 参考: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security
|
||||
type HSTSConfig struct {
|
||||
On bool `yaml:"on" json:"on"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||
MaxAge int `yaml:"maxAge" json:"maxAge"` // 单位秒
|
||||
IncludeSubDomains bool `yaml:"includeSubDomains" json:"includeSubDomains"`
|
||||
Preload bool `yaml:"preload" json:"preload"`
|
||||
|
||||
@@ -59,8 +59,10 @@ func (this *HTTPListener) Serve() error {
|
||||
}
|
||||
|
||||
func (this *HTTPListener) Close() error {
|
||||
// TODO
|
||||
return nil
|
||||
if this.httpServer != nil {
|
||||
_ = this.httpServer.Close()
|
||||
}
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *HTTPListener) handleHTTP(writer http.ResponseWriter, req *http.Request) {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"net"
|
||||
)
|
||||
|
||||
@@ -13,11 +15,99 @@ type TCPListener struct {
|
||||
}
|
||||
|
||||
func (this *TCPListener) Serve() error {
|
||||
// TODO
|
||||
for {
|
||||
conn, err := this.Listener.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = this.handleConn(conn)
|
||||
if err != nil {
|
||||
logs.Println("[TCP_LISTENER]" + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
firstServer := this.Group.FirstServer()
|
||||
if firstServer == nil {
|
||||
return errors.New("no server available")
|
||||
}
|
||||
if firstServer.ReverseProxy == nil {
|
||||
return errors.New("no ReverseProxy configured for the server")
|
||||
}
|
||||
originConn, err := this.connectOrigin(firstServer.ReverseProxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var closer = func() {
|
||||
_ = conn.Close()
|
||||
_ = originConn.Close()
|
||||
}
|
||||
|
||||
go func() {
|
||||
originBuffer := make([]byte, 32*1024) // TODO 需要可以设置,并可以使用Pool
|
||||
for {
|
||||
n, err := originConn.Read(originBuffer)
|
||||
if n > 0 {
|
||||
_, err = conn.Write(originBuffer[:n])
|
||||
if err != nil {
|
||||
closer()
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
closer()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
clientBuffer := make([]byte, 32*1024) // TODO 需要可以设置,并可以使用Pool
|
||||
for {
|
||||
n, err := conn.Read(clientBuffer)
|
||||
if n > 0 {
|
||||
_, err = originConn.Write(clientBuffer[:n])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
closer()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TCPListener) Close() error {
|
||||
// TODO
|
||||
return nil
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
|
||||
retries := 3
|
||||
for i := 0; i < retries; i++ {
|
||||
origin := reverseProxy.NextOrigin(nil)
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
conn, err = origin.Connect()
|
||||
if err != nil {
|
||||
logs.Println("[TCP_LISTENER]unable to connect origin: " + origin.Addr.Host + ":" + origin.Addr.PortRange + ": " + err.Error())
|
||||
continue
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = errors.New("no origin can be used")
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user