优化代码

This commit is contained in:
刘祥超
2023-08-08 15:12:28 +08:00
parent 3c5c961cd5
commit 6a2803187e
21 changed files with 61 additions and 78 deletions

View File

@@ -7,7 +7,7 @@ import (
var whitespaceReg = regexp.MustCompile(`\s+`) var whitespaceReg = regexp.MustCompile(`\s+`)
// 关键词匹配 // MatchKeyword 关键词匹配
func MatchKeyword(source, keyword string) bool { func MatchKeyword(source, keyword string) bool {
if len(keyword) == 0 { if len(keyword) == 0 {
return false return false
@@ -16,7 +16,7 @@ func MatchKeyword(source, keyword string) bool {
pieces := whitespaceReg.Split(keyword, -1) pieces := whitespaceReg.Split(keyword, -1)
source = strings.ToLower(source) source = strings.ToLower(source)
for _, piece := range pieces { for _, piece := range pieces {
if strings.Index(source, strings.ToLower(piece)) > -1 { if strings.Contains(source, strings.ToLower(piece)) {
return true return true
} }
} }

View File

@@ -46,7 +46,7 @@ func NewReader(reader io.Reader) (*Reader, error) {
// 从Reader中加载数据 // 从Reader中加载数据
func (this *Reader) load(reader io.Reader) error { func (this *Reader) load(reader io.Reader) error {
var buf = make([]byte, 1024) var buf = make([]byte, 1024)
var metaLine = []byte{} var metaLine []byte
var metaLineFound = false var metaLineFound = false
var dataBuf = []byte{} var dataBuf = []byte{}
for { for {

View File

@@ -13,7 +13,7 @@ import (
type FileReader struct { type FileReader struct {
rawReader *Reader rawReader *Reader
password string //password string
} }
func NewFileReader(path string, password string) (*FileReader, error) { func NewFileReader(path string, password string) (*FileReader, error) {

View File

@@ -9,15 +9,15 @@ import (
type UnicodeEncodeFilter struct { type UnicodeEncodeFilter struct {
} }
// 初始化 // Init 初始化
func (this *UnicodeEncodeFilter) Init() error { func (this *UnicodeEncodeFilter) Init() error {
return nil return nil
} }
// 执行过滤 // Do 执行过滤
func (this *UnicodeEncodeFilter) Do(input interface{}, options interface{}) (output interface{}, goNext bool, err error) { func (this *UnicodeEncodeFilter) Do(input interface{}, options interface{}) (output interface{}, goNext bool, err error) {
s := []rune(types.String(input)) var s = types.String(input)
result := strings.Builder{} var result = strings.Builder{}
for _, r := range s { for _, r := range s {
if r < 128 { if r < 128 {
result.WriteRune(r) result.WriteRune(r)

View File

@@ -304,7 +304,7 @@ var AllCheckpoints = []*HTTPFirewallCheckpointDefinition{
RightLabel: "秒", RightLabel: "秒",
MaxLength: 8, MaxLength: 8,
Validate: func(value string) (ok bool, message string) { Validate: func(value string) (ok bool, message string) {
if regexp.MustCompile("^\\d+$").MatchString(value) { if regexp.MustCompile(`^\d+$`).MatchString(value) {
ok = true ok = true
return return
} }

View File

@@ -71,29 +71,3 @@ func (this *HTTPAuthBaseMethod) MatchRequest(req *http.Request) bool {
return true return true
} }
// cleanPath 清理Path中的多余的字符
func (this *HTTPAuthBaseMethod) cleanPath(path string) string {
var l = len(path)
if l == 0 {
return "/"
}
var result = []byte{'/'}
var isSlash = true
for i := 0; i < l; i++ {
if path[i] == '?' {
result = append(result, path[i:]...)
break
}
if path[i] == '\\' || path[i] == '/' {
if !isSlash {
isSlash = true
result = append(result, '/')
}
} else {
isSlash = false
result = append(result, path[i])
}
}
return string(result)
}

View File

@@ -2,8 +2,10 @@ package serverconfigs
import ( import (
"errors" "errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
"net"
"path/filepath" "path/filepath"
"regexp" "regexp"
"time" "time"
@@ -60,11 +62,11 @@ func (this *HTTPFastcgiConfig) Init() error {
} }
// 校验地址 // 校验地址
if regexp.MustCompile("^\\d+$").MatchString(this.Address) { if regexp.MustCompile(`^\d+$`).MatchString(this.Address) {
this.network = "tcp" this.network = "tcp"
this.address = "127.0.0.1:" + this.Address this.address = "127.0.0.1:" + this.Address
} else if regexp.MustCompile("^(.*):(\\d+)$").MatchString(this.Address) { } else if regexp.MustCompile(`^(.*):(\d+)$`).MatchString(this.Address) {
matches := regexp.MustCompile("^(.*):(\\d+)$").FindStringSubmatch(this.Address) var matches = regexp.MustCompile(`^(.*):(\d+)$`).FindStringSubmatch(this.Address)
ip := matches[1] ip := matches[1]
port := matches[2] port := matches[2]
if len(ip) == 0 { if len(ip) == 0 {
@@ -72,9 +74,9 @@ func (this *HTTPFastcgiConfig) Init() error {
} }
this.network = "tcp" this.network = "tcp"
this.address = ip + ":" + port this.address = ip + ":" + port
} else if regexp.MustCompile("^\\d+\\.\\d+.\\d+.\\d+$").MatchString(this.Address) { } else if net.ParseIP(this.address) != nil {
this.network = "tcp" this.network = "tcp"
this.address = this.Address + ":9000" this.address = configutils.QuoteIP(this.Address) + ":9000"
} else if regexp.MustCompile("^unix:(.+)$").MatchString(this.Address) { } else if regexp.MustCompile("^unix:(.+)$").MatchString(this.Address) {
matches := regexp.MustCompile("^unix:(.+)$").FindStringSubmatch(this.Address) matches := regexp.MustCompile("^unix:(.+)$").FindStringSubmatch(this.Address)
path := matches[1] path := matches[1]

View File

@@ -301,13 +301,13 @@ func (this *HTTPLocationConfig) Match(path string, formatter func(source string)
if this.patternType == HTTPLocationPatternTypeExact { if this.patternType == HTTPLocationPatternTypeExact {
if this.reverse { if this.reverse {
if this.caseInsensitive { if this.caseInsensitive {
return nil, strings.ToLower(path) != strings.ToLower(this.path) return nil, !strings.EqualFold(path, this.path)
} else { } else {
return nil, path != this.path return nil, path != this.path
} }
} else { } else {
if this.caseInsensitive { if this.caseInsensitive {
return nil, strings.ToLower(path) == strings.ToLower(this.path) return nil, strings.EqualFold(path, this.path)
} else { } else {
return nil, path == this.path return nil, path == this.path
} }

View File

@@ -20,7 +20,7 @@ type MetricItemConfig struct {
Version int32 `yaml:"version" json:"version"` Version int32 `yaml:"version" json:"version"`
ExpiresPeriod int `yaml:"expiresPeriod" json:"expiresPeriod"` // 过期周期 ExpiresPeriod int `yaml:"expiresPeriod" json:"expiresPeriod"` // 过期周期
sumType string // 统计类型 //sumType string // 统计类型
baseTime time.Time // 基准时间 baseTime time.Time // 基准时间
hasHTTPConnectionValue bool // 是否有统计HTTP连接数的数值 hasHTTPConnectionValue bool // 是否有统计HTTP连接数的数值
} }

View File

@@ -66,8 +66,8 @@ type OriginConfig struct {
requestPath string requestPath string
requestArgs string requestArgs string
hasRequestHeaders bool //hasRequestHeaders bool
hasResponseHeaders bool //hasResponseHeaders bool
uniqueKey string uniqueKey string

View File

@@ -1,13 +1,16 @@
package serverconfigs package serverconfigs
import "testing" import (
"context"
"testing"
)
func TestOriginConfig_UniqueKey(t *testing.T) { func TestOriginConfig_UniqueKey(t *testing.T) {
origin := &OriginConfig{ origin := &OriginConfig{
Id: 1, Id: 1,
Version: 101, Version: 101,
} }
err := origin.Init(nil) err := origin.Init(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -53,9 +53,9 @@ type ReverseProxyConfig struct {
schedulingGroupMap map[string]*SchedulingGroup // domain => *SchedulingGroup schedulingGroupMap map[string]*SchedulingGroup // domain => *SchedulingGroup
schedulingLocker sync.RWMutex schedulingLocker sync.RWMutex
addXRealIPHeader bool addXRealIPHeader bool
addXForwardedForHeader bool addXForwardedForHeader bool
addForwardedHeader bool //addForwardedHeader bool
addXForwardedByHeader bool addXForwardedByHeader bool
addXForwardedHostHeader bool addXForwardedHostHeader bool
addXForwardedProtoHeader bool addXForwardedProtoHeader bool
@@ -131,12 +131,8 @@ func (this *ReverseProxyConfig) Init(ctx context.Context) error {
if domain == "" { if domain == "" {
continue continue
} }
for _, origin := range defaultGroup.PrimaryOrigins { group.PrimaryOrigins = append(group.PrimaryOrigins, defaultGroup.PrimaryOrigins...)
group.PrimaryOrigins = append(group.PrimaryOrigins, origin) group.BackupOrigins = append(group.BackupOrigins, defaultGroup.BackupOrigins...)
}
for _, origin := range defaultGroup.BackupOrigins {
group.BackupOrigins = append(group.BackupOrigins, origin)
}
} }
} }
} }

View File

@@ -3,6 +3,7 @@
package serverconfigs package serverconfigs
import ( import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"testing" "testing"
) )
@@ -30,7 +31,7 @@ func TestReverseProxyConfig_Init(t *testing.T) {
Addr: &NetworkAddressConfig{Host: "127.0.0.4"}, Addr: &NetworkAddressConfig{Host: "127.0.0.4"},
IsOn: true, IsOn: true,
}) })
err := config.Init(nil) err := config.Init(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -52,7 +52,7 @@ func (this *RoundRobinScheduling) Next(call *shared.RequestCall) CandidateInterf
if this.index > this.count-1 { if this.index > this.count-1 {
this.index = 0 this.index = 0
} }
weight := this.currentWeights[this.index] var weight = this.currentWeights[this.index]
// 已经一轮了,则重置状态 // 已经一轮了,则重置状态
if weight == 0 { if weight == 0 {
@@ -60,11 +60,10 @@ func (this *RoundRobinScheduling) Next(call *shared.RequestCall) CandidateInterf
this.currentWeights = append([]uint{}, this.rawWeights...) this.currentWeights = append([]uint{}, this.rawWeights...)
} }
this.index = 0 this.index = 0
weight = this.currentWeights[this.index]
} }
c := this.Candidates[this.index] c := this.Candidates[this.index]
this.currentWeights[this.index] -- this.currentWeights[this.index]--
this.index++ this.index++
return c return c
} }

View File

@@ -271,7 +271,7 @@ func (this *ServerConfig) Init(ctx context.Context) (results []error) {
this.isOk = true this.isOk = true
return nil return
} }
func (this *ServerConfig) IsInitialized() bool { func (this *ServerConfig) IsInitialized() bool {

View File

@@ -1,6 +1,9 @@
package serverconfigs package serverconfigs
import "testing" import (
"context"
"testing"
)
func TestServerConfig_Protocols(t *testing.T) { func TestServerConfig_Protocols(t *testing.T) {
{ {
@@ -65,7 +68,7 @@ func TestServerConfig_Protocols(t *testing.T) {
}, },
}, },
}} }}
err := server.Init(nil) err := server.Init(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -214,12 +214,12 @@ func (this *HTTPRequestCond) match(formatter func(source string) string) bool {
return types.Int64(paramValue)%100 == types.Int64(this.Value) return types.Int64(paramValue)%100 == types.Int64(this.Value)
case RequestCondOperatorEqString: case RequestCondOperatorEqString:
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return strings.ToUpper(paramValue) == strings.ToUpper(this.Value) return strings.EqualFold(paramValue, this.Value)
} }
return paramValue == this.Value return paramValue == this.Value
case RequestCondOperatorNeqString: case RequestCondOperatorNeqString:
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return strings.ToUpper(paramValue) != strings.ToUpper(this.Value) return !strings.EqualFold(paramValue, this.Value)
} }
return paramValue != this.Value return paramValue != this.Value
case RequestCondOperatorHasPrefix: case RequestCondOperatorHasPrefix:
@@ -243,11 +243,11 @@ func (this *HTTPRequestCond) match(formatter func(source string) string) bool {
} }
return !strings.Contains(paramValue, this.Value) return !strings.Contains(paramValue, this.Value)
case RequestCondOperatorEqIP: case RequestCondOperatorEqIP:
ip := net.ParseIP(paramValue) var ip = net.ParseIP(paramValue)
if ip == nil { if ip == nil {
return false return false
} }
return this.isIP && bytes.Compare(this.ipValue, ip) == 0 return this.isIP && ip.Equal(this.ipValue)
case RequestCondOperatorGtIP: case RequestCondOperatorGtIP:
ip := net.ParseIP(paramValue) ip := net.ParseIP(paramValue)
if ip == nil { if ip == nil {

View File

@@ -9,5 +9,5 @@ var (
RegexpAllDigitNumber = regexp.MustCompile(`^[+-]?\d+$`) // 整数,支持正负数 RegexpAllDigitNumber = regexp.MustCompile(`^[+-]?\d+$`) // 整数,支持正负数
RegexpAllFloatNumber = regexp.MustCompile(`^[+-]?\d+(\.\d+)?$`) // 浮点数支持正负数不支持e RegexpAllFloatNumber = regexp.MustCompile(`^[+-]?\d+(\.\d+)?$`) // 浮点数支持正负数不支持e
RegexpExternalURL = regexp.MustCompile("(?i)^(http|https|ftp)://") // URL RegexpExternalURL = regexp.MustCompile("(?i)^(http|https|ftp)://") // URL
RegexpNamedVariable = regexp.MustCompile("\\${[\\w.-]+}") // 命名变量 RegexpNamedVariable = regexp.MustCompile(`\${[\w.-]+}`) // 命名变量
) )

View File

@@ -1,17 +1,21 @@
package shared package shared_test
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"testing" "testing"
) )
func TestRegexp(t *testing.T) { func TestRegexp(t *testing.T) {
a := assert.NewAssertion(t) var a = assert.NewAssertion(t)
a.IsTrue(RegexpFloatNumber.MatchString("123")) a.IsTrue(shared.RegexpFloatNumber.MatchString("123"))
a.IsTrue(RegexpFloatNumber.MatchString("123.456")) a.IsTrue(shared.RegexpFloatNumber.MatchString("123.456"))
a.IsFalse(RegexpFloatNumber.MatchString(".456")) a.IsFalse(shared.RegexpFloatNumber.MatchString(".456"))
a.IsFalse(RegexpFloatNumber.MatchString("abc")) a.IsFalse(shared.RegexpFloatNumber.MatchString("abc"))
a.IsFalse(RegexpFloatNumber.MatchString("123.")) a.IsFalse(shared.RegexpFloatNumber.MatchString("123."))
a.IsFalse(RegexpFloatNumber.MatchString("123.456e7")) a.IsFalse(shared.RegexpFloatNumber.MatchString("123.456e7"))
a.IsTrue(shared.RegexpNamedVariable.MatchString("${abc.efg}"))
a.IsTrue(shared.RegexpNamedVariable.MatchString("${abc}"))
a.IsFalse(shared.RegexpNamedVariable.MatchString("{abc.efg}"))
} }

View File

@@ -252,7 +252,7 @@ func (this *SSLPolicy) certIsEqual(cert1 tls.Certificate, cert2 tls.Certificate)
} }
for index, b := range b1 { for index, b := range b1 {
if bytes.Compare(b, b2[index]) != 0 { if !bytes.Equal(b, b2[index]) {
return false return false
} }
} }

View File

@@ -3,6 +3,7 @@
package sslconfigs_test package sslconfigs_test
import ( import (
"context"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"testing" "testing"
@@ -120,7 +121,7 @@ Z3NIV2eNt6YBwkC69DzdazXT
OCSPExpiresAt: nowTime + 2, OCSPExpiresAt: nowTime + 2,
}) })
err := policy.Init(nil) err := policy.Init(context.TODO())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }