mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	WAF增加多个动作
This commit is contained in:
		@@ -6,7 +6,7 @@ import (
 | 
				
			|||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// IP名单
 | 
					// IPList IP名单
 | 
				
			||||||
type IPList struct {
 | 
					type IPList struct {
 | 
				
			||||||
	itemsMap   map[int64]*IPItem  // id => item
 | 
						itemsMap   map[int64]*IPItem  // id => item
 | 
				
			||||||
	ipMap      map[uint64][]int64 // ip => itemIds
 | 
						ipMap      map[uint64][]int64 // ip => itemIds
 | 
				
			||||||
@@ -96,7 +96,7 @@ func (this *IPList) Delete(itemId int64) {
 | 
				
			|||||||
	this.isAll = len(this.ipMap[0]) > 0
 | 
						this.isAll = len(this.ipMap[0]) > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 判断是否包含某个IP
 | 
					// Contains 判断是否包含某个IP
 | 
				
			||||||
func (this *IPList) Contains(ip uint64) bool {
 | 
					func (this *IPList) Contains(ip uint64) bool {
 | 
				
			||||||
	this.locker.RLock()
 | 
						this.locker.RLock()
 | 
				
			||||||
	if this.isAll {
 | 
						if this.isAll {
 | 
				
			||||||
@@ -109,7 +109,7 @@ func (this *IPList) Contains(ip uint64) bool {
 | 
				
			|||||||
	return ok
 | 
						return ok
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 是否包含一组IP
 | 
					// ContainsIPStrings 是否包含一组IP
 | 
				
			||||||
func (this *IPList) ContainsIPStrings(ipStrings []string) (found bool, item *IPItem) {
 | 
					func (this *IPList) ContainsIPStrings(ipStrings []string) (found bool, item *IPItem) {
 | 
				
			||||||
	if len(ipStrings) == 0 {
 | 
						if len(ipStrings) == 0 {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -68,12 +68,15 @@ type HTTPRequest struct {
 | 
				
			|||||||
	cacheKey             string                            // 缓存使用的Key
 | 
						cacheKey             string                            // 缓存使用的Key
 | 
				
			||||||
	isCached             bool                              // 是否已经被缓存
 | 
						isCached             bool                              // 是否已经被缓存
 | 
				
			||||||
	isAttack             bool                              // 是否是攻击请求
 | 
						isAttack             bool                              // 是否是攻击请求
 | 
				
			||||||
 | 
						bodyData             []byte                            // 读取的Body内容
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// WAF相关
 | 
						// WAF相关
 | 
				
			||||||
	firewallPolicyId    int64
 | 
						firewallPolicyId    int64
 | 
				
			||||||
	firewallRuleGroupId int64
 | 
						firewallRuleGroupId int64
 | 
				
			||||||
	firewallRuleSetId   int64
 | 
						firewallRuleSetId   int64
 | 
				
			||||||
	firewallRuleId      int64
 | 
						firewallRuleId      int64
 | 
				
			||||||
 | 
						firewallActions     []string
 | 
				
			||||||
 | 
						tags                []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	logAttrs map[string]string
 | 
						logAttrs map[string]string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1197,5 +1200,10 @@ func (this *HTTPRequest) canIgnore(err error) bool {
 | 
				
			|||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// HTTP内部错误
 | 
				
			||||||
 | 
						if strings.HasPrefix(err.Error(), "http:")  || strings.HasPrefix(err.Error(), "http2:") {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -128,6 +128,8 @@ func (this *HTTPRequest) log() {
 | 
				
			|||||||
		FirewallRuleGroupId: this.firewallRuleGroupId,
 | 
							FirewallRuleGroupId: this.firewallRuleGroupId,
 | 
				
			||||||
		FirewallRuleSetId:   this.firewallRuleSetId,
 | 
							FirewallRuleSetId:   this.firewallRuleSetId,
 | 
				
			||||||
		FirewallRuleId:      this.firewallRuleId,
 | 
							FirewallRuleId:      this.firewallRuleId,
 | 
				
			||||||
 | 
							FirewallActions:     this.firewallActions,
 | 
				
			||||||
 | 
							Tags:                this.tags,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Attrs: this.logAttrs,
 | 
							Attrs: this.logAttrs,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
package nodes
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
@@ -8,6 +9,8 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
				
			||||||
	"github.com/iwind/TeaGo/lists"
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -152,27 +155,36 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
 | 
				
			|||||||
	if w == nil {
 | 
						if w == nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	goNext, ruleGroup, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
 | 
					
 | 
				
			||||||
 | 
						w.OnAction(func(action waf.ActionInterface) (goNext bool) {
 | 
				
			||||||
 | 
							switch action.Code() {
 | 
				
			||||||
 | 
							case waf.ActionTag:
 | 
				
			||||||
 | 
								this.tags = action.(*waf.TagAction).Tags
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
 | 
							remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if ruleSet != nil {
 | 
						if ruleSet != nil {
 | 
				
			||||||
		if ruleSet.Action != waf.ActionAllow {
 | 
							if ruleSet.HasSpecialActions() {
 | 
				
			||||||
			this.firewallPolicyId = firewallPolicy.Id
 | 
								this.firewallPolicyId = firewallPolicy.Id
 | 
				
			||||||
			this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
 | 
								this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
 | 
				
			||||||
			this.firewallRuleSetId = types.Int64(ruleSet.Id)
 | 
								this.firewallRuleSetId = types.Int64(ruleSet.Id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if ruleSet.Action == waf.ActionBlock {
 | 
								if ruleSet.HasAttackActions() {
 | 
				
			||||||
				this.isAttack = true
 | 
									this.isAttack = true
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// 添加统计
 | 
								// 添加统计
 | 
				
			||||||
			stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
 | 
								stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		this.logAttrs["waf.action"] = ruleSet.Action
 | 
							this.firewallActions = ruleSet.ActionCodes()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return !goNext, false
 | 
						return !goNext, false
 | 
				
			||||||
@@ -208,28 +220,79 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	goNext, ruleGroup, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
 | 
						w.OnAction(func(action waf.ActionInterface) (goNext bool) {
 | 
				
			||||||
 | 
							switch action.Code() {
 | 
				
			||||||
 | 
							case waf.ActionTag:
 | 
				
			||||||
 | 
								this.tags = action.(*waf.TagAction).Tags
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
 | 
							remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if ruleSet != nil {
 | 
						if ruleSet != nil {
 | 
				
			||||||
		if ruleSet.Action != waf.ActionAllow {
 | 
							if ruleSet.HasSpecialActions() {
 | 
				
			||||||
			this.firewallPolicyId = firewallPolicy.Id
 | 
								this.firewallPolicyId = firewallPolicy.Id
 | 
				
			||||||
			this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
 | 
								this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
 | 
				
			||||||
			this.firewallRuleSetId = types.Int64(ruleSet.Id)
 | 
								this.firewallRuleSetId = types.Int64(ruleSet.Id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if ruleSet.Action == waf.ActionBlock {
 | 
								if ruleSet.HasAttackActions() {
 | 
				
			||||||
				this.isAttack = true
 | 
									this.isAttack = true
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// 添加统计
 | 
								// 添加统计
 | 
				
			||||||
			stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
 | 
								stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		this.logAttrs["waf.action"] = ruleSet.Action
 | 
							this.firewallActions = ruleSet.ActionCodes()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return !goNext
 | 
						return !goNext
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAFRaw 原始请求
 | 
				
			||||||
 | 
					func (this *HTTPRequest) WAFRaw() *http.Request {
 | 
				
			||||||
 | 
						return this.RawReq
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAFRemoteIP 客户端IP
 | 
				
			||||||
 | 
					func (this *HTTPRequest) WAFRemoteIP() string {
 | 
				
			||||||
 | 
						return this.requestRemoteAddr()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAFGetCacheBody 获取缓存中的Body
 | 
				
			||||||
 | 
					func (this *HTTPRequest) WAFGetCacheBody() []byte {
 | 
				
			||||||
 | 
						return this.bodyData
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAFSetCacheBody 设置Body
 | 
				
			||||||
 | 
					func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
 | 
				
			||||||
 | 
						this.bodyData = body
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAFReadBody 读取Body
 | 
				
			||||||
 | 
					func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
 | 
				
			||||||
 | 
						if this.RawReq.ContentLength > 0 {
 | 
				
			||||||
 | 
							data, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, max))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAFRestoreBody 恢复Body
 | 
				
			||||||
 | 
					func (this *HTTPRequest) WAFRestoreBody(data []byte) {
 | 
				
			||||||
 | 
						if len(data) > 0 {
 | 
				
			||||||
 | 
							rawReader := bytes.NewBuffer(data)
 | 
				
			||||||
 | 
							buf := make([]byte, 1024)
 | 
				
			||||||
 | 
							_, _ = io.CopyBuffer(rawReader, this.RawReq.Body, buf)
 | 
				
			||||||
 | 
							this.RawReq.Body = ioutil.NopCloser(rawReader)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAFServerId 服务ID
 | 
				
			||||||
 | 
					func (this *HTTPRequest) WAFServerId() int64 {
 | 
				
			||||||
 | 
						return this.Server.Id
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,7 +7,7 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
 | 
				
			||||||
	"github.com/iwind/TeaGo/lists"
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
	http2 "golang.org/x/net/http2"
 | 
						"golang.org/x/net/http2"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,7 +2,10 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
package nodes
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import "net"
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TrafficListener 用于统计流量的网络监听
 | 
					// TrafficListener 用于统计流量的网络监听
 | 
				
			||||||
type TrafficListener struct {
 | 
					type TrafficListener struct {
 | 
				
			||||||
@@ -18,6 +21,17 @@ func (this *TrafficListener) Accept() (net.Conn, error) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						// 是否在WAF名单中
 | 
				
			||||||
 | 
						ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackLIst.Contains(waf.IPTypeAll, ip) {
 | 
				
			||||||
 | 
								go func() {
 | 
				
			||||||
 | 
									_ = conn.Close()
 | 
				
			||||||
 | 
								}()
 | 
				
			||||||
 | 
								return conn, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return NewTrafficConn(conn), nil
 | 
						return NewTrafficConn(conn), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,20 +11,20 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var sharedWAFManager = NewWAFManager()
 | 
					var sharedWAFManager = NewWAFManager()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// WAF管理器
 | 
					// WAFManager WAF管理器
 | 
				
			||||||
type WAFManager struct {
 | 
					type WAFManager struct {
 | 
				
			||||||
	mapping map[int64]*waf.WAF // policyId => WAF
 | 
						mapping map[int64]*waf.WAF // policyId => WAF
 | 
				
			||||||
	locker  sync.RWMutex
 | 
						locker  sync.RWMutex
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 获取新对象
 | 
					// NewWAFManager 获取新对象
 | 
				
			||||||
func NewWAFManager() *WAFManager {
 | 
					func NewWAFManager() *WAFManager {
 | 
				
			||||||
	return &WAFManager{
 | 
						return &WAFManager{
 | 
				
			||||||
		mapping: map[int64]*waf.WAF{},
 | 
							mapping: map[int64]*waf.WAF{},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 更新策略
 | 
					// UpdatePolicies 更新策略
 | 
				
			||||||
func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
 | 
					func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
 | 
				
			||||||
	this.locker.Lock()
 | 
						this.locker.Lock()
 | 
				
			||||||
	defer this.locker.Unlock()
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
@@ -44,7 +44,7 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP
 | 
				
			|||||||
	this.mapping = m
 | 
						this.mapping = m
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 查找WAF
 | 
					// FindWAF 查找WAF
 | 
				
			||||||
func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
 | 
					func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
 | 
				
			||||||
	this.locker.RLock()
 | 
						this.locker.RLock()
 | 
				
			||||||
	w, _ := this.mapping[policyId]
 | 
						w, _ := this.mapping[policyId]
 | 
				
			||||||
@@ -78,14 +78,15 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
 | 
				
			|||||||
			// rule sets
 | 
								// rule sets
 | 
				
			||||||
			for _, set := range group.Sets {
 | 
								for _, set := range group.Sets {
 | 
				
			||||||
				s := &waf.RuleSet{
 | 
									s := &waf.RuleSet{
 | 
				
			||||||
					Id:            strconv.FormatInt(set.Id, 10),
 | 
										Id:          strconv.FormatInt(set.Id, 10),
 | 
				
			||||||
					Code:          set.Code,
 | 
										Code:        set.Code,
 | 
				
			||||||
					IsOn:          set.IsOn,
 | 
										IsOn:        set.IsOn,
 | 
				
			||||||
					Name:          set.Name,
 | 
										Name:        set.Name,
 | 
				
			||||||
					Description:   set.Description,
 | 
										Description: set.Description,
 | 
				
			||||||
					Connector:     set.Connector,
 | 
										Connector:   set.Connector,
 | 
				
			||||||
					Action:        set.Action,
 | 
									}
 | 
				
			||||||
					ActionOptions: set.ActionOptions,
 | 
									for _, a := range set.Actions {
 | 
				
			||||||
 | 
										s.AddAction(a.Code, a.Options)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// rules
 | 
									// rules
 | 
				
			||||||
@@ -132,14 +133,16 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
 | 
				
			|||||||
			// rule sets
 | 
								// rule sets
 | 
				
			||||||
			for _, set := range group.Sets {
 | 
								for _, set := range group.Sets {
 | 
				
			||||||
				s := &waf.RuleSet{
 | 
									s := &waf.RuleSet{
 | 
				
			||||||
					Id:            strconv.FormatInt(set.Id, 10),
 | 
										Id:          strconv.FormatInt(set.Id, 10),
 | 
				
			||||||
					Code:          set.Code,
 | 
										Code:        set.Code,
 | 
				
			||||||
					IsOn:          set.IsOn,
 | 
										IsOn:        set.IsOn,
 | 
				
			||||||
					Name:          set.Name,
 | 
										Name:        set.Name,
 | 
				
			||||||
					Description:   set.Description,
 | 
										Description: set.Description,
 | 
				
			||||||
					Connector:     set.Connector,
 | 
										Connector:   set.Connector,
 | 
				
			||||||
					Action:        set.Action,
 | 
									}
 | 
				
			||||||
					ActionOptions: set.ActionOptions,
 | 
					
 | 
				
			||||||
 | 
									for _, a := range set.Actions {
 | 
				
			||||||
 | 
										s.AddAction(a.Code, a.Options)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// rules
 | 
									// rules
 | 
				
			||||||
@@ -164,10 +167,11 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// action
 | 
						// action
 | 
				
			||||||
	if policy.BlockOptions != nil {
 | 
						if policy.BlockOptions != nil {
 | 
				
			||||||
		w.ActionBlock = &waf.BlockAction{
 | 
							w.DefaultBlockAction = &waf.BlockAction{
 | 
				
			||||||
			StatusCode: policy.BlockOptions.StatusCode,
 | 
								StatusCode: policy.BlockOptions.StatusCode,
 | 
				
			||||||
			Body:       policy.BlockOptions.Body,
 | 
								Body:       policy.BlockOptions.Body,
 | 
				
			||||||
			URL:        "",
 | 
								URL:        policy.BlockOptions.URL,
 | 
				
			||||||
 | 
								Timeout:    policy.BlockOptions.Timeout,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -113,6 +113,10 @@ func (this *RPCClient) MetricStatRPC() pb.MetricStatServiceClient {
 | 
				
			|||||||
	return pb.NewMetricStatServiceClient(this.pickConn())
 | 
						return pb.NewMetricStatServiceClient(this.pickConn())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RPCClient) FirewallService() pb.FirewallServiceClient {
 | 
				
			||||||
 | 
						return pb.NewFirewallServiceClient(this.pickConn())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Context 节点上下文信息
 | 
					// Context 节点上下文信息
 | 
				
			||||||
func (this *RPCClient) Context() context.Context {
 | 
					func (this *RPCClient) Context() context.Context {
 | 
				
			||||||
	ctx := context.Background()
 | 
						ctx := context.Background()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,6 +8,7 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeNode/internal/monitor"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/monitor"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
				
			||||||
	"github.com/iwind/TeaGo/Tea"
 | 
						"github.com/iwind/TeaGo/Tea"
 | 
				
			||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
@@ -132,17 +133,19 @@ func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent strin
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// AddFirewallRuleGroupId 添加防火墙拦截动作
 | 
					// AddFirewallRuleGroupId 添加防火墙拦截动作
 | 
				
			||||||
func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, action string) {
 | 
					func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, actions []*waf.ActionConfig) {
 | 
				
			||||||
	if firewallRuleGroupId <= 0 {
 | 
						if firewallRuleGroupId <= 0 {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	this.totalAttackRequests ++
 | 
						this.totalAttackRequests++
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	select {
 | 
						for _, action := range actions {
 | 
				
			||||||
	case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action:
 | 
							select {
 | 
				
			||||||
	default:
 | 
							case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action.Code:
 | 
				
			||||||
		// 超出容量我们就丢弃
 | 
							default:
 | 
				
			||||||
 | 
								// 超出容量我们就丢弃
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										159
									
								
								internal/utils/encrypt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								internal/utils/encrypt.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,159 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"crypto/aes"
 | 
				
			||||||
 | 
						"crypto/cipher"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/rands"
 | 
				
			||||||
 | 
						stringutil "github.com/iwind/TeaGo/utils/string"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						simpleEncryptMagicKey = rands.HexString(32)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						events.On(events.EventReload, func() {
 | 
				
			||||||
 | 
							nodeConfig, _ := nodeconfigs.SharedNodeConfig()
 | 
				
			||||||
 | 
							if nodeConfig != nil {
 | 
				
			||||||
 | 
								simpleEncryptMagicKey = stringutil.Md5(nodeConfig.NodeId + "@" + nodeConfig.Secret)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SimpleEncrypt 加密特殊信息
 | 
				
			||||||
 | 
					func SimpleEncrypt(data []byte) []byte {
 | 
				
			||||||
 | 
						var method = &AES256CFBMethod{}
 | 
				
			||||||
 | 
						err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[SimpleEncrypt]" + err.Error())
 | 
				
			||||||
 | 
							return data
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dst, err := method.Encrypt(data)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[SimpleEncrypt]" + err.Error())
 | 
				
			||||||
 | 
							return data
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return dst
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SimpleDecrypt 解密特殊信息
 | 
				
			||||||
 | 
					func SimpleDecrypt(data []byte) []byte {
 | 
				
			||||||
 | 
						var method = &AES256CFBMethod{}
 | 
				
			||||||
 | 
						err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[MagicKeyEncode]" + err.Error())
 | 
				
			||||||
 | 
							return data
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						src, err := method.Decrypt(data)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Println("[MagicKeyEncode]" + err.Error())
 | 
				
			||||||
 | 
							return data
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return src
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SimpleEncryptMap(m maps.Map) (base64String string, err error) {
 | 
				
			||||||
 | 
						mJSON, err := json.Marshal(m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						data := SimpleEncrypt(mJSON)
 | 
				
			||||||
 | 
						return base64.StdEncoding.EncodeToString(data), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SimpleDecryptMap(base64String string) (maps.Map, error) {
 | 
				
			||||||
 | 
						data, err := base64.StdEncoding.DecodeString(base64String)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						mJSON := SimpleDecrypt(data)
 | 
				
			||||||
 | 
						var result = maps.Map{}
 | 
				
			||||||
 | 
						err = json.Unmarshal(mJSON, &result)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AES256CFBMethod struct {
 | 
				
			||||||
 | 
						block cipher.Block
 | 
				
			||||||
 | 
						iv    []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AES256CFBMethod) Init(key, iv []byte) error {
 | 
				
			||||||
 | 
						// 判断key是否为32长度
 | 
				
			||||||
 | 
						l := len(key)
 | 
				
			||||||
 | 
						if l > 32 {
 | 
				
			||||||
 | 
							key = key[:32]
 | 
				
			||||||
 | 
						} else if l < 32 {
 | 
				
			||||||
 | 
							key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						block, err := aes.NewCipher(key)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.block = block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 判断iv长度
 | 
				
			||||||
 | 
						l2 := len(iv)
 | 
				
			||||||
 | 
						if l2 > aes.BlockSize {
 | 
				
			||||||
 | 
							iv = iv[:aes.BlockSize]
 | 
				
			||||||
 | 
						} else if l2 < aes.BlockSize {
 | 
				
			||||||
 | 
							iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.iv = iv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
 | 
				
			||||||
 | 
						if len(src) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							r := recover()
 | 
				
			||||||
 | 
							if r != nil {
 | 
				
			||||||
 | 
								err = errors.New("encrypt failed")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dst = make([]byte, len(src))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
 | 
				
			||||||
 | 
						encrypter.XORKeyStream(dst, src)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
 | 
				
			||||||
 | 
						if len(dst) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							r := recover()
 | 
				
			||||||
 | 
							if r != nil {
 | 
				
			||||||
 | 
								err = errors.New("decrypt failed")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						src = make([]byte, len(dst))
 | 
				
			||||||
 | 
						decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
 | 
				
			||||||
 | 
						decrypter.XORKeyStream(src, dst)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										52
									
								
								internal/utils/encrypt_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								internal/utils/encrypt_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSimpleEncrypt(t *testing.T) {
 | 
				
			||||||
 | 
						var arr = []string{"Hello", "World", "People"}
 | 
				
			||||||
 | 
						for _, s := range arr {
 | 
				
			||||||
 | 
							var value = []byte(s)
 | 
				
			||||||
 | 
							encoded := SimpleEncrypt(value)
 | 
				
			||||||
 | 
							t.Log(encoded, string(encoded))
 | 
				
			||||||
 | 
							decoded := SimpleDecrypt(encoded)
 | 
				
			||||||
 | 
							t.Log(decoded, string(decoded))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSimpleEncrypt_Concurrent(t *testing.T) {
 | 
				
			||||||
 | 
						wg := sync.WaitGroup{}
 | 
				
			||||||
 | 
						var arr = []string{"Hello", "World", "People"}
 | 
				
			||||||
 | 
						wg.Add(len(arr))
 | 
				
			||||||
 | 
						for _, s := range arr {
 | 
				
			||||||
 | 
							go func(s string) {
 | 
				
			||||||
 | 
								defer wg.Done()
 | 
				
			||||||
 | 
								t.Log(string(SimpleDecrypt(SimpleEncrypt([]byte(s)))))
 | 
				
			||||||
 | 
							}(s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSimpleEncryptMap(t *testing.T) {
 | 
				
			||||||
 | 
						var m = maps.Map{
 | 
				
			||||||
 | 
							"s": "Hello",
 | 
				
			||||||
 | 
							"i": 20,
 | 
				
			||||||
 | 
							"b": true,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						encodedResult, err := SimpleEncryptMap(m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log("result:", encodedResult)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						decodedResult, err := SimpleDecryptMap(encodedResult)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(decodedResult)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -12,6 +12,7 @@ type List struct {
 | 
				
			|||||||
	itemsMap  map[int64]int64   // itemId => timestamp
 | 
						itemsMap  map[int64]int64   // itemId => timestamp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	locker sync.Mutex
 | 
						locker sync.Mutex
 | 
				
			||||||
 | 
						ticker *time.Ticker
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewList() *List {
 | 
					func NewList() *List {
 | 
				
			||||||
@@ -21,10 +22,7 @@ func NewList() *List {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *List) Add(itemId int64, expiredAt int64) {
 | 
					func (this *List) Add(itemId int64, expiresAt int64) {
 | 
				
			||||||
	if expiredAt <= time.Now().Unix() {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	this.locker.Lock()
 | 
						this.locker.Lock()
 | 
				
			||||||
	defer this.locker.Unlock()
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -34,17 +32,17 @@ func (this *List) Add(itemId int64, expiredAt int64) {
 | 
				
			|||||||
		this.removeItem(itemId)
 | 
							this.removeItem(itemId)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	expireItemMap, ok := this.expireMap[expiredAt]
 | 
						expireItemMap, ok := this.expireMap[expiresAt]
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		expireItemMap[itemId] = true
 | 
							expireItemMap[itemId] = true
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		expireItemMap = ItemMap{
 | 
							expireItemMap = ItemMap{
 | 
				
			||||||
			itemId: true,
 | 
								itemId: true,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		this.expireMap[expiredAt] = expireItemMap
 | 
							this.expireMap[expiresAt] = expireItemMap
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	this.itemsMap[itemId] = expiredAt
 | 
						this.itemsMap[itemId] = expiresAt
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *List) Remove(itemId int64) {
 | 
					func (this *List) Remove(itemId int64) {
 | 
				
			||||||
@@ -64,21 +62,22 @@ func (this *List) GC(timestamp int64, callback func(itemId int64)) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *List) StartGC(callback func(itemId int64)) {
 | 
					func (this *List) StartGC(callback func(itemId int64)) {
 | 
				
			||||||
	ticker := time.NewTicker(1 * time.Second)
 | 
						this.ticker = time.NewTicker(1 * time.Second)
 | 
				
			||||||
	lastTimestamp := int64(0)
 | 
						lastTimestamp := int64(0)
 | 
				
			||||||
	for range ticker.C {
 | 
						for range this.ticker.C {
 | 
				
			||||||
		timestamp := time.Now().Unix()
 | 
							timestamp := time.Now().Unix()
 | 
				
			||||||
		if lastTimestamp == 0 {
 | 
							if lastTimestamp == 0 {
 | 
				
			||||||
			lastTimestamp = timestamp - 3600
 | 
								lastTimestamp = timestamp - 3600
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 防止死循环
 | 
							if timestamp >= lastTimestamp {
 | 
				
			||||||
		if lastTimestamp > timestamp {
 | 
								for i := lastTimestamp; i <= timestamp; i++ {
 | 
				
			||||||
			continue
 | 
									this.GC(i, callback)
 | 
				
			||||||
		}
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
		for i := lastTimestamp; i <= timestamp; i++ {
 | 
								for i := timestamp; i <= lastTimestamp; i++ {
 | 
				
			||||||
			this.GC(timestamp, callback)
 | 
									this.GC(i, callback)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 这样做是为了防止系统时钟突变
 | 
							// 这样做是为了防止系统时钟突变
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -58,6 +58,10 @@ func TestList_Start_GC(t *testing.T) {
 | 
				
			|||||||
	list.Add(2, time.Now().Unix()+1)
 | 
						list.Add(2, time.Now().Unix()+1)
 | 
				
			||||||
	list.Add(3, time.Now().Unix()+2)
 | 
						list.Add(3, time.Now().Unix()+2)
 | 
				
			||||||
	list.Add(4, time.Now().Unix()+5)
 | 
						list.Add(4, time.Now().Unix()+5)
 | 
				
			||||||
 | 
						list.Add(5, time.Now().Unix()+5)
 | 
				
			||||||
 | 
						list.Add(6, time.Now().Unix()+6)
 | 
				
			||||||
 | 
						list.Add(7, time.Now().Unix()+6)
 | 
				
			||||||
 | 
						list.Add(8, time.Now().Unix()+6)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		list.StartGC(func(itemId int64) {
 | 
							list.StartGC(func(itemId int64) {
 | 
				
			||||||
@@ -66,7 +70,7 @@ func TestList_Start_GC(t *testing.T) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	time.Sleep(10 * time.Second)
 | 
						time.Sleep(20 * time.Second)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestList_ManyItems(t *testing.T) {
 | 
					func TestList_ManyItems(t *testing.T) {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										35
									
								
								internal/utils/jsonutils/map.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								internal/utils/jsonutils/map.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package jsonutils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func MapToObject(m maps.Map, ptr interface{}) error {
 | 
				
			||||||
 | 
						if m == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						mJSON, err := json.Marshal(m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return json.Unmarshal(mJSON, ptr)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ObjectToMap(ptr interface{}) (maps.Map, error) {
 | 
				
			||||||
 | 
						if ptr == nil {
 | 
				
			||||||
 | 
							return maps.Map{}, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ptrJSON, err := json.Marshal(ptr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var result = maps.Map{}
 | 
				
			||||||
 | 
						err = json.Unmarshal(ptrJSON, &result)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										46
									
								
								internal/utils/jsonutils/map_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								internal/utils/jsonutils/map_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package jsonutils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMapToObject(t *testing.T) {
 | 
				
			||||||
 | 
						a := assert.NewAssertion(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						type typeA struct {
 | 
				
			||||||
 | 
							B int  `json:"b"`
 | 
				
			||||||
 | 
							C bool `json:"c"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							var obj = &typeA{B: 1, C: true}
 | 
				
			||||||
 | 
							m, err := ObjectToMap(obj)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							PrintT(m, t)
 | 
				
			||||||
 | 
							a.IsTrue(m.GetInt("b") == 1)
 | 
				
			||||||
 | 
							a.IsTrue(m.GetBool("c") == true)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							var obj = &typeA{}
 | 
				
			||||||
 | 
							err := MapToObject(maps.Map{
 | 
				
			||||||
 | 
								"b": 1024,
 | 
				
			||||||
 | 
								"c": true,
 | 
				
			||||||
 | 
							}, obj)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if obj == nil {
 | 
				
			||||||
 | 
								t.Fatal("obj should not be nil")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							a.IsTrue(obj.B == 1024)
 | 
				
			||||||
 | 
							a.IsTrue(obj.C == true)
 | 
				
			||||||
 | 
							PrintT(obj, t)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										17
									
								
								internal/utils/jsonutils/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								internal/utils/jsonutils/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package jsonutils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func PrintT(obj interface{}, t *testing.T) {
 | 
				
			||||||
 | 
						data, err := json.MarshalIndent(obj, "", "  ")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Log(err)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							t.Log(string(data))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -8,7 +8,23 @@ import (
 | 
				
			|||||||
type AllowAction struct {
 | 
					type AllowAction struct {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
					func (this *AllowAction) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AllowAction) Code() string {
 | 
				
			||||||
 | 
						return ActionAllow
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AllowAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AllowAction) WillChange() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
	// do nothing
 | 
						// do nothing
 | 
				
			||||||
	return true
 | 
						return true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										21
									
								
								internal/waf/action_base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/action_base.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type BaseAction struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CloseConn 关闭连接
 | 
				
			||||||
 | 
					func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
 | 
				
			||||||
 | 
						// 断开连接
 | 
				
			||||||
 | 
						hijack, ok := writer.(http.Hijacker)
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							conn, _, err := hijack.Hijack()
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								return conn.Close()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -23,12 +23,48 @@ type BlockAction struct {
 | 
				
			|||||||
	StatusCode int    `yaml:"statusCode" json:"statusCode"`
 | 
						StatusCode int    `yaml:"statusCode" json:"statusCode"`
 | 
				
			||||||
	Body       string `yaml:"body" json:"body"` // supports HTML
 | 
						Body       string `yaml:"body" json:"body"` // supports HTML
 | 
				
			||||||
	URL        string `yaml:"url" json:"url"`
 | 
						URL        string `yaml:"url" json:"url"`
 | 
				
			||||||
 | 
						Timeout    int32  `yaml:"timeout" json:"timeout"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
					func (this *BlockAction) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						if waf.DefaultBlockAction != nil {
 | 
				
			||||||
 | 
							if this.StatusCode <= 0 {
 | 
				
			||||||
 | 
								this.StatusCode = waf.DefaultBlockAction.StatusCode
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if len(this.Body) == 0 {
 | 
				
			||||||
 | 
								this.Body = waf.DefaultBlockAction.Body
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if len(this.URL) == 0 {
 | 
				
			||||||
 | 
								this.URL = waf.DefaultBlockAction.URL
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if this.Timeout <= 0 {
 | 
				
			||||||
 | 
								this.Timeout = waf.DefaultBlockAction.Timeout
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *BlockAction) Code() string {
 | 
				
			||||||
 | 
						return ActionBlock
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *BlockAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *BlockAction) WillChange() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						if this.Timeout > 0 {
 | 
				
			||||||
 | 
							// 加入到黑名单
 | 
				
			||||||
 | 
							SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(this.Timeout))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if writer != nil {
 | 
						if writer != nil {
 | 
				
			||||||
		// if status code eq 444, we close the connection
 | 
							// close the connection
 | 
				
			||||||
		if this.StatusCode == 444 {
 | 
							defer func() {
 | 
				
			||||||
			hijack, ok := writer.(http.Hijacker)
 | 
								hijack, ok := writer.(http.Hijacker)
 | 
				
			||||||
			if ok {
 | 
								if ok {
 | 
				
			||||||
				conn, _, _ := hijack.Hijack()
 | 
									conn, _, _ := hijack.Hijack()
 | 
				
			||||||
@@ -37,7 +73,7 @@ func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer htt
 | 
				
			|||||||
					return
 | 
										return
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// output response
 | 
							// output response
 | 
				
			||||||
		if this.StatusCode > 0 {
 | 
							if this.StatusCode > 0 {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,11 +1,14 @@
 | 
				
			|||||||
package waf
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	stringutil "github.com/iwind/TeaGo/utils/string"
 | 
						stringutil "github.com/iwind/TeaGo/utils/string"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -13,27 +16,63 @@ var captchaSalt = stringutil.Rand(32)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	CaptchaSeconds = 600 // 10 minutes
 | 
						CaptchaSeconds = 600 // 10 minutes
 | 
				
			||||||
 | 
						CaptchaPath    = "/WAF/VERIFY/CAPTCHA"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type CaptchaAction struct {
 | 
					type CaptchaAction struct {
 | 
				
			||||||
 | 
						Life           int32  `yaml:"life" json:"life"`
 | 
				
			||||||
 | 
						Language       string `yaml:"language" json:"language"`             // 语言,zh-CN, en-US ...
 | 
				
			||||||
 | 
						AddToWhiteList bool   `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
					func (this *CaptchaAction) Init(waf *WAF) error {
 | 
				
			||||||
	// TEAWEB_CAPTCHA:
 | 
						return nil
 | 
				
			||||||
	cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
 | 
					}
 | 
				
			||||||
	if err == nil && cookie != nil && len(cookie.Value) > 32 {
 | 
					
 | 
				
			||||||
		m := cookie.Value[:32]
 | 
					func (this *CaptchaAction) Code() string {
 | 
				
			||||||
		timestamp := cookie.Value[32:]
 | 
						return ActionCaptcha
 | 
				
			||||||
		if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
 | 
					}
 | 
				
			||||||
			return true
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaAction) WillChange() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						// 是否在白名单中
 | 
				
			||||||
 | 
						if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						refURL := request.WAFRaw().URL.String()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 覆盖配置
 | 
				
			||||||
 | 
						if strings.HasPrefix(refURL, CaptchaPath) {
 | 
				
			||||||
 | 
							info := request.WAFRaw().URL.Query().Get("info")
 | 
				
			||||||
 | 
							if len(info) > 0 {
 | 
				
			||||||
 | 
								m, err := utils.SimpleDecryptMap(info)
 | 
				
			||||||
 | 
								if err == nil && m != nil {
 | 
				
			||||||
 | 
									refURL = m.GetString("url")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	refURL := request.URL.String()
 | 
						var captchaConfig = maps.Map{
 | 
				
			||||||
	if len(request.Referer()) > 0 {
 | 
							"action":    this,
 | 
				
			||||||
		refURL = request.Referer()
 | 
							"timestamp": time.Now().Unix(),
 | 
				
			||||||
 | 
							"url":       refURL,
 | 
				
			||||||
 | 
							"setId":     set.Id,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
 | 
						info, err := utils.SimpleEncryptMap(captchaConfig)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										13
									
								
								internal/waf/action_category.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								internal/waf/action_category.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ActionCategory = string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						ActionCategoryAllow  ActionCategory = firewallconfigs.HTTPFirewallActionCategoryAllow
 | 
				
			||||||
 | 
						ActionCategoryBlock  ActionCategory = firewallconfigs.HTTPFirewallActionCategoryBlock
 | 
				
			||||||
 | 
						ActionCategoryVerify ActionCategory = firewallconfigs.HTTPFirewallActionCategoryVerify
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
							
								
								
									
										10
									
								
								internal/waf/action_config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								internal/waf/action_config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ActionConfig struct {
 | 
				
			||||||
 | 
						Code    string   `yaml:"code" json:"code"`
 | 
				
			||||||
 | 
						Options maps.Map `yaml:"options" json:"options"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -2,11 +2,12 @@ package waf
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import "reflect"
 | 
					import "reflect"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// action definition
 | 
					// ActionDefinition action definition
 | 
				
			||||||
type ActionDefinition struct {
 | 
					type ActionDefinition struct {
 | 
				
			||||||
	Name        string
 | 
						Name        string
 | 
				
			||||||
	Code        ActionString
 | 
						Code        ActionString
 | 
				
			||||||
	Description string
 | 
						Description string
 | 
				
			||||||
 | 
						Category    string // category: block, verify, allow
 | 
				
			||||||
	Instance    ActionInterface
 | 
						Instance    ActionInterface
 | 
				
			||||||
	Type        reflect.Type
 | 
						Type        reflect.Type
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										71
									
								
								internal/waf/action_get_302.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								internal/waf/action_get_302.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,71 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						Get302Path = "/WAF/VERIFY/GET"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get302Action
 | 
				
			||||||
 | 
					// 原理:  origin url --> 302 verify url --> origin url
 | 
				
			||||||
 | 
					// TODO 将来支持meta refresh验证
 | 
				
			||||||
 | 
					type Get302Action struct {
 | 
				
			||||||
 | 
						BaseAction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						Life int32 `yaml:"life" json:"life"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Get302Action) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Get302Action) Code() string {
 | 
				
			||||||
 | 
						return ActionGet302
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Get302Action) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Get302Action) WillChange() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						// 仅限于Get
 | 
				
			||||||
 | 
						if request.WAFRaw().Method != http.MethodGet {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 是否已经在白名单中
 | 
				
			||||||
 | 
						if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var m = maps.Map{
 | 
				
			||||||
 | 
							"url":       request.WAFRaw().URL.String(),
 | 
				
			||||||
 | 
							"timestamp": time.Now().Unix(),
 | 
				
			||||||
 | 
							"life":      this.Life,
 | 
				
			||||||
 | 
							"setId":     set.Id,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						info, err := utils.SimpleEncryptMap(m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 关闭连接
 | 
				
			||||||
 | 
						_ = this.CloseConn(writer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -10,13 +10,29 @@ type GoGroupAction struct {
 | 
				
			|||||||
	GroupId string `yaml:"groupId" json:"groupId"`
 | 
						GroupId string `yaml:"groupId" json:"groupId"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
					func (this *GoGroupAction) Init(waf *WAF) error {
 | 
				
			||||||
	group := waf.FindRuleGroup(this.GroupId)
 | 
						return nil
 | 
				
			||||||
	if group == nil || !group.IsOn {
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoGroupAction) Code() string {
 | 
				
			||||||
 | 
						return ActionGoGroup
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoGroupAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoGroupAction) WillChange() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						nextGroup := waf.FindRuleGroup(this.GroupId)
 | 
				
			||||||
 | 
						if nextGroup == nil || !nextGroup.IsOn {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b, set, err := group.MatchRequest(request)
 | 
						b, nextSet, err := nextGroup.MatchRequest(request)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logs.Error(err)
 | 
							logs.Error(err)
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
@@ -26,9 +42,5 @@ func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer h
 | 
				
			|||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	actionObject := FindActionInstance(set.Action, set.ActionOptions)
 | 
						return nextSet.PerformActions(waf, nextGroup, request, writer)
 | 
				
			||||||
	if actionObject == nil {
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return actionObject.Perform(waf, request, writer)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,17 +11,33 @@ type GoSetAction struct {
 | 
				
			|||||||
	SetId   string `yaml:"setId" json:"setId"`
 | 
						SetId   string `yaml:"setId" json:"setId"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
					func (this *GoSetAction) Init(waf *WAF) error {
 | 
				
			||||||
	group := waf.FindRuleGroup(this.GroupId)
 | 
						return nil
 | 
				
			||||||
	if group == nil || !group.IsOn {
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoSetAction) Code() string {
 | 
				
			||||||
 | 
						return ActionGoSet
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoSetAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoSetAction) WillChange() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						nextGroup := waf.FindRuleGroup(this.GroupId)
 | 
				
			||||||
 | 
						if nextGroup == nil || !nextGroup.IsOn {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	set := group.FindRuleSet(this.SetId)
 | 
						nextSet := nextGroup.FindRuleSet(this.SetId)
 | 
				
			||||||
	if set == nil || !set.IsOn {
 | 
						if nextSet == nil || !nextSet.IsOn {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	b, err := set.MatchRequest(request)
 | 
						b, err := nextSet.MatchRequest(request)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logs.Error(err)
 | 
							logs.Error(err)
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
@@ -29,9 +45,5 @@ func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer htt
 | 
				
			|||||||
	if !b {
 | 
						if !b {
 | 
				
			||||||
		return true
 | 
							return true
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	actionObject := FindActionInstance(set.Action, set.ActionOptions)
 | 
						return nextSet.PerformActions(waf, nextGroup, request, writer)
 | 
				
			||||||
	if actionObject == nil {
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return actionObject.Perform(waf, request, writer)
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										25
									
								
								internal/waf/action_interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								internal/waf/action_interface.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ActionInterface interface {
 | 
				
			||||||
 | 
						// Init 初始化
 | 
				
			||||||
 | 
						Init(waf *WAF) error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Code 代号
 | 
				
			||||||
 | 
						Code() string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// IsAttack 是否为拦截攻击动作
 | 
				
			||||||
 | 
						IsAttack() bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// WillChange determine if the action will change the request
 | 
				
			||||||
 | 
						WillChange() bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Perform perform the action
 | 
				
			||||||
 | 
						Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -8,6 +8,22 @@ import (
 | 
				
			|||||||
type LogAction struct {
 | 
					type LogAction struct {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
					func (this *LogAction) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *LogAction) Code() string {
 | 
				
			||||||
 | 
						return ActionLog
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *LogAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *LogAction) WillChange() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
	return true
 | 
						return true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										86
									
								
								internal/waf/action_notify.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								internal/waf/action_notify.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type notifyTask struct {
 | 
				
			||||||
 | 
						ServerId                int64
 | 
				
			||||||
 | 
						HttpFirewallPolicyId    int64
 | 
				
			||||||
 | 
						HttpFirewallRuleGroupId int64
 | 
				
			||||||
 | 
						HttpFirewallRuleSetId   int64
 | 
				
			||||||
 | 
						CreatedAt               int64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var notifyChan = make(chan *notifyTask, 128)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						events.On(events.EventLoaded, func() {
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								rpcClient, err := rpc.SharedRPC()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									remotelogs.Error("WAF_NOTIFY_ACTION", "create rpc client failed: "+err.Error())
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for task := range notifyChan {
 | 
				
			||||||
 | 
									_, err = rpcClient.FirewallService().NotifyHTTPFirewallEvent(rpcClient.Context(), &pb.NotifyHTTPFirewallEventRequest{
 | 
				
			||||||
 | 
										ServerId:                task.ServerId,
 | 
				
			||||||
 | 
										HttpFirewallPolicyId:    task.HttpFirewallPolicyId,
 | 
				
			||||||
 | 
										HttpFirewallRuleGroupId: task.HttpFirewallRuleGroupId,
 | 
				
			||||||
 | 
										HttpFirewallRuleSetId:   task.HttpFirewallRuleSetId,
 | 
				
			||||||
 | 
										CreatedAt:               task.CreatedAt,
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										remotelogs.Error("WAF_NOTIFY_ACTION", "notify failed: "+err.Error())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type NotifyAction struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *NotifyAction) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *NotifyAction) Code() string {
 | 
				
			||||||
 | 
						return ActionNotify
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *NotifyAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WillChange determine if the action will change the request
 | 
				
			||||||
 | 
					func (this *NotifyAction) WillChange() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Perform perform the action
 | 
				
			||||||
 | 
					func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case notifyChan <- ¬ifyTask{
 | 
				
			||||||
 | 
							ServerId:                request.WAFServerId(),
 | 
				
			||||||
 | 
							HttpFirewallPolicyId:    types.Int64(waf.Id),
 | 
				
			||||||
 | 
							HttpFirewallRuleGroupId: types.Int64(group.Id),
 | 
				
			||||||
 | 
							HttpFirewallRuleSetId:   types.Int64(set.Id),
 | 
				
			||||||
 | 
							CreatedAt:               time.Now().Unix(),
 | 
				
			||||||
 | 
						}:
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										88
									
								
								internal/waf/action_post_307.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								internal/waf/action_post_307.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Post307Action struct {
 | 
				
			||||||
 | 
						Life int32 `yaml:"life" json:"life"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						BaseAction
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Post307Action) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Post307Action) Code() string {
 | 
				
			||||||
 | 
						return ActionPost307
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Post307Action) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Post307Action) WillChange() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						var cookieName = "WAF_VALIDATOR_ID"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 仅限于POST
 | 
				
			||||||
 | 
						if request.WAFRaw().Method != http.MethodPost {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 是否已经在白名单中
 | 
				
			||||||
 | 
						if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 判断是否有Cookie
 | 
				
			||||||
 | 
						cookie, err := request.WAFRaw().Cookie(cookieName)
 | 
				
			||||||
 | 
						if err == nil && cookie != nil {
 | 
				
			||||||
 | 
							m, err := utils.SimpleDecryptMap(cookie.Value)
 | 
				
			||||||
 | 
							if err == nil && m.GetString("remoteIP") == request.WAFRemoteIP() && time.Now().Unix() < m.GetInt64("timestamp")+10 {
 | 
				
			||||||
 | 
								var life = m.GetInt64("life")
 | 
				
			||||||
 | 
								if life <= 0 {
 | 
				
			||||||
 | 
									life = 600 // 默认10分钟
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								var setId = m.GetString("setId")
 | 
				
			||||||
 | 
								SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var m = maps.Map{
 | 
				
			||||||
 | 
							"timestamp": time.Now().Unix(),
 | 
				
			||||||
 | 
							"life":      this.Life,
 | 
				
			||||||
 | 
							"setId":     set.Id,
 | 
				
			||||||
 | 
							"remoteIP":  request.WAFRemoteIP(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						info, err := utils.SimpleEncryptMap(m)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							remotelogs.Error("WAF_POST_302_ACTION", "encode info failed: "+err.Error())
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 设置Cookie
 | 
				
			||||||
 | 
						http.SetCookie(writer, &http.Cookie{
 | 
				
			||||||
 | 
							Name:   cookieName,
 | 
				
			||||||
 | 
							Path:   "/",
 | 
				
			||||||
 | 
							MaxAge: 10,
 | 
				
			||||||
 | 
							Value:  info,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 关闭连接
 | 
				
			||||||
 | 
						_ = this.CloseConn(writer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										120
									
								
								internal/waf/action_record_ip.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								internal/waf/action_record_ip.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,120 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type recordIPTask struct {
 | 
				
			||||||
 | 
						ip        string
 | 
				
			||||||
 | 
						listId    int64
 | 
				
			||||||
 | 
						expiredAt int64
 | 
				
			||||||
 | 
						level     string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var recordIPTaskChan = make(chan *recordIPTask, 1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						events.On(events.EventLoaded, func() {
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								rpcClient, err := rpc.SharedRPC()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									remotelogs.Error("WAF_RECORD_IP_ACTION", "create rpc client failed: "+err.Error())
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for task := range recordIPTaskChan {
 | 
				
			||||||
 | 
									ipType := "ipv4"
 | 
				
			||||||
 | 
									if strings.Contains(task.ip, ":") {
 | 
				
			||||||
 | 
										ipType = "ipv6"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									_, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
 | 
				
			||||||
 | 
										IpListId:   task.listId,
 | 
				
			||||||
 | 
										IpFrom:     task.ip,
 | 
				
			||||||
 | 
										IpTo:       "",
 | 
				
			||||||
 | 
										ExpiredAt:  task.expiredAt,
 | 
				
			||||||
 | 
										Reason:     "触发WAF规则自动加入",
 | 
				
			||||||
 | 
										Type:       ipType,
 | 
				
			||||||
 | 
										EventLevel: task.level,
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RecordIPAction struct {
 | 
				
			||||||
 | 
						BaseAction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						Type     string `yaml:"type" json:"type"`
 | 
				
			||||||
 | 
						IPListId int64  `yaml:"ipListId" json:"ipListId"`
 | 
				
			||||||
 | 
						Level    string `yaml:"level" json:"level"`
 | 
				
			||||||
 | 
						Timeout  int32  `yaml:"timeout" json:"timeout"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RecordIPAction) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RecordIPAction) Code() string {
 | 
				
			||||||
 | 
						return ActionRecordIP
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RecordIPAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return this.Type == "black"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RecordIPAction) WillChange() bool {
 | 
				
			||||||
 | 
						return this.Type == "black"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						// 是否在本地白名单中
 | 
				
			||||||
 | 
						if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 先加入本地的黑名单
 | 
				
			||||||
 | 
						timeout := this.Timeout
 | 
				
			||||||
 | 
						if timeout <= 0 {
 | 
				
			||||||
 | 
							timeout = 86400 // 1天
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						expiredAt := time.Now().Unix() + int64(timeout)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if this.Type == "black" {
 | 
				
			||||||
 | 
							_ = this.CloseConn(writer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							// 加入本地白名单
 | 
				
			||||||
 | 
							timeout := this.Timeout
 | 
				
			||||||
 | 
							if timeout <= 0 {
 | 
				
			||||||
 | 
								timeout = 86400 // 1天
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 上报
 | 
				
			||||||
 | 
						if this.IPListId > 0 {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case recordIPTaskChan <- &recordIPTask{
 | 
				
			||||||
 | 
								ip:        request.WAFRemoteIP(),
 | 
				
			||||||
 | 
								listId:    this.IPListId,
 | 
				
			||||||
 | 
								expiredAt: expiredAt,
 | 
				
			||||||
 | 
								level:     this.Level,
 | 
				
			||||||
 | 
							}:
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return this.Type != "black"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										30
									
								
								internal/waf/action_tag.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								internal/waf/action_tag.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TagAction struct {
 | 
				
			||||||
 | 
						Tags []string `yaml:"tags" json:"tags"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TagAction) Init(waf *WAF) error {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TagAction) Code() string {
 | 
				
			||||||
 | 
						return ActionTag
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TagAction) IsAttack() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TagAction) WillChange() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,21 +0,0 @@
 | 
				
			|||||||
package waf
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
					 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ActionString = string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	ActionLog     = "log"      // allow and log
 | 
					 | 
				
			||||||
	ActionBlock   = "block"    // block
 | 
					 | 
				
			||||||
	ActionCaptcha = "captcha"  // block and show captcha
 | 
					 | 
				
			||||||
	ActionAllow   = "allow"    // allow
 | 
					 | 
				
			||||||
	ActionGoGroup = "go_group" // go to next rule group
 | 
					 | 
				
			||||||
	ActionGoSet   = "go_set"   // go to next rule set
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ActionInterface interface {
 | 
					 | 
				
			||||||
	Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										88
									
								
								internal/waf/action_types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								internal/waf/action_types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "reflect"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ActionString = string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						ActionLog      ActionString = "log"       // allow and log
 | 
				
			||||||
 | 
						ActionBlock    ActionString = "block"     // block
 | 
				
			||||||
 | 
						ActionCaptcha  ActionString = "captcha"   // block and show captcha
 | 
				
			||||||
 | 
						ActionNotify   ActionString = "notify"    // 告警
 | 
				
			||||||
 | 
						ActionGet302   ActionString = "get_302"   // 针对GET的302重定向认证
 | 
				
			||||||
 | 
						ActionPost307  ActionString = "post_307"  // 针对POST的307重定向认证
 | 
				
			||||||
 | 
						ActionRecordIP ActionString = "record_ip" // 记录IP
 | 
				
			||||||
 | 
						ActionTag      ActionString = "tag"       // 标签
 | 
				
			||||||
 | 
						ActionAllow    ActionString = "allow"     // allow
 | 
				
			||||||
 | 
						ActionGoGroup  ActionString = "go_group"  // go to next rule group
 | 
				
			||||||
 | 
						ActionGoSet    ActionString = "go_set"    // go to next rule set
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var AllActions = []*ActionDefinition{
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "阻止",
 | 
				
			||||||
 | 
							Code:     ActionBlock,
 | 
				
			||||||
 | 
							Instance: new(BlockAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(BlockAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "允许通过",
 | 
				
			||||||
 | 
							Code:     ActionAllow,
 | 
				
			||||||
 | 
							Instance: new(AllowAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(AllowAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "允许并记录日志",
 | 
				
			||||||
 | 
							Code:     ActionLog,
 | 
				
			||||||
 | 
							Instance: new(LogAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(LogAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "Captcha验证码",
 | 
				
			||||||
 | 
							Code:     ActionCaptcha,
 | 
				
			||||||
 | 
							Instance: new(CaptchaAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(CaptchaAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "告警",
 | 
				
			||||||
 | 
							Code:     ActionNotify,
 | 
				
			||||||
 | 
							Instance: new(NotifyAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(NotifyAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "GET 302",
 | 
				
			||||||
 | 
							Code:     ActionGet302,
 | 
				
			||||||
 | 
							Instance: new(Get302Action),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(Get302Action)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "POST 307",
 | 
				
			||||||
 | 
							Code:     ActionPost307,
 | 
				
			||||||
 | 
							Instance: new(Post307Action),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(Post307Action)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "记录IP",
 | 
				
			||||||
 | 
							Code:     ActionRecordIP,
 | 
				
			||||||
 | 
							Instance: new(RecordIPAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(RecordIPAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "标签",
 | 
				
			||||||
 | 
							Code:     ActionTag,
 | 
				
			||||||
 | 
							Instance: new(TagAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(TagAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "跳到下一个规则分组",
 | 
				
			||||||
 | 
							Code:     ActionGoGroup,
 | 
				
			||||||
 | 
							Instance: new(GoGroupAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(GoGroupAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "跳到下一个规则集",
 | 
				
			||||||
 | 
							Code:     ActionGoSet,
 | 
				
			||||||
 | 
							Instance: new(GoSetAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(GoSetAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,45 +1,12 @@
 | 
				
			|||||||
package waf
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var AllActions = []*ActionDefinition{
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		Name:     "阻止",
 | 
					 | 
				
			||||||
		Code:     ActionBlock,
 | 
					 | 
				
			||||||
		Instance: new(BlockAction),
 | 
					 | 
				
			||||||
	},
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		Name:     "允许通过",
 | 
					 | 
				
			||||||
		Code:     ActionAllow,
 | 
					 | 
				
			||||||
		Instance: new(AllowAction),
 | 
					 | 
				
			||||||
	},
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		Name:     "允许并记录日志",
 | 
					 | 
				
			||||||
		Code:     ActionLog,
 | 
					 | 
				
			||||||
		Instance: new(LogAction),
 | 
					 | 
				
			||||||
	},
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		Name:     "Captcha验证码",
 | 
					 | 
				
			||||||
		Code:     ActionCaptcha,
 | 
					 | 
				
			||||||
		Instance: new(CaptchaAction),
 | 
					 | 
				
			||||||
	},
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		Name:     "跳到下一个规则分组",
 | 
					 | 
				
			||||||
		Code:     ActionGoGroup,
 | 
					 | 
				
			||||||
		Instance: new(GoGroupAction),
 | 
					 | 
				
			||||||
		Type:     reflect.TypeOf(new(GoGroupAction)).Elem(),
 | 
					 | 
				
			||||||
	},
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		Name:     "跳到下一个规则集",
 | 
					 | 
				
			||||||
		Code:     ActionGoSet,
 | 
					 | 
				
			||||||
		Instance: new(GoSetAction),
 | 
					 | 
				
			||||||
		Type:     reflect.TypeOf(new(GoSetAction)).Elem(),
 | 
					 | 
				
			||||||
	},
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
 | 
					func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
 | 
				
			||||||
	for _, def := range AllActions {
 | 
						for _, def := range AllActions {
 | 
				
			||||||
		if def.Code == action {
 | 
							if def.Code == action {
 | 
				
			||||||
@@ -49,15 +16,13 @@ func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
 | 
				
			|||||||
				instance := ptrValue.Interface().(ActionInterface)
 | 
									instance := ptrValue.Interface().(ActionInterface)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if len(options) > 0 {
 | 
									if len(options) > 0 {
 | 
				
			||||||
					count := def.Type.NumField()
 | 
										optionsJSON, err := json.Marshal(options)
 | 
				
			||||||
					for i := 0; i < count; i++ {
 | 
										if err != nil {
 | 
				
			||||||
						field := def.Type.Field(i)
 | 
											remotelogs.Error("WAF_FindActionInstance", "encode options to json failed: "+err.Error())
 | 
				
			||||||
						tag, ok := field.Tag.Lookup("yaml")
 | 
										} else {
 | 
				
			||||||
						if ok {
 | 
											err = json.Unmarshal(optionsJSON, instance)
 | 
				
			||||||
							v, ok := options[tag]
 | 
											if err != nil {
 | 
				
			||||||
							if ok && reflect.TypeOf(v) == field.Type {
 | 
												remotelogs.Error("WAF_FindActionInstance", "decode options from json failed: "+err.Error())
 | 
				
			||||||
								ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package waf
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/iwind/TeaGo/assert"
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"runtime"
 | 
						"runtime"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
@@ -16,11 +17,20 @@ func TestFindActionInstance(t *testing.T) {
 | 
				
			|||||||
	t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
 | 
						t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
 | 
				
			||||||
	t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
 | 
						t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
 | 
				
			||||||
	t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
 | 
						t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
 | 
				
			||||||
	t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
 | 
						t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
 | 
						a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFindActionInstance_Options(t *testing.T) {
 | 
				
			||||||
 | 
						//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
 | 
				
			||||||
 | 
						//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
 | 
				
			||||||
 | 
						//logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
 | 
				
			||||||
 | 
						logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{
 | 
				
			||||||
 | 
							"timeout": 3600,
 | 
				
			||||||
 | 
						}), t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BenchmarkFindActionInstance(b *testing.B) {
 | 
					func BenchmarkFindActionInstance(b *testing.B) {
 | 
				
			||||||
	runtime.GOMAXPROCS(1)
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,29 +3,64 @@ package waf
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"encoding/base64"
 | 
						"encoding/base64"
 | 
				
			||||||
	"fmt"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
	"github.com/dchest/captcha"
 | 
						"github.com/dchest/captcha"
 | 
				
			||||||
	"github.com/iwind/TeaGo/logs"
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
	stringutil "github.com/iwind/TeaGo/utils/string"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var captchaValidator = &CaptchaValidator{}
 | 
					var captchaValidator = NewCaptchaValidator()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type CaptchaValidator struct {
 | 
					type CaptchaValidator struct {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
 | 
					func NewCaptchaValidator() *CaptchaValidator {
 | 
				
			||||||
	if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
 | 
						return &CaptchaValidator{}
 | 
				
			||||||
		this.validate(request, writer)
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) {
 | 
				
			||||||
 | 
						var info = request.WAFRaw().URL.Query().Get("info")
 | 
				
			||||||
 | 
						if len(info) == 0 {
 | 
				
			||||||
 | 
							writer.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
							_, _ = writer.Write([]byte("invalid request"))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						m, err := utils.SimpleDecryptMap(info)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							_, _ = writer.Write([]byte("invalid request"))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						timestamp := m.GetInt64("timestamp")
 | 
				
			||||||
 | 
						if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
 | 
				
			||||||
 | 
							http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var actionConfig = &CaptchaAction{}
 | 
				
			||||||
 | 
						err = jsonutils.MapToObject(m.GetMap("action"), actionConfig)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var setId = m.GetInt64("setId")
 | 
				
			||||||
 | 
						var originURL = m.GetString("url")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
 | 
				
			||||||
 | 
							this.validate(actionConfig, setId, originURL, request, writer)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		this.show(request, writer)
 | 
							this.show(actionConfig, request, writer)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
 | 
					func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) {
 | 
				
			||||||
	// show captcha
 | 
						// show captcha
 | 
				
			||||||
	captchaId := captcha.NewLen(6)
 | 
						captchaId := captcha.NewLen(6)
 | 
				
			||||||
	buf := bytes.NewBuffer([]byte{})
 | 
						buf := bytes.NewBuffer([]byte{})
 | 
				
			||||||
@@ -35,48 +70,86 @@ func (this *CaptchaValidator) show(request *requests.Request, writer http.Respon
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var lang = actionConfig.Language
 | 
				
			||||||
 | 
						if len(lang) == 0 {
 | 
				
			||||||
 | 
							acceptLanguage := request.WAFRaw().Header.Get("Accept-Language")
 | 
				
			||||||
 | 
							if len(acceptLanguage) > 0 {
 | 
				
			||||||
 | 
								langIndex := strings.Index(acceptLanguage, ",")
 | 
				
			||||||
 | 
								if langIndex > 0 {
 | 
				
			||||||
 | 
									lang = acceptLanguage[:langIndex]
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(lang) == 0 {
 | 
				
			||||||
 | 
							lang = "en-US"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var msgTitle = ""
 | 
				
			||||||
 | 
						var msgPrompt = ""
 | 
				
			||||||
 | 
						var msgButtonTitle = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch lang {
 | 
				
			||||||
 | 
						case "en-US":
 | 
				
			||||||
 | 
							msgTitle = "Verify Yourself"
 | 
				
			||||||
 | 
							msgPrompt = "Input verify code above:"
 | 
				
			||||||
 | 
							msgButtonTitle = "Verify Yourself"
 | 
				
			||||||
 | 
						case "zh-CN":
 | 
				
			||||||
 | 
							msgTitle = "身份验证"
 | 
				
			||||||
 | 
							msgPrompt = "请输入上面的验证码"
 | 
				
			||||||
 | 
							msgButtonTitle = "提交验证"
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							msgTitle = "Verify Yourself"
 | 
				
			||||||
 | 
							msgPrompt = "Input verify code above:"
 | 
				
			||||||
 | 
							msgButtonTitle = "Verify Yourself"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 | 
				
			||||||
	_, _ = writer.Write([]byte(`<!DOCTYPE html>
 | 
						_, _ = writer.Write([]byte(`<!DOCTYPE html>
 | 
				
			||||||
<html>
 | 
					<html>
 | 
				
			||||||
<head>
 | 
					<head>
 | 
				
			||||||
	<title>Verify Yourself</title>
 | 
						<title>` + msgTitle + `</title>
 | 
				
			||||||
 | 
						<script type="text/javascript">
 | 
				
			||||||
 | 
						if (window.addEventListener != null) {
 | 
				
			||||||
 | 
							window.addEventListener("load", function () {
 | 
				
			||||||
 | 
								document.getElementById("GOEDGE_WAF_CAPTCHA_CODE").focus()
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						</script>
 | 
				
			||||||
</head>
 | 
					</head>
 | 
				
			||||||
<body>
 | 
					<body>
 | 
				
			||||||
<form method="POST">
 | 
					<form method="POST">
 | 
				
			||||||
	<input type="hidden" name="TEAWEB_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
 | 
						<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
 | 
				
			||||||
	<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
 | 
						<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
 | 
				
			||||||
	<div>
 | 
						<div>
 | 
				
			||||||
		<p>Input verify code above:</p>
 | 
							<p>` + msgPrompt + `</p>
 | 
				
			||||||
		<input type="text" name="TEAWEB_WAF_CAPTCHA_CODE" maxlength="6" size="18" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px"/>
 | 
							<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px; width: 160px"/>
 | 
				
			||||||
	</div>
 | 
						</div>
 | 
				
			||||||
	<div>
 | 
						<div>
 | 
				
			||||||
		<button type="submit" onclick="window.location = '/webhook'" style="line-height:24px;margin-top:10px">Verify Yourself</button>
 | 
							<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
 | 
				
			||||||
	</div>
 | 
						</div>
 | 
				
			||||||
</form>
 | 
					</form>
 | 
				
			||||||
</body>
 | 
					</body>
 | 
				
			||||||
</html>`))
 | 
					</html>`))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
					func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
	captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
 | 
						captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
 | 
				
			||||||
	if len(captchaId) > 0 {
 | 
						if len(captchaId) > 0 {
 | 
				
			||||||
		captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
 | 
							captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
 | 
				
			||||||
		if captcha.VerifyString(captchaId, captchaCode) {
 | 
							if captcha.VerifyString(captchaId, captchaCode) {
 | 
				
			||||||
			// set cookie
 | 
								var life = CaptchaSeconds
 | 
				
			||||||
			timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
 | 
								if actionConfig.Life > 0 {
 | 
				
			||||||
			m := stringutil.Md5(captchaSalt + timestamp)
 | 
									life = types.Int(actionConfig.Life)
 | 
				
			||||||
			http.SetCookie(writer, &http.Cookie{
 | 
								}
 | 
				
			||||||
				Name:   "TEAWEB_WAF_CAPTCHA",
 | 
					 | 
				
			||||||
				Value:  m + timestamp,
 | 
					 | 
				
			||||||
				MaxAge: CaptchaSeconds, // TODO 这个时间可以设置
 | 
					 | 
				
			||||||
				Path:   "/", // all of dirs
 | 
					 | 
				
			||||||
			})
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			rawURL := request.URL.Query().Get("url")
 | 
								// 加入到白名单
 | 
				
			||||||
			http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
 | 
								SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return false
 | 
								return false
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
 | 
								http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,14 +5,12 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
	"net"
 | 
					 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ${cc.arg}
 | 
					// CCCheckpoint ${cc.arg}
 | 
				
			||||||
// TODO implement more traffic rules
 | 
					// TODO implement more traffic rules
 | 
				
			||||||
type CCCheckpoint struct {
 | 
					type CCCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
@@ -32,7 +30,7 @@ func (this *CCCheckpoint) Start() {
 | 
				
			|||||||
	this.cache = ttlcache.NewCache()
 | 
						this.cache = ttlcache.NewCache()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = 0
 | 
						value = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if this.cache == nil {
 | 
						if this.cache == nil {
 | 
				
			||||||
@@ -66,12 +64,12 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
 | 
				
			|||||||
		var key = ""
 | 
							var key = ""
 | 
				
			||||||
		switch userType {
 | 
							switch userType {
 | 
				
			||||||
		case "ip":
 | 
							case "ip":
 | 
				
			||||||
			key = this.ip(req)
 | 
								key = req.WAFRemoteIP()
 | 
				
			||||||
		case "cookie":
 | 
							case "cookie":
 | 
				
			||||||
			if len(userField) == 0 {
 | 
								if len(userField) == 0 {
 | 
				
			||||||
				key = this.ip(req)
 | 
									key = req.WAFRemoteIP()
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				cookie, _ := req.Cookie(userField)
 | 
									cookie, _ := req.WAFRaw().Cookie(userField)
 | 
				
			||||||
				if cookie != nil {
 | 
									if cookie != nil {
 | 
				
			||||||
					v := cookie.Value
 | 
										v := cookie.Value
 | 
				
			||||||
					if userIndex > 0 && len(v) > userIndex {
 | 
										if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
@@ -82,9 +80,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		case "get":
 | 
							case "get":
 | 
				
			||||||
			if len(userField) == 0 {
 | 
								if len(userField) == 0 {
 | 
				
			||||||
				key = this.ip(req)
 | 
									key = req.WAFRemoteIP()
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				v := req.URL.Query().Get(userField)
 | 
									v := req.WAFRaw().URL.Query().Get(userField)
 | 
				
			||||||
				if userIndex > 0 && len(v) > userIndex {
 | 
									if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
					v = v[userIndex:]
 | 
										v = v[userIndex:]
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@@ -92,9 +90,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		case "post":
 | 
							case "post":
 | 
				
			||||||
			if len(userField) == 0 {
 | 
								if len(userField) == 0 {
 | 
				
			||||||
				key = this.ip(req)
 | 
									key = req.WAFRemoteIP()
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				v := req.PostFormValue(userField)
 | 
									v := req.WAFRaw().PostFormValue(userField)
 | 
				
			||||||
				if userIndex > 0 && len(v) > userIndex {
 | 
									if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
					v = v[userIndex:]
 | 
										v = v[userIndex:]
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@@ -102,19 +100,19 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		case "header":
 | 
							case "header":
 | 
				
			||||||
			if len(userField) == 0 {
 | 
								if len(userField) == 0 {
 | 
				
			||||||
				key = this.ip(req)
 | 
									key = req.WAFRemoteIP()
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				v := req.Header.Get(userField)
 | 
									v := req.WAFRaw().Header.Get(userField)
 | 
				
			||||||
				if userIndex > 0 && len(v) > userIndex {
 | 
									if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
					v = v[userIndex:]
 | 
										v = v[userIndex:]
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				key = "USER@" + userType + "@" + userField + "@" + v
 | 
									key = "USER@" + userType + "@" + userField + "@" + v
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			key = this.ip(req)
 | 
								key = req.WAFRemoteIP()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if len(key) == 0 {
 | 
							if len(key) == 0 {
 | 
				
			||||||
			key = this.ip(req)
 | 
								key = req.WAFRemoteIP()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
 | 
							value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -122,7 +120,7 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -210,38 +208,3 @@ func (this *CCCheckpoint) Stop() {
 | 
				
			|||||||
		this.cache = nil
 | 
							this.cache = nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (this *CCCheckpoint) ip(req *requests.Request) string {
 | 
					 | 
				
			||||||
	// X-Forwarded-For
 | 
					 | 
				
			||||||
	forwardedFor := req.Header.Get("X-Forwarded-For")
 | 
					 | 
				
			||||||
	if len(forwardedFor) > 0 {
 | 
					 | 
				
			||||||
		commaIndex := strings.Index(forwardedFor, ",")
 | 
					 | 
				
			||||||
		if commaIndex > 0 {
 | 
					 | 
				
			||||||
			return forwardedFor[:commaIndex]
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return forwardedFor
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Real-IP
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		realIP, ok := req.Header["X-Real-IP"]
 | 
					 | 
				
			||||||
		if ok && len(realIP) > 0 {
 | 
					 | 
				
			||||||
			return realIP[0]
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Real-Ip
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		realIP, ok := req.Header["X-Real-Ip"]
 | 
					 | 
				
			||||||
		if ok && len(realIP) > 0 {
 | 
					 | 
				
			||||||
			return realIP[0]
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Remote-Addr
 | 
					 | 
				
			||||||
	host, _, err := net.SplitHostPort(req.RemoteAddr)
 | 
					 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		return host
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return req.RemoteAddr
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package checkpoints
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -12,31 +13,31 @@ func TestCCCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(raw)
 | 
						req := requests.NewTestRequest(raw)
 | 
				
			||||||
	req.RemoteAddr = "127.0.0.1"
 | 
						req.WAFRaw().RemoteAddr = "127.0.0.1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(CCCheckpoint)
 | 
						checkpoint := new(CCCheckpoint)
 | 
				
			||||||
	checkpoint.Init()
 | 
						checkpoint.Init()
 | 
				
			||||||
	checkpoint.Start()
 | 
						checkpoint.Start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	options := map[string]string{
 | 
						options := maps.Map{
 | 
				
			||||||
		"period": "5",
 | 
							"period": "5",
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req.RemoteAddr = "127.0.0.2"
 | 
						req.WAFRaw().RemoteAddr = "127.0.0.2"
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req.RemoteAddr = "127.0.0.1"
 | 
						req.WAFRaw().RemoteAddr = "127.0.0.1"
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req.RemoteAddr = "127.0.0.2"
 | 
						req.WAFRaw().RemoteAddr = "127.0.0.2"
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req.RemoteAddr = "127.0.0.2"
 | 
						req.WAFRaw().RemoteAddr = "127.0.0.2"
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req.RemoteAddr = "127.0.0.2"
 | 
						req.WAFRaw().RemoteAddr = "127.0.0.2"
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,32 +5,32 @@ import (
 | 
				
			|||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Check Point
 | 
					// CheckpointInterface Check Point
 | 
				
			||||||
type CheckpointInterface interface {
 | 
					type CheckpointInterface interface {
 | 
				
			||||||
	// initialize
 | 
						// Init initialize
 | 
				
			||||||
	Init()
 | 
						Init()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// is request?
 | 
						// IsRequest is request?
 | 
				
			||||||
	IsRequest() bool
 | 
						IsRequest() bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// is composed?
 | 
						// IsComposed is composed?
 | 
				
			||||||
	IsComposed() bool
 | 
						IsComposed() bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// get request value
 | 
						// RequestValue get request value
 | 
				
			||||||
	RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
 | 
						RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// get response value
 | 
						// ResponseValue get response value
 | 
				
			||||||
	ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
 | 
						ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// param option list
 | 
						// ParamOptions param option list
 | 
				
			||||||
	ParamOptions() *ParamOptions
 | 
						ParamOptions() *ParamOptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// options
 | 
						// Options options
 | 
				
			||||||
	Options() []OptionInterface
 | 
						Options() []OptionInterface
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// start
 | 
						// Start start
 | 
				
			||||||
	Start()
 | 
						Start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// stop
 | 
						// Stop stop
 | 
				
			||||||
	Stop()
 | 
						Stop()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,32 +5,34 @@ import (
 | 
				
			|||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ${requestAll}
 | 
					// RequestAllCheckpoint ${requestAll}
 | 
				
			||||||
type RequestAllCheckpoint struct {
 | 
					type RequestAllCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	valueBytes := []byte{}
 | 
						valueBytes := []byte{}
 | 
				
			||||||
	if len(req.RequestURI) > 0 {
 | 
						if len(req.WAFRaw().RequestURI) > 0 {
 | 
				
			||||||
		valueBytes = append(valueBytes, req.RequestURI...)
 | 
							valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
 | 
				
			||||||
	} else if req.URL != nil {
 | 
						} else if req.WAFRaw().URL != nil {
 | 
				
			||||||
		valueBytes = append(valueBytes, req.URL.RequestURI()...)
 | 
							valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.Body != nil {
 | 
						if req.WAFRaw().Body != nil {
 | 
				
			||||||
		valueBytes = append(valueBytes, ' ')
 | 
							valueBytes = append(valueBytes, ' ')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if len(req.BodyData) == 0 {
 | 
							var bodyData = req.WAFGetCacheBody()
 | 
				
			||||||
			data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
							if len(bodyData) == 0 {
 | 
				
			||||||
 | 
								data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return "", err, nil
 | 
									return "", err, nil
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			req.BodyData = data
 | 
								bodyData = data
 | 
				
			||||||
			req.RestoreBody(data)
 | 
								req.WAFSetCacheBody(data)
 | 
				
			||||||
 | 
								req.WAFRestoreBody(data)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		valueBytes = append(valueBytes, req.BodyData...)
 | 
							valueBytes = append(valueBytes, bodyData...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	value = valueBytes
 | 
						value = valueBytes
 | 
				
			||||||
@@ -38,7 +40,7 @@ func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param stri
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = ""
 | 
						value = ""
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestAllCheckpoint)
 | 
						checkpoint := new(RequestAllCheckpoint)
 | 
				
			||||||
	v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
						v, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
 | 
				
			||||||
	if sysErr != nil {
 | 
						if sysErr != nil {
 | 
				
			||||||
		t.Fatal(sysErr)
 | 
							t.Fatal(sysErr)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -42,7 +42,7 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestBodyCheckpoint)
 | 
						checkpoint := new(RequestBodyCheckpoint)
 | 
				
			||||||
	value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
						value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestAllCheckpoint)
 | 
						checkpoint := new(RequestAllCheckpoint)
 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
		_, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
							_, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,11 +9,11 @@ type RequestArgCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	return req.URL.Query().Get(param), nil, nil
 | 
						return req.WAFRaw().URL.Query().Get(param), nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ func TestArgParam_RequestValue(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestArgCheckpoint)
 | 
						checkpoint := new(RequestArgCheckpoint)
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestArgsCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.URL.RawQuery
 | 
						value = req.WAFRaw().URL.RawQuery
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,31 +5,33 @@ import (
 | 
				
			|||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ${requestBody}
 | 
					// RequestBodyCheckpoint ${requestBody}
 | 
				
			||||||
type RequestBodyCheckpoint struct {
 | 
					type RequestBodyCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if req.Body == nil {
 | 
						if req.WAFRaw().Body == nil {
 | 
				
			||||||
		value = ""
 | 
							value = ""
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(req.BodyData) == 0 {
 | 
						var bodyData = req.WAFGetCacheBody()
 | 
				
			||||||
		data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
						if len(bodyData) == 0 {
 | 
				
			||||||
 | 
							data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return "", err, nil
 | 
								return "", err, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		req.BodyData = data
 | 
							bodyData = data
 | 
				
			||||||
		req.RestoreBody(data)
 | 
							req.WAFSetCacheBody(data)
 | 
				
			||||||
 | 
							req.WAFRestoreBody(data)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return req.BodyData, nil, nil
 | 
						return bodyData, nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,19 +11,20 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
 | 
					func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
	req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
 | 
						rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						var req = requests.NewTestRequest(rawReq)
 | 
				
			||||||
	checkpoint := new(RequestBodyCheckpoint)
 | 
						checkpoint := new(RequestBodyCheckpoint)
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	body, err := ioutil.ReadAll(req.Body)
 | 
						body, err := ioutil.ReadAll(rawReq.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	t.Log(string(body))
 | 
						t.Log(string(body))
 | 
				
			||||||
 | 
						t.Log(string(req.WAFGetCacheBody()))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
 | 
					func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
 | 
				
			||||||
@@ -33,7 +34,7 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestBodyCheckpoint)
 | 
						checkpoint := new(RequestBodyCheckpoint)
 | 
				
			||||||
	value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
						value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestContentTypeCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.Header.Get("Content-Type")
 | 
						value = req.WAFRaw().Header.Get("Content-Type")
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,8 +9,8 @@ type RequestCookieCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	cookie, err := req.Cookie(param)
 | 
						cookie, err := req.WAFRaw().Cookie(param)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		value = ""
 | 
							value = ""
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@@ -20,7 +20,7 @@ func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param s
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,16 +11,16 @@ type RequestCookiesCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	var cookies = []string{}
 | 
						var cookies = []string{}
 | 
				
			||||||
	for _, cookie := range req.Cookies() {
 | 
						for _, cookie := range req.WAFRaw().Cookies() {
 | 
				
			||||||
		cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
 | 
							cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	value = strings.Join(cookies, "&")
 | 
						value = strings.Join(cookies, "&")
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,33 +6,35 @@ import (
 | 
				
			|||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ${requestForm.arg}
 | 
					// RequestFormArgCheckpoint ${requestForm.arg}
 | 
				
			||||||
type RequestFormArgCheckpoint struct {
 | 
					type RequestFormArgCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if req.Body == nil {
 | 
						if req.WAFRaw().Body == nil {
 | 
				
			||||||
		value = ""
 | 
							value = ""
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(req.BodyData) == 0 {
 | 
						var bodyData = req.WAFGetCacheBody()
 | 
				
			||||||
		data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
 | 
						if len(bodyData) == 0 {
 | 
				
			||||||
 | 
							data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return "", err, nil
 | 
								return "", err, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		req.BodyData = data
 | 
							bodyData = data
 | 
				
			||||||
		req.RestoreBody(data)
 | 
							req.WAFSetCacheBody(data)
 | 
				
			||||||
 | 
							req.WAFRestoreBody(data)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO improve performance
 | 
						// TODO improve performance
 | 
				
			||||||
	values, _ := url.ParseQuery(string(req.BodyData))
 | 
						values, _ := url.ParseQuery(string(bodyData))
 | 
				
			||||||
	return values.Get(param), nil, nil
 | 
						return values.Get(param), nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,8 +15,8 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
						req.WAFRaw().Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestFormArgCheckpoint)
 | 
						checkpoint := new(RequestFormArgCheckpoint)
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
				
			||||||
@@ -24,7 +24,7 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
	t.Log(checkpoint.RequestValue(req, "Hello", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "Hello", nil))
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "encoded", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "encoded", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	body, err := ioutil.ReadAll(req.Body)
 | 
						body, err := ioutil.ReadAll(req.WAFRaw().Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,7 +14,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool {
 | 
				
			|||||||
	return true
 | 
						return true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = false
 | 
						value = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	headers := options.GetSlice("headers")
 | 
						headers := options.GetSlice("headers")
 | 
				
			||||||
@@ -25,7 +25,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
 | 
				
			|||||||
	length := options.GetInt("length")
 | 
						length := options.GetInt("length")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, header := range headers {
 | 
						for _, header := range headers {
 | 
				
			||||||
		v := req.Header.Get(types.String(header))
 | 
							v := req.WAFRaw().Header.Get(types.String(header))
 | 
				
			||||||
		if len(v) > length {
 | 
							if len(v) > length {
 | 
				
			||||||
			value = true
 | 
								value = true
 | 
				
			||||||
			break
 | 
								break
 | 
				
			||||||
@@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,8 +10,8 @@ type RequestHeaderCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	v, found := req.Header[param]
 | 
						v, found := req.WAFRaw().Header[param]
 | 
				
			||||||
	if !found {
 | 
						if !found {
 | 
				
			||||||
		value = ""
 | 
							value = ""
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@@ -20,7 +20,7 @@ func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param s
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,9 +11,9 @@ type RequestHeadersCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	var headers = []string{}
 | 
						var headers = []string{}
 | 
				
			||||||
	for k, v := range req.Header {
 | 
						for k, v := range req.WAFRaw().Header {
 | 
				
			||||||
		for _, subV := range v {
 | 
							for _, subV := range v {
 | 
				
			||||||
			headers = append(headers, k+": "+subV)
 | 
								headers = append(headers, k+": "+subV)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -23,7 +23,7 @@ func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestHostCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.Host
 | 
						value = req.WAFRaw().Host
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,8 +12,8 @@ func TestRequestHostCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	req.Header.Set("Host", "cloud.teaos.cn")
 | 
						req.WAFRaw().Header.Set("Host", "cloud.teaos.cn")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestHostCheckpoint)
 | 
						checkpoint := new(RequestHostCheckpoint)
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,24 +8,27 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ${requestJSON.arg}
 | 
					// RequestJSONArgCheckpoint ${requestJSON.arg}
 | 
				
			||||||
type RequestJSONArgCheckpoint struct {
 | 
					type RequestJSONArgCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if len(req.BodyData) == 0 {
 | 
						var bodyData = req.WAFGetCacheBody()
 | 
				
			||||||
		data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
						if len(bodyData) == 0 {
 | 
				
			||||||
 | 
							data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return "", err, nil
 | 
								return "", err, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		req.BodyData = data
 | 
					
 | 
				
			||||||
		defer req.RestoreBody(data)
 | 
							bodyData = data
 | 
				
			||||||
 | 
							req.WAFSetCacheBody(data)
 | 
				
			||||||
 | 
							defer req.WAFRestoreBody(data)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO improve performance
 | 
						// TODO improve performance
 | 
				
			||||||
	var m interface{} = nil
 | 
						var m interface{} = nil
 | 
				
			||||||
	err := json.Unmarshal(req.BodyData, &m)
 | 
						err := json.Unmarshal(bodyData, &m)
 | 
				
			||||||
	if err != nil || m == nil {
 | 
						if err != nil || m == nil {
 | 
				
			||||||
		return "", nil, err
 | 
							return "", nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -37,7 +40,7 @@ func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param
 | 
				
			|||||||
	return "", nil, nil
 | 
						return "", nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestJSONArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,7 +20,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
						//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestJSONArgCheckpoint)
 | 
						checkpoint := new(RequestJSONArgCheckpoint)
 | 
				
			||||||
@@ -31,7 +31,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
 | 
				
			|||||||
	t.Log(checkpoint.RequestValue(req, "books", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "books", nil))
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "books.1", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "books.1", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	body, err := ioutil.ReadAll(req.Body)
 | 
						body, err := ioutil.ReadAll(req.WAFRaw().Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -50,7 +50,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
						//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestJSONArgCheckpoint)
 | 
						checkpoint := new(RequestJSONArgCheckpoint)
 | 
				
			||||||
@@ -61,7 +61,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
 | 
				
			|||||||
	t.Log(checkpoint.RequestValue(req, "0.books", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "0.books", nil))
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	body, err := ioutil.ReadAll(req.Body)
 | 
						body, err := ioutil.ReadAll(req.WAFRaw().Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -80,7 +80,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
						//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestJSONArgCheckpoint)
 | 
						checkpoint := new(RequestJSONArgCheckpoint)
 | 
				
			||||||
@@ -91,7 +91,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
 | 
				
			|||||||
	t.Log(checkpoint.RequestValue(req, "0.books", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "0.books", nil))
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	body, err := ioutil.ReadAll(req.Body)
 | 
						body, err := ioutil.ReadAll(req.WAFRaw().Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestLengthCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.ContentLength
 | 
						value = req.WAFRaw().ContentLength
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestMethodCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestMethodCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.Method
 | 
						value = req.WAFRaw().Method
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestMethodCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,11 +9,11 @@ type RequestPathCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestPathCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	return req.URL.Path, nil, nil
 | 
						return req.WAFRaw().URL.Path, nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestPathCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ func TestRequestPathCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	checkpoint := new(RequestPathCheckpoint)
 | 
						checkpoint := new(RequestPathCheckpoint)
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestProtoCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestProtoCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.Proto
 | 
						value = req.WAFRaw().Proto
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestProtoCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,17 +10,17 @@ type RequestRawRemoteAddrCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	host, _, err := net.SplitHostPort(req.RemoteAddr)
 | 
						host, _, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		value = host
 | 
							value = host
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		value = req.RemoteAddr
 | 
							value = req.WAFRaw().RemoteAddr
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestRefererCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRefererCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.Referer()
 | 
						value = req.WAFRaw().Referer()
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,56 +3,18 @@ package checkpoints
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"net"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type RequestRemoteAddrCheckpoint struct {
 | 
					type RequestRemoteAddrCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	// X-Forwarded-For
 | 
						value = req.WAFRemoteIP()
 | 
				
			||||||
	forwardedFor := req.Header.Get("X-Forwarded-For")
 | 
					 | 
				
			||||||
	if len(forwardedFor) > 0 {
 | 
					 | 
				
			||||||
		commaIndex := strings.Index(forwardedFor, ",")
 | 
					 | 
				
			||||||
		if commaIndex > 0 {
 | 
					 | 
				
			||||||
			value = forwardedFor[:commaIndex]
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		value = forwardedFor
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Real-IP
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		realIP, ok := req.Header["X-Real-IP"]
 | 
					 | 
				
			||||||
		if ok && len(realIP) > 0 {
 | 
					 | 
				
			||||||
			value = realIP[0]
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Real-Ip
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		realIP, ok := req.Header["X-Real-Ip"]
 | 
					 | 
				
			||||||
		if ok && len(realIP) > 0 {
 | 
					 | 
				
			||||||
			value = realIP[0]
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Remote-Addr
 | 
					 | 
				
			||||||
	host, _, err := net.SplitHostPort(req.RemoteAddr)
 | 
					 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		value = host
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		value = req.RemoteAddr
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,8 +11,8 @@ type RequestRemotePortCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	_, port, err := net.SplitHostPort(req.RemoteAddr)
 | 
						_, port, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		value = types.Int(port)
 | 
							value = types.Int(port)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
@@ -21,7 +21,7 @@ func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, par
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRemotePortCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,8 +9,8 @@ type RequestRemoteUserCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	username, _, ok := req.BasicAuth()
 | 
						username, _, ok := req.WAFRaw().BasicAuth()
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		value = ""
 | 
							value = ""
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@@ -19,7 +19,7 @@ func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, par
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestRemoteUserCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestSchemeCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestSchemeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.URL.Scheme
 | 
						value = req.WAFRaw().URL.Scheme
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestSchemeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	checkpoint := new(RequestSchemeCheckpoint)
 | 
						checkpoint := new(RequestSchemeCheckpoint)
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,63 +11,65 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ${requestUpload.arg}
 | 
					// RequestUploadCheckpoint ${requestUpload.arg}
 | 
				
			||||||
type RequestUploadCheckpoint struct {
 | 
					type RequestUploadCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = ""
 | 
						value = ""
 | 
				
			||||||
	if param == "minSize" || param == "maxSize" {
 | 
						if param == "minSize" || param == "maxSize" {
 | 
				
			||||||
		value = 0
 | 
							value = 0
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.Method != http.MethodPost {
 | 
						if req.WAFRaw().Method != http.MethodPost {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.Body == nil {
 | 
						if req.WAFRaw().Body == nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.MultipartForm == nil {
 | 
						if req.WAFRaw().MultipartForm == nil {
 | 
				
			||||||
		if len(req.BodyData) == 0 {
 | 
							var bodyData = req.WAFGetCacheBody()
 | 
				
			||||||
			data, err := req.ReadBody(32 * 1024 * 1024)
 | 
							if len(bodyData) == 0 {
 | 
				
			||||||
 | 
								data, err := req.WAFReadBody(32 * 1024 * 1024)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				sysErr = err
 | 
									sysErr = err
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			req.BodyData = data
 | 
								bodyData = data
 | 
				
			||||||
			defer req.RestoreBody(data)
 | 
								req.WAFSetCacheBody(data)
 | 
				
			||||||
 | 
								defer req.WAFRestoreBody(data)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		oldBody := req.Body
 | 
							oldBody := req.WAFRaw().Body
 | 
				
			||||||
		req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData))
 | 
							req.WAFRaw().Body = ioutil.NopCloser(bytes.NewBuffer(bodyData))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		err := req.ParseMultipartForm(32 * 1024 * 1024)
 | 
							err := req.WAFRaw().ParseMultipartForm(32 * 1024 * 1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 还原
 | 
							// 还原
 | 
				
			||||||
		req.Body = oldBody
 | 
							req.WAFRaw().Body = oldBody
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			userErr = err
 | 
								userErr = err
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if req.MultipartForm == nil {
 | 
							if req.WAFRaw().MultipartForm == nil {
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if param == "field" { // field
 | 
						if param == "field" { // field
 | 
				
			||||||
		fields := []string{}
 | 
							fields := []string{}
 | 
				
			||||||
		for field := range req.MultipartForm.File {
 | 
							for field := range req.WAFRaw().MultipartForm.File {
 | 
				
			||||||
			fields = append(fields, field)
 | 
								fields = append(fields, field)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		value = strings.Join(fields, ",")
 | 
							value = strings.Join(fields, ",")
 | 
				
			||||||
	} else if param == "minSize" { // minSize
 | 
						} else if param == "minSize" { // minSize
 | 
				
			||||||
		minSize := int64(0)
 | 
							minSize := int64(0)
 | 
				
			||||||
		for _, files := range req.MultipartForm.File {
 | 
							for _, files := range req.WAFRaw().MultipartForm.File {
 | 
				
			||||||
			for _, file := range files {
 | 
								for _, file := range files {
 | 
				
			||||||
				if minSize == 0 || minSize > file.Size {
 | 
									if minSize == 0 || minSize > file.Size {
 | 
				
			||||||
					minSize = file.Size
 | 
										minSize = file.Size
 | 
				
			||||||
@@ -77,7 +79,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
 | 
				
			|||||||
		value = minSize
 | 
							value = minSize
 | 
				
			||||||
	} else if param == "maxSize" { // maxSize
 | 
						} else if param == "maxSize" { // maxSize
 | 
				
			||||||
		maxSize := int64(0)
 | 
							maxSize := int64(0)
 | 
				
			||||||
		for _, files := range req.MultipartForm.File {
 | 
							for _, files := range req.WAFRaw().MultipartForm.File {
 | 
				
			||||||
			for _, file := range files {
 | 
								for _, file := range files {
 | 
				
			||||||
				if maxSize < file.Size {
 | 
									if maxSize < file.Size {
 | 
				
			||||||
					maxSize = file.Size
 | 
										maxSize = file.Size
 | 
				
			||||||
@@ -87,7 +89,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
 | 
				
			|||||||
		value = maxSize
 | 
							value = maxSize
 | 
				
			||||||
	} else if param == "name" { // name
 | 
						} else if param == "name" { // name
 | 
				
			||||||
		names := []string{}
 | 
							names := []string{}
 | 
				
			||||||
		for _, files := range req.MultipartForm.File {
 | 
							for _, files := range req.WAFRaw().MultipartForm.File {
 | 
				
			||||||
			for _, file := range files {
 | 
								for _, file := range files {
 | 
				
			||||||
				if !lists.ContainsString(names, file.Filename) {
 | 
									if !lists.ContainsString(names, file.Filename) {
 | 
				
			||||||
					names = append(names, file.Filename)
 | 
										names = append(names, file.Filename)
 | 
				
			||||||
@@ -97,7 +99,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
 | 
				
			|||||||
		value = strings.Join(names, ",")
 | 
							value = strings.Join(names, ",")
 | 
				
			||||||
	} else if param == "ext" { // ext
 | 
						} else if param == "ext" { // ext
 | 
				
			||||||
		extensions := []string{}
 | 
							extensions := []string{}
 | 
				
			||||||
		for _, files := range req.MultipartForm.File {
 | 
							for _, files := range req.WAFRaw().MultipartForm.File {
 | 
				
			||||||
			for _, file := range files {
 | 
								for _, file := range files {
 | 
				
			||||||
				if len(file.Filename) > 0 {
 | 
									if len(file.Filename) > 0 {
 | 
				
			||||||
					exit := strings.ToLower(filepath.Ext(file.Filename))
 | 
										exit := strings.ToLower(filepath.Ext(file.Filename))
 | 
				
			||||||
@@ -113,7 +115,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -63,8 +63,8 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
		t.Fatal()
 | 
							t.Fatal()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	req.Header.Add("Content-Type", writer.FormDataContentType())
 | 
						req.WAFRaw().Header.Add("Content-Type", writer.FormDataContentType())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpoint := new(RequestUploadCheckpoint)
 | 
						checkpoint := new(RequestUploadCheckpoint)
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "field", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "field", nil))
 | 
				
			||||||
@@ -73,7 +73,7 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
 | 
				
			|||||||
	t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
				
			||||||
	t.Log(checkpoint.RequestValue(req, "ext", nil))
 | 
						t.Log(checkpoint.RequestValue(req, "ext", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	data, err := ioutil.ReadAll(req.Body)
 | 
						data, err := ioutil.ReadAll(req.WAFRaw().Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,16 +9,16 @@ type RequestURICheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestURICheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if len(req.RequestURI) > 0 {
 | 
						if len(req.WAFRaw().RequestURI) > 0 {
 | 
				
			||||||
		value = req.RequestURI
 | 
							value = req.WAFRaw().RequestURI
 | 
				
			||||||
	} else if req.URL != nil {
 | 
						} else if req.WAFRaw().URL != nil {
 | 
				
			||||||
		value = req.URL.RequestURI()
 | 
							value = req.WAFRaw().URL.RequestURI()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestURICheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -9,12 +9,12 @@ type RequestUserAgentCheckpoint struct {
 | 
				
			|||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestUserAgentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = req.UserAgent()
 | 
						value = req.WAFRaw().UserAgent()
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *RequestUserAgentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,12 +16,12 @@ func (this *ResponseBodyCheckpoint) IsRequest() bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = ""
 | 
						value = ""
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = ""
 | 
						value = ""
 | 
				
			||||||
	if resp != nil && resp.Body != nil {
 | 
						if resp != nil && resp.Body != nil {
 | 
				
			||||||
		if len(resp.BodyData) > 0 {
 | 
							if len(resp.BodyData) > 0 {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,12 +14,12 @@ func (this *ResponseBytesSentCheckpoint) IsRequest() bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseBytesSentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = 0
 | 
						value = 0
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseBytesSentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = 0
 | 
						value = 0
 | 
				
			||||||
	if resp != nil {
 | 
						if resp != nil {
 | 
				
			||||||
		value = resp.ContentLength
 | 
							value = resp.ContentLength
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,12 +18,12 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) IsComposed() bool {
 | 
				
			|||||||
	return true
 | 
						return true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = false
 | 
						value = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	headers := options.GetSlice("headers")
 | 
						headers := options.GetSlice("headers")
 | 
				
			||||||
@@ -34,7 +34,7 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.R
 | 
				
			|||||||
	length := options.GetInt("length")
 | 
						length := options.GetInt("length")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, header := range headers {
 | 
						for _, header := range headers {
 | 
				
			||||||
		v := req.Header.Get(types.String(header))
 | 
							v := req.WAFRaw().Header.Get(types.String(header))
 | 
				
			||||||
		if len(v) > length {
 | 
							if len(v) > length {
 | 
				
			||||||
			value = true
 | 
								value = true
 | 
				
			||||||
			break
 | 
								break
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,12 +14,12 @@ func (this *ResponseHeaderCheckpoint) IsRequest() bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = ""
 | 
						value = ""
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if resp != nil && resp.Header != nil {
 | 
						if resp != nil && resp.Header != nil {
 | 
				
			||||||
		value = resp.Header.Get(param)
 | 
							value = resp.Header.Get(param)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,7 +5,7 @@ import (
 | 
				
			|||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ${bytesSent}
 | 
					// ResponseStatusCheckpoint ${bytesSent}
 | 
				
			||||||
type ResponseStatusCheckpoint struct {
 | 
					type ResponseStatusCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -14,12 +14,12 @@ func (this *ResponseStatusCheckpoint) IsRequest() bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseStatusCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	value = 0
 | 
						value = 0
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *ResponseStatusCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if resp != nil {
 | 
						if resp != nil {
 | 
				
			||||||
		value = resp.StatusCode
 | 
							value = resp.StatusCode
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,16 +5,16 @@ import (
 | 
				
			|||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// just a sample checkpoint, copy and change it for your new checkpoint
 | 
					// SampleRequestCheckpoint just a sample checkpoint, copy and change it for your new checkpoint
 | 
				
			||||||
type SampleRequestCheckpoint struct {
 | 
					type SampleRequestCheckpoint struct {
 | 
				
			||||||
	Checkpoint
 | 
						Checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *SampleRequestCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
					func (this *SampleRequestCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
	if this.IsRequest() {
 | 
						if this.IsRequest() {
 | 
				
			||||||
		return this.RequestValue(req, param, options)
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,6 @@
 | 
				
			|||||||
package checkpoints
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// all check points list
 | 
					// AllCheckpoints all check points list
 | 
				
			||||||
var AllCheckpoints = []*CheckpointDefinition{
 | 
					var AllCheckpoints = []*CheckpointDefinition{
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		Name:        "通用请求Header长度限制",
 | 
							Name:        "通用请求Header长度限制",
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										52
									
								
								internal/waf/get302_validator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								internal/waf/get302_validator.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var get302Validator = NewGet302Validator()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Get302Validator struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewGet302Validator() *Get302Validator {
 | 
				
			||||||
 | 
						return &Get302Validator{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Get302Validator) Run(request requests.Request, writer http.ResponseWriter) {
 | 
				
			||||||
 | 
						var info = request.WAFRaw().URL.Query().Get("info")
 | 
				
			||||||
 | 
						if len(info) == 0 {
 | 
				
			||||||
 | 
							writer.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
							_, _ = writer.Write([]byte("invalid request"))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						m, err := utils.SimpleDecryptMap(info)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							_, _ = writer.Write([]byte("invalid request"))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var timestamp = m.GetInt64("timestamp")
 | 
				
			||||||
 | 
						if time.Now().Unix()-timestamp > 5 { // 超过5秒认为失效
 | 
				
			||||||
 | 
							writer.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
 | 
							_, _ = writer.Write([]byte("invalid request"))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 加入白名单
 | 
				
			||||||
 | 
						life := m.GetInt64("life")
 | 
				
			||||||
 | 
						if life <= 0 {
 | 
				
			||||||
 | 
							life = 600 // 默认10分钟
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						setId := m.GetString("setId")
 | 
				
			||||||
 | 
						SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 返回原始URL
 | 
				
			||||||
 | 
						var url = m.GetString("url")
 | 
				
			||||||
 | 
						http.Redirect(writer, request.WAFRaw(), url, http.StatusFound)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										82
									
								
								internal/waf/ip_list.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								internal/waf/ip_list.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,82 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var SharedIPWhiteList = NewIPList()
 | 
				
			||||||
 | 
					var SharedIPBlackLIst = NewIPList()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const IPTypeAll = "*"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IPList IP列表管理
 | 
				
			||||||
 | 
					type IPList struct {
 | 
				
			||||||
 | 
						expireList *expires.List
 | 
				
			||||||
 | 
						ipMap      map[string]int64 // ip => id
 | 
				
			||||||
 | 
						idMap      map[int64]string // id => ip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						id     int64
 | 
				
			||||||
 | 
						locker sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewIPList 获取新对象
 | 
				
			||||||
 | 
					func NewIPList() *IPList {
 | 
				
			||||||
 | 
						var list = &IPList{
 | 
				
			||||||
 | 
							ipMap: map[string]int64{},
 | 
				
			||||||
 | 
							idMap: map[int64]string{},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						e := expires.NewList()
 | 
				
			||||||
 | 
						list.expireList = e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							e.StartGC(func(itemId int64) {
 | 
				
			||||||
 | 
								list.remove(itemId)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return list
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add 添加IP
 | 
				
			||||||
 | 
					func (this *IPList) Add(ipType string, ip string, expiresAt int64) {
 | 
				
			||||||
 | 
						ip = ip + "@" + ipType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var id = this.nextId()
 | 
				
			||||||
 | 
						this.expireList.Add(id, expiresAt)
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						this.ipMap[ip] = id
 | 
				
			||||||
 | 
						this.idMap[id] = ip
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Contains 判断是否有某个IP
 | 
				
			||||||
 | 
					func (this *IPList) Contains(ipType string, ip string) bool {
 | 
				
			||||||
 | 
						ip = ip + "@" + ipType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.RLock()
 | 
				
			||||||
 | 
						defer this.locker.RUnlock()
 | 
				
			||||||
 | 
						_, ok := this.ipMap[ip]
 | 
				
			||||||
 | 
						return ok
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IPList) remove(id int64) {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						ip, ok := this.idMap[id]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							ipId, ok := this.ipMap[ip]
 | 
				
			||||||
 | 
							if ok && ipId == id {
 | 
				
			||||||
 | 
								delete(this.ipMap, ip)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							delete(this.idMap, id)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IPList) nextId() int64 {
 | 
				
			||||||
 | 
						return atomic.AddInt64(&this.id, 1)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										67
									
								
								internal/waf/ip_list_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								internal/waf/ip_list_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,67 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestNewIPList(t *testing.T) {
 | 
				
			||||||
 | 
						list := NewIPList()
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix())
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, "127.0.0.2", time.Now().Unix()+1)
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()+2)
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, "127.0.0.3", time.Now().Unix()+3)
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, "127.0.0.10", time.Now().Unix()+10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var ticker = time.NewTicker(1 * time.Second)
 | 
				
			||||||
 | 
						for range ticker.C {
 | 
				
			||||||
 | 
							t.Log("====")
 | 
				
			||||||
 | 
							logs.PrintAsJSON(list.ipMap, t)
 | 
				
			||||||
 | 
							logs.PrintAsJSON(list.idMap, t)
 | 
				
			||||||
 | 
							if len(list.idMap) == 0 {
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIPList_Contains(t *testing.T) {
 | 
				
			||||||
 | 
						a := assert.NewAssertion(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						list := NewIPList()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < 1_0000; i++ {
 | 
				
			||||||
 | 
							list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						a.IsTrue(list.Contains(IPTypeAll, "192.168.1.100"))
 | 
				
			||||||
 | 
						a.IsFalse(list.Contains(IPTypeAll, "192.168.2.100"))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkIPList_Add(b *testing.B) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						list := NewIPList()
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						b.Log(len(list.ipMap))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkIPList_Has(b *testing.B) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						list := NewIPList()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < 1_0000; i++ {
 | 
				
			||||||
 | 
							list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							list.Contains(IPTypeAll, "192.168.1.100")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,154 +0,0 @@
 | 
				
			|||||||
package waf
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
 | 
					 | 
				
			||||||
	"github.com/iwind/TeaGo/lists"
 | 
					 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
					 | 
				
			||||||
	stringutil "github.com/iwind/TeaGo/utils/string"
 | 
					 | 
				
			||||||
	"regexp"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type IPAction = string
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const (
 | 
					 | 
				
			||||||
	IPActionAccept IPAction = "accept"
 | 
					 | 
				
			||||||
	IPActionReject IPAction = "reject"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// ip table
 | 
					 | 
				
			||||||
type IPTable struct {
 | 
					 | 
				
			||||||
	Id       string   `yaml:"id" json:"id"`
 | 
					 | 
				
			||||||
	On       bool     `yaml:"on" json:"on"`
 | 
					 | 
				
			||||||
	IP       string   `yaml:"ip" json:"ip"`             // single ip, cidr, ip range, TODO support *
 | 
					 | 
				
			||||||
	Port     string   `yaml:"port" json:"port"`         // single port, range, *
 | 
					 | 
				
			||||||
	Action   IPAction `yaml:"action" json:"action"`     // accept, reject
 | 
					 | 
				
			||||||
	TimeFrom int64    `yaml:"timeFrom" json:"timeFrom"` // from timestamp
 | 
					 | 
				
			||||||
	TimeTo   int64    `yaml:"timeTo" json:"timeTo"`     // zero means forever
 | 
					 | 
				
			||||||
	Remark   string   `yaml:"remark" json:"remark"`
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// port
 | 
					 | 
				
			||||||
	minPort int
 | 
					 | 
				
			||||||
	maxPort int
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	minPortWildcard bool
 | 
					 | 
				
			||||||
	maxPortWildcard bool
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	ports []int
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// ip
 | 
					 | 
				
			||||||
	ipRange *shared.IPRangeConfig
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func NewIPTable() *IPTable {
 | 
					 | 
				
			||||||
	return &IPTable{
 | 
					 | 
				
			||||||
		On: true,
 | 
					 | 
				
			||||||
		Id: stringutil.Rand(16),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (this *IPTable) Init() error {
 | 
					 | 
				
			||||||
	// parse port
 | 
					 | 
				
			||||||
	if RegexpDigitNumber.MatchString(this.Port) {
 | 
					 | 
				
			||||||
		this.minPort = types.Int(this.Port)
 | 
					 | 
				
			||||||
		this.maxPort = types.Int(this.Port)
 | 
					 | 
				
			||||||
	} else if regexp.MustCompile(`[:-]`).MatchString(this.Port) {
 | 
					 | 
				
			||||||
		pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2)
 | 
					 | 
				
			||||||
		if pieces[0] == "*" {
 | 
					 | 
				
			||||||
			this.minPortWildcard = true
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			this.minPort = types.Int(pieces[0])
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if pieces[1] == "*" {
 | 
					 | 
				
			||||||
			this.maxPortWildcard = true
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			this.maxPort = types.Int(pieces[1])
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else if strings.Contains(this.Port, ",") {
 | 
					 | 
				
			||||||
		pieces := strings.Split(this.Port, ",")
 | 
					 | 
				
			||||||
		for _, piece := range pieces {
 | 
					 | 
				
			||||||
			piece = strings.TrimSpace(piece)
 | 
					 | 
				
			||||||
			if len(piece) > 0 {
 | 
					 | 
				
			||||||
				this.ports = append(this.ports, types.Int(piece))
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else if this.Port == "*" {
 | 
					 | 
				
			||||||
		this.minPortWildcard = true
 | 
					 | 
				
			||||||
		this.maxPortWildcard = true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// parse ip
 | 
					 | 
				
			||||||
	if len(this.IP) > 0 {
 | 
					 | 
				
			||||||
		ipRange, err := shared.ParseIPRange(this.IP)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return err
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		this.ipRange = ipRange
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// check ip
 | 
					 | 
				
			||||||
func (this *IPTable) Match(ip string, port int) (isMatched bool) {
 | 
					 | 
				
			||||||
	if !this.On {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	now := time.Now().Unix()
 | 
					 | 
				
			||||||
	if this.TimeFrom > 0 && now < this.TimeFrom {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if this.TimeTo > 0 && now > this.TimeTo {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !this.matchPort(port) {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if !this.matchIP(ip) {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (this *IPTable) matchPort(port int) bool {
 | 
					 | 
				
			||||||
	if port == 0 {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if this.minPortWildcard {
 | 
					 | 
				
			||||||
		if this.maxPortWildcard {
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if this.maxPort >= port {
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if this.maxPortWildcard {
 | 
					 | 
				
			||||||
		if this.minPortWildcard {
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if this.minPort <= port {
 | 
					 | 
				
			||||||
			return true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port {
 | 
					 | 
				
			||||||
		return true
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(this.ports) > 0 {
 | 
					 | 
				
			||||||
		return lists.ContainsInt(this.ports, port)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return false
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (this *IPTable) matchIP(ip string) bool {
 | 
					 | 
				
			||||||
	if this.ipRange == nil {
 | 
					 | 
				
			||||||
		return false
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return this.ipRange.Contains(ip)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -1,142 +0,0 @@
 | 
				
			|||||||
package waf
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import (
 | 
					 | 
				
			||||||
	"github.com/iwind/TeaGo/assert"
 | 
					 | 
				
			||||||
	"testing"
 | 
					 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestIPTable_MatchIP(t *testing.T) {
 | 
					 | 
				
			||||||
	a := assert.NewAssertion(t)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "*"
 | 
					 | 
				
			||||||
		table.Port = "8080"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("port:", table.minPort, table.maxPort)
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.1.100", 8081))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "*"
 | 
					 | 
				
			||||||
		table.Port = "8080-8082"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("port:", table.minPort, table.maxPort)
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.1.100", 8083))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "*"
 | 
					 | 
				
			||||||
		table.Port = "*-8082"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("port:", table.minPort, table.maxPort)
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8079))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.1.100", 8083))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "*"
 | 
					 | 
				
			||||||
		table.Port = "8080-*"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("port:", table.minPort, table.maxPort)
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.1.100", 8079))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8083))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "*"
 | 
					 | 
				
			||||||
		table.Port = "*"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("port:", table.minPort, table.maxPort)
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8079))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8083))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "192.168.1.100"
 | 
					 | 
				
			||||||
		table.Port = "*"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("port:", table.minPort, table.maxPort)
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "192.168.1.99-192.168.1.101"
 | 
					 | 
				
			||||||
		table.Port = "*"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("port:", table.minPort, table.maxPort)
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "192.168.1.99/24"
 | 
					 | 
				
			||||||
		table.Port = "*"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		t.Log("ip:", table.ipRange)
 | 
					 | 
				
			||||||
		a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.2.100", 8080))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	{
 | 
					 | 
				
			||||||
		table := NewIPTable()
 | 
					 | 
				
			||||||
		table.IP = "192.168.1.99/24"
 | 
					 | 
				
			||||||
		table.TimeTo = time.Now().Unix() - 10
 | 
					 | 
				
			||||||
		table.Port = "*"
 | 
					 | 
				
			||||||
		err := table.Init()
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			t.Fatal(err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.1.100", 8080))
 | 
					 | 
				
			||||||
		a.IsFalse(table.Match("192.168.2.100", 8080))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -1,39 +1,28 @@
 | 
				
			|||||||
package requests
 | 
					package requests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
					 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"io/ioutil"
 | 
					 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Request struct {
 | 
					type Request interface {
 | 
				
			||||||
	*http.Request
 | 
						// WAFRaw 原始请求
 | 
				
			||||||
	BodyData []byte
 | 
						WAFRaw() *http.Request
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewRequest(raw *http.Request) *Request {
 | 
						// WAFRemoteIP 客户端IP
 | 
				
			||||||
	return &Request{
 | 
						WAFRemoteIP() string
 | 
				
			||||||
		Request: raw,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Request) Raw() *http.Request {
 | 
						// WAFGetCacheBody 获取缓存中的Body
 | 
				
			||||||
	return this.Request
 | 
						WAFGetCacheBody() []byte
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Request) ReadBody(max int64) (data []byte, err error) {
 | 
						// WAFSetCacheBody 设置Body
 | 
				
			||||||
	if this.Request.ContentLength > 0 {
 | 
						WAFSetCacheBody(body []byte)
 | 
				
			||||||
		data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max))
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Request) RestoreBody(data []byte) {
 | 
						// WAFReadBody 读取Body
 | 
				
			||||||
	if len(data) > 0 {
 | 
						WAFReadBody(max int64) (data []byte, err error)
 | 
				
			||||||
		rawReader := bytes.NewBuffer(data)
 | 
					
 | 
				
			||||||
		buf := make([]byte, 1024)
 | 
						// WAFRestoreBody 恢复Body
 | 
				
			||||||
		_, _ = io.CopyBuffer(rawReader, this.Request.Body, buf)
 | 
						WAFRestoreBody(data []byte)
 | 
				
			||||||
		this.Request.Body = ioutil.NopCloser(rawReader)
 | 
					
 | 
				
			||||||
	}
 | 
						// WAFServerId 服务ID
 | 
				
			||||||
 | 
						WAFServerId() int64
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										67
									
								
								internal/waf/requests/test_request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								internal/waf/requests/test_request.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,67 @@
 | 
				
			|||||||
 | 
					// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package requests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type TestRequest struct {
 | 
				
			||||||
 | 
						req      *http.Request
 | 
				
			||||||
 | 
						BodyData []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewTestRequest(raw *http.Request) *TestRequest {
 | 
				
			||||||
 | 
						return &TestRequest{
 | 
				
			||||||
 | 
							req: raw,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFSetCacheBody(bodyData []byte) {
 | 
				
			||||||
 | 
						this.BodyData = bodyData
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFGetCacheBody() []byte {
 | 
				
			||||||
 | 
						return this.BodyData
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFRaw() *http.Request {
 | 
				
			||||||
 | 
						return this.req
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFRemoteAddr() string {
 | 
				
			||||||
 | 
						return this.req.RemoteAddr
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFRemoteIP() string {
 | 
				
			||||||
 | 
						host, _, err := net.SplitHostPort(this.req.RemoteAddr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return this.req.RemoteAddr
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return host
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFReadBody(max int64) (data []byte, err error) {
 | 
				
			||||||
 | 
						if this.req.ContentLength > 0 {
 | 
				
			||||||
 | 
							data, err = ioutil.ReadAll(io.LimitReader(this.req.Body, max))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFRestoreBody(data []byte) {
 | 
				
			||||||
 | 
						if len(data) > 0 {
 | 
				
			||||||
 | 
							rawReader := bytes.NewBuffer(data)
 | 
				
			||||||
 | 
							buf := make([]byte, 1024)
 | 
				
			||||||
 | 
							_, _ = io.CopyBuffer(rawReader, this.req.Body, buf)
 | 
				
			||||||
 | 
							this.req.Body = ioutil.NopCloser(rawReader)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *TestRequest) WAFServerId() int64 {
 | 
				
			||||||
 | 
						return 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -183,7 +183,7 @@ func (this *Rule) Init() error {
 | 
				
			|||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
 | 
					func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) {
 | 
				
			||||||
	if this.singleCheckpoint != nil {
 | 
						if this.singleCheckpoint != nil {
 | 
				
			||||||
		value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
 | 
							value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
@@ -233,7 +233,7 @@ func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
 | 
				
			|||||||
	return this.Test(value), nil
 | 
						return this.Test(value), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
 | 
					func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) {
 | 
				
			||||||
	if this.singleCheckpoint != nil {
 | 
						if this.singleCheckpoint != nil {
 | 
				
			||||||
		// if is request param
 | 
							// if is request param
 | 
				
			||||||
		if this.singleCheckpoint.IsRequest() {
 | 
							if this.singleCheckpoint.IsRequest() {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -23,12 +23,12 @@ func NewRuleGroup() *RuleGroup {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RuleGroup) Init() error {
 | 
					func (this *RuleGroup) Init(waf *WAF) error {
 | 
				
			||||||
	this.hasRuleSets = len(this.RuleSets) > 0
 | 
						this.hasRuleSets = len(this.RuleSets) > 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if this.hasRuleSets {
 | 
						if this.hasRuleSets {
 | 
				
			||||||
		for _, set := range this.RuleSets {
 | 
							for _, set := range this.RuleSets {
 | 
				
			||||||
			err := set.Init()
 | 
								err := set.Init(waf)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -79,7 +79,7 @@ func (this *RuleGroup) RemoveRuleSet(id string) {
 | 
				
			|||||||
	this.RuleSets = result
 | 
						this.RuleSets = result
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) {
 | 
					func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, set *RuleSet, err error) {
 | 
				
			||||||
	if !this.hasRuleSets {
 | 
						if !this.hasRuleSets {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -98,7 +98,7 @@ func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
 | 
					func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
 | 
				
			||||||
	if !this.hasRuleSets {
 | 
						if !this.hasRuleSets {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,9 +1,13 @@
 | 
				
			|||||||
package waf
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
	"github.com/iwind/TeaGo/maps"
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"github.com/iwind/TeaGo/utils/string"
 | 
						"github.com/iwind/TeaGo/utils/string"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type RuleConnector = string
 | 
					type RuleConnector = string
 | 
				
			||||||
@@ -14,16 +18,17 @@ const (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type RuleSet struct {
 | 
					type RuleSet struct {
 | 
				
			||||||
	Id          string        `yaml:"id" json:"id"`
 | 
						Id          string          `yaml:"id" json:"id"`
 | 
				
			||||||
	Code        string        `yaml:"code" json:"code"`
 | 
						Code        string          `yaml:"code" json:"code"`
 | 
				
			||||||
	IsOn        bool          `yaml:"isOn" json:"isOn"`
 | 
						IsOn        bool            `yaml:"isOn" json:"isOn"`
 | 
				
			||||||
	Name        string        `yaml:"name" json:"name"`
 | 
						Name        string          `yaml:"name" json:"name"`
 | 
				
			||||||
	Description string        `yaml:"description" json:"description"`
 | 
						Description string          `yaml:"description" json:"description"`
 | 
				
			||||||
	Rules       []*Rule       `yaml:"rules" json:"rules"`
 | 
						Rules       []*Rule         `yaml:"rules" json:"rules"`
 | 
				
			||||||
	Connector   RuleConnector `yaml:"connector" json:"connector"` // rules connector
 | 
						Connector   RuleConnector   `yaml:"connector" json:"connector"` // rules connector
 | 
				
			||||||
 | 
						Actions     []*ActionConfig `yaml:"actions" json:"actions"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	Action        ActionString `yaml:"action" json:"action"`
 | 
						actionCodes     []string
 | 
				
			||||||
	ActionOptions maps.Map     `yaml:"actionOptions" json:"actionOptions"` // TODO TO BE IMPLEMENTED
 | 
						actionInstances []ActionInterface
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hasRules bool
 | 
						hasRules bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -35,7 +40,7 @@ func NewRuleSet() *RuleSet {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RuleSet) Init() error {
 | 
					func (this *RuleSet) Init(waf *WAF) error {
 | 
				
			||||||
	this.hasRules = len(this.Rules) > 0
 | 
						this.hasRules = len(this.Rules) > 0
 | 
				
			||||||
	if this.hasRules {
 | 
						if this.hasRules {
 | 
				
			||||||
		for _, rule := range this.Rules {
 | 
							for _, rule := range this.Rules {
 | 
				
			||||||
@@ -45,6 +50,31 @@ func (this *RuleSet) Init() error {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// action codes
 | 
				
			||||||
 | 
						var actionCodes = []string{}
 | 
				
			||||||
 | 
						for _, action := range this.Actions {
 | 
				
			||||||
 | 
							if !lists.ContainsString(actionCodes, action.Code) {
 | 
				
			||||||
 | 
								actionCodes = append(actionCodes, action.Code)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.actionCodes = actionCodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// action instances
 | 
				
			||||||
 | 
						this.actionInstances = []ActionInterface{}
 | 
				
			||||||
 | 
						for _, action := range this.Actions {
 | 
				
			||||||
 | 
							instance := FindActionInstance(action.Code, action.Options)
 | 
				
			||||||
 | 
							if instance == nil {
 | 
				
			||||||
 | 
								remotelogs.Error("WAF_RULE_SET", "can not find instance for action '"+action.Code+"'")
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								this.actionInstances = append(this.actionInstances, instance)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							err := instance.Init(waf)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								remotelogs.Error("WAF_RULE_SET", "init action '"+action.Code+"' failed: "+err.Error())
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -52,7 +82,75 @@ func (this *RuleSet) AddRule(rule ...*Rule) {
 | 
				
			|||||||
	this.Rules = append(this.Rules, rule...)
 | 
						this.Rules = append(this.Rules, rule...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) {
 | 
					// AddAction 添加动作
 | 
				
			||||||
 | 
					func (this *RuleSet) AddAction(code string, options maps.Map) {
 | 
				
			||||||
 | 
						if options == nil {
 | 
				
			||||||
 | 
							options = maps.Map{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.Actions = append(this.Actions, &ActionConfig{
 | 
				
			||||||
 | 
							Code:    code,
 | 
				
			||||||
 | 
							Options: options,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HasSpecialActions 除了Allow之外是否还有别的动作
 | 
				
			||||||
 | 
					func (this *RuleSet) HasSpecialActions() bool {
 | 
				
			||||||
 | 
						for _, action := range this.Actions {
 | 
				
			||||||
 | 
							if action.Code != ActionAllow {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HasAttackActions 检查是否含有攻击防御动作
 | 
				
			||||||
 | 
					func (this *RuleSet) HasAttackActions() bool {
 | 
				
			||||||
 | 
						for _, action := range this.actionInstances {
 | 
				
			||||||
 | 
							if action.IsAttack() {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleSet) ActionCodes() []string {
 | 
				
			||||||
 | 
						return this.actionCodes
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) bool {
 | 
				
			||||||
 | 
						// 先执行allow
 | 
				
			||||||
 | 
						for _, instance := range this.actionInstances {
 | 
				
			||||||
 | 
							if !instance.WillChange() {
 | 
				
			||||||
 | 
								if waf.onActionCallback != nil {
 | 
				
			||||||
 | 
									goNext := waf.onActionCallback(instance)
 | 
				
			||||||
 | 
									if !goNext {
 | 
				
			||||||
 | 
										return false
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								logs.Printf("perform1: %#v", instance) // TODO
 | 
				
			||||||
 | 
								instance.Perform(waf, group, this, req, writer)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 再执行block|verify
 | 
				
			||||||
 | 
						for _, instance := range this.actionInstances {
 | 
				
			||||||
 | 
							// 只执行第一个可能改变请求的动作,其余的都会被忽略
 | 
				
			||||||
 | 
							if instance.WillChange() {
 | 
				
			||||||
 | 
								if waf.onActionCallback != nil {
 | 
				
			||||||
 | 
									goNext := waf.onActionCallback(instance)
 | 
				
			||||||
 | 
									if !goNext {
 | 
				
			||||||
 | 
										return false
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								logs.Printf("perform2: %#v", instance) // TODO
 | 
				
			||||||
 | 
								return instance.Perform(waf, group, this, req, writer)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleSet) MatchRequest(req requests.Request) (b bool, err error) {
 | 
				
			||||||
	if !this.hasRules {
 | 
						if !this.hasRules {
 | 
				
			||||||
		return false, nil
 | 
							return false, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -93,7 +191,7 @@ func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *RuleSet) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
 | 
					func (this *RuleSet) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) {
 | 
				
			||||||
	if !this.hasRules {
 | 
						if !this.hasRules {
 | 
				
			||||||
		return false, nil
 | 
							return false, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,7 +28,7 @@ func TestRuleSet_MatchRequest(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := set.Init()
 | 
						err := set.Init(nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -37,7 +37,7 @@ func TestRuleSet_MatchRequest(t *testing.T) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	t.Log(set.MatchRequest(req))
 | 
						t.Log(set.MatchRequest(req))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -60,7 +60,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := set.Init()
 | 
						err := set.Init(nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -69,7 +69,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	a.IsTrue(set.MatchRequest(req))
 | 
						a.IsTrue(set.MatchRequest(req))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -102,7 +102,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := set.Init()
 | 
						err := set.Init(nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		b.Fatal(err)
 | 
							b.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -111,7 +111,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		b.Fatal(err)
 | 
							b.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
		_, _ = set.MatchRequest(req)
 | 
							_, _ = set.MatchRequest(req)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -132,7 +132,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := set.Init()
 | 
						err := set.Init(nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		b.Fatal(err)
 | 
							b.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -141,7 +141,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		b.Fatal(err)
 | 
							b.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
		_, _ = set.MatchRequest(req)
 | 
							_, _ = set.MatchRequest(req)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,7 +25,7 @@ func TestRule_Init_Single(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	t.Log(rule.MatchRequest(req))
 | 
						t.Log(rule.MatchRequest(req))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -44,7 +44,7 @@ func TestRule_Init_Composite(t *testing.T) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
						req := requests.NewTestRequest(rawReq)
 | 
				
			||||||
	t.Log(rule.MatchRequest(req))
 | 
						t.Log(rule.MatchRequest(req))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,7 +20,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "Javascript事件"
 | 
								set.Name = "Javascript事件"
 | 
				
			||||||
			set.Code = "1001"
 | 
								set.Code = "1001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestURI}",
 | 
									Param:             "${requestURI}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -36,7 +36,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "Javascript函数"
 | 
								set.Name = "Javascript函数"
 | 
				
			||||||
			set.Code = "1002"
 | 
								set.Code = "1002"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestURI}",
 | 
									Param:             "${requestURI}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -52,7 +52,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "HTML标签"
 | 
								set.Name = "HTML标签"
 | 
				
			||||||
			set.Code = "1003"
 | 
								set.Code = "1003"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestURI}",
 | 
									Param:             "${requestURI}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -80,7 +80,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "上传文件扩展名"
 | 
								set.Name = "上传文件扩展名"
 | 
				
			||||||
			set.Code = "2001"
 | 
								set.Code = "2001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestUpload.ext}",
 | 
									Param:             "${requestUpload.ext}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -108,7 +108,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "Web Shell"
 | 
								set.Name = "Web Shell"
 | 
				
			||||||
			set.Code = "3001"
 | 
								set.Code = "3001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestAll}",
 | 
									Param:             "${requestAll}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -135,7 +135,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "命令注入"
 | 
								set.Name = "命令注入"
 | 
				
			||||||
			set.Code = "4001"
 | 
								set.Code = "4001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestURI}",
 | 
									Param:             "${requestURI}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -169,7 +169,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "路径穿越"
 | 
								set.Name = "路径穿越"
 | 
				
			||||||
			set.Code = "5001"
 | 
								set.Code = "5001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestURI}",
 | 
									Param:             "${requestURI}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -197,7 +197,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "特殊目录"
 | 
								set.Name = "特殊目录"
 | 
				
			||||||
			set.Code = "6001"
 | 
								set.Code = "6001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestPath}",
 | 
									Param:             "${requestPath}",
 | 
				
			||||||
				Operator:          RuleOperatorMatch,
 | 
									Operator:          RuleOperatorMatch,
 | 
				
			||||||
@@ -225,7 +225,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "Union SQL Injection"
 | 
								set.Name = "Union SQL Injection"
 | 
				
			||||||
			set.Code = "7001"
 | 
								set.Code = "7001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestAll}",
 | 
									Param:             "${requestAll}",
 | 
				
			||||||
@@ -243,7 +243,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "SQL注释"
 | 
								set.Name = "SQL注释"
 | 
				
			||||||
			set.Code = "7002"
 | 
								set.Code = "7002"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestAll}",
 | 
									Param:             "${requestAll}",
 | 
				
			||||||
@@ -261,7 +261,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "SQL条件"
 | 
								set.Name = "SQL条件"
 | 
				
			||||||
			set.Code = "7003"
 | 
								set.Code = "7003"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestAll}",
 | 
									Param:             "${requestAll}",
 | 
				
			||||||
@@ -297,7 +297,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "SQL函数"
 | 
								set.Name = "SQL函数"
 | 
				
			||||||
			set.Code = "7004"
 | 
								set.Code = "7004"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestAll}",
 | 
									Param:             "${requestAll}",
 | 
				
			||||||
@@ -315,7 +315,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "SQL附加语句"
 | 
								set.Name = "SQL附加语句"
 | 
				
			||||||
			set.Code = "7005"
 | 
								set.Code = "7005"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${requestAll}",
 | 
									Param:             "${requestAll}",
 | 
				
			||||||
@@ -345,7 +345,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Name = "常见网络爬虫"
 | 
								set.Name = "常见网络爬虫"
 | 
				
			||||||
			set.Code = "20001"
 | 
								set.Code = "20001"
 | 
				
			||||||
			set.Connector = RuleConnectorOr
 | 
								set.Connector = RuleConnectorOr
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:             "${userAgent}",
 | 
									Param:             "${userAgent}",
 | 
				
			||||||
@@ -376,7 +376,7 @@ func Template() *WAF {
 | 
				
			|||||||
			set.Description = "限制单IP在一定时间内的请求数"
 | 
								set.Description = "限制单IP在一定时间内的请求数"
 | 
				
			||||||
			set.Code = "8001"
 | 
								set.Code = "8001"
 | 
				
			||||||
			set.Connector = RuleConnectorAnd
 | 
								set.Connector = RuleConnectorAnd
 | 
				
			||||||
			set.Action = ActionBlock
 | 
								set.AddAction(ActionBlock, nil)
 | 
				
			||||||
			set.AddRule(&Rule{
 | 
								set.AddRule(&Rule{
 | 
				
			||||||
				Param:    "${cc.requests}",
 | 
									Param:    "${cc.requests}",
 | 
				
			||||||
				Operator: RuleOperatorGt,
 | 
									Operator: RuleOperatorGt,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,7 @@ package waf
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
	"github.com/iwind/TeaGo/assert"
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
	"github.com/iwind/TeaGo/lists"
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
	"github.com/iwind/TeaGo/logs"
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
@@ -22,8 +23,8 @@ func Test_Template(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	template.OnAction(func(action ActionString) (goNext bool) {
 | 
						template.OnAction(func(action ActionInterface) (goNext bool) {
 | 
				
			||||||
		return action != ActionBlock
 | 
							return action.Code() != ActionBlock
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	testTemplate1001(a, t, template)
 | 
						testTemplate1001(a, t, template)
 | 
				
			||||||
@@ -40,7 +41,7 @@ func Test_Template(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func Test_Template2(t *testing.T) {
 | 
					func Test_Template2(t *testing.T) {
 | 
				
			||||||
	reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024)))
 | 
						reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024)))
 | 
				
			||||||
	req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader)
 | 
						req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123", reader)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -52,7 +53,7 @@ func Test_Template2(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	now := time.Now()
 | 
						now := time.Now()
 | 
				
			||||||
	goNext, _, set, err := waf.MatchRequest(req, nil)
 | 
						goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -80,7 +81,7 @@ func BenchmarkTemplate(b *testing.B) {
 | 
				
			|||||||
			b.Fatal(err)
 | 
								b.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		_, _, _, _ = waf.MatchRequest(req, nil)
 | 
							_, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -89,7 +90,7 @@ func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, _, result, err := template.MatchRequest(req, nil)
 | 
						_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -104,7 +105,7 @@ func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, _, result, err := template.MatchRequest(req, nil)
 | 
						_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -119,7 +120,7 @@ func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, _, result, err := template.MatchRequest(req, nil)
 | 
						_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -185,7 +186,7 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	req.Header.Add("Content-Type", writer.FormDataContentType())
 | 
						req.Header.Add("Content-Type", writer.FormDataContentType())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	_, _, result, err := template.MatchRequest(req, nil)
 | 
						_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -200,7 +201,7 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, _, result, err := template.MatchRequest(req, nil)
 | 
						_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -215,7 +216,7 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_, _, result, err := template.MatchRequest(req, nil)
 | 
						_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -231,7 +232,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		_, _, result, err := template.MatchRequest(req, nil)
 | 
							_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -246,7 +247,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		_, _, result, err := template.MatchRequest(req, nil)
 | 
							_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -263,7 +264,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		_, _, result, err := template.MatchRequest(req, nil)
 | 
							_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -278,7 +279,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		_, _, result, err := template.MatchRequest(req, nil)
 | 
							_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -301,7 +302,7 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		_, _, result, err := template.MatchRequest(req, nil)
 | 
							_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -338,7 +339,7 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) {
 | 
				
			|||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		req.Header.Set("User-Agent", bot)
 | 
							req.Header.Set("User-Agent", bot)
 | 
				
			||||||
		_, _, result, err := template.MatchRequest(req, nil)
 | 
							_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			t.Fatal(err)
 | 
								t.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,13 +22,11 @@ type WAF struct {
 | 
				
			|||||||
	Outbound       []*RuleGroup `yaml:"outbound" json:"outbound"`
 | 
						Outbound       []*RuleGroup `yaml:"outbound" json:"outbound"`
 | 
				
			||||||
	CreatedVersion string       `yaml:"createdVersion" json:"createdVersion"`
 | 
						CreatedVersion string       `yaml:"createdVersion" json:"createdVersion"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ActionBlock *BlockAction `yaml:"actionBlock" json:"actionBlock"` // action block config
 | 
						DefaultBlockAction *BlockAction
 | 
				
			||||||
 | 
					 | 
				
			||||||
	IPTables []*IPTable `yaml:"ipTables" json:"ipTables"` // IP table list
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	hasInboundRules  bool
 | 
						hasInboundRules  bool
 | 
				
			||||||
	hasOutboundRules bool
 | 
						hasOutboundRules bool
 | 
				
			||||||
	onActionCallback func(action ActionString) (goNext bool)
 | 
						onActionCallback func(action ActionInterface) (goNext bool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint
 | 
						checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -87,7 +85,7 @@ func (this *WAF) Init() error {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			err := group.Init()
 | 
								err := group.Init(this)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -103,7 +101,7 @@ func (this *WAF) Init() error {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			err := group.Init()
 | 
								err := group.Init(this)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return err
 | 
									return err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -241,19 +239,24 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) {
 | 
				
			|||||||
	this.Outbound = result
 | 
						this.Outbound = result
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
 | 
					func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
 | 
				
			||||||
	if !this.hasInboundRules {
 | 
						if !this.hasInboundRules {
 | 
				
			||||||
		return true, nil, nil, nil
 | 
							return true, nil, nil, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// validate captcha
 | 
						// validate captcha
 | 
				
			||||||
	if rawReq.URL.Path == "/WAFCAPTCHA" {
 | 
						var rawPath = req.WAFRaw().URL.Path
 | 
				
			||||||
 | 
						if rawPath == CaptchaPath {
 | 
				
			||||||
		captchaValidator.Run(req, writer)
 | 
							captchaValidator.Run(req, writer)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Get 302验证
 | 
				
			||||||
 | 
						if rawPath == Get302Path {
 | 
				
			||||||
 | 
							get302Validator.Run(req, writer)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// match rules
 | 
						// match rules
 | 
				
			||||||
	for _, group := range this.Inbound {
 | 
						for _, group := range this.Inbound {
 | 
				
			||||||
		if !group.IsOn {
 | 
							if !group.IsOn {
 | 
				
			||||||
@@ -264,31 +267,17 @@ func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter)
 | 
				
			|||||||
			return true, nil, nil, err
 | 
								return true, nil, nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if b {
 | 
							if b {
 | 
				
			||||||
			if this.onActionCallback == nil {
 | 
								goNext := set.PerformActions(this, group, req, writer)
 | 
				
			||||||
				if set.Action == ActionBlock && this.ActionBlock != nil {
 | 
					 | 
				
			||||||
					return this.ActionBlock.Perform(this, req, writer), group, set, nil
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					actionObject := FindActionInstance(set.Action, set.ActionOptions)
 | 
					 | 
				
			||||||
					if actionObject == nil {
 | 
					 | 
				
			||||||
						return true, group, set, errors.New("no action called '" + set.Action + "'")
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					goNext := actionObject.Perform(this, req, writer)
 | 
					 | 
				
			||||||
					return goNext, group, set, nil
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				goNext = this.onActionCallback(set.Action)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return goNext, group, set, nil
 | 
								return goNext, group, set, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return true, nil, nil, nil
 | 
						return true, nil, nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
 | 
					func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
 | 
				
			||||||
	if !this.hasOutboundRules {
 | 
						if !this.hasOutboundRules {
 | 
				
			||||||
		return true, nil, nil, nil
 | 
							return true, nil, nil, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req := requests.NewRequest(rawReq)
 | 
					 | 
				
			||||||
	resp := requests.NewResponse(rawResp)
 | 
						resp := requests.NewResponse(rawResp)
 | 
				
			||||||
	for _, group := range this.Outbound {
 | 
						for _, group := range this.Outbound {
 | 
				
			||||||
		if !group.IsOn {
 | 
							if !group.IsOn {
 | 
				
			||||||
@@ -299,27 +288,14 @@ func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, wri
 | 
				
			|||||||
			return true, nil, nil, err
 | 
								return true, nil, nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if b {
 | 
							if b {
 | 
				
			||||||
			if this.onActionCallback == nil {
 | 
								goNext := set.PerformActions(this, group, req, writer)
 | 
				
			||||||
				if set.Action == ActionBlock && this.ActionBlock != nil {
 | 
					 | 
				
			||||||
					return this.ActionBlock.Perform(this, req, writer), group, set, nil
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					actionObject := FindActionInstance(set.Action, set.ActionOptions)
 | 
					 | 
				
			||||||
					if actionObject == nil {
 | 
					 | 
				
			||||||
						return true, group, set, errors.New("no action called '" + set.Action + "'")
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					goNext := actionObject.Perform(this, req, writer)
 | 
					 | 
				
			||||||
					return goNext, group, set, nil
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				goNext = this.onActionCallback(set.Action)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return goNext, group, set, nil
 | 
								return goNext, group, set, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return true, nil, nil, nil
 | 
						return true, nil, nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// save to file path
 | 
					// Save save to file path
 | 
				
			||||||
func (this *WAF) Save(path string) error {
 | 
					func (this *WAF) Save(path string) error {
 | 
				
			||||||
	if len(path) == 0 {
 | 
						if len(path) == 0 {
 | 
				
			||||||
		return errors.New("path should not be empty")
 | 
							return errors.New("path should not be empty")
 | 
				
			||||||
@@ -378,7 +354,7 @@ func (this *WAF) CountOutboundRuleSets() int {
 | 
				
			|||||||
	return count
 | 
						return count
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *WAF) OnAction(onActionCallback func(action ActionString) (goNext bool)) {
 | 
					func (this *WAF) OnAction(onActionCallback func(action ActionInterface) (goNext bool)) {
 | 
				
			||||||
	this.onActionCallback = onActionCallback
 | 
						this.onActionCallback = onActionCallback
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -390,21 +366,21 @@ func (this *WAF) FindCheckpointInstance(prefix string) checkpoints.CheckpointInt
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// start
 | 
					// Start start
 | 
				
			||||||
func (this *WAF) Start() {
 | 
					func (this *WAF) Start() {
 | 
				
			||||||
	for _, checkpoint := range this.checkpointsMap {
 | 
						for _, checkpoint := range this.checkpointsMap {
 | 
				
			||||||
		checkpoint.Start()
 | 
							checkpoint.Start()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// call stop() when the waf was deleted
 | 
					// Stop call stop() when the waf was deleted
 | 
				
			||||||
func (this *WAF) Stop() {
 | 
					func (this *WAF) Stop() {
 | 
				
			||||||
	for _, checkpoint := range this.checkpointsMap {
 | 
						for _, checkpoint := range this.checkpointsMap {
 | 
				
			||||||
		checkpoint.Stop()
 | 
							checkpoint.Stop()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// merge with template
 | 
					// MergeTemplate merge with template
 | 
				
			||||||
func (this *WAF) MergeTemplate() (changedItems []string) {
 | 
					func (this *WAF) MergeTemplate() (changedItems []string) {
 | 
				
			||||||
	changedItems = []string{}
 | 
						changedItems = []string{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
package waf
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
	"github.com/iwind/TeaGo/assert"
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
@@ -24,7 +25,7 @@ func TestWAF_MatchRequest(t *testing.T) {
 | 
				
			|||||||
			Value:    "20",
 | 
								Value:    "20",
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	set.Action = ActionBlock
 | 
						set.AddAction(ActionBlock, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	group := NewRuleGroup()
 | 
						group := NewRuleGroup()
 | 
				
			||||||
	group.AddRuleSet(set)
 | 
						group.AddRuleSet(set)
 | 
				
			||||||
@@ -37,15 +38,15 @@ func TestWAF_MatchRequest(t *testing.T) {
 | 
				
			|||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	waf.OnAction(func(action ActionString) (goNext bool) {
 | 
						waf.OnAction(func(action ActionInterface) (goNext bool) {
 | 
				
			||||||
		return action != ActionBlock
 | 
							return action.Code() != ActionBlock
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
 | 
						req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	goNext, _, set, err := waf.MatchRequest(req, nil)
 | 
						goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user