diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go index 4f04582..be7f86b 100644 --- a/internal/iplibrary/ip_list.go +++ b/internal/iplibrary/ip_list.go @@ -6,7 +6,7 @@ import ( "sync" ) -// IP名单 +// IPList IP名单 type IPList struct { itemsMap map[int64]*IPItem // id => item ipMap map[uint64][]int64 // ip => itemIds @@ -96,7 +96,7 @@ func (this *IPList) Delete(itemId int64) { this.isAll = len(this.ipMap[0]) > 0 } -// 判断是否包含某个IP +// Contains 判断是否包含某个IP func (this *IPList) Contains(ip uint64) bool { this.locker.RLock() if this.isAll { @@ -109,7 +109,7 @@ func (this *IPList) Contains(ip uint64) bool { return ok } -// 是否包含一组IP +// ContainsIPStrings 是否包含一组IP func (this *IPList) ContainsIPStrings(ipStrings []string) (found bool, item *IPItem) { if len(ipStrings) == 0 { return diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 1b610a4..0ffae99 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -68,12 +68,15 @@ type HTTPRequest struct { cacheKey string // 缓存使用的Key isCached bool // 是否已经被缓存 isAttack bool // 是否是攻击请求 + bodyData []byte // 读取的Body内容 // WAF相关 firewallPolicyId int64 firewallRuleGroupId int64 firewallRuleSetId int64 firewallRuleId int64 + firewallActions []string + tags []string logAttrs map[string]string @@ -1197,5 +1200,10 @@ func (this *HTTPRequest) canIgnore(err error) bool { return true } + // HTTP内部错误 + if strings.HasPrefix(err.Error(), "http:") || strings.HasPrefix(err.Error(), "http2:") { + return true + } + return false } diff --git a/internal/nodes/http_request_log.go b/internal/nodes/http_request_log.go index 93ac503..df92eeb 100644 --- a/internal/nodes/http_request_log.go +++ b/internal/nodes/http_request_log.go @@ -128,6 +128,8 @@ func (this *HTTPRequest) log() { FirewallRuleGroupId: this.firewallRuleGroupId, FirewallRuleSetId: this.firewallRuleSetId, FirewallRuleId: this.firewallRuleId, + FirewallActions: this.firewallActions, + Tags: this.tags, Attrs: this.logAttrs, } diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 559d2b1..c6100c5 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -1,6 +1,7 @@ package nodes import ( + "bytes" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" @@ -8,6 +9,8 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" + "io" + "io/ioutil" "net/http" ) @@ -152,27 +155,36 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir if w == nil { return } - goNext, ruleGroup, ruleSet, err := w.MatchRequest(this.RawReq, this.writer) + + w.OnAction(func(action waf.ActionInterface) (goNext bool) { + switch action.Code() { + case waf.ActionTag: + this.tags = action.(*waf.TagAction).Tags + } + return true + }) + + goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer) if err != nil { remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error()) return } if ruleSet != nil { - if ruleSet.Action != waf.ActionAllow { + if ruleSet.HasSpecialActions() { this.firewallPolicyId = firewallPolicy.Id this.firewallRuleGroupId = types.Int64(ruleGroup.Id) this.firewallRuleSetId = types.Int64(ruleSet.Id) - if ruleSet.Action == waf.ActionBlock { + if ruleSet.HasAttackActions() { this.isAttack = true } // 添加统计 - stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action) + stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions) } - this.logAttrs["waf.action"] = ruleSet.Action + this.firewallActions = ruleSet.ActionCodes() } return !goNext, false @@ -208,28 +220,79 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi return } - goNext, ruleGroup, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer) + w.OnAction(func(action waf.ActionInterface) (goNext bool) { + switch action.Code() { + case waf.ActionTag: + this.tags = action.(*waf.TagAction).Tags + } + return true + }) + + goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer) if err != nil { remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error()) return } if ruleSet != nil { - if ruleSet.Action != waf.ActionAllow { + if ruleSet.HasSpecialActions() { this.firewallPolicyId = firewallPolicy.Id this.firewallRuleGroupId = types.Int64(ruleGroup.Id) this.firewallRuleSetId = types.Int64(ruleSet.Id) - if ruleSet.Action == waf.ActionBlock { + if ruleSet.HasAttackActions() { this.isAttack = true } // 添加统计 - stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action) + stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions) } - this.logAttrs["waf.action"] = ruleSet.Action + this.firewallActions = ruleSet.ActionCodes() } return !goNext } + +// WAFRaw 原始请求 +func (this *HTTPRequest) WAFRaw() *http.Request { + return this.RawReq +} + +// WAFRemoteIP 客户端IP +func (this *HTTPRequest) WAFRemoteIP() string { + return this.requestRemoteAddr() +} + +// WAFGetCacheBody 获取缓存中的Body +func (this *HTTPRequest) WAFGetCacheBody() []byte { + return this.bodyData +} + +// WAFSetCacheBody 设置Body +func (this *HTTPRequest) WAFSetCacheBody(body []byte) { + this.bodyData = body +} + +// WAFReadBody 读取Body +func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) { + if this.RawReq.ContentLength > 0 { + data, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, max)) + } + return +} + +// WAFRestoreBody 恢复Body +func (this *HTTPRequest) WAFRestoreBody(data []byte) { + if len(data) > 0 { + rawReader := bytes.NewBuffer(data) + buf := make([]byte, 1024) + _, _ = io.CopyBuffer(rawReader, this.RawReq.Body, buf) + this.RawReq.Body = ioutil.NopCloser(rawReader) + } +} + +// WAFServerId 服务ID +func (this *HTTPRequest) WAFServerId() int64 { + return this.Server.Id +} diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index 910cf7e..d06e32f 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -7,7 +7,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" - http2 "golang.org/x/net/http2" + "golang.org/x/net/http2" "sync" ) diff --git a/internal/nodes/traffic_listener.go b/internal/nodes/traffic_listener.go index e934fbb..67d99d5 100644 --- a/internal/nodes/traffic_listener.go +++ b/internal/nodes/traffic_listener.go @@ -2,7 +2,10 @@ package nodes -import "net" +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf" + "net" +) // TrafficListener 用于统计流量的网络监听 type TrafficListener struct { @@ -18,6 +21,17 @@ func (this *TrafficListener) Accept() (net.Conn, error) { if err != nil { return nil, err } + // 是否在WAF名单中 + ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err == nil { + if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackLIst.Contains(waf.IPTypeAll, ip) { + go func() { + _ = conn.Close() + }() + return conn, nil + } + } + return NewTrafficConn(conn), nil } diff --git a/internal/nodes/waf_manager.go b/internal/nodes/waf_manager.go index 91f542e..ac3ab19 100644 --- a/internal/nodes/waf_manager.go +++ b/internal/nodes/waf_manager.go @@ -11,20 +11,20 @@ import ( var sharedWAFManager = NewWAFManager() -// WAF管理器 +// WAFManager WAF管理器 type WAFManager struct { mapping map[int64]*waf.WAF // policyId => WAF locker sync.RWMutex } -// 获取新对象 +// NewWAFManager 获取新对象 func NewWAFManager() *WAFManager { return &WAFManager{ mapping: map[int64]*waf.WAF{}, } } -// 更新策略 +// UpdatePolicies 更新策略 func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) { this.locker.Lock() defer this.locker.Unlock() @@ -44,7 +44,7 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP this.mapping = m } -// 查找WAF +// FindWAF 查找WAF func (this *WAFManager) FindWAF(policyId int64) *waf.WAF { this.locker.RLock() w, _ := this.mapping[policyId] @@ -78,14 +78,15 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( // rule sets for _, set := range group.Sets { s := &waf.RuleSet{ - Id: strconv.FormatInt(set.Id, 10), - Code: set.Code, - IsOn: set.IsOn, - Name: set.Name, - Description: set.Description, - Connector: set.Connector, - Action: set.Action, - ActionOptions: set.ActionOptions, + Id: strconv.FormatInt(set.Id, 10), + Code: set.Code, + IsOn: set.IsOn, + Name: set.Name, + Description: set.Description, + Connector: set.Connector, + } + for _, a := range set.Actions { + s.AddAction(a.Code, a.Options) } // rules @@ -132,14 +133,16 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( // rule sets for _, set := range group.Sets { s := &waf.RuleSet{ - Id: strconv.FormatInt(set.Id, 10), - Code: set.Code, - IsOn: set.IsOn, - Name: set.Name, - Description: set.Description, - Connector: set.Connector, - Action: set.Action, - ActionOptions: set.ActionOptions, + Id: strconv.FormatInt(set.Id, 10), + Code: set.Code, + IsOn: set.IsOn, + Name: set.Name, + Description: set.Description, + Connector: set.Connector, + } + + for _, a := range set.Actions { + s.AddAction(a.Code, a.Options) } // rules @@ -164,10 +167,11 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( // action if policy.BlockOptions != nil { - w.ActionBlock = &waf.BlockAction{ + w.DefaultBlockAction = &waf.BlockAction{ StatusCode: policy.BlockOptions.StatusCode, Body: policy.BlockOptions.Body, - URL: "", + URL: policy.BlockOptions.URL, + Timeout: policy.BlockOptions.Timeout, } } diff --git a/internal/rpc/rpc_client.go b/internal/rpc/rpc_client.go index 37f1a32..072262f 100644 --- a/internal/rpc/rpc_client.go +++ b/internal/rpc/rpc_client.go @@ -113,6 +113,10 @@ func (this *RPCClient) MetricStatRPC() pb.MetricStatServiceClient { return pb.NewMetricStatServiceClient(this.pickConn()) } +func (this *RPCClient) FirewallService() pb.FirewallServiceClient { + return pb.NewFirewallServiceClient(this.pickConn()) +} + // Context 节点上下文信息 func (this *RPCClient) Context() context.Context { ctx := context.Background() diff --git a/internal/stats/http_request_stat_manager.go b/internal/stats/http_request_stat_manager.go index 1fc29db..c30b61e 100644 --- a/internal/stats/http_request_stat_manager.go +++ b/internal/stats/http_request_stat_manager.go @@ -8,6 +8,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/monitor" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" @@ -132,17 +133,19 @@ func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent strin } // AddFirewallRuleGroupId 添加防火墙拦截动作 -func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, action string) { +func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, actions []*waf.ActionConfig) { if firewallRuleGroupId <= 0 { return } - this.totalAttackRequests ++ + this.totalAttackRequests++ - select { - case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action: - default: - // 超出容量我们就丢弃 + for _, action := range actions { + select { + case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action.Code: + default: + // 超出容量我们就丢弃 + } } } diff --git a/internal/utils/encrypt.go b/internal/utils/encrypt.go new file mode 100644 index 0000000..2bf38e6 --- /dev/null +++ b/internal/utils/encrypt.go @@ -0,0 +1,159 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package utils + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "encoding/json" + "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/maps" + "github.com/iwind/TeaGo/rands" + stringutil "github.com/iwind/TeaGo/utils/string" +) + +var ( + simpleEncryptMagicKey = rands.HexString(32) +) + +func init() { + events.On(events.EventReload, func() { + nodeConfig, _ := nodeconfigs.SharedNodeConfig() + if nodeConfig != nil { + simpleEncryptMagicKey = stringutil.Md5(nodeConfig.NodeId + "@" + nodeConfig.Secret) + } + }) +} + +// SimpleEncrypt 加密特殊信息 +func SimpleEncrypt(data []byte) []byte { + var method = &AES256CFBMethod{} + err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16])) + if err != nil { + logs.Println("[SimpleEncrypt]" + err.Error()) + return data + } + + dst, err := method.Encrypt(data) + if err != nil { + logs.Println("[SimpleEncrypt]" + err.Error()) + return data + } + return dst +} + +// SimpleDecrypt 解密特殊信息 +func SimpleDecrypt(data []byte) []byte { + var method = &AES256CFBMethod{} + err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16])) + if err != nil { + logs.Println("[MagicKeyEncode]" + err.Error()) + return data + } + + src, err := method.Decrypt(data) + if err != nil { + logs.Println("[MagicKeyEncode]" + err.Error()) + return data + } + return src +} + +func SimpleEncryptMap(m maps.Map) (base64String string, err error) { + mJSON, err := json.Marshal(m) + if err != nil { + return "", err + } + data := SimpleEncrypt(mJSON) + return base64.StdEncoding.EncodeToString(data), nil +} + +func SimpleDecryptMap(base64String string) (maps.Map, error) { + data, err := base64.StdEncoding.DecodeString(base64String) + if err != nil { + return nil, err + } + mJSON := SimpleDecrypt(data) + var result = maps.Map{} + err = json.Unmarshal(mJSON, &result) + if err != nil { + return nil, err + } + return result, nil +} + +type AES256CFBMethod struct { + block cipher.Block + iv []byte +} + +func (this *AES256CFBMethod) Init(key, iv []byte) error { + // 判断key是否为32长度 + l := len(key) + if l > 32 { + key = key[:32] + } else if l < 32 { + key = append(key, bytes.Repeat([]byte{' '}, 32-l)...) + } + + block, err := aes.NewCipher(key) + if err != nil { + return err + } + this.block = block + + // 判断iv长度 + l2 := len(iv) + if l2 > aes.BlockSize { + iv = iv[:aes.BlockSize] + } else if l2 < aes.BlockSize { + iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...) + } + this.iv = iv + + return nil +} + +func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) { + if len(src) == 0 { + return + } + + defer func() { + r := recover() + if r != nil { + err = errors.New("encrypt failed") + } + }() + + dst = make([]byte, len(src)) + + encrypter := cipher.NewCFBEncrypter(this.block, this.iv) + encrypter.XORKeyStream(dst, src) + + return +} + +func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) { + if len(dst) == 0 { + return + } + + defer func() { + r := recover() + if r != nil { + err = errors.New("decrypt failed") + } + }() + + src = make([]byte, len(dst)) + decrypter := cipher.NewCFBDecrypter(this.block, this.iv) + decrypter.XORKeyStream(src, dst) + + return +} diff --git a/internal/utils/encrypt_test.go b/internal/utils/encrypt_test.go new file mode 100644 index 0000000..3a5d411 --- /dev/null +++ b/internal/utils/encrypt_test.go @@ -0,0 +1,52 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package utils + +import ( + "github.com/iwind/TeaGo/maps" + "sync" + "testing" +) + +func TestSimpleEncrypt(t *testing.T) { + var arr = []string{"Hello", "World", "People"} + for _, s := range arr { + var value = []byte(s) + encoded := SimpleEncrypt(value) + t.Log(encoded, string(encoded)) + decoded := SimpleDecrypt(encoded) + t.Log(decoded, string(decoded)) + } +} + +func TestSimpleEncrypt_Concurrent(t *testing.T) { + wg := sync.WaitGroup{} + var arr = []string{"Hello", "World", "People"} + wg.Add(len(arr)) + for _, s := range arr { + go func(s string) { + defer wg.Done() + t.Log(string(SimpleDecrypt(SimpleEncrypt([]byte(s))))) + }(s) + } + wg.Wait() +} + +func TestSimpleEncryptMap(t *testing.T) { + var m = maps.Map{ + "s": "Hello", + "i": 20, + "b": true, + } + encodedResult, err := SimpleEncryptMap(m) + if err != nil { + t.Fatal(err) + } + t.Log("result:", encodedResult) + + decodedResult, err := SimpleDecryptMap(encodedResult) + if err != nil { + t.Fatal(err) + } + t.Log(decodedResult) +} diff --git a/internal/utils/expires/list.go b/internal/utils/expires/list.go index 2e83fe0..98fb485 100644 --- a/internal/utils/expires/list.go +++ b/internal/utils/expires/list.go @@ -12,6 +12,7 @@ type List struct { itemsMap map[int64]int64 // itemId => timestamp locker sync.Mutex + ticker *time.Ticker } func NewList() *List { @@ -21,10 +22,7 @@ func NewList() *List { } } -func (this *List) Add(itemId int64, expiredAt int64) { - if expiredAt <= time.Now().Unix() { - return - } +func (this *List) Add(itemId int64, expiresAt int64) { this.locker.Lock() defer this.locker.Unlock() @@ -34,17 +32,17 @@ func (this *List) Add(itemId int64, expiredAt int64) { this.removeItem(itemId) } - expireItemMap, ok := this.expireMap[expiredAt] + expireItemMap, ok := this.expireMap[expiresAt] if ok { expireItemMap[itemId] = true } else { expireItemMap = ItemMap{ itemId: true, } - this.expireMap[expiredAt] = expireItemMap + this.expireMap[expiresAt] = expireItemMap } - this.itemsMap[itemId] = expiredAt + this.itemsMap[itemId] = expiresAt } func (this *List) Remove(itemId int64) { @@ -64,21 +62,22 @@ func (this *List) GC(timestamp int64, callback func(itemId int64)) { } func (this *List) StartGC(callback func(itemId int64)) { - ticker := time.NewTicker(1 * time.Second) + this.ticker = time.NewTicker(1 * time.Second) lastTimestamp := int64(0) - for range ticker.C { + for range this.ticker.C { timestamp := time.Now().Unix() if lastTimestamp == 0 { lastTimestamp = timestamp - 3600 } - // 防止死循环 - if lastTimestamp > timestamp { - continue - } - - for i := lastTimestamp; i <= timestamp; i++ { - this.GC(timestamp, callback) + if timestamp >= lastTimestamp { + for i := lastTimestamp; i <= timestamp; i++ { + this.GC(i, callback) + } + } else { + for i := timestamp; i <= lastTimestamp; i++ { + this.GC(i, callback) + } } // 这样做是为了防止系统时钟突变 diff --git a/internal/utils/expires/list_test.go b/internal/utils/expires/list_test.go index c4b06d3..bca42e9 100644 --- a/internal/utils/expires/list_test.go +++ b/internal/utils/expires/list_test.go @@ -58,6 +58,10 @@ func TestList_Start_GC(t *testing.T) { list.Add(2, time.Now().Unix()+1) list.Add(3, time.Now().Unix()+2) list.Add(4, time.Now().Unix()+5) + list.Add(5, time.Now().Unix()+5) + list.Add(6, time.Now().Unix()+6) + list.Add(7, time.Now().Unix()+6) + list.Add(8, time.Now().Unix()+6) go func() { list.StartGC(func(itemId int64) { @@ -66,7 +70,7 @@ func TestList_Start_GC(t *testing.T) { }) }() - time.Sleep(10 * time.Second) + time.Sleep(20 * time.Second) } func TestList_ManyItems(t *testing.T) { diff --git a/internal/utils/jsonutils/map.go b/internal/utils/jsonutils/map.go new file mode 100644 index 0000000..4986f3e --- /dev/null +++ b/internal/utils/jsonutils/map.go @@ -0,0 +1,35 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package jsonutils + +import ( + "encoding/json" + "github.com/iwind/TeaGo/maps" +) + +func MapToObject(m maps.Map, ptr interface{}) error { + if m == nil { + return nil + } + mJSON, err := json.Marshal(m) + if err != nil { + return err + } + return json.Unmarshal(mJSON, ptr) +} + +func ObjectToMap(ptr interface{}) (maps.Map, error) { + if ptr == nil { + return maps.Map{}, nil + } + ptrJSON, err := json.Marshal(ptr) + if err != nil { + return nil, err + } + var result = maps.Map{} + err = json.Unmarshal(ptrJSON, &result) + if err != nil { + return nil, err + } + return result, nil +} diff --git a/internal/utils/jsonutils/map_test.go b/internal/utils/jsonutils/map_test.go new file mode 100644 index 0000000..6bccfcf --- /dev/null +++ b/internal/utils/jsonutils/map_test.go @@ -0,0 +1,46 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package jsonutils + +import ( + "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/maps" + "testing" +) + +func TestMapToObject(t *testing.T) { + a := assert.NewAssertion(t) + + type typeA struct { + B int `json:"b"` + C bool `json:"c"` + } + + { + var obj = &typeA{B: 1, C: true} + m, err := ObjectToMap(obj) + if err != nil { + t.Fatal(err) + } + PrintT(m, t) + a.IsTrue(m.GetInt("b") == 1) + a.IsTrue(m.GetBool("c") == true) + } + + { + var obj = &typeA{} + err := MapToObject(maps.Map{ + "b": 1024, + "c": true, + }, obj) + if err != nil { + t.Fatal(err) + } + if obj == nil { + t.Fatal("obj should not be nil") + } + a.IsTrue(obj.B == 1024) + a.IsTrue(obj.C == true) + PrintT(obj, t) + } +} diff --git a/internal/utils/jsonutils/utils.go b/internal/utils/jsonutils/utils.go new file mode 100644 index 0000000..5c37dfd --- /dev/null +++ b/internal/utils/jsonutils/utils.go @@ -0,0 +1,17 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package jsonutils + +import ( + "encoding/json" + "testing" +) + +func PrintT(obj interface{}, t *testing.T) { + data, err := json.MarshalIndent(obj, "", " ") + if err != nil { + t.Log(err) + } else { + t.Log(string(data)) + } +} diff --git a/internal/waf/action_allow.go b/internal/waf/action_allow.go index 35421ca..ea3b3a4 100644 --- a/internal/waf/action_allow.go +++ b/internal/waf/action_allow.go @@ -8,7 +8,23 @@ import ( type AllowAction struct { } -func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { +func (this *AllowAction) Init(waf *WAF) error { + return nil +} + +func (this *AllowAction) Code() string { + return ActionAllow +} + +func (this *AllowAction) IsAttack() bool { + return false +} + +func (this *AllowAction) WillChange() bool { + return false +} + +func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { // do nothing return true } diff --git a/internal/waf/action_base.go b/internal/waf/action_base.go new file mode 100644 index 0000000..e0e6bec --- /dev/null +++ b/internal/waf/action_base.go @@ -0,0 +1,21 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import "net/http" + +type BaseAction struct { +} + +// CloseConn 关闭连接 +func (this *BaseAction) CloseConn(writer http.ResponseWriter) error { + // 断开连接 + hijack, ok := writer.(http.Hijacker) + if ok { + conn, _, err := hijack.Hijack() + if err == nil { + return conn.Close() + } + } + return nil +} diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index e6ba70f..4b91fe4 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -23,12 +23,48 @@ type BlockAction struct { StatusCode int `yaml:"statusCode" json:"statusCode"` Body string `yaml:"body" json:"body"` // supports HTML URL string `yaml:"url" json:"url"` + Timeout int32 `yaml:"timeout" json:"timeout"` } -func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { +func (this *BlockAction) Init(waf *WAF) error { + if waf.DefaultBlockAction != nil { + if this.StatusCode <= 0 { + this.StatusCode = waf.DefaultBlockAction.StatusCode + } + if len(this.Body) == 0 { + this.Body = waf.DefaultBlockAction.Body + } + if len(this.URL) == 0 { + this.URL = waf.DefaultBlockAction.URL + } + if this.Timeout <= 0 { + this.Timeout = waf.DefaultBlockAction.Timeout + } + } + return nil +} + +func (this *BlockAction) Code() string { + return ActionBlock +} + +func (this *BlockAction) IsAttack() bool { + return true +} + +func (this *BlockAction) WillChange() bool { + return true +} + +func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + if this.Timeout > 0 { + // 加入到黑名单 + SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(this.Timeout)) + } + if writer != nil { - // if status code eq 444, we close the connection - if this.StatusCode == 444 { + // close the connection + defer func() { hijack, ok := writer.(http.Hijacker) if ok { conn, _, _ := hijack.Hijack() @@ -37,7 +73,7 @@ func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer htt return } } - } + }() // output response if this.StatusCode > 0 { diff --git a/internal/waf/action_captcha.go b/internal/waf/action_captcha.go index 120869d..4db432e 100644 --- a/internal/waf/action_captcha.go +++ b/internal/waf/action_captcha.go @@ -1,11 +1,14 @@ package waf import ( + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" - "github.com/iwind/TeaGo/types" + "github.com/iwind/TeaGo/maps" stringutil "github.com/iwind/TeaGo/utils/string" "net/http" "net/url" + "strings" "time" ) @@ -13,27 +16,63 @@ var captchaSalt = stringutil.Rand(32) const ( CaptchaSeconds = 600 // 10 minutes + CaptchaPath = "/WAF/VERIFY/CAPTCHA" ) type CaptchaAction struct { + Life int32 `yaml:"life" json:"life"` + Language string `yaml:"language" json:"language"` // 语言,zh-CN, en-US ... + AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单 } -func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { - // TEAWEB_CAPTCHA: - cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA") - if err == nil && cookie != nil && len(cookie.Value) > 32 { - m := cookie.Value[:32] - timestamp := cookie.Value[32:] - if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5 - return true +func (this *CaptchaAction) Init(waf *WAF) error { + return nil +} + +func (this *CaptchaAction) Code() string { + return ActionCaptcha +} + +func (this *CaptchaAction) IsAttack() bool { + return false +} + +func (this *CaptchaAction) WillChange() bool { + return true +} + +func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + // 是否在白名单中 + if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { + return true + } + + refURL := request.WAFRaw().URL.String() + + // 覆盖配置 + if strings.HasPrefix(refURL, CaptchaPath) { + info := request.WAFRaw().URL.Query().Get("info") + if len(info) > 0 { + m, err := utils.SimpleDecryptMap(info) + if err == nil && m != nil { + refURL = m.GetString("url") + } } } - refURL := request.URL.String() - if len(request.Referer()) > 0 { - refURL = request.Referer() + var captchaConfig = maps.Map{ + "action": this, + "timestamp": time.Now().Unix(), + "url": refURL, + "setId": set.Id, } - http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect) + info, err := utils.SimpleEncryptMap(captchaConfig) + if err != nil { + remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error()) + return true + } + + http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect) return false } diff --git a/internal/waf/action_category.go b/internal/waf/action_category.go new file mode 100644 index 0000000..f5bf9c2 --- /dev/null +++ b/internal/waf/action_category.go @@ -0,0 +1,13 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + +type ActionCategory = string + +const ( + ActionCategoryAllow ActionCategory = firewallconfigs.HTTPFirewallActionCategoryAllow + ActionCategoryBlock ActionCategory = firewallconfigs.HTTPFirewallActionCategoryBlock + ActionCategoryVerify ActionCategory = firewallconfigs.HTTPFirewallActionCategoryVerify +) diff --git a/internal/waf/action_config.go b/internal/waf/action_config.go new file mode 100644 index 0000000..5cae9cf --- /dev/null +++ b/internal/waf/action_config.go @@ -0,0 +1,10 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import "github.com/iwind/TeaGo/maps" + +type ActionConfig struct { + Code string `yaml:"code" json:"code"` + Options maps.Map `yaml:"options" json:"options"` +} diff --git a/internal/waf/action_definition.go b/internal/waf/action_definition.go index e268742..aff5bc4 100644 --- a/internal/waf/action_definition.go +++ b/internal/waf/action_definition.go @@ -2,11 +2,12 @@ package waf import "reflect" -// action definition +// ActionDefinition action definition type ActionDefinition struct { Name string Code ActionString Description string + Category string // category: block, verify, allow Instance ActionInterface Type reflect.Type } diff --git a/internal/waf/action_get_302.go b/internal/waf/action_get_302.go new file mode 100644 index 0000000..304d310 --- /dev/null +++ b/internal/waf/action_get_302.go @@ -0,0 +1,71 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" + "net/http" + "net/url" + "time" +) + +const ( + Get302Path = "/WAF/VERIFY/GET" +) + +// Get302Action +// 原理: origin url --> 302 verify url --> origin url +// TODO 将来支持meta refresh验证 +type Get302Action struct { + BaseAction + + Life int32 `yaml:"life" json:"life"` +} + +func (this *Get302Action) Init(waf *WAF) error { + return nil +} + +func (this *Get302Action) Code() string { + return ActionGet302 +} + +func (this *Get302Action) IsAttack() bool { + return false +} + +func (this *Get302Action) WillChange() bool { + return true +} + +func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + // 仅限于Get + if request.WAFRaw().Method != http.MethodGet { + return true + } + + // 是否已经在白名单中 + if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { + return true + } + + var m = maps.Map{ + "url": request.WAFRaw().URL.String(), + "timestamp": time.Now().Unix(), + "life": this.Life, + "setId": set.Id, + } + info, err := utils.SimpleEncryptMap(m) + if err != nil { + remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error()) + return true + } + + http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound) + + // 关闭连接 + _ = this.CloseConn(writer) + + return true +} diff --git a/internal/waf/action_go_group.go b/internal/waf/action_go_group.go index 446bd0a..85f2f64 100644 --- a/internal/waf/action_go_group.go +++ b/internal/waf/action_go_group.go @@ -10,13 +10,29 @@ type GoGroupAction struct { GroupId string `yaml:"groupId" json:"groupId"` } -func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { - group := waf.FindRuleGroup(this.GroupId) - if group == nil || !group.IsOn { +func (this *GoGroupAction) Init(waf *WAF) error { + return nil +} + +func (this *GoGroupAction) Code() string { + return ActionGoGroup +} + +func (this *GoGroupAction) IsAttack() bool { + return false +} + +func (this *GoGroupAction) WillChange() bool { + return true +} + +func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + nextGroup := waf.FindRuleGroup(this.GroupId) + if nextGroup == nil || !nextGroup.IsOn { return true } - b, set, err := group.MatchRequest(request) + b, nextSet, err := nextGroup.MatchRequest(request) if err != nil { logs.Error(err) return true @@ -26,9 +42,5 @@ func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer h return true } - actionObject := FindActionInstance(set.Action, set.ActionOptions) - if actionObject == nil { - return true - } - return actionObject.Perform(waf, request, writer) + return nextSet.PerformActions(waf, nextGroup, request, writer) } diff --git a/internal/waf/action_go_set.go b/internal/waf/action_go_set.go index ad8b049..eadfd03 100644 --- a/internal/waf/action_go_set.go +++ b/internal/waf/action_go_set.go @@ -11,17 +11,33 @@ type GoSetAction struct { SetId string `yaml:"setId" json:"setId"` } -func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { - group := waf.FindRuleGroup(this.GroupId) - if group == nil || !group.IsOn { +func (this *GoSetAction) Init(waf *WAF) error { + return nil +} + +func (this *GoSetAction) Code() string { + return ActionGoSet +} + +func (this *GoSetAction) IsAttack() bool { + return false +} + +func (this *GoSetAction) WillChange() bool { + return true +} + +func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + nextGroup := waf.FindRuleGroup(this.GroupId) + if nextGroup == nil || !nextGroup.IsOn { return true } - set := group.FindRuleSet(this.SetId) - if set == nil || !set.IsOn { + nextSet := nextGroup.FindRuleSet(this.SetId) + if nextSet == nil || !nextSet.IsOn { return true } - b, err := set.MatchRequest(request) + b, err := nextSet.MatchRequest(request) if err != nil { logs.Error(err) return true @@ -29,9 +45,5 @@ func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer htt if !b { return true } - actionObject := FindActionInstance(set.Action, set.ActionOptions) - if actionObject == nil { - return true - } - return actionObject.Perform(waf, request, writer) + return nextSet.PerformActions(waf, nextGroup, request, writer) } diff --git a/internal/waf/action_interface.go b/internal/waf/action_interface.go new file mode 100644 index 0000000..256b58e --- /dev/null +++ b/internal/waf/action_interface.go @@ -0,0 +1,25 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" +) + +type ActionInterface interface { + // Init 初始化 + Init(waf *WAF) error + + // Code 代号 + Code() string + + // IsAttack 是否为拦截攻击动作 + IsAttack() bool + + // WillChange determine if the action will change the request + WillChange() bool + + // Perform perform the action + Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) +} diff --git a/internal/waf/action_log.go b/internal/waf/action_log.go index 8b8efcd..74c85ac 100644 --- a/internal/waf/action_log.go +++ b/internal/waf/action_log.go @@ -8,6 +8,22 @@ import ( type LogAction struct { } -func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { +func (this *LogAction) Init(waf *WAF) error { + return nil +} + +func (this *LogAction) Code() string { + return ActionLog +} + +func (this *LogAction) IsAttack() bool { + return false +} + +func (this *LogAction) WillChange() bool { + return false +} + +func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { return true } diff --git a/internal/waf/action_notify.go b/internal/waf/action_notify.go new file mode 100644 index 0000000..1df7d4e --- /dev/null +++ b/internal/waf/action_notify.go @@ -0,0 +1,86 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/types" + "net/http" + "time" +) + +type notifyTask struct { + ServerId int64 + HttpFirewallPolicyId int64 + HttpFirewallRuleGroupId int64 + HttpFirewallRuleSetId int64 + CreatedAt int64 +} + +var notifyChan = make(chan *notifyTask, 128) + +func init() { + events.On(events.EventLoaded, func() { + go func() { + rpcClient, err := rpc.SharedRPC() + if err != nil { + remotelogs.Error("WAF_NOTIFY_ACTION", "create rpc client failed: "+err.Error()) + return + } + + for task := range notifyChan { + _, err = rpcClient.FirewallService().NotifyHTTPFirewallEvent(rpcClient.Context(), &pb.NotifyHTTPFirewallEventRequest{ + ServerId: task.ServerId, + HttpFirewallPolicyId: task.HttpFirewallPolicyId, + HttpFirewallRuleGroupId: task.HttpFirewallRuleGroupId, + HttpFirewallRuleSetId: task.HttpFirewallRuleSetId, + CreatedAt: task.CreatedAt, + }) + if err != nil { + remotelogs.Error("WAF_NOTIFY_ACTION", "notify failed: "+err.Error()) + } + } + }() + }) +} + +type NotifyAction struct { +} + +func (this *NotifyAction) Init(waf *WAF) error { + return nil +} + +func (this *NotifyAction) Code() string { + return ActionNotify +} + +func (this *NotifyAction) IsAttack() bool { + return false +} + +// WillChange determine if the action will change the request +func (this *NotifyAction) WillChange() bool { + return false +} + +// Perform perform the action +func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + select { + case notifyChan <- ¬ifyTask{ + ServerId: request.WAFServerId(), + HttpFirewallPolicyId: types.Int64(waf.Id), + HttpFirewallRuleGroupId: types.Int64(group.Id), + HttpFirewallRuleSetId: types.Int64(set.Id), + CreatedAt: time.Now().Unix(), + }: + default: + + } + + return true +} diff --git a/internal/waf/action_post_307.go b/internal/waf/action_post_307.go new file mode 100644 index 0000000..22a4dc1 --- /dev/null +++ b/internal/waf/action_post_307.go @@ -0,0 +1,88 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" + "net/http" + "time" +) + +type Post307Action struct { + Life int32 `yaml:"life" json:"life"` + + BaseAction +} + +func (this *Post307Action) Init(waf *WAF) error { + return nil +} + +func (this *Post307Action) Code() string { + return ActionPost307 +} + +func (this *Post307Action) IsAttack() bool { + return false +} + +func (this *Post307Action) WillChange() bool { + return true +} + +func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + var cookieName = "WAF_VALIDATOR_ID" + + // 仅限于POST + if request.WAFRaw().Method != http.MethodPost { + return true + } + + // 是否已经在白名单中 + if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { + return true + } + + // 判断是否有Cookie + cookie, err := request.WAFRaw().Cookie(cookieName) + if err == nil && cookie != nil { + m, err := utils.SimpleDecryptMap(cookie.Value) + if err == nil && m.GetString("remoteIP") == request.WAFRemoteIP() && time.Now().Unix() < m.GetInt64("timestamp")+10 { + var life = m.GetInt64("life") + if life <= 0 { + life = 600 // 默认10分钟 + } + var setId = m.GetString("setId") + SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life) + return true + } + } + + var m = maps.Map{ + "timestamp": time.Now().Unix(), + "life": this.Life, + "setId": set.Id, + "remoteIP": request.WAFRemoteIP(), + } + info, err := utils.SimpleEncryptMap(m) + if err != nil { + remotelogs.Error("WAF_POST_302_ACTION", "encode info failed: "+err.Error()) + return true + } + + // 设置Cookie + http.SetCookie(writer, &http.Cookie{ + Name: cookieName, + Path: "/", + MaxAge: 10, + Value: info, + }) + + http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect) + + // 关闭连接 + _ = this.CloseConn(writer) + + return true +} diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go new file mode 100644 index 0000000..8a34906 --- /dev/null +++ b/internal/waf/action_record_ip.go @@ -0,0 +1,120 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "strings" + "time" +) + +type recordIPTask struct { + ip string + listId int64 + expiredAt int64 + level string +} + +var recordIPTaskChan = make(chan *recordIPTask, 1024) + +func init() { + events.On(events.EventLoaded, func() { + go func() { + rpcClient, err := rpc.SharedRPC() + if err != nil { + remotelogs.Error("WAF_RECORD_IP_ACTION", "create rpc client failed: "+err.Error()) + return + } + + for task := range recordIPTaskChan { + ipType := "ipv4" + if strings.Contains(task.ip, ":") { + ipType = "ipv6" + } + _, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{ + IpListId: task.listId, + IpFrom: task.ip, + IpTo: "", + ExpiredAt: task.expiredAt, + Reason: "触发WAF规则自动加入", + Type: ipType, + EventLevel: task.level, + }) + if err != nil { + remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error()) + } + } + }() + }) +} + +type RecordIPAction struct { + BaseAction + + Type string `yaml:"type" json:"type"` + IPListId int64 `yaml:"ipListId" json:"ipListId"` + Level string `yaml:"level" json:"level"` + Timeout int32 `yaml:"timeout" json:"timeout"` +} + +func (this *RecordIPAction) Init(waf *WAF) error { + return nil +} + +func (this *RecordIPAction) Code() string { + return ActionRecordIP +} + +func (this *RecordIPAction) IsAttack() bool { + return this.Type == "black" +} + +func (this *RecordIPAction) WillChange() bool { + return this.Type == "black" +} + +func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + // 是否在本地白名单中 + if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) { + return true + } + + // 先加入本地的黑名单 + timeout := this.Timeout + if timeout <= 0 { + timeout = 86400 // 1天 + } + expiredAt := time.Now().Unix() + int64(timeout) + + if this.Type == "black" { + _ = this.CloseConn(writer) + + SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt) + } else { + // 加入本地白名单 + timeout := this.Timeout + if timeout <= 0 { + timeout = 86400 // 1天 + } + SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt) + } + + // 上报 + if this.IPListId > 0 { + select { + case recordIPTaskChan <- &recordIPTask{ + ip: request.WAFRemoteIP(), + listId: this.IPListId, + expiredAt: expiredAt, + level: this.Level, + }: + default: + + } + } + + return this.Type != "black" +} diff --git a/internal/waf/action_tag.go b/internal/waf/action_tag.go new file mode 100644 index 0000000..b39794f --- /dev/null +++ b/internal/waf/action_tag.go @@ -0,0 +1,30 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" +) + +type TagAction struct { + Tags []string `yaml:"tags" json:"tags"` +} + +func (this *TagAction) Init(waf *WAF) error { + return nil +} + +func (this *TagAction) Code() string { + return ActionTag +} + +func (this *TagAction) IsAttack() bool { + return false +} + +func (this *TagAction) WillChange() bool { + return false +} + +func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { + return true +} diff --git a/internal/waf/action_type.go b/internal/waf/action_type.go deleted file mode 100644 index 221226d..0000000 --- a/internal/waf/action_type.go +++ /dev/null @@ -1,21 +0,0 @@ -package waf - -import ( - "github.com/TeaOSLab/EdgeNode/internal/waf/requests" - "net/http" -) - -type ActionString = string - -const ( - ActionLog = "log" // allow and log - ActionBlock = "block" // block - ActionCaptcha = "captcha" // block and show captcha - ActionAllow = "allow" // allow - ActionGoGroup = "go_group" // go to next rule group - ActionGoSet = "go_set" // go to next rule set -) - -type ActionInterface interface { - Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) -} diff --git a/internal/waf/action_types.go b/internal/waf/action_types.go new file mode 100644 index 0000000..3a9d512 --- /dev/null +++ b/internal/waf/action_types.go @@ -0,0 +1,88 @@ +package waf + +import "reflect" + +type ActionString = string + +const ( + ActionLog ActionString = "log" // allow and log + ActionBlock ActionString = "block" // block + ActionCaptcha ActionString = "captcha" // block and show captcha + ActionNotify ActionString = "notify" // 告警 + ActionGet302 ActionString = "get_302" // 针对GET的302重定向认证 + ActionPost307 ActionString = "post_307" // 针对POST的307重定向认证 + ActionRecordIP ActionString = "record_ip" // 记录IP + ActionTag ActionString = "tag" // 标签 + ActionAllow ActionString = "allow" // allow + ActionGoGroup ActionString = "go_group" // go to next rule group + ActionGoSet ActionString = "go_set" // go to next rule set +) + +var AllActions = []*ActionDefinition{ + { + Name: "阻止", + Code: ActionBlock, + Instance: new(BlockAction), + Type: reflect.TypeOf(new(BlockAction)).Elem(), + }, + { + Name: "允许通过", + Code: ActionAllow, + Instance: new(AllowAction), + Type: reflect.TypeOf(new(AllowAction)).Elem(), + }, + { + Name: "允许并记录日志", + Code: ActionLog, + Instance: new(LogAction), + Type: reflect.TypeOf(new(LogAction)).Elem(), + }, + { + Name: "Captcha验证码", + Code: ActionCaptcha, + Instance: new(CaptchaAction), + Type: reflect.TypeOf(new(CaptchaAction)).Elem(), + }, + { + Name: "告警", + Code: ActionNotify, + Instance: new(NotifyAction), + Type: reflect.TypeOf(new(NotifyAction)).Elem(), + }, + { + Name: "GET 302", + Code: ActionGet302, + Instance: new(Get302Action), + Type: reflect.TypeOf(new(Get302Action)).Elem(), + }, + { + Name: "POST 307", + Code: ActionPost307, + Instance: new(Post307Action), + Type: reflect.TypeOf(new(Post307Action)).Elem(), + }, + { + Name: "记录IP", + Code: ActionRecordIP, + Instance: new(RecordIPAction), + Type: reflect.TypeOf(new(RecordIPAction)).Elem(), + }, + { + Name: "标签", + Code: ActionTag, + Instance: new(TagAction), + Type: reflect.TypeOf(new(TagAction)).Elem(), + }, + { + Name: "跳到下一个规则分组", + Code: ActionGoGroup, + Instance: new(GoGroupAction), + Type: reflect.TypeOf(new(GoGroupAction)).Elem(), + }, + { + Name: "跳到下一个规则集", + Code: ActionGoSet, + Instance: new(GoSetAction), + Type: reflect.TypeOf(new(GoSetAction)).Elem(), + }, +} diff --git a/internal/waf/action_utils.go b/internal/waf/action_utils.go index d2178e9..39f1259 100644 --- a/internal/waf/action_utils.go +++ b/internal/waf/action_utils.go @@ -1,45 +1,12 @@ package waf import ( + "encoding/json" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/iwind/TeaGo/maps" "reflect" ) -var AllActions = []*ActionDefinition{ - { - Name: "阻止", - Code: ActionBlock, - Instance: new(BlockAction), - }, - { - Name: "允许通过", - Code: ActionAllow, - Instance: new(AllowAction), - }, - { - Name: "允许并记录日志", - Code: ActionLog, - Instance: new(LogAction), - }, - { - Name: "Captcha验证码", - Code: ActionCaptcha, - Instance: new(CaptchaAction), - }, - { - Name: "跳到下一个规则分组", - Code: ActionGoGroup, - Instance: new(GoGroupAction), - Type: reflect.TypeOf(new(GoGroupAction)).Elem(), - }, - { - Name: "跳到下一个规则集", - Code: ActionGoSet, - Instance: new(GoSetAction), - Type: reflect.TypeOf(new(GoSetAction)).Elem(), - }, -} - func FindActionInstance(action ActionString, options maps.Map) ActionInterface { for _, def := range AllActions { if def.Code == action { @@ -49,15 +16,13 @@ func FindActionInstance(action ActionString, options maps.Map) ActionInterface { instance := ptrValue.Interface().(ActionInterface) if len(options) > 0 { - count := def.Type.NumField() - for i := 0; i < count; i++ { - field := def.Type.Field(i) - tag, ok := field.Tag.Lookup("yaml") - if ok { - v, ok := options[tag] - if ok && reflect.TypeOf(v) == field.Type { - ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v)) - } + optionsJSON, err := json.Marshal(options) + if err != nil { + remotelogs.Error("WAF_FindActionInstance", "encode options to json failed: "+err.Error()) + } else { + err = json.Unmarshal(optionsJSON, instance) + if err != nil { + remotelogs.Error("WAF_FindActionInstance", "decode options from json failed: "+err.Error()) } } } diff --git a/internal/waf/action_utils_test.go b/internal/waf/action_utils_test.go index e219f55..735fe32 100644 --- a/internal/waf/action_utils_test.go +++ b/internal/waf/action_utils_test.go @@ -2,6 +2,7 @@ package waf import ( "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/maps" "runtime" "testing" @@ -16,11 +17,20 @@ func TestFindActionInstance(t *testing.T) { t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil)) t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) - t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",})) + t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"})) a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil)) } +func TestFindActionInstance_Options(t *testing.T) { + //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{})) + //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{})) + //logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t) + logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{ + "timeout": 3600, + }), t) +} + func BenchmarkFindActionInstance(b *testing.B) { runtime.GOMAXPROCS(1) for i := 0; i < b.N; i++ { diff --git a/internal/waf/captcha_validator.go b/internal/waf/captcha_validator.go index e945e27..954bbdf 100644 --- a/internal/waf/captcha_validator.go +++ b/internal/waf/captcha_validator.go @@ -3,29 +3,64 @@ package waf import ( "bytes" "encoding/base64" - "fmt" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/dchest/captcha" "github.com/iwind/TeaGo/logs" - stringutil "github.com/iwind/TeaGo/utils/string" + "github.com/iwind/TeaGo/types" "net/http" + "strconv" + "strings" "time" ) -var captchaValidator = &CaptchaValidator{} +var captchaValidator = NewCaptchaValidator() type CaptchaValidator struct { } -func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) { - if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 { - this.validate(request, writer) +func NewCaptchaValidator() *CaptchaValidator { + return &CaptchaValidator{} +} + +func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) { + var info = request.WAFRaw().URL.Query().Get("info") + if len(info) == 0 { + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte("invalid request")) + return + } + m, err := utils.SimpleDecryptMap(info) + if err != nil { + _, _ = writer.Write([]byte("invalid request")) + return + } + + timestamp := m.GetInt64("timestamp") + if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期 + http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect) + return + } + + var actionConfig = &CaptchaAction{} + err = jsonutils.MapToObject(m.GetMap("action"), actionConfig) + if err != nil { + http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect) + return + } + + var setId = m.GetInt64("setId") + var originURL = m.GetString("url") + + if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 { + this.validate(actionConfig, setId, originURL, request, writer) } else { - this.show(request, writer) + this.show(actionConfig, request, writer) } } -func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) { +func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) { // show captcha captchaId := captcha.NewLen(6) buf := bytes.NewBuffer([]byte{}) @@ -35,48 +70,86 @@ func (this *CaptchaValidator) show(request *requests.Request, writer http.Respon return } + var lang = actionConfig.Language + if len(lang) == 0 { + acceptLanguage := request.WAFRaw().Header.Get("Accept-Language") + if len(acceptLanguage) > 0 { + langIndex := strings.Index(acceptLanguage, ",") + if langIndex > 0 { + lang = acceptLanguage[:langIndex] + } + } + } + if len(lang) == 0 { + lang = "en-US" + } + + var msgTitle = "" + var msgPrompt = "" + var msgButtonTitle = "" + + switch lang { + case "en-US": + msgTitle = "Verify Yourself" + msgPrompt = "Input verify code above:" + msgButtonTitle = "Verify Yourself" + case "zh-CN": + msgTitle = "身份验证" + msgPrompt = "请输入上面的验证码" + msgButtonTitle = "提交验证" + default: + msgTitle = "Verify Yourself" + msgPrompt = "Input verify code above:" + msgButtonTitle = "Verify Yourself" + } + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") _, _ = writer.Write([]byte(` - Verify Yourself + ` + msgTitle + ` +
- + ` + `
-

