diff --git a/internal/configs/serverconfigs/cache_policy.go b/internal/configs/serverconfigs/cache_policy.go new file mode 100644 index 0000000..02bc2a7 --- /dev/null +++ b/internal/configs/serverconfigs/cache_policy.go @@ -0,0 +1,155 @@ +package serverconfigs + +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/configutils" + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/files" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/logs" + "strconv" + "strings" + "time" +) + +var DefaultSkippedResponseCacheControlValues = []string{"private", "no-cache", "no-store"} + +// 缓存策略配置 +type CachePolicy struct { + Id int `yaml:"id" json:"id"` + IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启 TODO + Name string `yaml:"name" json:"name"` // 名称 + + Key string `yaml:"key" json:"key"` // 每个缓存的Key规则,里面可以有变量 + Capacity shared.SizeCapacity `yaml:"capacity" json:"capacity"` // 最大内容容量 + Life shared.TimeDuration `yaml:"life" json:"life"` // 时间 + Status []int `yaml:"status" json:"status"` // 缓存的状态码列表 + MaxSize shared.SizeCapacity `yaml:"maxSize" json:"maxSize"` // 能够请求的最大尺寸 + + SkipResponseCacheControlValues []string `yaml:"skipCacheControlValues" json:"skipCacheControlValues"` // 可以跳过的响应的Cache-Control值 + SkipResponseSetCookie bool `yaml:"skipSetCookie" json:"skipSetCookie"` // 是否跳过响应的Set-Cookie Header + EnableRequestCachePragma bool `yaml:"enableRequestCachePragma" json:"enableRequestCachePragma"` // 是否支持客户端的Pragma: no-cache + + Cond []*shared.RequestCond `yaml:"cond" json:"cond"` + + life time.Duration + maxSize int64 + capacity int64 + + uppercaseSkipCacheControlValues []string + + Type string `yaml:"type" json:"type"` // 类型 + Options map[string]interface{} `yaml:"options" json:"options"` // 选项 +} + +// 获取新对象 +func NewCachePolicy() *CachePolicy { + return &CachePolicy{ + SkipResponseCacheControlValues: DefaultSkippedResponseCacheControlValues, + SkipResponseSetCookie: true, + } +} + +// 从文件中读取缓存策略 +func NewCachePolicyFromFile(file string) *CachePolicy { + if len(file) == 0 { + return nil + } + reader, err := files.NewReader(Tea.ConfigFile(file)) + if err != nil { + logs.Error(err) + return nil + } + defer func() { + _ = reader.Close() + }() + + p := NewCachePolicy() + err = reader.ReadYAML(p) + if err != nil { + logs.Error(err) + return nil + } + + return p +} + +// 校验 +func (this *CachePolicy) Validate() error { + var err error + this.maxSize = this.MaxSize.Bytes() + this.life = this.Life.Duration() + this.capacity = this.Capacity.Bytes() + + this.uppercaseSkipCacheControlValues = []string{} + for _, value := range this.SkipResponseCacheControlValues { + this.uppercaseSkipCacheControlValues = append(this.uppercaseSkipCacheControlValues, strings.ToUpper(value)) + } + + // cond + if len(this.Cond) > 0 { + for _, cond := range this.Cond { + err := cond.Validate() + if err != nil { + return err + } + } + } + + return err +} + +// 最大数据尺寸 +func (this *CachePolicy) MaxDataSize() int64 { + return this.maxSize +} + +// 容量 +func (this *CachePolicy) CapacitySize() int64 { + return this.capacity +} + +// 生命周期 +func (this *CachePolicy) LifeDuration() time.Duration { + return this.life +} + +// 保存 +func (this *CachePolicy) Save() error { + shared.Locker.Lock() + defer shared.Locker.Unlock() + + filename := "cache.policy." + strconv.Itoa(this.Id) + ".conf" + writer, err := files.NewWriter(Tea.ConfigFile(filename)) + if err != nil { + return err + } + defer func() { + _ = writer.Close() + }() + _, err = writer.WriteYAML(this) + return err +} + +// 删除 +func (this *CachePolicy) Delete() error { + filename := "cache.policy." + strconv.Itoa(this.Id) + ".conf" + return files.NewFile(Tea.ConfigFile(filename)).Delete() +} + +// 是否包含某个Cache-Control值 +func (this *CachePolicy) ContainsCacheControl(value string) bool { + return lists.ContainsString(this.uppercaseSkipCacheControlValues, strings.ToUpper(value)) +} + +// 检查是否匹配关键词 +func (this *CachePolicy) MatchKeyword(keyword string) (matched bool, name string, tags []string) { + if configutils.MatchKeyword(this.Name, keyword) || configutils.MatchKeyword(this.Type, keyword) { + matched = true + name = this.Name + if len(this.Type) > 0 { + tags = []string{"类型:" + this.Type} + } + } + return +} diff --git a/internal/configs/serverconfigs/global_config.go b/internal/configs/serverconfigs/global_config.go index 63dc101..481d4a9 100644 --- a/internal/configs/serverconfigs/global_config.go +++ b/internal/configs/serverconfigs/global_config.go @@ -11,15 +11,15 @@ var globalConfigFile = "global.yaml" // 全局设置 type GlobalConfig struct { HTTPAll struct { - MatchDomainStrictly bool `yaml:"matchDomainStrictly"` - } `yaml:"httpAll"` - HTTP struct{} `yaml:"http"` - HTTPS struct{} `yaml:"https"` - TCPAll struct{} `yaml:"tcpAll"` - TCP struct{} `yaml:"tcp"` - TLS struct{} `yaml:"tls"` - Unix struct{} `yaml:"unix"` - UDP struct{} `yaml:"udp"` + MatchDomainStrictly bool `yaml:"matchDomainStrictly" json:"matchDomainStrictly"` + } `yaml:"httpAll" json:"httpAll"` + HTTP struct{} `yaml:"http" json:"http"` + HTTPS struct{} `yaml:"https" json:"https"` + TCPAll struct{} `yaml:"tcpAll" json:"tcpAll"` + TCP struct{} `yaml:"tcp" json:"tcp"` + TLS struct{} `yaml:"tls" json:"tls"` + Unix struct{} `yaml:"unix" json:"unix"` + UDP struct{} `yaml:"udp" json:"udp"` } func SharedGlobalConfig() *GlobalConfig { diff --git a/internal/configs/serverconfigs/origin_server_config.go b/internal/configs/serverconfigs/origin_server_config.go index 3b03966..56d1630 100644 --- a/internal/configs/serverconfigs/origin_server_config.go +++ b/internal/configs/serverconfigs/origin_server_config.go @@ -1,10 +1,181 @@ package serverconfigs +import ( + "crypto/tls" + "errors" + "fmt" + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/sslconfigs" + "net" + "strconv" + "strings" + "time" +) + // 源站服务配置 type OriginServerConfig struct { - Id string `yaml:"id" json:"id"` // ID - IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用 + HeaderList *shared.HeaderList `yaml:"headers" json:"headers"` + + Id int64 `yaml:"id" json:"id"` // ID + IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用 TODO + Version int `yaml:"version" json:"version"` // 版本 Name string `yaml:"name" json:"name"` // 名称 TODO Addr *NetworkAddressConfig `yaml:"addr" json:"addr"` // 地址 Description string `yaml:"description" json:"description"` // 描述 TODO + Code string `yaml:"code" json:"code"` // 代号 TODO + Scheme string `yaml:"scheme" json:"scheme"` // 协议 TODO + + Weight uint `yaml:"weight" json:"weight"` // 权重 TODO + IsBackup bool `yaml:"backup" json:"isBackup"` // 是否为备份 TODO + FailTimeout string `yaml:"failTimeout" json:"failTimeout"` // 连接失败超时 TODO + ReadTimeout string `yaml:"readTimeout" json:"readTimeout"` // 读取超时时间 TODO + IdleTimeout string `yaml:"idleTimeout" json:"idleTimeout"` // 空闲连接超时时间 TODO + MaxFails int32 `yaml:"maxFails" json:"maxFails"` // 最多失败次数 TODO + CurrentFails int32 `yaml:"currentFails" json:"currentFails"` // 当前已失败次数 TODO + MaxConns int32 `yaml:"maxConns" json:"maxConns"` // 最大并发连接数 TODO + CurrentConns int32 `yaml:"currentConns" json:"currentConns"` // 当前连接数 TODO + IdleConns int32 `yaml:"idleConns" json:"idleConns"` // 最大空闲连接数 TODO + + IsDown bool `yaml:"down" json:"isDown"` // 是否下线 TODO + DownTime time.Time `yaml:"downTime,omitempty" json:"downTime,omitempty"` // 下线时间 TODO + + RequestURI string `yaml:"requestURI" json:"requestURI"` // 转发后的请求URI TODO + ResponseHeaders []*shared.HeaderConfig `yaml:"responseHeaders" json:"responseHeaders"` // 响应Header TODO + Host string `yaml:"host" json:"host"` // 自定义主机名 TODO + + // 健康检查URL,目前支持: + // - http|https 返回2xx-3xx认为成功 + HealthCheck struct { + IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启 TODO + URL string `yaml:"url" json:"url"` // TODO + Interval int `yaml:"interval" json:"interval"` // TODO + StatusCodes []int `yaml:"statusCodes" json:"statusCodes"` // TODO + Timeout shared.TimeDuration `yaml:"timeout" json:"timeout"` // 超时时间 TODO + } `yaml:"healthCheck" json:"healthCheck"` + + Cert *sslconfigs.SSLCertConfig `yaml:"cert" json:"cert"` // 请求源服务器用的证书 + + // ftp + FTP *OriginServerFTPConfig `yaml:"ftp" json:"ftp"` + + failTimeoutDuration time.Duration + readTimeoutDuration time.Duration + idleTimeoutDuration time.Duration + + hasRequestURI bool + requestPath string + requestArgs string + + hasRequestHeaders bool + hasResponseHeaders bool + + hasHost bool + + uniqueKey string + + hasAddrVariables bool // 地址中是否含有变量 +} + +// 校验 +func (this *OriginServerConfig) Init() error { + // 证书 + if this.Cert != nil { + err := this.Cert.Init() + if err != nil { + return err + } + } + + // unique key + this.uniqueKey = strconv.FormatInt(this.Id, 10) + "@" + fmt.Sprintf("%d", this.Version) + + // failTimeout + if len(this.FailTimeout) > 0 { + this.failTimeoutDuration, _ = time.ParseDuration(this.FailTimeout) + } + + // readTimeout + if len(this.ReadTimeout) > 0 { + this.readTimeoutDuration, _ = time.ParseDuration(this.ReadTimeout) + } + + // idleTimeout + if len(this.IdleTimeout) > 0 { + this.idleTimeoutDuration, _ = time.ParseDuration(this.IdleTimeout) + } + + // Headers + if this.HeaderList != nil { + err := this.HeaderList.Init() + if err != nil { + return err + } + } + + // request uri + if len(this.RequestURI) == 0 || this.RequestURI == "${requestURI}" { + this.hasRequestURI = false + } else { + this.hasRequestURI = true + + if strings.Contains(this.RequestURI, "?") { + pieces := strings.SplitN(this.RequestURI, "?", -1) + this.requestPath = pieces[0] + this.requestArgs = pieces[1] + } else { + this.requestPath = this.RequestURI + } + } + + // TODO init health check + + // headers + if this.HeaderList != nil { + this.hasRequestHeaders = len(this.HeaderList.RequestHeaders) > 0 + } + this.hasResponseHeaders = len(this.ResponseHeaders) > 0 + + // host + this.hasHost = len(this.Host) > 0 + + // variables + // TODO 在host和port中支持变量 + this.hasAddrVariables = false + + return nil +} + +// 候选对象代号 +func (this *OriginServerConfig) CandidateCodes() []string { + codes := []string{strconv.FormatInt(this.Id, 10)} + if len(this.Code) > 0 { + codes = append(codes, this.Code) + } + return codes +} + +// 候选对象权重 +func (this *OriginServerConfig) CandidateWeight() uint { + return this.Weight +} + +// 连接源站 +func (this *OriginServerConfig) Connect() (net.Conn, error) { + switch this.Scheme { + case "", ProtocolTCP: + // TODO 支持TCP4/TCP6 + // TODO 支持指定特定网卡 + // TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用 + return net.DialTimeout("tcp", this.Addr.Host+":"+this.Addr.PortRange, this.failTimeoutDuration) + case ProtocolTLS: + // TODO 支持TCP4/TCP6 + // TODO 支持指定特定网卡 + // TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用 + // TODO 支持使用证书 + return tls.Dial("tcp", this.Addr.Host+":"+this.Addr.PortRange, &tls.Config{}) + } + + // TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据 + + return nil, errors.New("invalid scheme '" + this.Scheme + "'") } diff --git a/internal/configs/serverconfigs/origin_server_ftp.go b/internal/configs/serverconfigs/origin_server_ftp.go new file mode 100644 index 0000000..4d4f309 --- /dev/null +++ b/internal/configs/serverconfigs/origin_server_ftp.go @@ -0,0 +1,8 @@ +package serverconfigs + +// FTP源站配置 +type OriginServerFTPConfig struct { + Username string `yaml:"username" json:"username"` // 用户名 + Password string `yaml:"password" json:"password"` // 密码 + Dir string `yaml:"dir" json:"dir"` // 目录 +} diff --git a/internal/configs/serverconfigs/protocol_tls_config.go b/internal/configs/serverconfigs/protocol_tls_config.go index b90f0cc..5880a5c 100644 --- a/internal/configs/serverconfigs/protocol_tls_config.go +++ b/internal/configs/serverconfigs/protocol_tls_config.go @@ -5,9 +5,10 @@ import "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/sslconfigs" type TLSProtocolConfig struct { BaseProtocol `yaml:",inline"` - SSL *sslconfigs.SSLConfig `yaml:"ssl"` + SSL *sslconfigs.SSLConfig `yaml:"ssl" json:"ssl"` } +// 初始化 func (this *TLSProtocolConfig) Init() error { err := this.InitBase() if err != nil { diff --git a/internal/configs/serverconfigs/reverse_proxy_config.go b/internal/configs/serverconfigs/reverse_proxy_config.go index dc9b4d7..49b43de 100644 --- a/internal/configs/serverconfigs/reverse_proxy_config.go +++ b/internal/configs/serverconfigs/reverse_proxy_config.go @@ -1,6 +1,105 @@ package serverconfigs +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/scheduling" + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "sync" +) + +// 反向代理设置 type ReverseProxyConfig struct { - IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用 - Origins []*OriginServerConfig `yaml:"origins" json:"origins"` // 源站列表 + IsOn bool `yaml:"isOn" json:"isOn"` // 是否启用 TODO + Origins []*OriginServerConfig `yaml:"origins" json:"origins"` // 源站列表 + Scheduling *SchedulingConfig `yaml:"scheduling" json:"scheduling"` // 调度算法选项 + + hasOrigins bool + schedulingIsBackup bool + schedulingObject scheduling.SchedulingInterface + schedulingLocker sync.Mutex +} + +// 初始化 +func (this *ReverseProxyConfig) Init() error { + this.hasOrigins = len(this.Origins) > 0 + + for _, origin := range this.Origins { + err := origin.Init() + if err != nil { + return err + } + } + + // scheduling + this.SetupScheduling(false) + + return nil +} + +// 取得下一个可用的后端服务 +func (this *ReverseProxyConfig) NextOrigin(call *shared.RequestCall) *OriginServerConfig { + this.schedulingLocker.Lock() + defer this.schedulingLocker.Unlock() + + if this.schedulingObject == nil { + return nil + } + + if this.Scheduling != nil && call != nil && call.Options != nil { + for k, v := range this.Scheduling.Options { + call.Options[k] = v + } + } + + candidate := this.schedulingObject.Next(call) + if candidate == nil { + // 启用备用服务器 + if !this.schedulingIsBackup { + this.SetupScheduling(true) + + candidate = this.schedulingObject.Next(call) + if candidate == nil { + return nil + } + } + + if candidate == nil { + return nil + } + } + + return candidate.(*OriginServerConfig) +} + +// 设置调度算法 +func (this *ReverseProxyConfig) SetupScheduling(isBackup bool) { + if !isBackup { + this.schedulingLocker.Lock() + defer this.schedulingLocker.Unlock() + } + this.schedulingIsBackup = isBackup + + if this.Scheduling == nil { + this.schedulingObject = &scheduling.RandomScheduling{} + } else { + typeCode := this.Scheduling.Code + s := scheduling.FindSchedulingType(typeCode) + if s == nil { + this.Scheduling = nil + this.schedulingObject = &scheduling.RandomScheduling{} + } else { + this.schedulingObject = s["instance"].(scheduling.SchedulingInterface) + } + } + + for _, origin := range this.Origins { + if origin.IsOn && !origin.IsDown { + if isBackup && origin.IsBackup { + this.schedulingObject.Add(origin) + } else if !isBackup && !origin.IsBackup { + this.schedulingObject.Add(origin) + } + } + } + + this.schedulingObject.Start() } diff --git a/internal/configs/serverconfigs/scheduling/candidate.go b/internal/configs/serverconfigs/scheduling/candidate.go new file mode 100644 index 0000000..701e3c0 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/candidate.go @@ -0,0 +1,10 @@ +package scheduling + +// 候选对象接口 +type CandidateInterface interface { + // 权重 + CandidateWeight() uint + + // 代号 + CandidateCodes() []string +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling.go b/internal/configs/serverconfigs/scheduling/scheduling.go new file mode 100644 index 0000000..f86cb19 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling.go @@ -0,0 +1,39 @@ +package scheduling + +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/iwind/TeaGo/maps" +) + +// 调度算法接口 +type SchedulingInterface interface { + // 是否有候选对象 + HasCandidates() bool + + // 添加候选对象 + Add(candidate ...CandidateInterface) + + // 启动 + Start() + + // 查找下一个候选对象 + Next(call *shared.RequestCall) CandidateInterface + + // 获取简要信息 + Summary() maps.Map +} + +// 调度算法基础类 +type Scheduling struct { + Candidates []CandidateInterface +} + +// 判断是否有候选对象 +func (this *Scheduling) HasCandidates() bool { + return len(this.Candidates) > 0 +} + +// 添加候选对象 +func (this *Scheduling) Add(candidate ...CandidateInterface) { + this.Candidates = append(this.Candidates, candidate...) +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_hash.go b/internal/configs/serverconfigs/scheduling/scheduling_hash.go new file mode 100644 index 0000000..9f64069 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_hash.go @@ -0,0 +1,45 @@ +package scheduling + +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/iwind/TeaGo/maps" + "hash/crc32" +) + +// Hash调度算法 +type HashScheduling struct { + Scheduling + + count uint32 +} + +// 启动 +func (this *HashScheduling) Start() { + this.count = uint32(len(this.Candidates)) +} + +// 获取下一个候选对象 +func (this *HashScheduling) Next(call *shared.RequestCall) CandidateInterface { + if this.count == 0 { + return nil + } + + key := call.Options.GetString("key") + + if call.Formatter != nil { + key = call.Formatter(key) + } + + sum := crc32.ChecksumIEEE([]byte(key)) + return this.Candidates[sum%this.count] +} + +// 获取简要信息 +func (this *HashScheduling) Summary() maps.Map { + return maps.Map{ + "code": "hash", + "name": "Hash算法", + "description": "根据自定义的键值的Hash值分配后端服务器", + "networks": []string{"http"}, + } +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_hash_test.go b/internal/configs/serverconfigs/scheduling/scheduling_hash_test.go new file mode 100644 index 0000000..290131b --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_hash_test.go @@ -0,0 +1,45 @@ +package scheduling + +import ( + "fmt" + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "math/rand" + "testing" + "time" +) + +func TestHashScheduling_Next(t *testing.T) { + s := &HashScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "c", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "d", + Weight: 30, + }) + s.Start() + + hits := map[string]uint{} + for _, c := range s.Candidates { + hits[c.(*TestCandidate).Name] = 0 + } + + rand.Seed(time.Now().UnixNano()) + for i := 0; i < 1000000; i ++ { + call := shared.NewRequestCall() + call.Options["key"] = "192.168.1." + fmt.Sprintf("%d", rand.Int()) + + c := s.Next(call) + hits[c.(*TestCandidate).Name] ++ + } + t.Log(hits) +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_random.go b/internal/configs/serverconfigs/scheduling/scheduling_random.go new file mode 100644 index 0000000..bd9165c --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_random.go @@ -0,0 +1,78 @@ +package scheduling + +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/iwind/TeaGo/maps" + "math" + "math/rand" + "time" +) + +// 随机调度算法 +type RandomScheduling struct { + Scheduling + + array []CandidateInterface + count uint // 实际总的服务器数 +} + +// 启动 +func (this *RandomScheduling) Start() { + sumWeight := uint(0) + for _, c := range this.Candidates { + weight := c.CandidateWeight() + if weight == 0 { + weight = 1 + } else if weight > 10000 { + weight = 10000 + } + sumWeight += weight + } + + if sumWeight == 0 { + return + } + + for _, c := range this.Candidates { + weight := c.CandidateWeight() + if weight == 0 { + weight = 1 + } else if weight > 10000 { + weight = 10000 + } + count := uint(0) + if sumWeight <= 1000 { + count = weight + } else { + count = uint(math.Round(float64(weight*10000) / float64(sumWeight))) // 1% 产生 100个数据,最多支持10000个服务器 + } + for i := uint(0); i < count; i++ { + this.array = append(this.array, c) + } + this.count += count + } + + rand.Seed(time.Now().UnixNano()) +} + +// 获取下一个候选对象 +func (this *RandomScheduling) Next(call *shared.RequestCall) CandidateInterface { + if this.count == 0 { + return nil + } + if this.count == 1 { + return this.array[0] + } + index := rand.Int() % int(this.count) + return this.array[index] +} + +// 获取简要信息 +func (this *RandomScheduling) Summary() maps.Map { + return maps.Map{ + "code": "random", + "name": "Random随机算法", + "description": "根据权重设置随机分配后端服务器", + "networks": []string{"http", "tcp"}, + } +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_random_test.go b/internal/configs/serverconfigs/scheduling/scheduling_random_test.go new file mode 100644 index 0000000..4b89e90 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_random_test.go @@ -0,0 +1,79 @@ +package scheduling + +import ( + "sync" + "testing" +) + +type TestCandidate struct { + Name string + Weight uint +} + +func (this *TestCandidate) CandidateWeight() uint { + return this.Weight +} + +func (this *TestCandidate) CandidateCodes() []string { + return []string{this.Name} +} + +func TestRandomScheduling_Next(t *testing.T) { + s := &RandomScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "c", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "d", + Weight: 30, + }) + s.Start() + + /**for _, c := range s.array { + t.Log(c.(*TestCandidate).Name, ":", c.CandidateWeight()) + }**/ + + hits := map[string]uint{} + for _, c := range s.array { + hits[c.(*TestCandidate).Name] = 0 + } + + t.Log("count:", s.count, "array length:", len(s.array)) + + var locker sync.Mutex + var wg = sync.WaitGroup{} + wg.Add(100 * 10000) + for i := 0; i < 100*10000; i ++ { + go func() { + defer wg.Done() + + c := s.Next(nil) + + locker.Lock() + defer locker.Unlock() + hits[c.(*TestCandidate).Name] ++ + }() + } + wg.Wait() + + t.Log(hits) +} + +func TestRandomScheduling_NextZero(t *testing.T) { + s := &RandomScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 0, + }) + s.Start() + t.Log(s.Next(nil)) +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_round_robin.go b/internal/configs/serverconfigs/scheduling/scheduling_round_robin.go new file mode 100644 index 0000000..de9de53 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_round_robin.go @@ -0,0 +1,80 @@ +package scheduling + +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/maps" + "sync" +) + +// 轮询调度算法 +type RoundRobinScheduling struct { + Scheduling + + rawWeights []uint + currentWeights []uint + count uint + index uint + + locker sync.Mutex +} + +// 启动 +func (this *RoundRobinScheduling) Start() { + lists.Sort(this.Candidates, func(i int, j int) bool { + c1 := this.Candidates[i] + c2 := this.Candidates[j] + return c1.CandidateWeight() > c2.CandidateWeight() + }) + + for _, c := range this.Candidates { + weight := c.CandidateWeight() + if weight == 0 { + weight = 1 + } else if weight > 10000 { + weight = 10000 + } + this.rawWeights = append(this.rawWeights, weight) + } + + this.currentWeights = append([]uint{}, this.rawWeights...) + this.count = uint(len(this.Candidates)) +} + +// 获取下一个候选对象 +func (this *RoundRobinScheduling) Next(call *shared.RequestCall) CandidateInterface { + if this.count == 0 { + return nil + } + this.locker.Lock() + defer this.locker.Unlock() + + if this.index > this.count-1 { + this.index = 0 + } + weight := this.currentWeights[this.index] + + // 已经一轮了,则重置状态 + if weight == 0 { + if this.currentWeights[0] == 0 { + this.currentWeights = append([]uint{}, this.rawWeights...) + } + this.index = 0 + weight = this.currentWeights[this.index] + } + + c := this.Candidates[this.index] + this.currentWeights[this.index] -- + this.index++ + return c +} + +// 获取简要信息 +func (this *RoundRobinScheduling) Summary() maps.Map { + return maps.Map{ + "code": "roundRobin", + "name": "RoundRobin轮询算法", + "description": "根据权重,依次分配后端服务器", + "networks": []string{"http", "tcp"}, + } +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_round_robin_test.go b/internal/configs/serverconfigs/scheduling/scheduling_round_robin_test.go new file mode 100644 index 0000000..20c8152 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_round_robin_test.go @@ -0,0 +1,101 @@ +package scheduling + +import "testing" + +func TestRoundRobinScheduling_Next(t *testing.T) { + s := &RoundRobinScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 5, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "c", + Weight: 20, + }) + s.Add(&TestCandidate{ + Name: "d", + Weight: 30, + }) + s.Start() + + for _, c := range s.Candidates { + t.Log(c.(*TestCandidate).Name, c.CandidateWeight()) + } + + t.Log(s.currentWeights) + + for i := 0; i < 100; i ++ { + t.Log("===", "round", i, "===") + t.Log(s.Next(nil)) + t.Log(s.currentWeights) + t.Log(s.rawWeights) + } +} + +func TestRoundRobinScheduling_Two(t *testing.T) { + s := &RoundRobinScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 10, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 10, + }) + s.Start() + + for _, c := range s.Candidates { + t.Log(c.(*TestCandidate).Name, c.CandidateWeight()) + } + + t.Log(s.currentWeights) + + for i := 0; i < 100; i ++ { + t.Log("===", "round", i, "===") + t.Log(s.Next(nil)) + t.Log(s.currentWeights) + t.Log(s.rawWeights) + } +} + +func TestRoundRobinScheduling_NextPerformance(t *testing.T) { + s := &RoundRobinScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 1, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 2, + }) + s.Add(&TestCandidate{ + Name: "c", + Weight: 3, + }) + s.Add(&TestCandidate{ + Name: "d", + Weight: 6, + }) + s.Start() + + for _, c := range s.Candidates { + t.Log(c.(*TestCandidate).Name, c.CandidateWeight()) + } + + t.Log(s.currentWeights) + + hits := map[string]uint{} + for _, c := range s.Candidates { + hits[c.(*TestCandidate).Name] = 0 + } + for i := 0; i < 100*10000; i ++ { + c := s.Next(nil) + hits[c.(*TestCandidate).Name] ++ + } + + t.Log(hits) +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_sticky.go b/internal/configs/serverconfigs/scheduling/scheduling_sticky.go new file mode 100644 index 0000000..205bb61 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_sticky.go @@ -0,0 +1,106 @@ +package scheduling + +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/iwind/TeaGo/maps" + "math/rand" + "net/http" + "time" +) + +// Sticky调度算法 +type StickyScheduling struct { + Scheduling + + count uint32 + mapping map[string]CandidateInterface // code => candidate +} + +// 启动 +func (this *StickyScheduling) Start() { + this.mapping = map[string]CandidateInterface{} + for _, c := range this.Candidates { + for _, code := range c.CandidateCodes() { + this.mapping[code] = c + } + } + + this.count = uint32(len(this.Candidates)) + rand.Seed(time.Now().UnixNano()) +} + +// 获取下一个候选对象 +func (this *StickyScheduling) Next(call *shared.RequestCall) CandidateInterface { + if this.count == 0 { + return nil + } + typeCode := call.Options.GetString("type") + param := call.Options.GetString("param") + + if call.Request == nil { + return this.Candidates[uint32(rand.Int())%this.count] + } + + code := "" + if typeCode == "cookie" { + cookie, err := call.Request.Cookie(param) + if err == nil { + code = cookie.Value + } + } else if typeCode == "header" { + code = call.Request.Header.Get(param) + } else if typeCode == "argument" { + code = call.Request.URL.Query().Get(param) + } + + matched := false + var c CandidateInterface = nil + + defer func() { + if !matched && c != nil { + codes := c.CandidateCodes() + if len(codes) == 0 { + return + } + if typeCode == "cookie" { + call.AddResponseCall(func(resp http.ResponseWriter) { + http.SetCookie(resp, &http.Cookie{ + Name: param, + Value: codes[0], + Path: "/", + Expires: time.Now().AddDate(0, 1, 0), + }) + }) + } else { + call.AddResponseCall(func(resp http.ResponseWriter) { + resp.Header().Set(param, codes[0]) + }) + } + } + }() + + if len(code) == 0 { + c = this.Candidates[uint32(rand.Int())%this.count] + return c + } + + found := false + c, found = this.mapping[code] + if !found { + c = this.Candidates[uint32(rand.Int())%this.count] + return c + } + + matched = true + return c +} + +// 获取简要信息 +func (this *StickyScheduling) Summary() maps.Map { + return maps.Map{ + "code": "sticky", + "name": "Sticky算法", + "description": "利用Cookie、URL参数或者HTTP Header来指定后端服务器", + "networks": []string{"http"}, + } +} diff --git a/internal/configs/serverconfigs/scheduling/scheduling_sticky_test.go b/internal/configs/serverconfigs/scheduling/scheduling_sticky_test.go new file mode 100644 index 0000000..a0d8ba2 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/scheduling_sticky_test.go @@ -0,0 +1,128 @@ +package scheduling + +import ( + "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs/shared" + "github.com/iwind/TeaGo/maps" + "net/http" + "testing" +) + +func TestStickyScheduling_NextArgument(t *testing.T) { + s := &StickyScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 1, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 2, + }) + s.Add(&TestCandidate{ + Name: "c", + Weight: 3, + }) + s.Add(&TestCandidate{ + Name: "d", + Weight: 6, + }) + s.Start() + + t.Log(s.mapping) + + req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil) + if err != nil { + t.Fatal(err) + } + + options := maps.Map{ + "type": "argument", + "param": "backend", + } + call := shared.NewRequestCall() + call.Request = req + call.Options = options + t.Log(s.Next(call)) + t.Log(options) +} + +func TestStickyScheduling_NextCookie(t *testing.T) { + s := &StickyScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 1, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 2, + }) + s.Add(&TestCandidate{ + Name: "c", + Weight: 3, + }) + s.Add(&TestCandidate{ + Name: "d", + Weight: 6, + }) + s.Start() + + t.Log(s.mapping) + + req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil) + if err != nil { + t.Fatal(err) + } + + req.AddCookie(&http.Cookie{ + Name: "backend", + Value: "c", + }) + + options := maps.Map{ + "type": "cookie", + "param": "backend", + } + call := shared.NewRequestCall() + call.Request = req + call.Options = options + t.Log(s.Next(call)) + t.Log(options) +} + +func TestStickyScheduling_NextHeader(t *testing.T) { + s := &StickyScheduling{} + s.Add(&TestCandidate{ + Name: "a", + Weight: 1, + }) + s.Add(&TestCandidate{ + Name: "b", + Weight: 2, + }) + s.Add(&TestCandidate{ + Name: "c", + Weight: 3, + }) + s.Add(&TestCandidate{ + Name: "d", + Weight: 6, + }) + s.Start() + + t.Log(s.mapping) + + req, err := http.NewRequest("GET", "http://www.example.com/?backend=c", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("backend", "c") + + options := maps.Map{ + "type": "header", + "param": "backend", + } + call := shared.NewRequestCall() + call.Request = req + call.Options = options + t.Log(s.Next(call)) + t.Log(options) +} diff --git a/internal/configs/serverconfigs/scheduling/utils.go b/internal/configs/serverconfigs/scheduling/utils.go new file mode 100644 index 0000000..aa44d98 --- /dev/null +++ b/internal/configs/serverconfigs/scheduling/utils.go @@ -0,0 +1,28 @@ +package scheduling + +import "github.com/iwind/TeaGo/maps" + +// 所有请求类型 +func AllSchedulingTypes() []maps.Map { + types := []maps.Map{} + for _, s := range []SchedulingInterface{ + new(RandomScheduling), + new(RoundRobinScheduling), + new(HashScheduling), + new(StickyScheduling), + } { + summary := s.Summary() + summary["instance"] = s + types = append(types, summary) + } + return types +} + +func FindSchedulingType(code string) maps.Map { + for _, summary := range AllSchedulingTypes() { + if summary["code"] == code { + return summary + } + } + return nil +} diff --git a/internal/configs/serverconfigs/scheduling_config.go b/internal/configs/serverconfigs/scheduling_config.go new file mode 100644 index 0000000..d091b1b --- /dev/null +++ b/internal/configs/serverconfigs/scheduling_config.go @@ -0,0 +1,14 @@ +package serverconfigs + +import "github.com/iwind/TeaGo/maps" + +// 调度算法配置 +type SchedulingConfig struct { + Code string `yaml:"code" json:"code"` // 类型 + Options maps.Map `yaml:"options" json:"options"` // 选项 +} + +// 获取新对象 +func NewSchedulingConfig() *SchedulingConfig { + return &SchedulingConfig{} +} diff --git a/internal/configs/serverconfigs/server_config.go b/internal/configs/serverconfigs/server_config.go index 170c583..e0e1c6b 100644 --- a/internal/configs/serverconfigs/server_config.go +++ b/internal/configs/serverconfigs/server_config.go @@ -76,6 +76,13 @@ func (this *ServerConfig) Init() error { } } + if this.ReverseProxy != nil { + err := this.ReverseProxy.Init() + if err != nil { + return err + } + } + return nil } diff --git a/internal/configs/serverconfigs/shared/header.go b/internal/configs/serverconfigs/shared/header.go new file mode 100644 index 0000000..70049a6 --- /dev/null +++ b/internal/configs/serverconfigs/shared/header.go @@ -0,0 +1,68 @@ +package shared + +import ( + "github.com/iwind/TeaGo/utils/string" + "regexp" +) + +var regexpNamedVariable = regexp.MustCompile("\\${[\\w.-]+}") + +// 头部信息定义 +type HeaderConfig struct { + IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启 + Id string `yaml:"id" json:"id"` // ID + Name string `yaml:"name" json:"name"` // Name + Value string `yaml:"value" json:"value"` // Value + Always bool `yaml:"always" json:"always"` // 是否忽略状态码 + Status []int `yaml:"status" json:"status"` // 支持的状态码 + + statusMap map[int]bool + hasVariables bool +} + +// 获取新Header对象 +func NewHeaderConfig() *HeaderConfig { + return &HeaderConfig{ + IsOn: true, + Id: stringutil.Rand(16), + } +} + +// 校验 +func (this *HeaderConfig) Validate() error { + this.statusMap = map[int]bool{} + this.hasVariables = regexpNamedVariable.MatchString(this.Value) + + if this.Status == nil { + this.Status = []int{} + } + + for _, status := range this.Status { + this.statusMap[status] = true + } + + return nil +} + +// 判断是否匹配状态码 +func (this *HeaderConfig) Match(statusCode int) bool { + if !this.IsOn { + return false + } + + if this.Always { + return true + } + + if this.statusMap != nil { + _, found := this.statusMap[statusCode] + return found + } + + return false +} + +// 是否有变量 +func (this *HeaderConfig) HasVariables() bool { + return this.hasVariables +} diff --git a/internal/configs/serverconfigs/shared/header_list.go b/internal/configs/serverconfigs/shared/header_list.go new file mode 100644 index 0000000..b04a963 --- /dev/null +++ b/internal/configs/serverconfigs/shared/header_list.go @@ -0,0 +1,245 @@ +package shared + +import ( + "github.com/iwind/TeaGo/lists" + "strings" +) + +// HeaderList相关操作接口 +type HeaderListInterface interface { + // 校验 + ValidateHeaders() error + + // 取得所有的IgnoreHeader + AllIgnoreResponseHeaders() []string + + // 添加IgnoreHeader + AddIgnoreResponseHeader(name string) + + // 判断是否包含IgnoreHeader + ContainsIgnoreResponseHeader(name string) bool + + // 移除IgnoreHeader + RemoveIgnoreResponseHeader(name string) + + // 修改IgnoreHeader + UpdateIgnoreResponseHeader(oldName string, newName string) + + // 取得所有的Header + AllResponseHeaders() []*HeaderConfig + + // 添加Header + AddResponseHeader(header *HeaderConfig) + + // 判断是否包含Header + ContainsResponseHeader(name string) bool + + // 查找Header + FindResponseHeader(headerId string) *HeaderConfig + + // 移除Header + RemoveResponseHeader(headerId string) + + // 取得所有的请求Header + AllRequestHeaders() []*HeaderConfig + + // 添加请求Header + AddRequestHeader(header *HeaderConfig) + + // 查找请求Header + FindRequestHeader(headerId string) *HeaderConfig + + // 移除请求Header + RemoveRequestHeader(headerId string) +} + +// HeaderList定义 +type HeaderList struct { + // 添加的响应Headers + Headers []*HeaderConfig `yaml:"headers" json:"headers"` + + // 忽略的响应Headers + IgnoreHeaders []string `yaml:"ignoreHeaders" json:"ignoreHeaders"` + + // 自定义请求Headers + RequestHeaders []*HeaderConfig `yaml:"requestHeaders" json:"requestHeaders"` + + hasResponseHeaders bool + hasRequestHeaders bool + + hasIgnoreHeaders bool + uppercaseIgnoreHeaders []string +} + +// 校验 +func (this *HeaderList) Init() error { + this.hasResponseHeaders = len(this.Headers) > 0 + this.hasRequestHeaders = len(this.RequestHeaders) > 0 + + for _, h := range this.Headers { + err := h.Validate() + if err != nil { + return err + } + } + + for _, h := range this.RequestHeaders { + err := h.Validate() + if err != nil { + return err + } + } + + this.hasIgnoreHeaders = len(this.IgnoreHeaders) > 0 + this.uppercaseIgnoreHeaders = []string{} + for _, headerKey := range this.IgnoreHeaders { + this.uppercaseIgnoreHeaders = append(this.uppercaseIgnoreHeaders, strings.ToUpper(headerKey)) + } + + return nil +} + +// 是否有Headers +func (this *HeaderList) HasResponseHeaders() bool { + return this.hasResponseHeaders +} + +// 取得所有的IgnoreHeader +func (this *HeaderList) AllIgnoreResponseHeaders() []string { + if this.IgnoreHeaders == nil { + return []string{} + } + return this.IgnoreHeaders +} + +// 添加IgnoreHeader +func (this *HeaderList) AddIgnoreResponseHeader(name string) { + if !lists.ContainsString(this.IgnoreHeaders, name) { + this.IgnoreHeaders = append(this.IgnoreHeaders, name) + } +} + +// 判断是否包含IgnoreHeader +func (this *HeaderList) ContainsIgnoreResponseHeader(name string) bool { + if len(this.IgnoreHeaders) == 0 { + return false + } + return lists.ContainsString(this.IgnoreHeaders, name) +} + +// 修改IgnoreHeader +func (this *HeaderList) UpdateIgnoreResponseHeader(oldName string, newName string) { + result := []string{} + for _, h := range this.IgnoreHeaders { + if h == oldName { + result = append(result, newName) + } else { + result = append(result, h) + } + } + this.IgnoreHeaders = result +} + +// 移除IgnoreHeader +func (this *HeaderList) RemoveIgnoreResponseHeader(name string) { + result := []string{} + for _, n := range this.IgnoreHeaders { + if n == name { + continue + } + result = append(result, n) + } + this.IgnoreHeaders = result +} + +// 取得所有的Header +func (this *HeaderList) AllResponseHeaders() []*HeaderConfig { + if this.Headers == nil { + return []*HeaderConfig{} + } + return this.Headers +} + +// 添加Header +func (this *HeaderList) AddResponseHeader(header *HeaderConfig) { + this.Headers = append(this.Headers, header) +} + +// 判断是否包含Header +func (this *HeaderList) ContainsResponseHeader(name string) bool { + for _, h := range this.Headers { + if h.Name == name { + return true + } + } + return false +} + +// 查找Header +func (this *HeaderList) FindResponseHeader(headerId string) *HeaderConfig { + for _, h := range this.Headers { + if h.Id == headerId { + return h + } + } + return nil +} + +// 移除Header +func (this *HeaderList) RemoveResponseHeader(headerId string) { + result := []*HeaderConfig{} + for _, h := range this.Headers { + if h.Id == headerId { + continue + } + result = append(result, h) + } + this.Headers = result +} + +// 添加请求Header +func (this *HeaderList) AddRequestHeader(header *HeaderConfig) { + this.RequestHeaders = append(this.RequestHeaders, header) +} + +// 判断是否有请求Header +func (this *HeaderList) HasRequestHeaders() bool { + return this.hasRequestHeaders +} + +// 取得所有的请求Header +func (this *HeaderList) AllRequestHeaders() []*HeaderConfig { + return this.RequestHeaders +} + +// 查找请求Header +func (this *HeaderList) FindRequestHeader(headerId string) *HeaderConfig { + for _, h := range this.RequestHeaders { + if h.Id == headerId { + return h + } + } + return nil +} + +// 移除请求Header +func (this *HeaderList) RemoveRequestHeader(headerId string) { + result := []*HeaderConfig{} + for _, h := range this.RequestHeaders { + if h.Id == headerId { + continue + } + result = append(result, h) + } + this.RequestHeaders = result +} + +// 判断是否有Ignore Headers +func (this *HeaderList) HasIgnoreHeaders() bool { + return this.hasIgnoreHeaders +} + +// 查找大写的Ignore Headers +func (this *HeaderList) UppercaseIgnoreHeaders() []string { + return this.uppercaseIgnoreHeaders +} diff --git a/internal/configs/serverconfigs/shared/header_list_test.go b/internal/configs/serverconfigs/shared/header_list_test.go new file mode 100644 index 0000000..940dc7b --- /dev/null +++ b/internal/configs/serverconfigs/shared/header_list_test.go @@ -0,0 +1,23 @@ +package shared + +import ( + "fmt" + "testing" +) + +func TestHeaderList_FormatHeaders(t *testing.T) { + list := &HeaderList{} + + for i := 0; i < 5; i++ { + list.AddRequestHeader(&HeaderConfig{ + IsOn: true, + Name: "A" + fmt.Sprintf("%d", i), + Value: "ABCDEFGHIJ${name}KLM${hello}NEFGHIJILKKKk", + }) + } + + err := list.Init() + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/configs/serverconfigs/shared/header_test.go b/internal/configs/serverconfigs/shared/header_test.go new file mode 100644 index 0000000..823f9c7 --- /dev/null +++ b/internal/configs/serverconfigs/shared/header_test.go @@ -0,0 +1,28 @@ +package shared + +import ( + "github.com/iwind/TeaGo/assert" + "testing" +) + +func TestHeaderConfig_Match(t *testing.T) { + a := assert.NewAssertion(t) + h := NewHeaderConfig() + err := h.Validate() + if err != nil { + t.Fatal(err) + } + a.IsFalse(h.Match(200)) + a.IsFalse(h.Match(400)) + + h.Status = []int{200, 201, 400} + err = h.Validate() + if err != nil { + t.Fatal(err) + } + a.IsTrue(h.Match(400)) + a.IsFalse(h.Match(500)) + + h.Always = true + a.IsTrue(h.Match(500)) +} diff --git a/internal/configs/serverconfigs/shared/regexp.go b/internal/configs/serverconfigs/shared/regexp.go new file mode 100644 index 0000000..ab49949 --- /dev/null +++ b/internal/configs/serverconfigs/shared/regexp.go @@ -0,0 +1,13 @@ +package shared + +import "regexp" + +// 常用的正则表达式 +var ( + RegexpDigitNumber = regexp.MustCompile(`^\d+$`) // 正整数 + RegexpFloatNumber = regexp.MustCompile(`^\d+(\.\d+)?$`) // 正浮点数,不支持e + RegexpAllDigitNumber = regexp.MustCompile(`^[+-]?\d+$`) // 整数,支持正负数 + RegexpAllFloatNumber = regexp.MustCompile(`^[+-]?\d+(\.\d+)?$`) // 浮点数,支持正负数,不支持e + RegexpExternalURL = regexp.MustCompile("(?i)^(http|https|ftp)://") // URL + RegexpNamedVariable = regexp.MustCompile("\\${[\\w.-]+}") // 命名变量 +) diff --git a/internal/configs/serverconfigs/shared/regexp_test.go b/internal/configs/serverconfigs/shared/regexp_test.go new file mode 100644 index 0000000..44f6228 --- /dev/null +++ b/internal/configs/serverconfigs/shared/regexp_test.go @@ -0,0 +1,17 @@ +package shared + +import ( + "github.com/iwind/TeaGo/assert" + "testing" +) + +func TestRegexp(t *testing.T) { + a := assert.NewAssertion(t) + + a.IsTrue(RegexpFloatNumber.MatchString("123")) + a.IsTrue(RegexpFloatNumber.MatchString("123.456")) + a.IsFalse(RegexpFloatNumber.MatchString(".456")) + a.IsFalse(RegexpFloatNumber.MatchString("abc")) + a.IsFalse(RegexpFloatNumber.MatchString("123.")) + a.IsFalse(RegexpFloatNumber.MatchString("123.456e7")) +} diff --git a/internal/configs/serverconfigs/shared/request_call.go b/internal/configs/serverconfigs/shared/request_call.go new file mode 100644 index 0000000..95d5756 --- /dev/null +++ b/internal/configs/serverconfigs/shared/request_call.go @@ -0,0 +1,41 @@ +package shared + +import ( + "github.com/iwind/TeaGo/maps" + "net/http" +) + +// 请求调用 +type RequestCall struct { + Formatter func(source string) string + Request *http.Request + ResponseCallbacks []func(resp http.ResponseWriter) + Options maps.Map +} + +// 获取新对象 +func NewRequestCall() *RequestCall { + return &RequestCall{ + Options: maps.Map{}, + } +} + +// 重置 +func (this *RequestCall) Reset() { + this.Formatter = nil + this.Request = nil + this.ResponseCallbacks = nil + this.Options = maps.Map{} +} + +// 添加响应回调 +func (this *RequestCall) AddResponseCall(callback func(resp http.ResponseWriter)) { + this.ResponseCallbacks = append(this.ResponseCallbacks, callback) +} + +// 执行响应回调 +func (this *RequestCall) CallResponseCallbacks(resp http.ResponseWriter) { + for _, callback := range this.ResponseCallbacks { + callback(resp) + } +} diff --git a/internal/configs/serverconfigs/shared/request_cond.go b/internal/configs/serverconfigs/shared/request_cond.go new file mode 100644 index 0000000..11f9bd2 --- /dev/null +++ b/internal/configs/serverconfigs/shared/request_cond.go @@ -0,0 +1,372 @@ +package shared + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "errors" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/types" + "github.com/iwind/TeaGo/utils/string" + "net" + "os" + "path/filepath" + "regexp" + "strings" +) + +// 重写条件定义 +type RequestCond struct { + Id string `yaml:"id" json:"id"` // ID + + // 要测试的字符串 + // 其中可以使用跟请求相关的参数,比如: + // ${arg.name}, ${requestPath} + Param string `yaml:"param" json:"param"` + + // 运算符 + Operator RequestCondOperator `yaml:"operator" json:"operator"` + + // 对比 + Value string `yaml:"value" json:"value"` + + isInt bool + isFloat bool + isIP bool + + regValue *regexp.Regexp + floatValue float64 + ipValue net.IP + arrayValue []string +} + +// 取得新对象 +func NewRequestCond() *RequestCond { + return &RequestCond{ + Id: stringutil.Rand(16), + } +} + +// 校验配置 +func (this *RequestCond) Validate() error { + this.isInt = RegexpDigitNumber.MatchString(this.Value) + this.isFloat = RegexpFloatNumber.MatchString(this.Value) + + if lists.ContainsString([]string{ + RequestCondOperatorRegexp, + RequestCondOperatorNotRegexp, + }, this.Operator) { + reg, err := regexp.Compile(this.Value) + if err != nil { + return err + } + this.regValue = reg + } else if lists.ContainsString([]string{ + RequestCondOperatorEqFloat, + RequestCondOperatorGtFloat, + RequestCondOperatorGteFloat, + RequestCondOperatorLtFloat, + RequestCondOperatorLteFloat, + }, this.Operator) { + this.floatValue = types.Float64(this.Value) + } else if lists.ContainsString([]string{ + RequestCondOperatorEqIP, + RequestCondOperatorGtIP, + RequestCondOperatorGteIP, + RequestCondOperatorLtIP, + RequestCondOperatorLteIP, + }, this.Operator) { + this.ipValue = net.ParseIP(this.Value) + this.isIP = this.ipValue != nil + + if !this.isIP { + return errors.New("value should be a valid ip") + } + } else if lists.ContainsString([]string{ + RequestCondOperatorIPRange, + }, this.Operator) { + if strings.Contains(this.Value, ",") { + ipList := strings.SplitN(this.Value, ",", 2) + ipString1 := strings.TrimSpace(ipList[0]) + ipString2 := strings.TrimSpace(ipList[1]) + + if len(ipString1) > 0 { + ip1 := net.ParseIP(ipString1) + if ip1 == nil { + return errors.New("start ip is invalid") + } + } + + if len(ipString2) > 0 { + ip2 := net.ParseIP(ipString2) + if ip2 == nil { + return errors.New("end ip is invalid") + } + } + } else if strings.Contains(this.Value, "/") { + _, _, err := net.ParseCIDR(this.Value) + if err != nil { + return err + } + } else { + return errors.New("invalid ip range") + } + } else if lists.ContainsString([]string{ + RequestCondOperatorIn, + RequestCondOperatorNotIn, + RequestCondOperatorFileExt, + }, this.Operator) { + stringsValue := []string{} + err := json.Unmarshal([]byte(this.Value), &stringsValue) + if err != nil { + return err + } + this.arrayValue = stringsValue + } else if lists.ContainsString([]string{ + RequestCondOperatorFileMimeType, + }, this.Operator) { + stringsValue := []string{} + err := json.Unmarshal([]byte(this.Value), &stringsValue) + if err != nil { + return err + } + for k, v := range stringsValue { + if strings.Contains(v, "*") { + v = regexp.QuoteMeta(v) + v = strings.Replace(v, `\*`, ".*", -1) + stringsValue[k] = v + } + } + this.arrayValue = stringsValue + } + return nil +} + +// 将此条件应用于请求,检查是否匹配 +func (this *RequestCond) Match(formatter func(source string) string) bool { + paramValue := formatter(this.Param) + switch this.Operator { + case RequestCondOperatorRegexp: + if this.regValue == nil { + return false + } + return this.regValue.MatchString(paramValue) + case RequestCondOperatorNotRegexp: + if this.regValue == nil { + return false + } + return !this.regValue.MatchString(paramValue) + case RequestCondOperatorEqInt: + return this.isInt && paramValue == this.Value + case RequestCondOperatorEqFloat: + return this.isFloat && types.Float64(paramValue) == this.floatValue + case RequestCondOperatorGtFloat: + return this.isFloat && types.Float64(paramValue) > this.floatValue + case RequestCondOperatorGteFloat: + return this.isFloat && types.Float64(paramValue) >= this.floatValue + case RequestCondOperatorLtFloat: + return this.isFloat && types.Float64(paramValue) < this.floatValue + case RequestCondOperatorLteFloat: + return this.isFloat && types.Float64(paramValue) <= this.floatValue + case RequestCondOperatorMod: + pieces := strings.SplitN(this.Value, ",", 2) + if len(pieces) == 1 { + rem := types.Int64(pieces[0]) + return types.Int64(paramValue)%10 == rem + } + div := types.Int64(pieces[0]) + if div == 0 { + return false + } + rem := types.Int64(pieces[1]) + return types.Int64(paramValue)%div == rem + case RequestCondOperatorMod10: + return types.Int64(paramValue)%10 == types.Int64(this.Value) + case RequestCondOperatorMod100: + return types.Int64(paramValue)%100 == types.Int64(this.Value) + case RequestCondOperatorEqString: + return paramValue == this.Value + case RequestCondOperatorNeqString: + return paramValue != this.Value + case RequestCondOperatorHasPrefix: + return strings.HasPrefix(paramValue, this.Value) + case RequestCondOperatorHasSuffix: + return strings.HasSuffix(paramValue, this.Value) + case RequestCondOperatorContainsString: + return strings.Contains(paramValue, this.Value) + case RequestCondOperatorNotContainsString: + return !strings.Contains(paramValue, this.Value) + case RequestCondOperatorEqIP: + ip := net.ParseIP(paramValue) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(this.ipValue, ip) == 0 + case RequestCondOperatorGtIP: + ip := net.ParseIP(paramValue) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) > 0 + case RequestCondOperatorGteIP: + ip := net.ParseIP(paramValue) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) >= 0 + case RequestCondOperatorLtIP: + ip := net.ParseIP(paramValue) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) < 0 + case RequestCondOperatorLteIP: + ip := net.ParseIP(paramValue) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) <= 0 + case RequestCondOperatorIPRange: + ip := net.ParseIP(paramValue) + if ip == nil { + return false + } + + // 检查IP范围格式 + if strings.Contains(this.Value, ",") { + ipList := strings.SplitN(this.Value, ",", 2) + ipString1 := strings.TrimSpace(ipList[0]) + ipString2 := strings.TrimSpace(ipList[1]) + + if len(ipString1) > 0 { + ip1 := net.ParseIP(ipString1) + if ip1 == nil { + return false + } + + if bytes.Compare(ip, ip1) < 0 { + return false + } + } + + if len(ipString2) > 0 { + ip2 := net.ParseIP(ipString2) + if ip2 == nil { + return false + } + + if bytes.Compare(ip, ip2) > 0 { + return false + } + } + + return true + } else if strings.Contains(this.Value, "/") { + _, ipNet, err := net.ParseCIDR(this.Value) + if err != nil { + return false + } + return ipNet.Contains(ip) + } else { + return false + } + case RequestCondOperatorIn: + return lists.ContainsString(this.arrayValue, paramValue) + case RequestCondOperatorNotIn: + return !lists.ContainsString(this.arrayValue, paramValue) + case RequestCondOperatorFileExt: + ext := filepath.Ext(paramValue) + if len(ext) > 0 { + ext = ext[1:] // remove dot + } + return lists.ContainsString(this.arrayValue, strings.ToLower(ext)) + case RequestCondOperatorFileMimeType: + index := strings.Index(paramValue, ";") + if index >= 0 { + paramValue = strings.TrimSpace(paramValue[:index]) + } + if len(this.arrayValue) == 0 { + return false + } + for _, v := range this.arrayValue { + if strings.Contains(v, "*") { + reg, err := stringutil.RegexpCompile("^" + v + "$") + if err == nil && reg.MatchString(paramValue) { + return true + } + } else if paramValue == v { + return true + } + } + case RequestCondOperatorVersionRange: + if strings.Contains(this.Value, ",") { + versions := strings.SplitN(this.Value, ",", 2) + version1 := strings.TrimSpace(versions[0]) + version2 := strings.TrimSpace(versions[1]) + if len(version1) > 0 && stringutil.VersionCompare(paramValue, version1) < 0 { + return false + } + if len(version2) > 0 && stringutil.VersionCompare(paramValue, version2) > 0 { + return false + } + return true + } else { + return stringutil.VersionCompare(paramValue, this.Value) >= 0 + } + case RequestCondOperatorIPMod: + pieces := strings.SplitN(this.Value, ",", 2) + if len(pieces) == 1 { + rem := types.Int64(pieces[0]) + return this.ipToInt64(net.ParseIP(paramValue))%10 == rem + } + div := types.Int64(pieces[0]) + if div == 0 { + return false + } + rem := types.Int64(pieces[1]) + return this.ipToInt64(net.ParseIP(paramValue))%div == rem + case RequestCondOperatorIPMod10: + return this.ipToInt64(net.ParseIP(paramValue))%10 == types.Int64(this.Value) + case RequestCondOperatorIPMod100: + return this.ipToInt64(net.ParseIP(paramValue))%100 == types.Int64(this.Value) + case RequestCondOperatorFileExist: + index := strings.Index(paramValue, "?") + if index > -1 { + paramValue = paramValue[:index] + } + if len(paramValue) == 0 { + return false + } + if !filepath.IsAbs(paramValue) { + paramValue = Tea.Root + Tea.DS + paramValue + } + stat, err := os.Stat(paramValue) + return err == nil && !stat.IsDir() + case RequestCondOperatorFileNotExist: + index := strings.Index(paramValue, "?") + if index > -1 { + paramValue = paramValue[:index] + } + if len(paramValue) == 0 { + return true + } + if !filepath.IsAbs(paramValue) { + paramValue = Tea.Root + Tea.DS + paramValue + } + stat, err := os.Stat(paramValue) + return err != nil || stat.IsDir() + } + + return false +} + +func (this *RequestCond) ipToInt64(ip net.IP) int64 { + if len(ip) == 0 { + return 0 + } + if len(ip) == 16 { + return int64(binary.BigEndian.Uint32(ip[12:16])) + } + return int64(binary.BigEndian.Uint32(ip)) +} diff --git a/internal/configs/serverconfigs/shared/request_cond_test.go b/internal/configs/serverconfigs/shared/request_cond_test.go new file mode 100644 index 0000000..b62ecbf --- /dev/null +++ b/internal/configs/serverconfigs/shared/request_cond_test.go @@ -0,0 +1,1000 @@ +package shared + +import ( + "bytes" + "fmt" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/assert" + "net" + "regexp" + "testing" +) + +func TestRequestCond_Compare1(t *testing.T) { + a := assert.NewAssertion(t) + + { + cond := RequestCond{ + Param: "/hello", + Operator: RequestCondOperatorRegexp, + Value: "abc", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(format string) string { + return format + })) + } + + { + cond := RequestCond{ + Param: "/hello", + Operator: RequestCondOperatorRegexp, + Value: "/\\w+", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(format string) string { + return format + })) + } + + { + cond := RequestCond{ + Param: "/article/123.html", + Operator: RequestCondOperatorRegexp, + Value: `^/article/\d+\.html$`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(format string) string { + return format + })) + } + + { + cond := RequestCond{ + Param: "/hello", + Operator: RequestCondOperatorRegexp, + Value: "[", + } + a.IsNotNil(cond.Validate()) + a.IsFalse(cond.Match(func(format string) string { + return format + })) + } + + { + cond := RequestCond{ + Param: "/hello", + Operator: RequestCondOperatorNotRegexp, + Value: "abc", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(format string) string { + return format + })) + } + + { + cond := RequestCond{ + Param: "/hello", + Operator: RequestCondOperatorNotRegexp, + Value: "/\\w+", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(format string) string { + return format + })) + } + + { + cond := RequestCond{ + Param: "123.123", + Operator: RequestCondOperatorEqInt, + Value: "123", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "123", + Operator: RequestCondOperatorEqInt, + Value: "123", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "abc", + Operator: RequestCondOperatorEqInt, + Value: "abc", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "123", + Operator: RequestCondOperatorEqFloat, + Value: "123", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "123.0", + Operator: RequestCondOperatorEqFloat, + Value: "123", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "123.123", + Operator: RequestCondOperatorEqFloat, + Value: "123.12", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "123", + Operator: RequestCondOperatorGtFloat, + Value: "1", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "123", + Operator: RequestCondOperatorGtFloat, + Value: "125", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "125", + Operator: RequestCondOperatorGteFloat, + Value: "125", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "125", + Operator: RequestCondOperatorLtFloat, + Value: "127", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "125", + Operator: RequestCondOperatorLteFloat, + Value: "127", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "125", + Operator: RequestCondOperatorEqString, + Value: "125", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "125", + Operator: RequestCondOperatorNeqString, + Value: "125", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "125", + Operator: RequestCondOperatorNeqString, + Value: "127", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorHasPrefix, + Value: "/hello", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorHasPrefix, + Value: "/hello2", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorHasSuffix, + Value: "world", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorHasSuffix, + Value: "world/", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorContainsString, + Value: "wo", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorContainsString, + Value: "wr", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorNotContainsString, + Value: "HELLO", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "/hello/world", + Operator: RequestCondOperatorNotContainsString, + Value: "hello", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } +} + +func TestRequestCond_IP(t *testing.T) { + a := assert.NewAssertion(t) + + { + cond := RequestCond{ + Param: "hello", + Operator: RequestCondOperatorEqIP, + Value: "hello", + } + a.IsNotNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorEqIP, + Value: "hello", + } + a.IsNotNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorEqIP, + Value: "192.168.1.100", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorGtIP, + Value: "192.168.1.90", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorGteIP, + Value: "192.168.1.90", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.80", + Operator: RequestCondOperatorLtIP, + Value: "192.168.1.90", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.0.100", + Operator: RequestCondOperatorLteIP, + Value: "192.168.1.90", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.0.100", + Operator: RequestCondOperatorIPRange, + Value: "192.168.0.90,", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.0.100", + Operator: RequestCondOperatorIPRange, + Value: "192.168.0.90,192.168.1.100", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.0.100", + Operator: RequestCondOperatorIPRange, + Value: ",192.168.1.100", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorIPRange, + Value: "192.168.0.90,192.168.1.99", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorIPRange, + Value: "192.168.0.90/24", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorIPRange, + Value: "192.168.0.90/18", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorIPRange, + Value: "a/18", + } + a.IsNotNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorIPMod10, + Value: "6", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorIPMod100, + Value: "76", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "192.168.1.100", + Operator: RequestCondOperatorIPMod, + Value: "10,6", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } +} + +func TestRequestCondIPCompare(t *testing.T) { + { + ip1 := net.ParseIP("192.168.3.100") + ip2 := net.ParseIP("192.168.2.100") + t.Log(bytes.Compare(ip1, ip2)) + } + + { + ip1 := net.ParseIP("192.168.3.100") + ip2 := net.ParseIP("a") + t.Log(bytes.Compare(ip1, ip2)) + } + + { + ip1 := net.ParseIP("b") + ip2 := net.ParseIP("192.168.2.100") + t.Log(bytes.Compare(ip1, ip2)) + } + + { + ip1 := net.ParseIP("b") + ip2 := net.ParseIP("a") + t.Log(ip1 == nil) + t.Log(bytes.Compare(ip1, ip2)) + } + + { + cond := RequestCond{} + t.Log(cond.ipToInt64(net.ParseIP("192.168.1.100"))) + t.Log(cond.ipToInt64(net.ParseIP("192.168.1.99"))) + t.Log(cond.ipToInt64(net.ParseIP("0.0.0.0"))) + t.Log(cond.ipToInt64(net.ParseIP("127.0.0.1"))) + t.Log(cond.ipToInt64(net.ParseIP("abc"))) + t.Log(cond.ipToInt64(net.ParseIP("192.168"))) + t.Log(cond.ipToInt64(net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"))) + } +} + +func TestRequestCond_In(t *testing.T) { + a := assert.NewAssertion(t) + + { + cond := RequestCond{ + Param: "a", + Operator: RequestCondOperatorIn, + Value: `a`, + } + a.IsNotNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "a", + Operator: RequestCondOperatorIn, + Value: `["a", "b"]`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "c", + Operator: RequestCondOperatorNotIn, + Value: `["a", "b"]`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "a", + Operator: RequestCondOperatorNotIn, + Value: `["a", "b"]`, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } +} + +func TestRequestCond_File(t *testing.T) { + a := assert.NewAssertion(t) + + { + cond := RequestCond{ + Param: "a", + Operator: RequestCondOperatorFileExt, + Value: `["jpeg", "jpg", "png"]`, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "a.gif", + Operator: RequestCondOperatorFileExt, + Value: `["jpeg", "jpg", "png"]`, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "a.png", + Operator: RequestCondOperatorFileExt, + Value: `["jpeg", "jpg", "png"]`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "a.png", + Operator: RequestCondOperatorFileExist, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: Tea.Root + "/README.md", + Operator: RequestCondOperatorFileExist, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: Tea.Root + "/README.md?v=1", + Operator: RequestCondOperatorFileExist, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: Tea.Root, + Operator: RequestCondOperatorFileExist, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: Tea.Root, + Operator: RequestCondOperatorFileExist, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "a.png", + Operator: RequestCondOperatorFileNotExist, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: Tea.Root + "/README.md", + Operator: RequestCondOperatorFileNotExist, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } +} + +func TestRequestCond_MimeType(t *testing.T) { + a := assert.NewAssertion(t) + + { + cond := RequestCond{ + Param: "text/html; charset=utf-8", + Operator: RequestCondOperatorFileMimeType, + Value: `["text/html"]`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "text/html; charset=utf-8", + Operator: RequestCondOperatorFileMimeType, + Value: `["text/*"]`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "text/html; charset=utf-8", + Operator: RequestCondOperatorFileMimeType, + Value: `["image/*"]`, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "text/plain; charset=utf-8", + Operator: RequestCondOperatorFileMimeType, + Value: `["text/html", "image/jpeg", "image/png"]`, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } +} + +func TestRequestCond_Version(t *testing.T) { + a := assert.NewAssertion(t) + + { + cond := RequestCond{ + Param: "1.0", + Operator: RequestCondOperatorVersionRange, + Value: `1.0,1.1`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "1.0", + Operator: RequestCondOperatorVersionRange, + Value: `1.0,`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "1.0", + Operator: RequestCondOperatorVersionRange, + Value: `,1.1`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "0.9", + Operator: RequestCondOperatorVersionRange, + Value: `1.0,1.1`, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "0.9", + Operator: RequestCondOperatorVersionRange, + Value: `1.0`, + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "1.1", + Operator: RequestCondOperatorVersionRange, + Value: `1.0`, + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } +} + +func TestRequestCond_RegexpQuote(t *testing.T) { + t.Log(regexp.QuoteMeta("a")) + t.Log(regexp.QuoteMeta("*")) + t.Log(regexp.QuoteMeta("([\\d]).*")) +} + +func TestRequestCond_Mod(t *testing.T) { + a := assert.NewAssertion(t) + + { + cond := RequestCond{ + Param: "1", + Operator: RequestCondOperatorMod, + Value: "1", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "1", + Operator: RequestCondOperatorMod, + Value: "2", + } + a.IsNil(cond.Validate()) + a.IsFalse(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "3", + Operator: RequestCondOperatorMod, + Value: "3", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "1", + Operator: RequestCondOperatorMod, + Value: "11,1", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "3", + Operator: RequestCondOperatorMod, + Value: "11,3", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + { + cond := RequestCond{ + Param: "4", + Operator: RequestCondOperatorMod, + Value: "2,0", + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + for i := 0; i < 100; i++ { + cond := RequestCond{ + Param: fmt.Sprintf("%d", i), + Operator: RequestCondOperatorMod10, + Value: fmt.Sprintf("%d", i%10), + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } + + for i := 0; i < 2000; i++ { + cond := RequestCond{ + Param: fmt.Sprintf("%d", i), + Operator: RequestCondOperatorMod100, + Value: fmt.Sprintf("%d", i%100), + } + a.IsNil(cond.Validate()) + a.IsTrue(cond.Match(func(source string) string { + return source + })) + } +} diff --git a/internal/configs/serverconfigs/shared/request_operators.go b/internal/configs/serverconfigs/shared/request_operators.go new file mode 100644 index 0000000..ce88cea --- /dev/null +++ b/internal/configs/serverconfigs/shared/request_operators.go @@ -0,0 +1,226 @@ +package shared + +import "github.com/iwind/TeaGo/maps" + +// 运算符定义 +type RequestCondOperator = string + +const ( + // 正则 + RequestCondOperatorRegexp RequestCondOperator = "regexp" + RequestCondOperatorNotRegexp RequestCondOperator = "not regexp" + + // 数字相关 + RequestCondOperatorEqInt RequestCondOperator = "eq int" // 整数等于 + RequestCondOperatorEqFloat RequestCondOperator = "eq float" // 浮点数等于 + RequestCondOperatorGtFloat RequestCondOperator = "gt" + RequestCondOperatorGteFloat RequestCondOperator = "gte" + RequestCondOperatorLtFloat RequestCondOperator = "lt" + RequestCondOperatorLteFloat RequestCondOperator = "lte" + + // 取模 + RequestCondOperatorMod10 RequestCondOperator = "mod 10" + RequestCondOperatorMod100 RequestCondOperator = "mod 100" + RequestCondOperatorMod RequestCondOperator = "mod" + + // 字符串相关 + RequestCondOperatorEqString RequestCondOperator = "eq" + RequestCondOperatorNeqString RequestCondOperator = "not" + RequestCondOperatorHasPrefix RequestCondOperator = "prefix" + RequestCondOperatorHasSuffix RequestCondOperator = "suffix" + RequestCondOperatorContainsString RequestCondOperator = "contains" + RequestCondOperatorNotContainsString RequestCondOperator = "not contains" + RequestCondOperatorIn RequestCondOperator = "in" + RequestCondOperatorNotIn RequestCondOperator = "not in" + RequestCondOperatorFileExt RequestCondOperator = "file ext" + RequestCondOperatorFileMimeType RequestCondOperator = "mime type" + RequestCondOperatorVersionRange RequestCondOperator = "version range" + + // IP相关 + RequestCondOperatorEqIP RequestCondOperator = "eq ip" + RequestCondOperatorGtIP RequestCondOperator = "gt ip" + RequestCondOperatorGteIP RequestCondOperator = "gte ip" + RequestCondOperatorLtIP RequestCondOperator = "lt ip" + RequestCondOperatorLteIP RequestCondOperator = "lte ip" + RequestCondOperatorIPRange RequestCondOperator = "ip range" + RequestCondOperatorIPMod10 RequestCondOperator = "ip mod 10" + RequestCondOperatorIPMod100 RequestCondOperator = "ip mod 100" + RequestCondOperatorIPMod RequestCondOperator = "ip mod" + + // 文件相关 + RequestCondOperatorFileExist RequestCondOperator = "file exist" + RequestCondOperatorFileNotExist RequestCondOperator = "file not exist" +) + +// 所有的运算符 +func AllRequestOperators() []maps.Map { + return []maps.Map{ + { + "name": "正则表达式匹配", + "op": RequestCondOperatorRegexp, + "description": "判断是否正则表达式匹配", + }, + { + "name": "正则表达式不匹配", + "op": RequestCondOperatorNotRegexp, + "description": "判断是否正则表达式不匹配", + }, + { + "name": "字符串等于", + "op": RequestCondOperatorEqString, + "description": "使用字符串对比参数值是否相等于某个值", + }, + { + "name": "字符串前缀", + "op": RequestCondOperatorHasPrefix, + "description": "参数值包含某个前缀", + }, + { + "name": "字符串后缀", + "op": RequestCondOperatorHasSuffix, + "description": "参数值包含某个后缀", + }, + { + "name": "字符串包含", + "op": RequestCondOperatorContainsString, + "description": "参数值包含另外一个字符串", + }, + { + "name": "字符串不包含", + "op": RequestCondOperatorNotContainsString, + "description": "参数值不包含另外一个字符串", + }, + { + "name": "字符串不等于", + "op": RequestCondOperatorNeqString, + "description": "使用字符串对比参数值是否不相等于某个值", + }, + { + "name": "在列表中", + "op": RequestCondOperatorIn, + "description": "判断参数值在某个列表中", + }, + { + "name": "不在列表中", + "op": RequestCondOperatorNotIn, + "description": "判断参数值不在某个列表中", + }, + { + "name": "扩展名", + "op": RequestCondOperatorFileExt, + "description": "判断小写的扩展名(不带点)在某个列表中", + }, + { + "name": "MimeType", + "op": RequestCondOperatorFileMimeType, + "description": "判断MimeType在某个列表中,支持类似于image/*的语法", + }, + { + "name": "版本号范围", + "op": RequestCondOperatorVersionRange, + "description": "判断版本号在某个范围内,格式为version1,version2", + }, + { + "name": "整数等于", + "op": RequestCondOperatorEqInt, + "description": "将参数转换为整数数字后进行对比", + }, + { + "name": "浮点数等于", + "op": RequestCondOperatorEqFloat, + "description": "将参数转换为可以有小数的浮点数字进行对比", + }, + { + "name": "数字大于", + "op": RequestCondOperatorGtFloat, + "description": "将参数转换为数字进行对比", + }, + { + "name": "数字大于等于", + "op": RequestCondOperatorGteFloat, + "description": "将参数转换为数字进行对比", + }, + { + "name": "数字小于", + "op": RequestCondOperatorLtFloat, + "description": "将参数转换为数字进行对比", + }, + { + "name": "数字小于等于", + "op": RequestCondOperatorLteFloat, + "description": "将参数转换为数字进行对比", + }, + { + "name": "整数取模10", + "op": RequestCondOperatorMod10, + "description": "对整数参数值取模,除数为10,对比值为余数", + }, + { + "name": "整数取模100", + "op": RequestCondOperatorMod100, + "description": "对整数参数值取模,除数为100,对比值为余数", + }, + { + "name": "整数取模", + "op": RequestCondOperatorMod, + "description": "对整数参数值取模,对比值格式为:除数,余数,比如10,1", + }, + { + "name": "IP等于", + "op": RequestCondOperatorEqIP, + "description": "将参数转换为IP进行对比", + }, + { + "name": "IP大于", + "op": RequestCondOperatorGtIP, + "description": "将参数转换为IP进行对比", + }, + { + "name": "IP大于等于", + "op": RequestCondOperatorGteIP, + "description": "将参数转换为IP进行对比", + }, + { + "name": "IP小于", + "op": RequestCondOperatorLtIP, + "description": "将参数转换为IP进行对比", + }, + { + "name": "IP小于等于", + "op": RequestCondOperatorLteIP, + "description": "将参数转换为IP进行对比", + }, + { + "name": "IP范围", + "op": RequestCondOperatorIPRange, + "description": "IP在某个范围之内,范围格式可以是英文逗号分隔的ip1,ip2,或者CIDR格式的ip/bits", + }, + { + "name": "IP取模10", + "op": RequestCondOperatorIPMod10, + "description": "对IP参数值取模,除数为10,对比值为余数", + }, + { + "name": "IP取模100", + "op": RequestCondOperatorIPMod100, + "description": "对IP参数值取模,除数为100,对比值为余数", + }, + { + "name": "IP取模", + "op": RequestCondOperatorIPMod, + "description": "对IP参数值取模,对比值格式为:除数,余数,比如10,1", + }, + + { + "name": "文件存在", + "op": RequestCondOperatorFileExist, + "description": "判断参数值解析后的文件是否存在", + }, + + { + "name": "文件不存在", + "op": RequestCondOperatorFileNotExist, + "description": "判断参数值解析后的文件是否不存在", + }, + } +} diff --git a/internal/configs/serverconfigs/shared/size_capacity.go b/internal/configs/serverconfigs/shared/size_capacity.go new file mode 100644 index 0000000..be2bc02 --- /dev/null +++ b/internal/configs/serverconfigs/shared/size_capacity.go @@ -0,0 +1,30 @@ +package shared + +type SizeCapacityUnit = string + +const ( + SizeCapacityUnitByte SizeCapacityUnit = "byte" + SizeCapacityUnitKB SizeCapacityUnit = "kb" + SizeCapacityUnitMB SizeCapacityUnit = "mb" + SizeCapacityUnitGB SizeCapacityUnit = "gb" +) + +type SizeCapacity struct { + Count int64 `json:"count" yaml:"count"` + Unit SizeCapacityUnit `json:"unit" yaml:"unit"` +} + +func (this *SizeCapacity) Bytes() int64 { + switch this.Unit { + case SizeCapacityUnitByte: + return this.Count + case SizeCapacityUnitKB: + return this.Count * 1024 + case SizeCapacityUnitMB: + return this.Count * 1024 * 1024 + case SizeCapacityUnitGB: + return this.Count * 1024 * 1024 * 1024 + default: + return this.Count + } +} diff --git a/internal/configs/serverconfigs/shared/time_duration.go b/internal/configs/serverconfigs/shared/time_duration.go new file mode 100644 index 0000000..af866d1 --- /dev/null +++ b/internal/configs/serverconfigs/shared/time_duration.go @@ -0,0 +1,36 @@ +package shared + +import "time" + +type TimeDurationUnit = string + +const ( + TimeDurationUnitMS TimeDurationUnit = "ms" + TimeDurationUnitSecond TimeDurationUnit = "second" + TimeDurationUnitMinute TimeDurationUnit = "minute" + TimeDurationUnitHour TimeDurationUnit = "hour" + TimeDurationUnitDay TimeDurationUnit = "day" +) + +// 时间间隔 +type TimeDuration struct { + Count int64 `yaml:"count" json:"count"` // 数量 + Unit TimeDurationUnit `yaml:"unit" json:"unit"` // 单位 +} + +func (this *TimeDuration) Duration() time.Duration { + switch this.Unit { + case TimeDurationUnitMS: + return time.Duration(this.Count) * time.Millisecond + case TimeDurationUnitSecond: + return time.Duration(this.Count) * time.Second + case TimeDurationUnitMinute: + return time.Duration(this.Count) * time.Minute + case TimeDurationUnitHour: + return time.Duration(this.Count) * time.Hour + case TimeDurationUnitDay: + return time.Duration(this.Count) * 24 * time.Hour + default: + return time.Duration(this.Count) * time.Second + } +} diff --git a/internal/configs/serverconfigs/sslconfigs/ssl.go b/internal/configs/serverconfigs/sslconfigs/ssl.go index e953ce1..6f607ef 100644 --- a/internal/configs/serverconfigs/sslconfigs/ssl.go +++ b/internal/configs/serverconfigs/sslconfigs/ssl.go @@ -97,7 +97,7 @@ func (this *SSLConfig) Init() error { if cert == nil { continue } - if !cert.On { + if !cert.IsOn { continue } data, err := ioutil.ReadFile(cert.FullCertPath()) diff --git a/internal/configs/serverconfigs/sslconfigs/ssl_cert.go b/internal/configs/serverconfigs/sslconfigs/ssl_cert.go index 90ea334..089b58a 100644 --- a/internal/configs/serverconfigs/sslconfigs/ssl_cert.go +++ b/internal/configs/serverconfigs/sslconfigs/ssl_cert.go @@ -19,7 +19,7 @@ import ( // SSL证书 type SSLCertConfig struct { Id string `yaml:"id" json:"id"` - On bool `yaml:"on" json:"on"` + IsOn bool `yaml:"isOn" json:"isOn"` Description string `yaml:"description" json:"description"` // 说明 CertFile string `yaml:"certFile" json:"certFile"` KeyFile string `yaml:"keyFile" json:"keyFile"` @@ -39,7 +39,7 @@ type SSLCertConfig struct { // 获取新的SSL证书 func NewSSLCertConfig(certFile string, keyFile string) *SSLCertConfig { return &SSLCertConfig{ - On: true, + IsOn: true, Id: stringutil.Rand(16), CertFile: certFile, KeyFile: keyFile, diff --git a/internal/configs/serverconfigs/sslconfigs/ssl_hsts.go b/internal/configs/serverconfigs/sslconfigs/ssl_hsts.go index d54303a..98e1f7b 100644 --- a/internal/configs/serverconfigs/sslconfigs/ssl_hsts.go +++ b/internal/configs/serverconfigs/sslconfigs/ssl_hsts.go @@ -9,7 +9,7 @@ import ( // HSTS设置 // 参考: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security type HSTSConfig struct { - On bool `yaml:"on" json:"on"` + IsOn bool `yaml:"isOn" json:"isOn"` MaxAge int `yaml:"maxAge" json:"maxAge"` // 单位秒 IncludeSubDomains bool `yaml:"includeSubDomains" json:"includeSubDomains"` Preload bool `yaml:"preload" json:"preload"` diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index 7f9984b..b0c561c 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -59,8 +59,10 @@ func (this *HTTPListener) Serve() error { } func (this *HTTPListener) Close() error { - // TODO - return nil + if this.httpServer != nil { + _ = this.httpServer.Close() + } + return this.Listener.Close() } func (this *HTTPListener) handleHTTP(writer http.ResponseWriter, req *http.Request) { diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index f0eb92b..6743dbb 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -1,7 +1,9 @@ package nodes import ( + "errors" "github.com/TeaOSLab/EdgeNode/internal/configs/serverconfigs" + "github.com/iwind/TeaGo/logs" "net" ) @@ -13,11 +15,99 @@ type TCPListener struct { } func (this *TCPListener) Serve() error { - // TODO + for { + conn, err := this.Listener.Accept() + if err != nil { + break + } + err = this.handleConn(conn) + if err != nil { + logs.Println("[TCP_LISTENER]" + err.Error()) + } + } + + return nil +} + +func (this *TCPListener) handleConn(conn net.Conn) error { + firstServer := this.Group.FirstServer() + if firstServer == nil { + return errors.New("no server available") + } + if firstServer.ReverseProxy == nil { + return errors.New("no ReverseProxy configured for the server") + } + originConn, err := this.connectOrigin(firstServer.ReverseProxy) + if err != nil { + return err + } + + var closer = func() { + _ = conn.Close() + _ = originConn.Close() + } + + go func() { + originBuffer := make([]byte, 32*1024) // TODO 需要可以设置,并可以使用Pool + for { + n, err := originConn.Read(originBuffer) + if n > 0 { + _, err = conn.Write(originBuffer[:n]) + if err != nil { + closer() + break + } + } + if err != nil { + closer() + break + } + } + }() + + clientBuffer := make([]byte, 32*1024) // TODO 需要可以设置,并可以使用Pool + for { + n, err := conn.Read(clientBuffer) + if n > 0 { + _, err = originConn.Write(clientBuffer[:n]) + if err != nil { + break + } + } + if err != nil { + break + } + } + + // 关闭连接 + closer() + return nil } func (this *TCPListener) Close() error { - // TODO - return nil + return this.Listener.Close() +} + +func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig) (conn net.Conn, err error) { + if reverseProxy == nil { + return nil, errors.New("no reverse proxy config") + } + + retries := 3 + for i := 0; i < retries; i++ { + origin := reverseProxy.NextOrigin(nil) + if origin == nil { + continue + } + conn, err = origin.Connect() + if err != nil { + logs.Println("[TCP_LISTENER]unable to connect origin: " + origin.Addr.Host + ":" + origin.Addr.PortRange + ": " + err.Error()) + continue + } else { + return + } + } + err = errors.New("no origin can be used") + return }