Input verify code above:

- +

` + msgPrompt + `

+
- +
`)) } -func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) { - captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID") +func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) { + captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID") if len(captchaId) > 0 { - captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE") + captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE") if captcha.VerifyString(captchaId, captchaCode) { - // set cookie - timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds) - m := stringutil.Md5(captchaSalt + timestamp) - http.SetCookie(writer, &http.Cookie{ - Name: "TEAWEB_WAF_CAPTCHA", - Value: m + timestamp, - MaxAge: CaptchaSeconds, // TODO 这个时间可以设置 - Path: "/", // all of dirs - }) + var life = CaptchaSeconds + if actionConfig.Life > 0 { + life = types.Int(actionConfig.Life) + } - rawURL := request.URL.Query().Get("url") - http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther) + // 加入到白名单 + SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO + + http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther) return false } else { - http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther) + http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther) } } diff --git a/internal/waf/checkpoints/cc.go b/internal/waf/checkpoints/cc.go index 8de7ef9..de6c3c9 100644 --- a/internal/waf/checkpoints/cc.go +++ b/internal/waf/checkpoints/cc.go @@ -5,14 +5,12 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" - "net" "regexp" - "strings" "sync" "time" ) -// ${cc.arg} +// CCCheckpoint ${cc.arg} // TODO implement more traffic rules type CCCheckpoint struct { Checkpoint @@ -32,7 +30,7 @@ func (this *CCCheckpoint) Start() { this.cache = ttlcache.NewCache() } -func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = 0 if this.cache == nil { @@ -66,12 +64,12 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti var key = "" switch userType { case "ip": - key = this.ip(req) + key = req.WAFRemoteIP() case "cookie": if len(userField) == 0 { - key = this.ip(req) + key = req.WAFRemoteIP() } else { - cookie, _ := req.Cookie(userField) + cookie, _ := req.WAFRaw().Cookie(userField) if cookie != nil { v := cookie.Value if userIndex > 0 && len(v) > userIndex { @@ -82,9 +80,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti } case "get": if len(userField) == 0 { - key = this.ip(req) + key = req.WAFRemoteIP() } else { - v := req.URL.Query().Get(userField) + v := req.WAFRaw().URL.Query().Get(userField) if userIndex > 0 && len(v) > userIndex { v = v[userIndex:] } @@ -92,9 +90,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti } case "post": if len(userField) == 0 { - key = this.ip(req) + key = req.WAFRemoteIP() } else { - v := req.PostFormValue(userField) + v := req.WAFRaw().PostFormValue(userField) if userIndex > 0 && len(v) > userIndex { v = v[userIndex:] } @@ -102,19 +100,19 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti } case "header": if len(userField) == 0 { - key = this.ip(req) + key = req.WAFRemoteIP() } else { - v := req.Header.Get(userField) + v := req.WAFRaw().Header.Get(userField) if userIndex > 0 && len(v) > userIndex { v = v[userIndex:] } key = "USER@" + userType + "@" + userField + "@" + v } default: - key = this.ip(req) + key = req.WAFRemoteIP() } if len(key) == 0 { - key = this.ip(req) + key = req.WAFRemoteIP() } value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period) } @@ -122,7 +120,7 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti return } -func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } @@ -210,38 +208,3 @@ func (this *CCCheckpoint) Stop() { this.cache = nil } } - -func (this *CCCheckpoint) ip(req *requests.Request) string { - // X-Forwarded-For - forwardedFor := req.Header.Get("X-Forwarded-For") - if len(forwardedFor) > 0 { - commaIndex := strings.Index(forwardedFor, ",") - if commaIndex > 0 { - return forwardedFor[:commaIndex] - } - return forwardedFor - } - - // Real-IP - { - realIP, ok := req.Header["X-Real-IP"] - if ok && len(realIP) > 0 { - return realIP[0] - } - } - - // Real-Ip - { - realIP, ok := req.Header["X-Real-Ip"] - if ok && len(realIP) > 0 { - return realIP[0] - } - } - - // Remote-Addr - host, _, err := net.SplitHostPort(req.RemoteAddr) - if err == nil { - return host - } - return req.RemoteAddr -} diff --git a/internal/waf/checkpoints/cc_test.go b/internal/waf/checkpoints/cc_test.go index 6245798..249b477 100644 --- a/internal/waf/checkpoints/cc_test.go +++ b/internal/waf/checkpoints/cc_test.go @@ -2,6 +2,7 @@ package checkpoints import ( "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" "net/http" "testing" ) @@ -12,31 +13,31 @@ func TestCCCheckpoint_RequestValue(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(raw) - req.RemoteAddr = "127.0.0.1" + req := requests.NewTestRequest(raw) + req.WAFRaw().RemoteAddr = "127.0.0.1" checkpoint := new(CCCheckpoint) checkpoint.Init() checkpoint.Start() - options := map[string]string{ + options := maps.Map{ "period": "5", } t.Log(checkpoint.RequestValue(req, "requests", options)) t.Log(checkpoint.RequestValue(req, "requests", options)) - req.RemoteAddr = "127.0.0.2" + req.WAFRaw().RemoteAddr = "127.0.0.2" t.Log(checkpoint.RequestValue(req, "requests", options)) - req.RemoteAddr = "127.0.0.1" + req.WAFRaw().RemoteAddr = "127.0.0.1" t.Log(checkpoint.RequestValue(req, "requests", options)) - req.RemoteAddr = "127.0.0.2" + req.WAFRaw().RemoteAddr = "127.0.0.2" t.Log(checkpoint.RequestValue(req, "requests", options)) - req.RemoteAddr = "127.0.0.2" + req.WAFRaw().RemoteAddr = "127.0.0.2" t.Log(checkpoint.RequestValue(req, "requests", options)) - req.RemoteAddr = "127.0.0.2" + req.WAFRaw().RemoteAddr = "127.0.0.2" t.Log(checkpoint.RequestValue(req, "requests", options)) } diff --git a/internal/waf/checkpoints/checkpoint_interface.go b/internal/waf/checkpoints/checkpoint_interface.go index 0a8ac8d..532ae62 100644 --- a/internal/waf/checkpoints/checkpoint_interface.go +++ b/internal/waf/checkpoints/checkpoint_interface.go @@ -5,32 +5,32 @@ import ( "github.com/iwind/TeaGo/maps" ) -// Check Point +// CheckpointInterface Check Point type CheckpointInterface interface { - // initialize + // Init initialize Init() - // is request? + // IsRequest is request? IsRequest() bool - // is composed? + // IsComposed is composed? IsComposed() bool - // get request value - RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) + // RequestValue get request value + RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) - // get response value - ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) + // ResponseValue get response value + ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) - // param option list + // ParamOptions param option list ParamOptions() *ParamOptions - // options + // Options options Options() []OptionInterface - // start + // Start start Start() - // stop + // Stop stop Stop() } diff --git a/internal/waf/checkpoints/request_all.go b/internal/waf/checkpoints/request_all.go index 64664a6..30a5f98 100644 --- a/internal/waf/checkpoints/request_all.go +++ b/internal/waf/checkpoints/request_all.go @@ -5,32 +5,34 @@ import ( "github.com/iwind/TeaGo/maps" ) -// ${requestAll} +// RequestAllCheckpoint ${requestAll} type RequestAllCheckpoint struct { Checkpoint } -func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { valueBytes := []byte{} - if len(req.RequestURI) > 0 { - valueBytes = append(valueBytes, req.RequestURI...) - } else if req.URL != nil { - valueBytes = append(valueBytes, req.URL.RequestURI()...) + if len(req.WAFRaw().RequestURI) > 0 { + valueBytes = append(valueBytes, req.WAFRaw().RequestURI...) + } else if req.WAFRaw().URL != nil { + valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...) } - if req.Body != nil { + if req.WAFRaw().Body != nil { valueBytes = append(valueBytes, ' ') - if len(req.BodyData) == 0 { - data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes + var bodyData = req.WAFGetCacheBody() + if len(bodyData) == 0 { + data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes if err != nil { return "", err, nil } - req.BodyData = data - req.RestoreBody(data) + bodyData = data + req.WAFSetCacheBody(data) + req.WAFRestoreBody(data) } - valueBytes = append(valueBytes, req.BodyData...) + valueBytes = append(valueBytes, bodyData...) } value = valueBytes @@ -38,7 +40,7 @@ func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param stri return } -func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = "" if this.IsRequest() { return this.RequestValue(req, param, options) diff --git a/internal/waf/checkpoints/request_all_test.go b/internal/waf/checkpoints/request_all_test.go index d8a12a0..5ee81d5 100644 --- a/internal/waf/checkpoints/request_all_test.go +++ b/internal/waf/checkpoints/request_all_test.go @@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) { } checkpoint := new(RequestAllCheckpoint) - v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil) + v, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) if sysErr != nil { t.Fatal(sysErr) } @@ -42,7 +42,7 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) { } checkpoint := new(RequestBodyCheckpoint) - value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil) + value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) if err != nil { t.Fatal(err) } @@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) { checkpoint := new(RequestAllCheckpoint) for i := 0; i < b.N; i++ { - _, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil) + _, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) } } diff --git a/internal/waf/checkpoints/request_arg.go b/internal/waf/checkpoints/request_arg.go index 813fc51..a9c51a5 100644 --- a/internal/waf/checkpoints/request_arg.go +++ b/internal/waf/checkpoints/request_arg.go @@ -9,11 +9,11 @@ type RequestArgCheckpoint struct { Checkpoint } -func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - return req.URL.Query().Get(param), nil, nil +func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + return req.WAFRaw().URL.Query().Get(param), nil, nil } -func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_arg_test.go b/internal/waf/checkpoints/request_arg_test.go index a7cdaf3..6ac84f6 100644 --- a/internal/waf/checkpoints/request_arg_test.go +++ b/internal/waf/checkpoints/request_arg_test.go @@ -12,7 +12,7 @@ func TestArgParam_RequestValue(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) checkpoint := new(RequestArgCheckpoint) t.Log(checkpoint.RequestValue(req, "name", nil)) diff --git a/internal/waf/checkpoints/request_args.go b/internal/waf/checkpoints/request_args.go index 9a3883c..a83dc3f 100644 --- a/internal/waf/checkpoints/request_args.go +++ b/internal/waf/checkpoints/request_args.go @@ -9,12 +9,12 @@ type RequestArgsCheckpoint struct { Checkpoint } -func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.URL.RawQuery +func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().URL.RawQuery return } -func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_body.go b/internal/waf/checkpoints/request_body.go index d6e54ca..a04d0ff 100644 --- a/internal/waf/checkpoints/request_body.go +++ b/internal/waf/checkpoints/request_body.go @@ -5,31 +5,33 @@ import ( "github.com/iwind/TeaGo/maps" ) -// ${requestBody} +// RequestBodyCheckpoint ${requestBody} type RequestBodyCheckpoint struct { Checkpoint } -func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - if req.Body == nil { +func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + if req.WAFRaw().Body == nil { value = "" return } - if len(req.BodyData) == 0 { - data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes + var bodyData = req.WAFGetCacheBody() + if len(bodyData) == 0 { + data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes if err != nil { return "", err, nil } - req.BodyData = data - req.RestoreBody(data) + bodyData = data + req.WAFSetCacheBody(data) + req.WAFRestoreBody(data) } - return req.BodyData, nil, nil + return bodyData, nil, nil } -func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_body_test.go b/internal/waf/checkpoints/request_body_test.go index b1c982d..8bdb0d2 100644 --- a/internal/waf/checkpoints/request_body_test.go +++ b/internal/waf/checkpoints/request_body_test.go @@ -11,19 +11,20 @@ import ( ) func TestRequestBodyCheckpoint_RequestValue(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456"))) + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456"))) if err != nil { t.Fatal(err) } - + var req = requests.NewTestRequest(rawReq) checkpoint := new(RequestBodyCheckpoint) - t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil)) + t.Log(checkpoint.RequestValue(req, "", nil)) - body, err := ioutil.ReadAll(req.Body) + body, err := ioutil.ReadAll(rawReq.Body) if err != nil { t.Fatal(err) } t.Log(string(body)) + t.Log(string(req.WAFGetCacheBody())) } func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) { @@ -33,7 +34,7 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) { } checkpoint := new(RequestBodyCheckpoint) - value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil) + value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) if err != nil { t.Fatal(err) } diff --git a/internal/waf/checkpoints/request_content_type.go b/internal/waf/checkpoints/request_content_type.go index 6a11132..6ff04cd 100644 --- a/internal/waf/checkpoints/request_content_type.go +++ b/internal/waf/checkpoints/request_content_type.go @@ -9,12 +9,12 @@ type RequestContentTypeCheckpoint struct { Checkpoint } -func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.Header.Get("Content-Type") +func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().Header.Get("Content-Type") return } -func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_cookie.go b/internal/waf/checkpoints/request_cookie.go index eb1ba91..33fd968 100644 --- a/internal/waf/checkpoints/request_cookie.go +++ b/internal/waf/checkpoints/request_cookie.go @@ -9,8 +9,8 @@ type RequestCookieCheckpoint struct { Checkpoint } -func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - cookie, err := req.Cookie(param) +func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + cookie, err := req.WAFRaw().Cookie(param) if err != nil { value = "" return @@ -20,7 +20,7 @@ func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param s return } -func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_cookies.go b/internal/waf/checkpoints/request_cookies.go index f284788..a9f1035 100644 --- a/internal/waf/checkpoints/request_cookies.go +++ b/internal/waf/checkpoints/request_cookies.go @@ -11,16 +11,16 @@ type RequestCookiesCheckpoint struct { Checkpoint } -func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { var cookies = []string{} - for _, cookie := range req.Cookies() { + for _, cookie := range req.WAFRaw().Cookies() { cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value)) } value = strings.Join(cookies, "&") return } -func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_form_arg.go b/internal/waf/checkpoints/request_form_arg.go index 92fa0ab..1862784 100644 --- a/internal/waf/checkpoints/request_form_arg.go +++ b/internal/waf/checkpoints/request_form_arg.go @@ -6,33 +6,35 @@ import ( "net/url" ) -// ${requestForm.arg} +// RequestFormArgCheckpoint ${requestForm.arg} type RequestFormArgCheckpoint struct { Checkpoint } -func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - if req.Body == nil { +func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + if req.WAFRaw().Body == nil { value = "" return } - if len(req.BodyData) == 0 { - data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes + var bodyData = req.WAFGetCacheBody() + if len(bodyData) == 0 { + data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes if err != nil { return "", err, nil } - req.BodyData = data - req.RestoreBody(data) + bodyData = data + req.WAFSetCacheBody(data) + req.WAFRestoreBody(data) } // TODO improve performance - values, _ := url.ParseQuery(string(req.BodyData)) + values, _ := url.ParseQuery(string(bodyData)) return values.Get(param), nil, nil } -func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_form_arg_test.go b/internal/waf/checkpoints/request_form_arg_test.go index 01c0396..5da0624 100644 --- a/internal/waf/checkpoints/request_form_arg_test.go +++ b/internal/waf/checkpoints/request_form_arg_test.go @@ -15,8 +15,8 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req := requests.NewTestRequest(rawReq) + req.WAFRaw().Header.Set("Content-Type", "application/x-www-form-urlencoded") checkpoint := new(RequestFormArgCheckpoint) t.Log(checkpoint.RequestValue(req, "name", nil)) @@ -24,7 +24,7 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) { t.Log(checkpoint.RequestValue(req, "Hello", nil)) t.Log(checkpoint.RequestValue(req, "encoded", nil)) - body, err := ioutil.ReadAll(req.Body) + body, err := ioutil.ReadAll(req.WAFRaw().Body) if err != nil { t.Fatal(err) } diff --git a/internal/waf/checkpoints/request_general_header_length.go b/internal/waf/checkpoints/request_general_header_length.go index 4f2a430..50e8251 100644 --- a/internal/waf/checkpoints/request_general_header_length.go +++ b/internal/waf/checkpoints/request_general_header_length.go @@ -14,7 +14,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool { return true } -func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = false headers := options.GetSlice("headers") @@ -25,7 +25,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req length := options.GetInt("length") for _, header := range headers { - v := req.Header.Get(types.String(header)) + v := req.WAFRaw().Header.Get(types.String(header)) if len(v) > length { value = true break @@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req return } -func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { return } diff --git a/internal/waf/checkpoints/request_header.go b/internal/waf/checkpoints/request_header.go index 029def4..8b206d0 100644 --- a/internal/waf/checkpoints/request_header.go +++ b/internal/waf/checkpoints/request_header.go @@ -10,8 +10,8 @@ type RequestHeaderCheckpoint struct { Checkpoint } -func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - v, found := req.Header[param] +func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + v, found := req.WAFRaw().Header[param] if !found { value = "" return @@ -20,7 +20,7 @@ func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param s return } -func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_headers.go b/internal/waf/checkpoints/request_headers.go index c5ef280..0fdb225 100644 --- a/internal/waf/checkpoints/request_headers.go +++ b/internal/waf/checkpoints/request_headers.go @@ -11,9 +11,9 @@ type RequestHeadersCheckpoint struct { Checkpoint } -func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { var headers = []string{} - for k, v := range req.Header { + for k, v := range req.WAFRaw().Header { for _, subV := range v { headers = append(headers, k+": "+subV) } @@ -23,7 +23,7 @@ func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param return } -func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_host.go b/internal/waf/checkpoints/request_host.go index 60174d0..105f4a7 100644 --- a/internal/waf/checkpoints/request_host.go +++ b/internal/waf/checkpoints/request_host.go @@ -9,12 +9,12 @@ type RequestHostCheckpoint struct { Checkpoint } -func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.Host +func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().Host return } -func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_host_test.go b/internal/waf/checkpoints/request_host_test.go index fc1b449..b9274a7 100644 --- a/internal/waf/checkpoints/request_host_test.go +++ b/internal/waf/checkpoints/request_host_test.go @@ -12,8 +12,8 @@ func TestRequestHostCheckpoint_RequestValue(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) - req.Header.Set("Host", "cloud.teaos.cn") + req := requests.NewTestRequest(rawReq) + req.WAFRaw().Header.Set("Host", "cloud.teaos.cn") checkpoint := new(RequestHostCheckpoint) t.Log(checkpoint.RequestValue(req, "", nil)) diff --git a/internal/waf/checkpoints/request_json_arg.go b/internal/waf/checkpoints/request_json_arg.go index 1a6414e..341db2d 100644 --- a/internal/waf/checkpoints/request_json_arg.go +++ b/internal/waf/checkpoints/request_json_arg.go @@ -8,24 +8,27 @@ import ( "strings" ) -// ${requestJSON.arg} +// RequestJSONArgCheckpoint ${requestJSON.arg} type RequestJSONArgCheckpoint struct { Checkpoint } -func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - if len(req.BodyData) == 0 { - data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes +func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + var bodyData = req.WAFGetCacheBody() + if len(bodyData) == 0 { + data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes if err != nil { return "", err, nil } - req.BodyData = data - defer req.RestoreBody(data) + + bodyData = data + req.WAFSetCacheBody(data) + defer req.WAFRestoreBody(data) } // TODO improve performance var m interface{} = nil - err := json.Unmarshal(req.BodyData, &m) + err := json.Unmarshal(bodyData, &m) if err != nil || m == nil { return "", nil, err } @@ -37,7 +40,7 @@ func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param return "", nil, nil } -func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestJSONArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_json_arg_test.go b/internal/waf/checkpoints/request_json_arg_test.go index 00708be..63fae0b 100644 --- a/internal/waf/checkpoints/request_json_arg_test.go +++ b/internal/waf/checkpoints/request_json_arg_test.go @@ -20,7 +20,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) //req.Header.Set("Content-Type", "application/x-www-form-urlencoded") checkpoint := new(RequestJSONArgCheckpoint) @@ -31,7 +31,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) { t.Log(checkpoint.RequestValue(req, "books", nil)) t.Log(checkpoint.RequestValue(req, "books.1", nil)) - body, err := ioutil.ReadAll(req.Body) + body, err := ioutil.ReadAll(req.WAFRaw().Body) if err != nil { t.Fatal(err) } @@ -50,7 +50,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) //req.Header.Set("Content-Type", "application/x-www-form-urlencoded") checkpoint := new(RequestJSONArgCheckpoint) @@ -61,7 +61,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) { t.Log(checkpoint.RequestValue(req, "0.books", nil)) t.Log(checkpoint.RequestValue(req, "0.books.1", nil)) - body, err := ioutil.ReadAll(req.Body) + body, err := ioutil.ReadAll(req.WAFRaw().Body) if err != nil { t.Fatal(err) } @@ -80,7 +80,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) //req.Header.Set("Content-Type", "application/x-www-form-urlencoded") checkpoint := new(RequestJSONArgCheckpoint) @@ -91,7 +91,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) { t.Log(checkpoint.RequestValue(req, "0.books", nil)) t.Log(checkpoint.RequestValue(req, "0.books.1", nil)) - body, err := ioutil.ReadAll(req.Body) + body, err := ioutil.ReadAll(req.WAFRaw().Body) if err != nil { t.Fatal(err) } diff --git a/internal/waf/checkpoints/request_length.go b/internal/waf/checkpoints/request_length.go index 9a09556..e26a18b 100644 --- a/internal/waf/checkpoints/request_length.go +++ b/internal/waf/checkpoints/request_length.go @@ -9,12 +9,12 @@ type RequestLengthCheckpoint struct { Checkpoint } -func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.ContentLength +func (this *RequestLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().ContentLength return } -func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_method.go b/internal/waf/checkpoints/request_method.go index b27deb0..3b85fc0 100644 --- a/internal/waf/checkpoints/request_method.go +++ b/internal/waf/checkpoints/request_method.go @@ -9,12 +9,12 @@ type RequestMethodCheckpoint struct { Checkpoint } -func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.Method +func (this *RequestMethodCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().Method return } -func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestMethodCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_path.go b/internal/waf/checkpoints/request_path.go index 7c934de..5e757bb 100644 --- a/internal/waf/checkpoints/request_path.go +++ b/internal/waf/checkpoints/request_path.go @@ -9,11 +9,11 @@ type RequestPathCheckpoint struct { Checkpoint } -func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - return req.URL.Path, nil, nil +func (this *RequestPathCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + return req.WAFRaw().URL.Path, nil, nil } -func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestPathCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_path_test.go b/internal/waf/checkpoints/request_path_test.go index e100602..88f47cb 100644 --- a/internal/waf/checkpoints/request_path_test.go +++ b/internal/waf/checkpoints/request_path_test.go @@ -12,7 +12,7 @@ func TestRequestPathCheckpoint_RequestValue(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) checkpoint := new(RequestPathCheckpoint) t.Log(checkpoint.RequestValue(req, "", nil)) } diff --git a/internal/waf/checkpoints/request_proto.go b/internal/waf/checkpoints/request_proto.go index f3cd372..235b2db 100644 --- a/internal/waf/checkpoints/request_proto.go +++ b/internal/waf/checkpoints/request_proto.go @@ -9,12 +9,12 @@ type RequestProtoCheckpoint struct { Checkpoint } -func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.Proto +func (this *RequestProtoCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().Proto return } -func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestProtoCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_raw_remote_addr.go b/internal/waf/checkpoints/request_raw_remote_addr.go index 9da7f11..7886c44 100644 --- a/internal/waf/checkpoints/request_raw_remote_addr.go +++ b/internal/waf/checkpoints/request_raw_remote_addr.go @@ -10,17 +10,17 @@ type RequestRawRemoteAddrCheckpoint struct { Checkpoint } -func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - host, _, err := net.SplitHostPort(req.RemoteAddr) +func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + host, _, err := net.SplitHostPort(req.WAFRaw().RemoteAddr) if err == nil { value = host } else { - value = req.RemoteAddr + value = req.WAFRaw().RemoteAddr } return } -func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_referer.go b/internal/waf/checkpoints/request_referer.go index 0160579..775c084 100644 --- a/internal/waf/checkpoints/request_referer.go +++ b/internal/waf/checkpoints/request_referer.go @@ -9,12 +9,12 @@ type RequestRefererCheckpoint struct { Checkpoint } -func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.Referer() +func (this *RequestRefererCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().Referer() return } -func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_remote_addr.go b/internal/waf/checkpoints/request_remote_addr.go index b80e42b..dc26a10 100644 --- a/internal/waf/checkpoints/request_remote_addr.go +++ b/internal/waf/checkpoints/request_remote_addr.go @@ -3,56 +3,18 @@ package checkpoints import ( "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/iwind/TeaGo/maps" - "net" - "strings" ) type RequestRemoteAddrCheckpoint struct { Checkpoint } -func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - // X-Forwarded-For - forwardedFor := req.Header.Get("X-Forwarded-For") - if len(forwardedFor) > 0 { - commaIndex := strings.Index(forwardedFor, ",") - if commaIndex > 0 { - value = forwardedFor[:commaIndex] - return - } - value = forwardedFor - return - } - - // Real-IP - { - realIP, ok := req.Header["X-Real-IP"] - if ok && len(realIP) > 0 { - value = realIP[0] - return - } - } - - // Real-Ip - { - realIP, ok := req.Header["X-Real-Ip"] - if ok && len(realIP) > 0 { - value = realIP[0] - return - } - } - - // Remote-Addr - host, _, err := net.SplitHostPort(req.RemoteAddr) - if err == nil { - value = host - } else { - value = req.RemoteAddr - } +func (this *RequestRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRemoteIP() return } -func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_remote_port.go b/internal/waf/checkpoints/request_remote_port.go index 5f65c7e..f5aa158 100644 --- a/internal/waf/checkpoints/request_remote_port.go +++ b/internal/waf/checkpoints/request_remote_port.go @@ -11,8 +11,8 @@ type RequestRemotePortCheckpoint struct { Checkpoint } -func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - _, port, err := net.SplitHostPort(req.RemoteAddr) +func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + _, port, err := net.SplitHostPort(req.WAFRaw().RemoteAddr) if err == nil { value = types.Int(port) } else { @@ -21,7 +21,7 @@ func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, par return } -func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemotePortCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_remote_user.go b/internal/waf/checkpoints/request_remote_user.go index 1c27cc8..a2d1e20 100644 --- a/internal/waf/checkpoints/request_remote_user.go +++ b/internal/waf/checkpoints/request_remote_user.go @@ -9,8 +9,8 @@ type RequestRemoteUserCheckpoint struct { Checkpoint } -func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - username, _, ok := req.BasicAuth() +func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + username, _, ok := req.WAFRaw().BasicAuth() if !ok { value = "" return @@ -19,7 +19,7 @@ func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, par return } -func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemoteUserCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_scheme.go b/internal/waf/checkpoints/request_scheme.go index 11e27e1..05f98c6 100644 --- a/internal/waf/checkpoints/request_scheme.go +++ b/internal/waf/checkpoints/request_scheme.go @@ -9,12 +9,12 @@ type RequestSchemeCheckpoint struct { Checkpoint } -func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.URL.Scheme +func (this *RequestSchemeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().URL.Scheme return } -func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestSchemeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_scheme_test.go b/internal/waf/checkpoints/request_scheme_test.go index 461cf23..8738a3d 100644 --- a/internal/waf/checkpoints/request_scheme_test.go +++ b/internal/waf/checkpoints/request_scheme_test.go @@ -12,7 +12,7 @@ func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) checkpoint := new(RequestSchemeCheckpoint) t.Log(checkpoint.RequestValue(req, "", nil)) } diff --git a/internal/waf/checkpoints/request_upload.go b/internal/waf/checkpoints/request_upload.go index d76656d..ed43dcf 100644 --- a/internal/waf/checkpoints/request_upload.go +++ b/internal/waf/checkpoints/request_upload.go @@ -11,63 +11,65 @@ import ( "strings" ) -// ${requestUpload.arg} +// RequestUploadCheckpoint ${requestUpload.arg} type RequestUploadCheckpoint struct { Checkpoint } -func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = "" if param == "minSize" || param == "maxSize" { value = 0 } - if req.Method != http.MethodPost { + if req.WAFRaw().Method != http.MethodPost { return } - if req.Body == nil { + if req.WAFRaw().Body == nil { return } - if req.MultipartForm == nil { - if len(req.BodyData) == 0 { - data, err := req.ReadBody(32 * 1024 * 1024) + if req.WAFRaw().MultipartForm == nil { + var bodyData = req.WAFGetCacheBody() + if len(bodyData) == 0 { + data, err := req.WAFReadBody(32 * 1024 * 1024) if err != nil { sysErr = err return } - req.BodyData = data - defer req.RestoreBody(data) + bodyData = data + req.WAFSetCacheBody(data) + defer req.WAFRestoreBody(data) } - oldBody := req.Body - req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData)) + oldBody := req.WAFRaw().Body + req.WAFRaw().Body = ioutil.NopCloser(bytes.NewBuffer(bodyData)) - err := req.ParseMultipartForm(32 * 1024 * 1024) + err := req.WAFRaw().ParseMultipartForm(32 * 1024 * 1024) // 还原 - req.Body = oldBody + req.WAFRaw().Body = oldBody if err != nil { userErr = err return } - if req.MultipartForm == nil { + if req.WAFRaw().MultipartForm == nil { return } } if param == "field" { // field fields := []string{} - for field := range req.MultipartForm.File { + for field := range req.WAFRaw().MultipartForm.File { fields = append(fields, field) } value = strings.Join(fields, ",") } else if param == "minSize" { // minSize minSize := int64(0) - for _, files := range req.MultipartForm.File { + for _, files := range req.WAFRaw().MultipartForm.File { for _, file := range files { if minSize == 0 || minSize > file.Size { minSize = file.Size @@ -77,7 +79,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s value = minSize } else if param == "maxSize" { // maxSize maxSize := int64(0) - for _, files := range req.MultipartForm.File { + for _, files := range req.WAFRaw().MultipartForm.File { for _, file := range files { if maxSize < file.Size { maxSize = file.Size @@ -87,7 +89,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s value = maxSize } else if param == "name" { // name names := []string{} - for _, files := range req.MultipartForm.File { + for _, files := range req.WAFRaw().MultipartForm.File { for _, file := range files { if !lists.ContainsString(names, file.Filename) { names = append(names, file.Filename) @@ -97,7 +99,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s value = strings.Join(names, ",") } else if param == "ext" { // ext extensions := []string{} - for _, files := range req.MultipartForm.File { + for _, files := range req.WAFRaw().MultipartForm.File { for _, file := range files { if len(file.Filename) > 0 { exit := strings.ToLower(filepath.Ext(file.Filename)) @@ -113,7 +115,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s return } -func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_upload_test.go b/internal/waf/checkpoints/request_upload_test.go index cc1ab64..0bc03a5 100644 --- a/internal/waf/checkpoints/request_upload_test.go +++ b/internal/waf/checkpoints/request_upload_test.go @@ -63,8 +63,8 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) { t.Fatal() } - req := requests.NewRequest(rawReq) - req.Header.Add("Content-Type", writer.FormDataContentType()) + req := requests.NewTestRequest(rawReq) + req.WAFRaw().Header.Add("Content-Type", writer.FormDataContentType()) checkpoint := new(RequestUploadCheckpoint) t.Log(checkpoint.RequestValue(req, "field", nil)) @@ -73,7 +73,7 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) { t.Log(checkpoint.RequestValue(req, "name", nil)) t.Log(checkpoint.RequestValue(req, "ext", nil)) - data, err := ioutil.ReadAll(req.Body) + data, err := ioutil.ReadAll(req.WAFRaw().Body) if err != nil { t.Fatal(err) } diff --git a/internal/waf/checkpoints/request_uri.go b/internal/waf/checkpoints/request_uri.go index d927baf..bfe72fd 100644 --- a/internal/waf/checkpoints/request_uri.go +++ b/internal/waf/checkpoints/request_uri.go @@ -9,16 +9,16 @@ type RequestURICheckpoint struct { Checkpoint } -func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - if len(req.RequestURI) > 0 { - value = req.RequestURI - } else if req.URL != nil { - value = req.URL.RequestURI() +func (this *RequestURICheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + if len(req.WAFRaw().RequestURI) > 0 { + value = req.WAFRaw().RequestURI + } else if req.WAFRaw().URL != nil { + value = req.WAFRaw().URL.RequestURI() } return } -func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestURICheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_user_agent.go b/internal/waf/checkpoints/request_user_agent.go index 407fe50..a9c1bec 100644 --- a/internal/waf/checkpoints/request_user_agent.go +++ b/internal/waf/checkpoints/request_user_agent.go @@ -9,12 +9,12 @@ type RequestUserAgentCheckpoint struct { Checkpoint } -func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - value = req.UserAgent() +func (this *RequestUserAgentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.WAFRaw().UserAgent() return } -func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestUserAgentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/response_body.go b/internal/waf/checkpoints/response_body.go index 4e48dac..a39fc67 100644 --- a/internal/waf/checkpoints/response_body.go +++ b/internal/waf/checkpoints/response_body.go @@ -16,12 +16,12 @@ func (this *ResponseBodyCheckpoint) IsRequest() bool { return false } -func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = "" return } -func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = "" if resp != nil && resp.Body != nil { if len(resp.BodyData) > 0 { diff --git a/internal/waf/checkpoints/response_bytes_sent.go b/internal/waf/checkpoints/response_bytes_sent.go index 9461f97..75a719a 100644 --- a/internal/waf/checkpoints/response_bytes_sent.go +++ b/internal/waf/checkpoints/response_bytes_sent.go @@ -14,12 +14,12 @@ func (this *ResponseBytesSentCheckpoint) IsRequest() bool { return false } -func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBytesSentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = 0 return } -func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBytesSentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = 0 if resp != nil { value = resp.ContentLength diff --git a/internal/waf/checkpoints/response_general_header_length.go b/internal/waf/checkpoints/response_general_header_length.go index f1ef6ff..00404fd 100644 --- a/internal/waf/checkpoints/response_general_header_length.go +++ b/internal/waf/checkpoints/response_general_header_length.go @@ -18,12 +18,12 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) IsComposed() bool { return true } -func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { return } -func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = false headers := options.GetSlice("headers") @@ -34,7 +34,7 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.R length := options.GetInt("length") for _, header := range headers { - v := req.Header.Get(types.String(header)) + v := req.WAFRaw().Header.Get(types.String(header)) if len(v) > length { value = true break diff --git a/internal/waf/checkpoints/response_header.go b/internal/waf/checkpoints/response_header.go index 5d23df5..839e657 100644 --- a/internal/waf/checkpoints/response_header.go +++ b/internal/waf/checkpoints/response_header.go @@ -14,12 +14,12 @@ func (this *ResponseHeaderCheckpoint) IsRequest() bool { return false } -func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = "" return } -func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if resp != nil && resp.Header != nil { value = resp.Header.Get(param) } else { diff --git a/internal/waf/checkpoints/response_status.go b/internal/waf/checkpoints/response_status.go index fb63a79..eb9a9bd 100644 --- a/internal/waf/checkpoints/response_status.go +++ b/internal/waf/checkpoints/response_status.go @@ -5,7 +5,7 @@ import ( "github.com/iwind/TeaGo/maps" ) -// ${bytesSent} +// ResponseStatusCheckpoint ${bytesSent} type ResponseStatusCheckpoint struct { Checkpoint } @@ -14,12 +14,12 @@ func (this *ResponseStatusCheckpoint) IsRequest() bool { return false } -func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseStatusCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { value = 0 return } -func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseStatusCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if resp != nil { value = resp.StatusCode } diff --git a/internal/waf/checkpoints/sample_request.go b/internal/waf/checkpoints/sample_request.go index e542eda..1aa1197 100644 --- a/internal/waf/checkpoints/sample_request.go +++ b/internal/waf/checkpoints/sample_request.go @@ -5,16 +5,16 @@ import ( "github.com/iwind/TeaGo/maps" ) -// just a sample checkpoint, copy and change it for your new checkpoint +// SampleRequestCheckpoint just a sample checkpoint, copy and change it for your new checkpoint type SampleRequestCheckpoint struct { Checkpoint } -func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *SampleRequestCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { return } -func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *SampleRequestCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/utils.go b/internal/waf/checkpoints/utils.go index e3393d7..6d79e13 100644 --- a/internal/waf/checkpoints/utils.go +++ b/internal/waf/checkpoints/utils.go @@ -1,6 +1,6 @@ package checkpoints -// all check points list +// AllCheckpoints all check points list var AllCheckpoints = []*CheckpointDefinition{ { Name: "通用请求Header长度限制", diff --git a/internal/waf/get302_validator.go b/internal/waf/get302_validator.go new file mode 100644 index 0000000..fbb562b --- /dev/null +++ b/internal/waf/get302_validator.go @@ -0,0 +1,52 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "time" +) + +var get302Validator = NewGet302Validator() + +type Get302Validator struct { +} + +func NewGet302Validator() *Get302Validator { + return &Get302Validator{} +} + +func (this *Get302Validator) Run(request requests.Request, writer http.ResponseWriter) { + var info = request.WAFRaw().URL.Query().Get("info") + if len(info) == 0 { + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte("invalid request")) + return + } + m, err := utils.SimpleDecryptMap(info) + if err != nil { + _, _ = writer.Write([]byte("invalid request")) + return + } + + var timestamp = m.GetInt64("timestamp") + if time.Now().Unix()-timestamp > 5 { // 超过5秒认为失效 + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte("invalid request")) + return + } + + // 加入白名单 + life := m.GetInt64("life") + if life <= 0 { + life = 600 // 默认10分钟 + } + setId := m.GetString("setId") + SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life) + + // 返回原始URL + var url = m.GetString("url") + http.Redirect(writer, request.WAFRaw(), url, http.StatusFound) +} diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go new file mode 100644 index 0000000..5c53624 --- /dev/null +++ b/internal/waf/ip_list.go @@ -0,0 +1,82 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/utils/expires" + "sync" + "sync/atomic" +) + +var SharedIPWhiteList = NewIPList() +var SharedIPBlackLIst = NewIPList() + +const IPTypeAll = "*" + +// IPList IP列表管理 +type IPList struct { + expireList *expires.List + ipMap map[string]int64 // ip => id + idMap map[int64]string // id => ip + + id int64 + locker sync.RWMutex +} + +// NewIPList 获取新对象 +func NewIPList() *IPList { + var list = &IPList{ + ipMap: map[string]int64{}, + idMap: map[int64]string{}, + } + + e := expires.NewList() + list.expireList = e + + go func() { + e.StartGC(func(itemId int64) { + list.remove(itemId) + }) + }() + + return list +} + +// Add 添加IP +func (this *IPList) Add(ipType string, ip string, expiresAt int64) { + ip = ip + "@" + ipType + + var id = this.nextId() + this.expireList.Add(id, expiresAt) + this.locker.Lock() + this.ipMap[ip] = id + this.idMap[id] = ip + this.locker.Unlock() +} + +// Contains 判断是否有某个IP +func (this *IPList) Contains(ipType string, ip string) bool { + ip = ip + "@" + ipType + + this.locker.RLock() + defer this.locker.RUnlock() + _, ok := this.ipMap[ip] + return ok +} + +func (this *IPList) remove(id int64) { + this.locker.Lock() + ip, ok := this.idMap[id] + if ok { + ipId, ok := this.ipMap[ip] + if ok && ipId == id { + delete(this.ipMap, ip) + } + delete(this.idMap, id) + } + this.locker.Unlock() +} + +func (this *IPList) nextId() int64 { + return atomic.AddInt64(&this.id, 1) +} diff --git a/internal/waf/ip_list_test.go b/internal/waf/ip_list_test.go new file mode 100644 index 0000000..c175f43 --- /dev/null +++ b/internal/waf/ip_list_test.go @@ -0,0 +1,67 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package waf + +import ( + "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/logs" + "runtime" + "strconv" + "testing" + "time" +) + +func TestNewIPList(t *testing.T) { + list := NewIPList() + list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()) + list.Add(IPTypeAll, "127.0.0.2", time.Now().Unix()+1) + list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()+2) + list.Add(IPTypeAll, "127.0.0.3", time.Now().Unix()+3) + list.Add(IPTypeAll, "127.0.0.10", time.Now().Unix()+10) + + var ticker = time.NewTicker(1 * time.Second) + for range ticker.C { + t.Log("====") + logs.PrintAsJSON(list.ipMap, t) + logs.PrintAsJSON(list.idMap, t) + if len(list.idMap) == 0 { + break + } + } +} + +func TestIPList_Contains(t *testing.T) { + a := assert.NewAssertion(t) + + list := NewIPList() + + for i := 0; i < 1_0000; i++ { + list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + } + a.IsTrue(list.Contains(IPTypeAll, "192.168.1.100")) + a.IsFalse(list.Contains(IPTypeAll, "192.168.2.100")) +} + +func BenchmarkIPList_Add(b *testing.B) { + runtime.GOMAXPROCS(1) + + list := NewIPList() + for i := 0; i < b.N; i++ { + list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + } + b.Log(len(list.ipMap)) +} + +func BenchmarkIPList_Has(b *testing.B) { + runtime.GOMAXPROCS(1) + + list := NewIPList() + + for i := 0; i < 1_0000; i++ { + list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + } + + for i := 0; i < b.N; i++ { + list.Contains(IPTypeAll, "192.168.1.100") + } +} diff --git a/internal/waf/ip_table.go b/internal/waf/ip_table.go deleted file mode 100644 index 7367b84..0000000 --- a/internal/waf/ip_table.go +++ /dev/null @@ -1,154 +0,0 @@ -package waf - -import ( - "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" - "github.com/iwind/TeaGo/lists" - "github.com/iwind/TeaGo/types" - stringutil "github.com/iwind/TeaGo/utils/string" - "regexp" - "strings" - "time" -) - -type IPAction = string - -var RegexpDigitNumber = regexp.MustCompile("^\\d+$") - -const ( - IPActionAccept IPAction = "accept" - IPActionReject IPAction = "reject" -) - -// ip table -type IPTable struct { - Id string `yaml:"id" json:"id"` - On bool `yaml:"on" json:"on"` - IP string `yaml:"ip" json:"ip"` // single ip, cidr, ip range, TODO support * - Port string `yaml:"port" json:"port"` // single port, range, * - Action IPAction `yaml:"action" json:"action"` // accept, reject - TimeFrom int64 `yaml:"timeFrom" json:"timeFrom"` // from timestamp - TimeTo int64 `yaml:"timeTo" json:"timeTo"` // zero means forever - Remark string `yaml:"remark" json:"remark"` - - // port - minPort int - maxPort int - - minPortWildcard bool - maxPortWildcard bool - - ports []int - - // ip - ipRange *shared.IPRangeConfig -} - -func NewIPTable() *IPTable { - return &IPTable{ - On: true, - Id: stringutil.Rand(16), - } -} - -func (this *IPTable) Init() error { - // parse port - if RegexpDigitNumber.MatchString(this.Port) { - this.minPort = types.Int(this.Port) - this.maxPort = types.Int(this.Port) - } else if regexp.MustCompile(`[:-]`).MatchString(this.Port) { - pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2) - if pieces[0] == "*" { - this.minPortWildcard = true - } else { - this.minPort = types.Int(pieces[0]) - } - if pieces[1] == "*" { - this.maxPortWildcard = true - } else { - this.maxPort = types.Int(pieces[1]) - } - } else if strings.Contains(this.Port, ",") { - pieces := strings.Split(this.Port, ",") - for _, piece := range pieces { - piece = strings.TrimSpace(piece) - if len(piece) > 0 { - this.ports = append(this.ports, types.Int(piece)) - } - } - } else if this.Port == "*" { - this.minPortWildcard = true - this.maxPortWildcard = true - } - - // parse ip - if len(this.IP) > 0 { - ipRange, err := shared.ParseIPRange(this.IP) - if err != nil { - return err - } - this.ipRange = ipRange - } - - return nil -} - -// check ip -func (this *IPTable) Match(ip string, port int) (isMatched bool) { - if !this.On { - return - } - - now := time.Now().Unix() - if this.TimeFrom > 0 && now < this.TimeFrom { - return - } - if this.TimeTo > 0 && now > this.TimeTo { - return - } - - if !this.matchPort(port) { - return - } - - if !this.matchIP(ip) { - return - } - - return true -} - -func (this *IPTable) matchPort(port int) bool { - if port == 0 { - return false - } - if this.minPortWildcard { - if this.maxPortWildcard { - return true - } - if this.maxPort >= port { - return true - } - } - if this.maxPortWildcard { - if this.minPortWildcard { - return true - } - if this.minPort <= port { - return true - } - } - if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port { - return true - } - if len(this.ports) > 0 { - return lists.ContainsInt(this.ports, port) - } - return false -} - -func (this *IPTable) matchIP(ip string) bool { - if this.ipRange == nil { - return false - } - return this.ipRange.Contains(ip) -} diff --git a/internal/waf/ip_table_test.go b/internal/waf/ip_table_test.go deleted file mode 100644 index 5fbd3e8..0000000 --- a/internal/waf/ip_table_test.go +++ /dev/null @@ -1,142 +0,0 @@ -package waf - -import ( - "github.com/iwind/TeaGo/assert" - "testing" - "time" -) - -func TestIPTable_MatchIP(t *testing.T) { - a := assert.NewAssertion(t) - - { - table := NewIPTable() - err := table.Init() - if err != nil { - t.Fatal(err) - } - a.IsFalse(table.Match("192.168.1.100", 8080)) - } - - { - table := NewIPTable() - table.IP = "*" - table.Port = "8080" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("port:", table.minPort, table.maxPort) - a.IsTrue(table.Match("192.168.1.100", 8080)) - a.IsFalse(table.Match("192.168.1.100", 8081)) - } - - { - table := NewIPTable() - table.IP = "*" - table.Port = "8080-8082" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("port:", table.minPort, table.maxPort) - a.IsTrue(table.Match("192.168.1.100", 8080)) - a.IsTrue(table.Match("192.168.1.100", 8081)) - a.IsFalse(table.Match("192.168.1.100", 8083)) - } - - { - table := NewIPTable() - table.IP = "*" - table.Port = "*-8082" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("port:", table.minPort, table.maxPort) - a.IsTrue(table.Match("192.168.1.100", 8079)) - a.IsTrue(table.Match("192.168.1.100", 8080)) - a.IsTrue(table.Match("192.168.1.100", 8081)) - a.IsFalse(table.Match("192.168.1.100", 8083)) - } - - { - table := NewIPTable() - table.IP = "*" - table.Port = "8080-*" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("port:", table.minPort, table.maxPort) - a.IsFalse(table.Match("192.168.1.100", 8079)) - a.IsTrue(table.Match("192.168.1.100", 8080)) - a.IsTrue(table.Match("192.168.1.100", 8081)) - a.IsTrue(table.Match("192.168.1.100", 8083)) - } - - { - table := NewIPTable() - table.IP = "*" - table.Port = "*" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("port:", table.minPort, table.maxPort) - a.IsTrue(table.Match("192.168.1.100", 8079)) - a.IsTrue(table.Match("192.168.1.100", 8080)) - a.IsTrue(table.Match("192.168.1.100", 8081)) - a.IsTrue(table.Match("192.168.1.100", 8083)) - } - - { - table := NewIPTable() - table.IP = "192.168.1.100" - table.Port = "*" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("port:", table.minPort, table.maxPort) - a.IsTrue(table.Match("192.168.1.100", 8080)) - } - - { - table := NewIPTable() - table.IP = "192.168.1.99-192.168.1.101" - table.Port = "*" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("port:", table.minPort, table.maxPort) - a.IsTrue(table.Match("192.168.1.100", 8080)) - } - - { - table := NewIPTable() - table.IP = "192.168.1.99/24" - table.Port = "*" - err := table.Init() - if err != nil { - t.Fatal(err) - } - t.Log("ip:", table.ipRange) - a.IsTrue(table.Match("192.168.1.100", 8080)) - a.IsFalse(table.Match("192.168.2.100", 8080)) - } - - { - table := NewIPTable() - table.IP = "192.168.1.99/24" - table.TimeTo = time.Now().Unix() - 10 - table.Port = "*" - err := table.Init() - if err != nil { - t.Fatal(err) - } - a.IsFalse(table.Match("192.168.1.100", 8080)) - a.IsFalse(table.Match("192.168.2.100", 8080)) - } -} diff --git a/internal/waf/requests/request.go b/internal/waf/requests/request.go index 28c29b9..b2f35a9 100644 --- a/internal/waf/requests/request.go +++ b/internal/waf/requests/request.go @@ -1,39 +1,28 @@ package requests import ( - "bytes" - "io" - "io/ioutil" "net/http" ) -type Request struct { - *http.Request - BodyData []byte -} +type Request interface { + // WAFRaw 原始请求 + WAFRaw() *http.Request -func NewRequest(raw *http.Request) *Request { - return &Request{ - Request: raw, - } -} + // WAFRemoteIP 客户端IP + WAFRemoteIP() string -func (this *Request) Raw() *http.Request { - return this.Request -} + // WAFGetCacheBody 获取缓存中的Body + WAFGetCacheBody() []byte -func (this *Request) ReadBody(max int64) (data []byte, err error) { - if this.Request.ContentLength > 0 { - data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max)) - } - return -} + // WAFSetCacheBody 设置Body + WAFSetCacheBody(body []byte) -func (this *Request) RestoreBody(data []byte) { - if len(data) > 0 { - rawReader := bytes.NewBuffer(data) - buf := make([]byte, 1024) - _, _ = io.CopyBuffer(rawReader, this.Request.Body, buf) - this.Request.Body = ioutil.NopCloser(rawReader) - } + // WAFReadBody 读取Body + WAFReadBody(max int64) (data []byte, err error) + + // WAFRestoreBody 恢复Body + WAFRestoreBody(data []byte) + + // WAFServerId 服务ID + WAFServerId() int64 } diff --git a/internal/waf/requests/test_request.go b/internal/waf/requests/test_request.go new file mode 100644 index 0000000..4682d37 --- /dev/null +++ b/internal/waf/requests/test_request.go @@ -0,0 +1,67 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package requests + +import ( + "bytes" + "io" + "io/ioutil" + "net" + "net/http" +) + +type TestRequest struct { + req *http.Request + BodyData []byte +} + +func NewTestRequest(raw *http.Request) *TestRequest { + return &TestRequest{ + req: raw, + } +} + +func (this *TestRequest) WAFSetCacheBody(bodyData []byte) { + this.BodyData = bodyData +} + +func (this *TestRequest) WAFGetCacheBody() []byte { + return this.BodyData +} + +func (this *TestRequest) WAFRaw() *http.Request { + return this.req +} + +func (this *TestRequest) WAFRemoteAddr() string { + return this.req.RemoteAddr +} + +func (this *TestRequest) WAFRemoteIP() string { + host, _, err := net.SplitHostPort(this.req.RemoteAddr) + if err != nil { + return this.req.RemoteAddr + } else { + return host + } +} + +func (this *TestRequest) WAFReadBody(max int64) (data []byte, err error) { + if this.req.ContentLength > 0 { + data, err = ioutil.ReadAll(io.LimitReader(this.req.Body, max)) + } + return +} + +func (this *TestRequest) WAFRestoreBody(data []byte) { + if len(data) > 0 { + rawReader := bytes.NewBuffer(data) + buf := make([]byte, 1024) + _, _ = io.CopyBuffer(rawReader, this.req.Body, buf) + this.req.Body = ioutil.NopCloser(rawReader) + } +} + +func (this *TestRequest) WAFServerId() int64 { + return 0 +} diff --git a/internal/waf/rule.go b/internal/waf/rule.go index 7455df1..a9ad080 100644 --- a/internal/waf/rule.go +++ b/internal/waf/rule.go @@ -183,7 +183,7 @@ func (this *Rule) Init() error { return err } -func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) { +func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) { if this.singleCheckpoint != nil { value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions) if err != nil { @@ -233,7 +233,7 @@ func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) { return this.Test(value), nil } -func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) { +func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) { if this.singleCheckpoint != nil { // if is request param if this.singleCheckpoint.IsRequest() { diff --git a/internal/waf/rule_group.go b/internal/waf/rule_group.go index 1c7bea5..19577c4 100644 --- a/internal/waf/rule_group.go +++ b/internal/waf/rule_group.go @@ -23,12 +23,12 @@ func NewRuleGroup() *RuleGroup { } } -func (this *RuleGroup) Init() error { +func (this *RuleGroup) Init(waf *WAF) error { this.hasRuleSets = len(this.RuleSets) > 0 if this.hasRuleSets { for _, set := range this.RuleSets { - err := set.Init() + err := set.Init(waf) if err != nil { return err } @@ -79,7 +79,7 @@ func (this *RuleGroup) RemoveRuleSet(id string) { this.RuleSets = result } -func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) { +func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, set *RuleSet, err error) { if !this.hasRuleSets { return } @@ -98,7 +98,7 @@ func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet return } -func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) { +func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) { if !this.hasRuleSets { return } diff --git a/internal/waf/rule_set.go b/internal/waf/rule_set.go index 431a504..38b6e15 100644 --- a/internal/waf/rule_set.go +++ b/internal/waf/rule_set.go @@ -1,9 +1,13 @@ package waf import ( + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/utils/string" + "net/http" ) type RuleConnector = string @@ -14,16 +18,17 @@ const ( ) type RuleSet struct { - Id string `yaml:"id" json:"id"` - Code string `yaml:"code" json:"code"` - IsOn bool `yaml:"isOn" json:"isOn"` - Name string `yaml:"name" json:"name"` - Description string `yaml:"description" json:"description"` - Rules []*Rule `yaml:"rules" json:"rules"` - Connector RuleConnector `yaml:"connector" json:"connector"` // rules connector + Id string `yaml:"id" json:"id"` + Code string `yaml:"code" json:"code"` + IsOn bool `yaml:"isOn" json:"isOn"` + Name string `yaml:"name" json:"name"` + Description string `yaml:"description" json:"description"` + Rules []*Rule `yaml:"rules" json:"rules"` + Connector RuleConnector `yaml:"connector" json:"connector"` // rules connector + Actions []*ActionConfig `yaml:"actions" json:"actions"` - Action ActionString `yaml:"action" json:"action"` - ActionOptions maps.Map `yaml:"actionOptions" json:"actionOptions"` // TODO TO BE IMPLEMENTED + actionCodes []string + actionInstances []ActionInterface hasRules bool } @@ -35,7 +40,7 @@ func NewRuleSet() *RuleSet { } } -func (this *RuleSet) Init() error { +func (this *RuleSet) Init(waf *WAF) error { this.hasRules = len(this.Rules) > 0 if this.hasRules { for _, rule := range this.Rules { @@ -45,6 +50,31 @@ func (this *RuleSet) Init() error { } } } + + // action codes + var actionCodes = []string{} + for _, action := range this.Actions { + if !lists.ContainsString(actionCodes, action.Code) { + actionCodes = append(actionCodes, action.Code) + } + } + this.actionCodes = actionCodes + + // action instances + this.actionInstances = []ActionInterface{} + for _, action := range this.Actions { + instance := FindActionInstance(action.Code, action.Options) + if instance == nil { + remotelogs.Error("WAF_RULE_SET", "can not find instance for action '"+action.Code+"'") + } else { + this.actionInstances = append(this.actionInstances, instance) + } + err := instance.Init(waf) + if err != nil { + remotelogs.Error("WAF_RULE_SET", "init action '"+action.Code+"' failed: "+err.Error()) + } + } + return nil } @@ -52,7 +82,75 @@ func (this *RuleSet) AddRule(rule ...*Rule) { this.Rules = append(this.Rules, rule...) } -func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) { +// AddAction 添加动作 +func (this *RuleSet) AddAction(code string, options maps.Map) { + if options == nil { + options = maps.Map{} + } + this.Actions = append(this.Actions, &ActionConfig{ + Code: code, + Options: options, + }) +} + +// HasSpecialActions 除了Allow之外是否还有别的动作 +func (this *RuleSet) HasSpecialActions() bool { + for _, action := range this.Actions { + if action.Code != ActionAllow { + return true + } + } + return false +} + +// HasAttackActions 检查是否含有攻击防御动作 +func (this *RuleSet) HasAttackActions() bool { + for _, action := range this.actionInstances { + if action.IsAttack() { + return true + } + } + return false +} + +func (this *RuleSet) ActionCodes() []string { + return this.actionCodes +} + +func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) bool { + // 先执行allow + for _, instance := range this.actionInstances { + if !instance.WillChange() { + if waf.onActionCallback != nil { + goNext := waf.onActionCallback(instance) + if !goNext { + return false + } + } + logs.Printf("perform1: %#v", instance) // TODO + instance.Perform(waf, group, this, req, writer) + } + } + + // 再执行block|verify + for _, instance := range this.actionInstances { + // 只执行第一个可能改变请求的动作,其余的都会被忽略 + if instance.WillChange() { + if waf.onActionCallback != nil { + goNext := waf.onActionCallback(instance) + if !goNext { + return false + } + } + logs.Printf("perform2: %#v", instance) // TODO + return instance.Perform(waf, group, this, req, writer) + } + } + + return true +} + +func (this *RuleSet) MatchRequest(req requests.Request) (b bool, err error) { if !this.hasRules { return false, nil } @@ -93,7 +191,7 @@ func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) { return } -func (this *RuleSet) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) { +func (this *RuleSet) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) { if !this.hasRules { return false, nil } diff --git a/internal/waf/rule_set_test.go b/internal/waf/rule_set_test.go index 9d39377..5317643 100644 --- a/internal/waf/rule_set_test.go +++ b/internal/waf/rule_set_test.go @@ -28,7 +28,7 @@ func TestRuleSet_MatchRequest(t *testing.T) { }, } - err := set.Init() + err := set.Init(nil) if err != nil { t.Fatal(err) } @@ -37,7 +37,7 @@ func TestRuleSet_MatchRequest(t *testing.T) { if err != nil { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) t.Log(set.MatchRequest(req)) } @@ -60,7 +60,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) { }, } - err := set.Init() + err := set.Init(nil) if err != nil { t.Fatal(err) } @@ -69,7 +69,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) { if err != nil { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) a.IsTrue(set.MatchRequest(req)) } @@ -102,7 +102,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) { }, } - err := set.Init() + err := set.Init(nil) if err != nil { b.Fatal(err) } @@ -111,7 +111,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) { if err != nil { b.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) for i := 0; i < b.N; i++ { _, _ = set.MatchRequest(req) } @@ -132,7 +132,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) { }, } - err := set.Init() + err := set.Init(nil) if err != nil { b.Fatal(err) } @@ -141,7 +141,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) { if err != nil { b.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) for i := 0; i < b.N; i++ { _, _ = set.MatchRequest(req) } diff --git a/internal/waf/rule_test.go b/internal/waf/rule_test.go index 6a7731c..e9597b3 100644 --- a/internal/waf/rule_test.go +++ b/internal/waf/rule_test.go @@ -25,7 +25,7 @@ func TestRule_Init_Single(t *testing.T) { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) t.Log(rule.MatchRequest(req)) } @@ -44,7 +44,7 @@ func TestRule_Init_Composite(t *testing.T) { if err != nil { t.Fatal(err) } - req := requests.NewRequest(rawReq) + req := requests.NewTestRequest(rawReq) t.Log(rule.MatchRequest(req)) } diff --git a/internal/waf/template.go b/internal/waf/template.go index 83ffe83..5bd06c3 100644 --- a/internal/waf/template.go +++ b/internal/waf/template.go @@ -20,7 +20,7 @@ func Template() *WAF { set.Name = "Javascript事件" set.Code = "1001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestURI}", Operator: RuleOperatorMatch, @@ -36,7 +36,7 @@ func Template() *WAF { set.Name = "Javascript函数" set.Code = "1002" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestURI}", Operator: RuleOperatorMatch, @@ -52,7 +52,7 @@ func Template() *WAF { set.Name = "HTML标签" set.Code = "1003" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestURI}", Operator: RuleOperatorMatch, @@ -80,7 +80,7 @@ func Template() *WAF { set.Name = "上传文件扩展名" set.Code = "2001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestUpload.ext}", Operator: RuleOperatorMatch, @@ -108,7 +108,7 @@ func Template() *WAF { set.Name = "Web Shell" set.Code = "3001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestAll}", Operator: RuleOperatorMatch, @@ -135,7 +135,7 @@ func Template() *WAF { set.Name = "命令注入" set.Code = "4001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestURI}", Operator: RuleOperatorMatch, @@ -169,7 +169,7 @@ func Template() *WAF { set.Name = "路径穿越" set.Code = "5001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestURI}", Operator: RuleOperatorMatch, @@ -197,7 +197,7 @@ func Template() *WAF { set.Name = "特殊目录" set.Code = "6001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestPath}", Operator: RuleOperatorMatch, @@ -225,7 +225,7 @@ func Template() *WAF { set.Name = "Union SQL Injection" set.Code = "7001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestAll}", @@ -243,7 +243,7 @@ func Template() *WAF { set.Name = "SQL注释" set.Code = "7002" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestAll}", @@ -261,7 +261,7 @@ func Template() *WAF { set.Name = "SQL条件" set.Code = "7003" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestAll}", @@ -297,7 +297,7 @@ func Template() *WAF { set.Name = "SQL函数" set.Code = "7004" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestAll}", @@ -315,7 +315,7 @@ func Template() *WAF { set.Name = "SQL附加语句" set.Code = "7005" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${requestAll}", @@ -345,7 +345,7 @@ func Template() *WAF { set.Name = "常见网络爬虫" set.Code = "20001" set.Connector = RuleConnectorOr - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${userAgent}", @@ -376,7 +376,7 @@ func Template() *WAF { set.Description = "限制单IP在一定时间内的请求数" set.Code = "8001" set.Connector = RuleConnectorAnd - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) set.AddRule(&Rule{ Param: "${cc.requests}", Operator: RuleOperatorGt, diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go index c2074b6..e10d691 100644 --- a/internal/waf/template_test.go +++ b/internal/waf/template_test.go @@ -2,6 +2,7 @@ package waf import ( "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/logs" @@ -22,8 +23,8 @@ func Test_Template(t *testing.T) { t.Fatal(err) } - template.OnAction(func(action ActionString) (goNext bool) { - return action != ActionBlock + template.OnAction(func(action ActionInterface) (goNext bool) { + return action.Code() != ActionBlock }) testTemplate1001(a, t, template) @@ -40,7 +41,7 @@ func Test_Template(t *testing.T) { func Test_Template2(t *testing.T) { reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024))) - req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader) + req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123", reader) if err != nil { t.Fatal(err) } @@ -52,7 +53,7 @@ func Test_Template2(t *testing.T) { } now := time.Now() - goNext, _, set, err := waf.MatchRequest(req, nil) + goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -80,7 +81,7 @@ func BenchmarkTemplate(b *testing.B) { b.Fatal(err) } - _, _, _, _ = waf.MatchRequest(req, nil) + _, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil) } } @@ -89,7 +90,7 @@ func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -104,7 +105,7 @@ func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -119,7 +120,7 @@ func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -185,7 +186,7 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) { req.Header.Add("Content-Type", writer.FormDataContentType()) - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -200,7 +201,7 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -215,7 +216,7 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -231,7 +232,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -246,7 +247,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -263,7 +264,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -278,7 +279,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -301,7 +302,7 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -338,7 +339,7 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) { t.Fatal(err) } req.Header.Set("User-Agent", bot) - _, _, result, err := template.MatchRequest(req, nil) + _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } diff --git a/internal/waf/waf.go b/internal/waf/waf.go index 7d6fc8f..08c762c 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -22,13 +22,11 @@ type WAF struct { Outbound []*RuleGroup `yaml:"outbound" json:"outbound"` CreatedVersion string `yaml:"createdVersion" json:"createdVersion"` - ActionBlock *BlockAction `yaml:"actionBlock" json:"actionBlock"` // action block config - - IPTables []*IPTable `yaml:"ipTables" json:"ipTables"` // IP table list + DefaultBlockAction *BlockAction hasInboundRules bool hasOutboundRules bool - onActionCallback func(action ActionString) (goNext bool) + onActionCallback func(action ActionInterface) (goNext bool) checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint } @@ -87,7 +85,7 @@ func (this *WAF) Init() error { } } - err := group.Init() + err := group.Init(this) if err != nil { return err } @@ -103,7 +101,7 @@ func (this *WAF) Init() error { } } - err := group.Init() + err := group.Init(this) if err != nil { return err } @@ -241,19 +239,24 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) { this.Outbound = result } -func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) { +func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) { if !this.hasInboundRules { return true, nil, nil, nil } - req := requests.NewRequest(rawReq) - // validate captcha - if rawReq.URL.Path == "/WAFCAPTCHA" { + var rawPath = req.WAFRaw().URL.Path + if rawPath == CaptchaPath { captchaValidator.Run(req, writer) return } + // Get 302验证 + if rawPath == Get302Path { + get302Validator.Run(req, writer) + return + } + // match rules for _, group := range this.Inbound { if !group.IsOn { @@ -264,31 +267,17 @@ func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter) return true, nil, nil, err } if b { - if this.onActionCallback == nil { - if set.Action == ActionBlock && this.ActionBlock != nil { - return this.ActionBlock.Perform(this, req, writer), group, set, nil - } else { - actionObject := FindActionInstance(set.Action, set.ActionOptions) - if actionObject == nil { - return true, group, set, errors.New("no action called '" + set.Action + "'") - } - goNext := actionObject.Perform(this, req, writer) - return goNext, group, set, nil - } - } else { - goNext = this.onActionCallback(set.Action) - } + goNext := set.PerformActions(this, group, req, writer) return goNext, group, set, nil } } return true, nil, nil, nil } -func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) { +func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) { if !this.hasOutboundRules { return true, nil, nil, nil } - req := requests.NewRequest(rawReq) resp := requests.NewResponse(rawResp) for _, group := range this.Outbound { if !group.IsOn { @@ -299,27 +288,14 @@ func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, wri return true, nil, nil, err } if b { - if this.onActionCallback == nil { - if set.Action == ActionBlock && this.ActionBlock != nil { - return this.ActionBlock.Perform(this, req, writer), group, set, nil - } else { - actionObject := FindActionInstance(set.Action, set.ActionOptions) - if actionObject == nil { - return true, group, set, errors.New("no action called '" + set.Action + "'") - } - goNext := actionObject.Perform(this, req, writer) - return goNext, group, set, nil - } - } else { - goNext = this.onActionCallback(set.Action) - } + goNext := set.PerformActions(this, group, req, writer) return goNext, group, set, nil } } return true, nil, nil, nil } -// save to file path +// Save save to file path func (this *WAF) Save(path string) error { if len(path) == 0 { return errors.New("path should not be empty") @@ -378,7 +354,7 @@ func (this *WAF) CountOutboundRuleSets() int { return count } -func (this *WAF) OnAction(onActionCallback func(action ActionString) (goNext bool)) { +func (this *WAF) OnAction(onActionCallback func(action ActionInterface) (goNext bool)) { this.onActionCallback = onActionCallback } @@ -390,21 +366,21 @@ func (this *WAF) FindCheckpointInstance(prefix string) checkpoints.CheckpointInt return nil } -// start +// Start start func (this *WAF) Start() { for _, checkpoint := range this.checkpointsMap { checkpoint.Start() } } -// call stop() when the waf was deleted +// Stop call stop() when the waf was deleted func (this *WAF) Stop() { for _, checkpoint := range this.checkpointsMap { checkpoint.Stop() } } -// merge with template +// MergeTemplate merge with template func (this *WAF) MergeTemplate() (changedItems []string) { changedItems = []string{} diff --git a/internal/waf/waf_test.go b/internal/waf/waf_test.go index acca3d1..5395eb8 100644 --- a/internal/waf/waf_test.go +++ b/internal/waf/waf_test.go @@ -1,6 +1,7 @@ package waf import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/iwind/TeaGo/assert" "net/http" "testing" @@ -24,7 +25,7 @@ func TestWAF_MatchRequest(t *testing.T) { Value: "20", }, } - set.Action = ActionBlock + set.AddAction(ActionBlock, nil) group := NewRuleGroup() group.AddRuleSet(set) @@ -37,15 +38,15 @@ func TestWAF_MatchRequest(t *testing.T) { t.Fatal(err) } - waf.OnAction(func(action ActionString) (goNext bool) { - return action != ActionBlock + waf.OnAction(func(action ActionInterface) (goNext bool) { + return action.Code() != ActionBlock }) req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) if err != nil { t.Fatal(err) } - goNext, _, set, err := waf.MatchRequest(req, nil) + goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) }