diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go
index 4f04582..be7f86b 100644
--- a/internal/iplibrary/ip_list.go
+++ b/internal/iplibrary/ip_list.go
@@ -6,7 +6,7 @@ import (
"sync"
)
-// IP名单
+// IPList IP名单
type IPList struct {
itemsMap map[int64]*IPItem // id => item
ipMap map[uint64][]int64 // ip => itemIds
@@ -96,7 +96,7 @@ func (this *IPList) Delete(itemId int64) {
this.isAll = len(this.ipMap[0]) > 0
}
-// 判断是否包含某个IP
+// Contains 判断是否包含某个IP
func (this *IPList) Contains(ip uint64) bool {
this.locker.RLock()
if this.isAll {
@@ -109,7 +109,7 @@ func (this *IPList) Contains(ip uint64) bool {
return ok
}
-// 是否包含一组IP
+// ContainsIPStrings 是否包含一组IP
func (this *IPList) ContainsIPStrings(ipStrings []string) (found bool, item *IPItem) {
if len(ipStrings) == 0 {
return
diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go
index 1b610a4..0ffae99 100644
--- a/internal/nodes/http_request.go
+++ b/internal/nodes/http_request.go
@@ -68,12 +68,15 @@ type HTTPRequest struct {
cacheKey string // 缓存使用的Key
isCached bool // 是否已经被缓存
isAttack bool // 是否是攻击请求
+ bodyData []byte // 读取的Body内容
// WAF相关
firewallPolicyId int64
firewallRuleGroupId int64
firewallRuleSetId int64
firewallRuleId int64
+ firewallActions []string
+ tags []string
logAttrs map[string]string
@@ -1197,5 +1200,10 @@ func (this *HTTPRequest) canIgnore(err error) bool {
return true
}
+ // HTTP内部错误
+ if strings.HasPrefix(err.Error(), "http:") || strings.HasPrefix(err.Error(), "http2:") {
+ return true
+ }
+
return false
}
diff --git a/internal/nodes/http_request_log.go b/internal/nodes/http_request_log.go
index 93ac503..df92eeb 100644
--- a/internal/nodes/http_request_log.go
+++ b/internal/nodes/http_request_log.go
@@ -128,6 +128,8 @@ func (this *HTTPRequest) log() {
FirewallRuleGroupId: this.firewallRuleGroupId,
FirewallRuleSetId: this.firewallRuleSetId,
FirewallRuleId: this.firewallRuleId,
+ FirewallActions: this.firewallActions,
+ Tags: this.tags,
Attrs: this.logAttrs,
}
diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go
index 559d2b1..c6100c5 100644
--- a/internal/nodes/http_request_waf.go
+++ b/internal/nodes/http_request_waf.go
@@ -1,6 +1,7 @@
package nodes
import (
+ "bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
@@ -8,6 +9,8 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
+ "io"
+ "io/ioutil"
"net/http"
)
@@ -152,27 +155,36 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
if w == nil {
return
}
- goNext, ruleGroup, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
+
+ w.OnAction(func(action waf.ActionInterface) (goNext bool) {
+ switch action.Code() {
+ case waf.ActionTag:
+ this.tags = action.(*waf.TagAction).Tags
+ }
+ return true
+ })
+
+ goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer)
if err != nil {
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
return
}
if ruleSet != nil {
- if ruleSet.Action != waf.ActionAllow {
+ if ruleSet.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
- if ruleSet.Action == waf.ActionBlock {
+ if ruleSet.HasAttackActions() {
this.isAttack = true
}
// 添加统计
- stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
+ stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
}
- this.logAttrs["waf.action"] = ruleSet.Action
+ this.firewallActions = ruleSet.ActionCodes()
}
return !goNext, false
@@ -208,28 +220,79 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return
}
- goNext, ruleGroup, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
+ w.OnAction(func(action waf.ActionInterface) (goNext bool) {
+ switch action.Code() {
+ case waf.ActionTag:
+ this.tags = action.(*waf.TagAction).Tags
+ }
+ return true
+ })
+
+ goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
if err != nil {
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
return
}
if ruleSet != nil {
- if ruleSet.Action != waf.ActionAllow {
+ if ruleSet.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
- if ruleSet.Action == waf.ActionBlock {
+ if ruleSet.HasAttackActions() {
this.isAttack = true
}
// 添加统计
- stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
+ stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
}
- this.logAttrs["waf.action"] = ruleSet.Action
+ this.firewallActions = ruleSet.ActionCodes()
}
return !goNext
}
+
+// WAFRaw 原始请求
+func (this *HTTPRequest) WAFRaw() *http.Request {
+ return this.RawReq
+}
+
+// WAFRemoteIP 客户端IP
+func (this *HTTPRequest) WAFRemoteIP() string {
+ return this.requestRemoteAddr()
+}
+
+// WAFGetCacheBody 获取缓存中的Body
+func (this *HTTPRequest) WAFGetCacheBody() []byte {
+ return this.bodyData
+}
+
+// WAFSetCacheBody 设置Body
+func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
+ this.bodyData = body
+}
+
+// WAFReadBody 读取Body
+func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
+ if this.RawReq.ContentLength > 0 {
+ data, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, max))
+ }
+ return
+}
+
+// WAFRestoreBody 恢复Body
+func (this *HTTPRequest) WAFRestoreBody(data []byte) {
+ if len(data) > 0 {
+ rawReader := bytes.NewBuffer(data)
+ buf := make([]byte, 1024)
+ _, _ = io.CopyBuffer(rawReader, this.RawReq.Body, buf)
+ this.RawReq.Body = ioutil.NopCloser(rawReader)
+ }
+}
+
+// WAFServerId 服务ID
+func (this *HTTPRequest) WAFServerId() int64 {
+ return this.Server.Id
+}
diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go
index 910cf7e..d06e32f 100644
--- a/internal/nodes/listener_base.go
+++ b/internal/nodes/listener_base.go
@@ -7,7 +7,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
- http2 "golang.org/x/net/http2"
+ "golang.org/x/net/http2"
"sync"
)
diff --git a/internal/nodes/traffic_listener.go b/internal/nodes/traffic_listener.go
index e934fbb..67d99d5 100644
--- a/internal/nodes/traffic_listener.go
+++ b/internal/nodes/traffic_listener.go
@@ -2,7 +2,10 @@
package nodes
-import "net"
+import (
+ "github.com/TeaOSLab/EdgeNode/internal/waf"
+ "net"
+)
// TrafficListener 用于统计流量的网络监听
type TrafficListener struct {
@@ -18,6 +21,17 @@ func (this *TrafficListener) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
+ // 是否在WAF名单中
+ ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
+ if err == nil {
+ if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackLIst.Contains(waf.IPTypeAll, ip) {
+ go func() {
+ _ = conn.Close()
+ }()
+ return conn, nil
+ }
+ }
+
return NewTrafficConn(conn), nil
}
diff --git a/internal/nodes/waf_manager.go b/internal/nodes/waf_manager.go
index 91f542e..ac3ab19 100644
--- a/internal/nodes/waf_manager.go
+++ b/internal/nodes/waf_manager.go
@@ -11,20 +11,20 @@ import (
var sharedWAFManager = NewWAFManager()
-// WAF管理器
+// WAFManager WAF管理器
type WAFManager struct {
mapping map[int64]*waf.WAF // policyId => WAF
locker sync.RWMutex
}
-// 获取新对象
+// NewWAFManager 获取新对象
func NewWAFManager() *WAFManager {
return &WAFManager{
mapping: map[int64]*waf.WAF{},
}
}
-// 更新策略
+// UpdatePolicies 更新策略
func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
this.locker.Lock()
defer this.locker.Unlock()
@@ -44,7 +44,7 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP
this.mapping = m
}
-// 查找WAF
+// FindWAF 查找WAF
func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
this.locker.RLock()
w, _ := this.mapping[policyId]
@@ -78,14 +78,15 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
// rule sets
for _, set := range group.Sets {
s := &waf.RuleSet{
- Id: strconv.FormatInt(set.Id, 10),
- Code: set.Code,
- IsOn: set.IsOn,
- Name: set.Name,
- Description: set.Description,
- Connector: set.Connector,
- Action: set.Action,
- ActionOptions: set.ActionOptions,
+ Id: strconv.FormatInt(set.Id, 10),
+ Code: set.Code,
+ IsOn: set.IsOn,
+ Name: set.Name,
+ Description: set.Description,
+ Connector: set.Connector,
+ }
+ for _, a := range set.Actions {
+ s.AddAction(a.Code, a.Options)
}
// rules
@@ -132,14 +133,16 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
// rule sets
for _, set := range group.Sets {
s := &waf.RuleSet{
- Id: strconv.FormatInt(set.Id, 10),
- Code: set.Code,
- IsOn: set.IsOn,
- Name: set.Name,
- Description: set.Description,
- Connector: set.Connector,
- Action: set.Action,
- ActionOptions: set.ActionOptions,
+ Id: strconv.FormatInt(set.Id, 10),
+ Code: set.Code,
+ IsOn: set.IsOn,
+ Name: set.Name,
+ Description: set.Description,
+ Connector: set.Connector,
+ }
+
+ for _, a := range set.Actions {
+ s.AddAction(a.Code, a.Options)
}
// rules
@@ -164,10 +167,11 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
// action
if policy.BlockOptions != nil {
- w.ActionBlock = &waf.BlockAction{
+ w.DefaultBlockAction = &waf.BlockAction{
StatusCode: policy.BlockOptions.StatusCode,
Body: policy.BlockOptions.Body,
- URL: "",
+ URL: policy.BlockOptions.URL,
+ Timeout: policy.BlockOptions.Timeout,
}
}
diff --git a/internal/rpc/rpc_client.go b/internal/rpc/rpc_client.go
index 37f1a32..072262f 100644
--- a/internal/rpc/rpc_client.go
+++ b/internal/rpc/rpc_client.go
@@ -113,6 +113,10 @@ func (this *RPCClient) MetricStatRPC() pb.MetricStatServiceClient {
return pb.NewMetricStatServiceClient(this.pickConn())
}
+func (this *RPCClient) FirewallService() pb.FirewallServiceClient {
+ return pb.NewFirewallServiceClient(this.pickConn())
+}
+
// Context 节点上下文信息
func (this *RPCClient) Context() context.Context {
ctx := context.Background()
diff --git a/internal/stats/http_request_stat_manager.go b/internal/stats/http_request_stat_manager.go
index 1fc29db..c30b61e 100644
--- a/internal/stats/http_request_stat_manager.go
+++ b/internal/stats/http_request_stat_manager.go
@@ -8,6 +8,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/monitor"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
+ "github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
@@ -132,17 +133,19 @@ func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent strin
}
// AddFirewallRuleGroupId 添加防火墙拦截动作
-func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, action string) {
+func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, actions []*waf.ActionConfig) {
if firewallRuleGroupId <= 0 {
return
}
- this.totalAttackRequests ++
+ this.totalAttackRequests++
- select {
- case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action:
- default:
- // 超出容量我们就丢弃
+ for _, action := range actions {
+ select {
+ case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action.Code:
+ default:
+ // 超出容量我们就丢弃
+ }
}
}
diff --git a/internal/utils/encrypt.go b/internal/utils/encrypt.go
new file mode 100644
index 0000000..2bf38e6
--- /dev/null
+++ b/internal/utils/encrypt.go
@@ -0,0 +1,159 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package utils
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
+ "github.com/TeaOSLab/EdgeNode/internal/events"
+ "github.com/iwind/TeaGo/logs"
+ "github.com/iwind/TeaGo/maps"
+ "github.com/iwind/TeaGo/rands"
+ stringutil "github.com/iwind/TeaGo/utils/string"
+)
+
+var (
+ simpleEncryptMagicKey = rands.HexString(32)
+)
+
+func init() {
+ events.On(events.EventReload, func() {
+ nodeConfig, _ := nodeconfigs.SharedNodeConfig()
+ if nodeConfig != nil {
+ simpleEncryptMagicKey = stringutil.Md5(nodeConfig.NodeId + "@" + nodeConfig.Secret)
+ }
+ })
+}
+
+// SimpleEncrypt 加密特殊信息
+func SimpleEncrypt(data []byte) []byte {
+ var method = &AES256CFBMethod{}
+ err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
+ if err != nil {
+ logs.Println("[SimpleEncrypt]" + err.Error())
+ return data
+ }
+
+ dst, err := method.Encrypt(data)
+ if err != nil {
+ logs.Println("[SimpleEncrypt]" + err.Error())
+ return data
+ }
+ return dst
+}
+
+// SimpleDecrypt 解密特殊信息
+func SimpleDecrypt(data []byte) []byte {
+ var method = &AES256CFBMethod{}
+ err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
+ if err != nil {
+ logs.Println("[MagicKeyEncode]" + err.Error())
+ return data
+ }
+
+ src, err := method.Decrypt(data)
+ if err != nil {
+ logs.Println("[MagicKeyEncode]" + err.Error())
+ return data
+ }
+ return src
+}
+
+func SimpleEncryptMap(m maps.Map) (base64String string, err error) {
+ mJSON, err := json.Marshal(m)
+ if err != nil {
+ return "", err
+ }
+ data := SimpleEncrypt(mJSON)
+ return base64.StdEncoding.EncodeToString(data), nil
+}
+
+func SimpleDecryptMap(base64String string) (maps.Map, error) {
+ data, err := base64.StdEncoding.DecodeString(base64String)
+ if err != nil {
+ return nil, err
+ }
+ mJSON := SimpleDecrypt(data)
+ var result = maps.Map{}
+ err = json.Unmarshal(mJSON, &result)
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+type AES256CFBMethod struct {
+ block cipher.Block
+ iv []byte
+}
+
+func (this *AES256CFBMethod) Init(key, iv []byte) error {
+ // 判断key是否为32长度
+ l := len(key)
+ if l > 32 {
+ key = key[:32]
+ } else if l < 32 {
+ key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
+ }
+
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return err
+ }
+ this.block = block
+
+ // 判断iv长度
+ l2 := len(iv)
+ if l2 > aes.BlockSize {
+ iv = iv[:aes.BlockSize]
+ } else if l2 < aes.BlockSize {
+ iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
+ }
+ this.iv = iv
+
+ return nil
+}
+
+func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
+ if len(src) == 0 {
+ return
+ }
+
+ defer func() {
+ r := recover()
+ if r != nil {
+ err = errors.New("encrypt failed")
+ }
+ }()
+
+ dst = make([]byte, len(src))
+
+ encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
+ encrypter.XORKeyStream(dst, src)
+
+ return
+}
+
+func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
+ if len(dst) == 0 {
+ return
+ }
+
+ defer func() {
+ r := recover()
+ if r != nil {
+ err = errors.New("decrypt failed")
+ }
+ }()
+
+ src = make([]byte, len(dst))
+ decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
+ decrypter.XORKeyStream(src, dst)
+
+ return
+}
diff --git a/internal/utils/encrypt_test.go b/internal/utils/encrypt_test.go
new file mode 100644
index 0000000..3a5d411
--- /dev/null
+++ b/internal/utils/encrypt_test.go
@@ -0,0 +1,52 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package utils
+
+import (
+ "github.com/iwind/TeaGo/maps"
+ "sync"
+ "testing"
+)
+
+func TestSimpleEncrypt(t *testing.T) {
+ var arr = []string{"Hello", "World", "People"}
+ for _, s := range arr {
+ var value = []byte(s)
+ encoded := SimpleEncrypt(value)
+ t.Log(encoded, string(encoded))
+ decoded := SimpleDecrypt(encoded)
+ t.Log(decoded, string(decoded))
+ }
+}
+
+func TestSimpleEncrypt_Concurrent(t *testing.T) {
+ wg := sync.WaitGroup{}
+ var arr = []string{"Hello", "World", "People"}
+ wg.Add(len(arr))
+ for _, s := range arr {
+ go func(s string) {
+ defer wg.Done()
+ t.Log(string(SimpleDecrypt(SimpleEncrypt([]byte(s)))))
+ }(s)
+ }
+ wg.Wait()
+}
+
+func TestSimpleEncryptMap(t *testing.T) {
+ var m = maps.Map{
+ "s": "Hello",
+ "i": 20,
+ "b": true,
+ }
+ encodedResult, err := SimpleEncryptMap(m)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log("result:", encodedResult)
+
+ decodedResult, err := SimpleDecryptMap(encodedResult)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Log(decodedResult)
+}
diff --git a/internal/utils/expires/list.go b/internal/utils/expires/list.go
index 2e83fe0..98fb485 100644
--- a/internal/utils/expires/list.go
+++ b/internal/utils/expires/list.go
@@ -12,6 +12,7 @@ type List struct {
itemsMap map[int64]int64 // itemId => timestamp
locker sync.Mutex
+ ticker *time.Ticker
}
func NewList() *List {
@@ -21,10 +22,7 @@ func NewList() *List {
}
}
-func (this *List) Add(itemId int64, expiredAt int64) {
- if expiredAt <= time.Now().Unix() {
- return
- }
+func (this *List) Add(itemId int64, expiresAt int64) {
this.locker.Lock()
defer this.locker.Unlock()
@@ -34,17 +32,17 @@ func (this *List) Add(itemId int64, expiredAt int64) {
this.removeItem(itemId)
}
- expireItemMap, ok := this.expireMap[expiredAt]
+ expireItemMap, ok := this.expireMap[expiresAt]
if ok {
expireItemMap[itemId] = true
} else {
expireItemMap = ItemMap{
itemId: true,
}
- this.expireMap[expiredAt] = expireItemMap
+ this.expireMap[expiresAt] = expireItemMap
}
- this.itemsMap[itemId] = expiredAt
+ this.itemsMap[itemId] = expiresAt
}
func (this *List) Remove(itemId int64) {
@@ -64,21 +62,22 @@ func (this *List) GC(timestamp int64, callback func(itemId int64)) {
}
func (this *List) StartGC(callback func(itemId int64)) {
- ticker := time.NewTicker(1 * time.Second)
+ this.ticker = time.NewTicker(1 * time.Second)
lastTimestamp := int64(0)
- for range ticker.C {
+ for range this.ticker.C {
timestamp := time.Now().Unix()
if lastTimestamp == 0 {
lastTimestamp = timestamp - 3600
}
- // 防止死循环
- if lastTimestamp > timestamp {
- continue
- }
-
- for i := lastTimestamp; i <= timestamp; i++ {
- this.GC(timestamp, callback)
+ if timestamp >= lastTimestamp {
+ for i := lastTimestamp; i <= timestamp; i++ {
+ this.GC(i, callback)
+ }
+ } else {
+ for i := timestamp; i <= lastTimestamp; i++ {
+ this.GC(i, callback)
+ }
}
// 这样做是为了防止系统时钟突变
diff --git a/internal/utils/expires/list_test.go b/internal/utils/expires/list_test.go
index c4b06d3..bca42e9 100644
--- a/internal/utils/expires/list_test.go
+++ b/internal/utils/expires/list_test.go
@@ -58,6 +58,10 @@ func TestList_Start_GC(t *testing.T) {
list.Add(2, time.Now().Unix()+1)
list.Add(3, time.Now().Unix()+2)
list.Add(4, time.Now().Unix()+5)
+ list.Add(5, time.Now().Unix()+5)
+ list.Add(6, time.Now().Unix()+6)
+ list.Add(7, time.Now().Unix()+6)
+ list.Add(8, time.Now().Unix()+6)
go func() {
list.StartGC(func(itemId int64) {
@@ -66,7 +70,7 @@ func TestList_Start_GC(t *testing.T) {
})
}()
- time.Sleep(10 * time.Second)
+ time.Sleep(20 * time.Second)
}
func TestList_ManyItems(t *testing.T) {
diff --git a/internal/utils/jsonutils/map.go b/internal/utils/jsonutils/map.go
new file mode 100644
index 0000000..4986f3e
--- /dev/null
+++ b/internal/utils/jsonutils/map.go
@@ -0,0 +1,35 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package jsonutils
+
+import (
+ "encoding/json"
+ "github.com/iwind/TeaGo/maps"
+)
+
+func MapToObject(m maps.Map, ptr interface{}) error {
+ if m == nil {
+ return nil
+ }
+ mJSON, err := json.Marshal(m)
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal(mJSON, ptr)
+}
+
+func ObjectToMap(ptr interface{}) (maps.Map, error) {
+ if ptr == nil {
+ return maps.Map{}, nil
+ }
+ ptrJSON, err := json.Marshal(ptr)
+ if err != nil {
+ return nil, err
+ }
+ var result = maps.Map{}
+ err = json.Unmarshal(ptrJSON, &result)
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
diff --git a/internal/utils/jsonutils/map_test.go b/internal/utils/jsonutils/map_test.go
new file mode 100644
index 0000000..6bccfcf
--- /dev/null
+++ b/internal/utils/jsonutils/map_test.go
@@ -0,0 +1,46 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package jsonutils
+
+import (
+ "github.com/iwind/TeaGo/assert"
+ "github.com/iwind/TeaGo/maps"
+ "testing"
+)
+
+func TestMapToObject(t *testing.T) {
+ a := assert.NewAssertion(t)
+
+ type typeA struct {
+ B int `json:"b"`
+ C bool `json:"c"`
+ }
+
+ {
+ var obj = &typeA{B: 1, C: true}
+ m, err := ObjectToMap(obj)
+ if err != nil {
+ t.Fatal(err)
+ }
+ PrintT(m, t)
+ a.IsTrue(m.GetInt("b") == 1)
+ a.IsTrue(m.GetBool("c") == true)
+ }
+
+ {
+ var obj = &typeA{}
+ err := MapToObject(maps.Map{
+ "b": 1024,
+ "c": true,
+ }, obj)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if obj == nil {
+ t.Fatal("obj should not be nil")
+ }
+ a.IsTrue(obj.B == 1024)
+ a.IsTrue(obj.C == true)
+ PrintT(obj, t)
+ }
+}
diff --git a/internal/utils/jsonutils/utils.go b/internal/utils/jsonutils/utils.go
new file mode 100644
index 0000000..5c37dfd
--- /dev/null
+++ b/internal/utils/jsonutils/utils.go
@@ -0,0 +1,17 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package jsonutils
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func PrintT(obj interface{}, t *testing.T) {
+ data, err := json.MarshalIndent(obj, "", " ")
+ if err != nil {
+ t.Log(err)
+ } else {
+ t.Log(string(data))
+ }
+}
diff --git a/internal/waf/action_allow.go b/internal/waf/action_allow.go
index 35421ca..ea3b3a4 100644
--- a/internal/waf/action_allow.go
+++ b/internal/waf/action_allow.go
@@ -8,7 +8,23 @@ import (
type AllowAction struct {
}
-func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
+func (this *AllowAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *AllowAction) Code() string {
+ return ActionAllow
+}
+
+func (this *AllowAction) IsAttack() bool {
+ return false
+}
+
+func (this *AllowAction) WillChange() bool {
+ return false
+}
+
+func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// do nothing
return true
}
diff --git a/internal/waf/action_base.go b/internal/waf/action_base.go
new file mode 100644
index 0000000..e0e6bec
--- /dev/null
+++ b/internal/waf/action_base.go
@@ -0,0 +1,21 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import "net/http"
+
+type BaseAction struct {
+}
+
+// CloseConn 关闭连接
+func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
+ // 断开连接
+ hijack, ok := writer.(http.Hijacker)
+ if ok {
+ conn, _, err := hijack.Hijack()
+ if err == nil {
+ return conn.Close()
+ }
+ }
+ return nil
+}
diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go
index e6ba70f..4b91fe4 100644
--- a/internal/waf/action_block.go
+++ b/internal/waf/action_block.go
@@ -23,12 +23,48 @@ type BlockAction struct {
StatusCode int `yaml:"statusCode" json:"statusCode"`
Body string `yaml:"body" json:"body"` // supports HTML
URL string `yaml:"url" json:"url"`
+ Timeout int32 `yaml:"timeout" json:"timeout"`
}
-func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
+func (this *BlockAction) Init(waf *WAF) error {
+ if waf.DefaultBlockAction != nil {
+ if this.StatusCode <= 0 {
+ this.StatusCode = waf.DefaultBlockAction.StatusCode
+ }
+ if len(this.Body) == 0 {
+ this.Body = waf.DefaultBlockAction.Body
+ }
+ if len(this.URL) == 0 {
+ this.URL = waf.DefaultBlockAction.URL
+ }
+ if this.Timeout <= 0 {
+ this.Timeout = waf.DefaultBlockAction.Timeout
+ }
+ }
+ return nil
+}
+
+func (this *BlockAction) Code() string {
+ return ActionBlock
+}
+
+func (this *BlockAction) IsAttack() bool {
+ return true
+}
+
+func (this *BlockAction) WillChange() bool {
+ return true
+}
+
+func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ if this.Timeout > 0 {
+ // 加入到黑名单
+ SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(this.Timeout))
+ }
+
if writer != nil {
- // if status code eq 444, we close the connection
- if this.StatusCode == 444 {
+ // close the connection
+ defer func() {
hijack, ok := writer.(http.Hijacker)
if ok {
conn, _, _ := hijack.Hijack()
@@ -37,7 +73,7 @@ func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer htt
return
}
}
- }
+ }()
// output response
if this.StatusCode > 0 {
diff --git a/internal/waf/action_captcha.go b/internal/waf/action_captcha.go
index 120869d..4db432e 100644
--- a/internal/waf/action_captcha.go
+++ b/internal/waf/action_captcha.go
@@ -1,11 +1,14 @@
package waf
import (
+ "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
+ "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
- "github.com/iwind/TeaGo/types"
+ "github.com/iwind/TeaGo/maps"
stringutil "github.com/iwind/TeaGo/utils/string"
"net/http"
"net/url"
+ "strings"
"time"
)
@@ -13,27 +16,63 @@ var captchaSalt = stringutil.Rand(32)
const (
CaptchaSeconds = 600 // 10 minutes
+ CaptchaPath = "/WAF/VERIFY/CAPTCHA"
)
type CaptchaAction struct {
+ Life int32 `yaml:"life" json:"life"`
+ Language string `yaml:"language" json:"language"` // 语言,zh-CN, en-US ...
+ AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
}
-func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
- // TEAWEB_CAPTCHA:
- cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
- if err == nil && cookie != nil && len(cookie.Value) > 32 {
- m := cookie.Value[:32]
- timestamp := cookie.Value[32:]
- if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
- return true
+func (this *CaptchaAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *CaptchaAction) Code() string {
+ return ActionCaptcha
+}
+
+func (this *CaptchaAction) IsAttack() bool {
+ return false
+}
+
+func (this *CaptchaAction) WillChange() bool {
+ return true
+}
+
+func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ // 是否在白名单中
+ if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
+ return true
+ }
+
+ refURL := request.WAFRaw().URL.String()
+
+ // 覆盖配置
+ if strings.HasPrefix(refURL, CaptchaPath) {
+ info := request.WAFRaw().URL.Query().Get("info")
+ if len(info) > 0 {
+ m, err := utils.SimpleDecryptMap(info)
+ if err == nil && m != nil {
+ refURL = m.GetString("url")
+ }
}
}
- refURL := request.URL.String()
- if len(request.Referer()) > 0 {
- refURL = request.Referer()
+ var captchaConfig = maps.Map{
+ "action": this,
+ "timestamp": time.Now().Unix(),
+ "url": refURL,
+ "setId": set.Id,
}
- http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
+ info, err := utils.SimpleEncryptMap(captchaConfig)
+ if err != nil {
+ remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
+ return true
+ }
+
+ http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
return false
}
diff --git a/internal/waf/action_category.go b/internal/waf/action_category.go
new file mode 100644
index 0000000..f5bf9c2
--- /dev/null
+++ b/internal/waf/action_category.go
@@ -0,0 +1,13 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
+
+type ActionCategory = string
+
+const (
+ ActionCategoryAllow ActionCategory = firewallconfigs.HTTPFirewallActionCategoryAllow
+ ActionCategoryBlock ActionCategory = firewallconfigs.HTTPFirewallActionCategoryBlock
+ ActionCategoryVerify ActionCategory = firewallconfigs.HTTPFirewallActionCategoryVerify
+)
diff --git a/internal/waf/action_config.go b/internal/waf/action_config.go
new file mode 100644
index 0000000..5cae9cf
--- /dev/null
+++ b/internal/waf/action_config.go
@@ -0,0 +1,10 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import "github.com/iwind/TeaGo/maps"
+
+type ActionConfig struct {
+ Code string `yaml:"code" json:"code"`
+ Options maps.Map `yaml:"options" json:"options"`
+}
diff --git a/internal/waf/action_definition.go b/internal/waf/action_definition.go
index e268742..aff5bc4 100644
--- a/internal/waf/action_definition.go
+++ b/internal/waf/action_definition.go
@@ -2,11 +2,12 @@ package waf
import "reflect"
-// action definition
+// ActionDefinition action definition
type ActionDefinition struct {
Name string
Code ActionString
Description string
+ Category string // category: block, verify, allow
Instance ActionInterface
Type reflect.Type
}
diff --git a/internal/waf/action_get_302.go b/internal/waf/action_get_302.go
new file mode 100644
index 0000000..304d310
--- /dev/null
+++ b/internal/waf/action_get_302.go
@@ -0,0 +1,71 @@
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
+ "github.com/TeaOSLab/EdgeNode/internal/utils"
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "github.com/iwind/TeaGo/maps"
+ "net/http"
+ "net/url"
+ "time"
+)
+
+const (
+ Get302Path = "/WAF/VERIFY/GET"
+)
+
+// Get302Action
+// 原理: origin url --> 302 verify url --> origin url
+// TODO 将来支持meta refresh验证
+type Get302Action struct {
+ BaseAction
+
+ Life int32 `yaml:"life" json:"life"`
+}
+
+func (this *Get302Action) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *Get302Action) Code() string {
+ return ActionGet302
+}
+
+func (this *Get302Action) IsAttack() bool {
+ return false
+}
+
+func (this *Get302Action) WillChange() bool {
+ return true
+}
+
+func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ // 仅限于Get
+ if request.WAFRaw().Method != http.MethodGet {
+ return true
+ }
+
+ // 是否已经在白名单中
+ if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
+ return true
+ }
+
+ var m = maps.Map{
+ "url": request.WAFRaw().URL.String(),
+ "timestamp": time.Now().Unix(),
+ "life": this.Life,
+ "setId": set.Id,
+ }
+ info, err := utils.SimpleEncryptMap(m)
+ if err != nil {
+ remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
+ return true
+ }
+
+ http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
+
+ // 关闭连接
+ _ = this.CloseConn(writer)
+
+ return true
+}
diff --git a/internal/waf/action_go_group.go b/internal/waf/action_go_group.go
index 446bd0a..85f2f64 100644
--- a/internal/waf/action_go_group.go
+++ b/internal/waf/action_go_group.go
@@ -10,13 +10,29 @@ type GoGroupAction struct {
GroupId string `yaml:"groupId" json:"groupId"`
}
-func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
- group := waf.FindRuleGroup(this.GroupId)
- if group == nil || !group.IsOn {
+func (this *GoGroupAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *GoGroupAction) Code() string {
+ return ActionGoGroup
+}
+
+func (this *GoGroupAction) IsAttack() bool {
+ return false
+}
+
+func (this *GoGroupAction) WillChange() bool {
+ return true
+}
+
+func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ nextGroup := waf.FindRuleGroup(this.GroupId)
+ if nextGroup == nil || !nextGroup.IsOn {
return true
}
- b, set, err := group.MatchRequest(request)
+ b, nextSet, err := nextGroup.MatchRequest(request)
if err != nil {
logs.Error(err)
return true
@@ -26,9 +42,5 @@ func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer h
return true
}
- actionObject := FindActionInstance(set.Action, set.ActionOptions)
- if actionObject == nil {
- return true
- }
- return actionObject.Perform(waf, request, writer)
+ return nextSet.PerformActions(waf, nextGroup, request, writer)
}
diff --git a/internal/waf/action_go_set.go b/internal/waf/action_go_set.go
index ad8b049..eadfd03 100644
--- a/internal/waf/action_go_set.go
+++ b/internal/waf/action_go_set.go
@@ -11,17 +11,33 @@ type GoSetAction struct {
SetId string `yaml:"setId" json:"setId"`
}
-func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
- group := waf.FindRuleGroup(this.GroupId)
- if group == nil || !group.IsOn {
+func (this *GoSetAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *GoSetAction) Code() string {
+ return ActionGoSet
+}
+
+func (this *GoSetAction) IsAttack() bool {
+ return false
+}
+
+func (this *GoSetAction) WillChange() bool {
+ return true
+}
+
+func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ nextGroup := waf.FindRuleGroup(this.GroupId)
+ if nextGroup == nil || !nextGroup.IsOn {
return true
}
- set := group.FindRuleSet(this.SetId)
- if set == nil || !set.IsOn {
+ nextSet := nextGroup.FindRuleSet(this.SetId)
+ if nextSet == nil || !nextSet.IsOn {
return true
}
- b, err := set.MatchRequest(request)
+ b, err := nextSet.MatchRequest(request)
if err != nil {
logs.Error(err)
return true
@@ -29,9 +45,5 @@ func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer htt
if !b {
return true
}
- actionObject := FindActionInstance(set.Action, set.ActionOptions)
- if actionObject == nil {
- return true
- }
- return actionObject.Perform(waf, request, writer)
+ return nextSet.PerformActions(waf, nextGroup, request, writer)
}
diff --git a/internal/waf/action_interface.go b/internal/waf/action_interface.go
new file mode 100644
index 0000000..256b58e
--- /dev/null
+++ b/internal/waf/action_interface.go
@@ -0,0 +1,25 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "net/http"
+)
+
+type ActionInterface interface {
+ // Init 初始化
+ Init(waf *WAF) error
+
+ // Code 代号
+ Code() string
+
+ // IsAttack 是否为拦截攻击动作
+ IsAttack() bool
+
+ // WillChange determine if the action will change the request
+ WillChange() bool
+
+ // Perform perform the action
+ Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool)
+}
diff --git a/internal/waf/action_log.go b/internal/waf/action_log.go
index 8b8efcd..74c85ac 100644
--- a/internal/waf/action_log.go
+++ b/internal/waf/action_log.go
@@ -8,6 +8,22 @@ import (
type LogAction struct {
}
-func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
+func (this *LogAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *LogAction) Code() string {
+ return ActionLog
+}
+
+func (this *LogAction) IsAttack() bool {
+ return false
+}
+
+func (this *LogAction) WillChange() bool {
+ return false
+}
+
+func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
return true
}
diff --git a/internal/waf/action_notify.go b/internal/waf/action_notify.go
new file mode 100644
index 0000000..1df7d4e
--- /dev/null
+++ b/internal/waf/action_notify.go
@@ -0,0 +1,86 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
+ "github.com/TeaOSLab/EdgeNode/internal/events"
+ "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
+ "github.com/TeaOSLab/EdgeNode/internal/rpc"
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "github.com/iwind/TeaGo/types"
+ "net/http"
+ "time"
+)
+
+type notifyTask struct {
+ ServerId int64
+ HttpFirewallPolicyId int64
+ HttpFirewallRuleGroupId int64
+ HttpFirewallRuleSetId int64
+ CreatedAt int64
+}
+
+var notifyChan = make(chan *notifyTask, 128)
+
+func init() {
+ events.On(events.EventLoaded, func() {
+ go func() {
+ rpcClient, err := rpc.SharedRPC()
+ if err != nil {
+ remotelogs.Error("WAF_NOTIFY_ACTION", "create rpc client failed: "+err.Error())
+ return
+ }
+
+ for task := range notifyChan {
+ _, err = rpcClient.FirewallService().NotifyHTTPFirewallEvent(rpcClient.Context(), &pb.NotifyHTTPFirewallEventRequest{
+ ServerId: task.ServerId,
+ HttpFirewallPolicyId: task.HttpFirewallPolicyId,
+ HttpFirewallRuleGroupId: task.HttpFirewallRuleGroupId,
+ HttpFirewallRuleSetId: task.HttpFirewallRuleSetId,
+ CreatedAt: task.CreatedAt,
+ })
+ if err != nil {
+ remotelogs.Error("WAF_NOTIFY_ACTION", "notify failed: "+err.Error())
+ }
+ }
+ }()
+ })
+}
+
+type NotifyAction struct {
+}
+
+func (this *NotifyAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *NotifyAction) Code() string {
+ return ActionNotify
+}
+
+func (this *NotifyAction) IsAttack() bool {
+ return false
+}
+
+// WillChange determine if the action will change the request
+func (this *NotifyAction) WillChange() bool {
+ return false
+}
+
+// Perform perform the action
+func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ select {
+ case notifyChan <- ¬ifyTask{
+ ServerId: request.WAFServerId(),
+ HttpFirewallPolicyId: types.Int64(waf.Id),
+ HttpFirewallRuleGroupId: types.Int64(group.Id),
+ HttpFirewallRuleSetId: types.Int64(set.Id),
+ CreatedAt: time.Now().Unix(),
+ }:
+ default:
+
+ }
+
+ return true
+}
diff --git a/internal/waf/action_post_307.go b/internal/waf/action_post_307.go
new file mode 100644
index 0000000..22a4dc1
--- /dev/null
+++ b/internal/waf/action_post_307.go
@@ -0,0 +1,88 @@
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
+ "github.com/TeaOSLab/EdgeNode/internal/utils"
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "github.com/iwind/TeaGo/maps"
+ "net/http"
+ "time"
+)
+
+type Post307Action struct {
+ Life int32 `yaml:"life" json:"life"`
+
+ BaseAction
+}
+
+func (this *Post307Action) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *Post307Action) Code() string {
+ return ActionPost307
+}
+
+func (this *Post307Action) IsAttack() bool {
+ return false
+}
+
+func (this *Post307Action) WillChange() bool {
+ return true
+}
+
+func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ var cookieName = "WAF_VALIDATOR_ID"
+
+ // 仅限于POST
+ if request.WAFRaw().Method != http.MethodPost {
+ return true
+ }
+
+ // 是否已经在白名单中
+ if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
+ return true
+ }
+
+ // 判断是否有Cookie
+ cookie, err := request.WAFRaw().Cookie(cookieName)
+ if err == nil && cookie != nil {
+ m, err := utils.SimpleDecryptMap(cookie.Value)
+ if err == nil && m.GetString("remoteIP") == request.WAFRemoteIP() && time.Now().Unix() < m.GetInt64("timestamp")+10 {
+ var life = m.GetInt64("life")
+ if life <= 0 {
+ life = 600 // 默认10分钟
+ }
+ var setId = m.GetString("setId")
+ SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
+ return true
+ }
+ }
+
+ var m = maps.Map{
+ "timestamp": time.Now().Unix(),
+ "life": this.Life,
+ "setId": set.Id,
+ "remoteIP": request.WAFRemoteIP(),
+ }
+ info, err := utils.SimpleEncryptMap(m)
+ if err != nil {
+ remotelogs.Error("WAF_POST_302_ACTION", "encode info failed: "+err.Error())
+ return true
+ }
+
+ // 设置Cookie
+ http.SetCookie(writer, &http.Cookie{
+ Name: cookieName,
+ Path: "/",
+ MaxAge: 10,
+ Value: info,
+ })
+
+ http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
+
+ // 关闭连接
+ _ = this.CloseConn(writer)
+
+ return true
+}
diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go
new file mode 100644
index 0000000..8a34906
--- /dev/null
+++ b/internal/waf/action_record_ip.go
@@ -0,0 +1,120 @@
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
+ "github.com/TeaOSLab/EdgeNode/internal/events"
+ "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
+ "github.com/TeaOSLab/EdgeNode/internal/rpc"
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "net/http"
+ "strings"
+ "time"
+)
+
+type recordIPTask struct {
+ ip string
+ listId int64
+ expiredAt int64
+ level string
+}
+
+var recordIPTaskChan = make(chan *recordIPTask, 1024)
+
+func init() {
+ events.On(events.EventLoaded, func() {
+ go func() {
+ rpcClient, err := rpc.SharedRPC()
+ if err != nil {
+ remotelogs.Error("WAF_RECORD_IP_ACTION", "create rpc client failed: "+err.Error())
+ return
+ }
+
+ for task := range recordIPTaskChan {
+ ipType := "ipv4"
+ if strings.Contains(task.ip, ":") {
+ ipType = "ipv6"
+ }
+ _, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
+ IpListId: task.listId,
+ IpFrom: task.ip,
+ IpTo: "",
+ ExpiredAt: task.expiredAt,
+ Reason: "触发WAF规则自动加入",
+ Type: ipType,
+ EventLevel: task.level,
+ })
+ if err != nil {
+ remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
+ }
+ }
+ }()
+ })
+}
+
+type RecordIPAction struct {
+ BaseAction
+
+ Type string `yaml:"type" json:"type"`
+ IPListId int64 `yaml:"ipListId" json:"ipListId"`
+ Level string `yaml:"level" json:"level"`
+ Timeout int32 `yaml:"timeout" json:"timeout"`
+}
+
+func (this *RecordIPAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *RecordIPAction) Code() string {
+ return ActionRecordIP
+}
+
+func (this *RecordIPAction) IsAttack() bool {
+ return this.Type == "black"
+}
+
+func (this *RecordIPAction) WillChange() bool {
+ return this.Type == "black"
+}
+
+func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ // 是否在本地白名单中
+ if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) {
+ return true
+ }
+
+ // 先加入本地的黑名单
+ timeout := this.Timeout
+ if timeout <= 0 {
+ timeout = 86400 // 1天
+ }
+ expiredAt := time.Now().Unix() + int64(timeout)
+
+ if this.Type == "black" {
+ _ = this.CloseConn(writer)
+
+ SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt)
+ } else {
+ // 加入本地白名单
+ timeout := this.Timeout
+ if timeout <= 0 {
+ timeout = 86400 // 1天
+ }
+ SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt)
+ }
+
+ // 上报
+ if this.IPListId > 0 {
+ select {
+ case recordIPTaskChan <- &recordIPTask{
+ ip: request.WAFRemoteIP(),
+ listId: this.IPListId,
+ expiredAt: expiredAt,
+ level: this.Level,
+ }:
+ default:
+
+ }
+ }
+
+ return this.Type != "black"
+}
diff --git a/internal/waf/action_tag.go b/internal/waf/action_tag.go
new file mode 100644
index 0000000..b39794f
--- /dev/null
+++ b/internal/waf/action_tag.go
@@ -0,0 +1,30 @@
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "net/http"
+)
+
+type TagAction struct {
+ Tags []string `yaml:"tags" json:"tags"`
+}
+
+func (this *TagAction) Init(waf *WAF) error {
+ return nil
+}
+
+func (this *TagAction) Code() string {
+ return ActionTag
+}
+
+func (this *TagAction) IsAttack() bool {
+ return false
+}
+
+func (this *TagAction) WillChange() bool {
+ return false
+}
+
+func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ return true
+}
diff --git a/internal/waf/action_type.go b/internal/waf/action_type.go
deleted file mode 100644
index 221226d..0000000
--- a/internal/waf/action_type.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package waf
-
-import (
- "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
- "net/http"
-)
-
-type ActionString = string
-
-const (
- ActionLog = "log" // allow and log
- ActionBlock = "block" // block
- ActionCaptcha = "captcha" // block and show captcha
- ActionAllow = "allow" // allow
- ActionGoGroup = "go_group" // go to next rule group
- ActionGoSet = "go_set" // go to next rule set
-)
-
-type ActionInterface interface {
- Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
-}
diff --git a/internal/waf/action_types.go b/internal/waf/action_types.go
new file mode 100644
index 0000000..3a9d512
--- /dev/null
+++ b/internal/waf/action_types.go
@@ -0,0 +1,88 @@
+package waf
+
+import "reflect"
+
+type ActionString = string
+
+const (
+ ActionLog ActionString = "log" // allow and log
+ ActionBlock ActionString = "block" // block
+ ActionCaptcha ActionString = "captcha" // block and show captcha
+ ActionNotify ActionString = "notify" // 告警
+ ActionGet302 ActionString = "get_302" // 针对GET的302重定向认证
+ ActionPost307 ActionString = "post_307" // 针对POST的307重定向认证
+ ActionRecordIP ActionString = "record_ip" // 记录IP
+ ActionTag ActionString = "tag" // 标签
+ ActionAllow ActionString = "allow" // allow
+ ActionGoGroup ActionString = "go_group" // go to next rule group
+ ActionGoSet ActionString = "go_set" // go to next rule set
+)
+
+var AllActions = []*ActionDefinition{
+ {
+ Name: "阻止",
+ Code: ActionBlock,
+ Instance: new(BlockAction),
+ Type: reflect.TypeOf(new(BlockAction)).Elem(),
+ },
+ {
+ Name: "允许通过",
+ Code: ActionAllow,
+ Instance: new(AllowAction),
+ Type: reflect.TypeOf(new(AllowAction)).Elem(),
+ },
+ {
+ Name: "允许并记录日志",
+ Code: ActionLog,
+ Instance: new(LogAction),
+ Type: reflect.TypeOf(new(LogAction)).Elem(),
+ },
+ {
+ Name: "Captcha验证码",
+ Code: ActionCaptcha,
+ Instance: new(CaptchaAction),
+ Type: reflect.TypeOf(new(CaptchaAction)).Elem(),
+ },
+ {
+ Name: "告警",
+ Code: ActionNotify,
+ Instance: new(NotifyAction),
+ Type: reflect.TypeOf(new(NotifyAction)).Elem(),
+ },
+ {
+ Name: "GET 302",
+ Code: ActionGet302,
+ Instance: new(Get302Action),
+ Type: reflect.TypeOf(new(Get302Action)).Elem(),
+ },
+ {
+ Name: "POST 307",
+ Code: ActionPost307,
+ Instance: new(Post307Action),
+ Type: reflect.TypeOf(new(Post307Action)).Elem(),
+ },
+ {
+ Name: "记录IP",
+ Code: ActionRecordIP,
+ Instance: new(RecordIPAction),
+ Type: reflect.TypeOf(new(RecordIPAction)).Elem(),
+ },
+ {
+ Name: "标签",
+ Code: ActionTag,
+ Instance: new(TagAction),
+ Type: reflect.TypeOf(new(TagAction)).Elem(),
+ },
+ {
+ Name: "跳到下一个规则分组",
+ Code: ActionGoGroup,
+ Instance: new(GoGroupAction),
+ Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
+ },
+ {
+ Name: "跳到下一个规则集",
+ Code: ActionGoSet,
+ Instance: new(GoSetAction),
+ Type: reflect.TypeOf(new(GoSetAction)).Elem(),
+ },
+}
diff --git a/internal/waf/action_utils.go b/internal/waf/action_utils.go
index d2178e9..39f1259 100644
--- a/internal/waf/action_utils.go
+++ b/internal/waf/action_utils.go
@@ -1,45 +1,12 @@
package waf
import (
+ "encoding/json"
+ "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/maps"
"reflect"
)
-var AllActions = []*ActionDefinition{
- {
- Name: "阻止",
- Code: ActionBlock,
- Instance: new(BlockAction),
- },
- {
- Name: "允许通过",
- Code: ActionAllow,
- Instance: new(AllowAction),
- },
- {
- Name: "允许并记录日志",
- Code: ActionLog,
- Instance: new(LogAction),
- },
- {
- Name: "Captcha验证码",
- Code: ActionCaptcha,
- Instance: new(CaptchaAction),
- },
- {
- Name: "跳到下一个规则分组",
- Code: ActionGoGroup,
- Instance: new(GoGroupAction),
- Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
- },
- {
- Name: "跳到下一个规则集",
- Code: ActionGoSet,
- Instance: new(GoSetAction),
- Type: reflect.TypeOf(new(GoSetAction)).Elem(),
- },
-}
-
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
for _, def := range AllActions {
if def.Code == action {
@@ -49,15 +16,13 @@ func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
instance := ptrValue.Interface().(ActionInterface)
if len(options) > 0 {
- count := def.Type.NumField()
- for i := 0; i < count; i++ {
- field := def.Type.Field(i)
- tag, ok := field.Tag.Lookup("yaml")
- if ok {
- v, ok := options[tag]
- if ok && reflect.TypeOf(v) == field.Type {
- ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
- }
+ optionsJSON, err := json.Marshal(options)
+ if err != nil {
+ remotelogs.Error("WAF_FindActionInstance", "encode options to json failed: "+err.Error())
+ } else {
+ err = json.Unmarshal(optionsJSON, instance)
+ if err != nil {
+ remotelogs.Error("WAF_FindActionInstance", "decode options from json failed: "+err.Error())
}
}
}
diff --git a/internal/waf/action_utils_test.go b/internal/waf/action_utils_test.go
index e219f55..735fe32 100644
--- a/internal/waf/action_utils_test.go
+++ b/internal/waf/action_utils_test.go
@@ -2,6 +2,7 @@ package waf
import (
"github.com/iwind/TeaGo/assert"
+ "github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"runtime"
"testing"
@@ -16,11 +17,20 @@ func TestFindActionInstance(t *testing.T) {
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
- t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
+ t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
}
+func TestFindActionInstance_Options(t *testing.T) {
+ //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
+ //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
+ //logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
+ logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{
+ "timeout": 3600,
+ }), t)
+}
+
func BenchmarkFindActionInstance(b *testing.B) {
runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ {
diff --git a/internal/waf/captcha_validator.go b/internal/waf/captcha_validator.go
index e945e27..954bbdf 100644
--- a/internal/waf/captcha_validator.go
+++ b/internal/waf/captcha_validator.go
@@ -3,29 +3,64 @@ package waf
import (
"bytes"
"encoding/base64"
- "fmt"
+ "github.com/TeaOSLab/EdgeNode/internal/utils"
+ "github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/dchest/captcha"
"github.com/iwind/TeaGo/logs"
- stringutil "github.com/iwind/TeaGo/utils/string"
+ "github.com/iwind/TeaGo/types"
"net/http"
+ "strconv"
+ "strings"
"time"
)
-var captchaValidator = &CaptchaValidator{}
+var captchaValidator = NewCaptchaValidator()
type CaptchaValidator struct {
}
-func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
- if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
- this.validate(request, writer)
+func NewCaptchaValidator() *CaptchaValidator {
+ return &CaptchaValidator{}
+}
+
+func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) {
+ var info = request.WAFRaw().URL.Query().Get("info")
+ if len(info) == 0 {
+ writer.WriteHeader(http.StatusBadRequest)
+ _, _ = writer.Write([]byte("invalid request"))
+ return
+ }
+ m, err := utils.SimpleDecryptMap(info)
+ if err != nil {
+ _, _ = writer.Write([]byte("invalid request"))
+ return
+ }
+
+ timestamp := m.GetInt64("timestamp")
+ if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
+ http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
+ return
+ }
+
+ var actionConfig = &CaptchaAction{}
+ err = jsonutils.MapToObject(m.GetMap("action"), actionConfig)
+ if err != nil {
+ http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
+ return
+ }
+
+ var setId = m.GetInt64("setId")
+ var originURL = m.GetString("url")
+
+ if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
+ this.validate(actionConfig, setId, originURL, request, writer)
} else {
- this.show(request, writer)
+ this.show(actionConfig, request, writer)
}
}
-func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
+func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) {
// show captcha
captchaId := captcha.NewLen(6)
buf := bytes.NewBuffer([]byte{})
@@ -35,48 +70,86 @@ func (this *CaptchaValidator) show(request *requests.Request, writer http.Respon
return
}
+ var lang = actionConfig.Language
+ if len(lang) == 0 {
+ acceptLanguage := request.WAFRaw().Header.Get("Accept-Language")
+ if len(acceptLanguage) > 0 {
+ langIndex := strings.Index(acceptLanguage, ",")
+ if langIndex > 0 {
+ lang = acceptLanguage[:langIndex]
+ }
+ }
+ }
+ if len(lang) == 0 {
+ lang = "en-US"
+ }
+
+ var msgTitle = ""
+ var msgPrompt = ""
+ var msgButtonTitle = ""
+
+ switch lang {
+ case "en-US":
+ msgTitle = "Verify Yourself"
+ msgPrompt = "Input verify code above:"
+ msgButtonTitle = "Verify Yourself"
+ case "zh-CN":
+ msgTitle = "身份验证"
+ msgPrompt = "请输入上面的验证码"
+ msgButtonTitle = "提交验证"
+ default:
+ msgTitle = "Verify Yourself"
+ msgPrompt = "Input verify code above:"
+ msgButtonTitle = "Verify Yourself"
+ }
+
+ writer.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = writer.Write([]byte(`
- Verify Yourself
+ ` + msgTitle + `
+
`))
}
-func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
- captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
+func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
+ captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
if len(captchaId) > 0 {
- captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
+ captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
if captcha.VerifyString(captchaId, captchaCode) {
- // set cookie
- timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
- m := stringutil.Md5(captchaSalt + timestamp)
- http.SetCookie(writer, &http.Cookie{
- Name: "TEAWEB_WAF_CAPTCHA",
- Value: m + timestamp,
- MaxAge: CaptchaSeconds, // TODO 这个时间可以设置
- Path: "/", // all of dirs
- })
+ var life = CaptchaSeconds
+ if actionConfig.Life > 0 {
+ life = types.Int(actionConfig.Life)
+ }
- rawURL := request.URL.Query().Get("url")
- http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
+ // 加入到白名单
+ SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO
+
+ http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
return false
} else {
- http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
+ http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther)
}
}
diff --git a/internal/waf/checkpoints/cc.go b/internal/waf/checkpoints/cc.go
index 8de7ef9..de6c3c9 100644
--- a/internal/waf/checkpoints/cc.go
+++ b/internal/waf/checkpoints/cc.go
@@ -5,14 +5,12 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
- "net"
"regexp"
- "strings"
"sync"
"time"
)
-// ${cc.arg}
+// CCCheckpoint ${cc.arg}
// TODO implement more traffic rules
type CCCheckpoint struct {
Checkpoint
@@ -32,7 +30,7 @@ func (this *CCCheckpoint) Start() {
this.cache = ttlcache.NewCache()
}
-func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = 0
if this.cache == nil {
@@ -66,12 +64,12 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
var key = ""
switch userType {
case "ip":
- key = this.ip(req)
+ key = req.WAFRemoteIP()
case "cookie":
if len(userField) == 0 {
- key = this.ip(req)
+ key = req.WAFRemoteIP()
} else {
- cookie, _ := req.Cookie(userField)
+ cookie, _ := req.WAFRaw().Cookie(userField)
if cookie != nil {
v := cookie.Value
if userIndex > 0 && len(v) > userIndex {
@@ -82,9 +80,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
}
case "get":
if len(userField) == 0 {
- key = this.ip(req)
+ key = req.WAFRemoteIP()
} else {
- v := req.URL.Query().Get(userField)
+ v := req.WAFRaw().URL.Query().Get(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
@@ -92,9 +90,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
}
case "post":
if len(userField) == 0 {
- key = this.ip(req)
+ key = req.WAFRemoteIP()
} else {
- v := req.PostFormValue(userField)
+ v := req.WAFRaw().PostFormValue(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
@@ -102,19 +100,19 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
}
case "header":
if len(userField) == 0 {
- key = this.ip(req)
+ key = req.WAFRemoteIP()
} else {
- v := req.Header.Get(userField)
+ v := req.WAFRaw().Header.Get(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
key = "USER@" + userType + "@" + userField + "@" + v
}
default:
- key = this.ip(req)
+ key = req.WAFRemoteIP()
}
if len(key) == 0 {
- key = this.ip(req)
+ key = req.WAFRemoteIP()
}
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
}
@@ -122,7 +120,7 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
return
}
-func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
@@ -210,38 +208,3 @@ func (this *CCCheckpoint) Stop() {
this.cache = nil
}
}
-
-func (this *CCCheckpoint) ip(req *requests.Request) string {
- // X-Forwarded-For
- forwardedFor := req.Header.Get("X-Forwarded-For")
- if len(forwardedFor) > 0 {
- commaIndex := strings.Index(forwardedFor, ",")
- if commaIndex > 0 {
- return forwardedFor[:commaIndex]
- }
- return forwardedFor
- }
-
- // Real-IP
- {
- realIP, ok := req.Header["X-Real-IP"]
- if ok && len(realIP) > 0 {
- return realIP[0]
- }
- }
-
- // Real-Ip
- {
- realIP, ok := req.Header["X-Real-Ip"]
- if ok && len(realIP) > 0 {
- return realIP[0]
- }
- }
-
- // Remote-Addr
- host, _, err := net.SplitHostPort(req.RemoteAddr)
- if err == nil {
- return host
- }
- return req.RemoteAddr
-}
diff --git a/internal/waf/checkpoints/cc_test.go b/internal/waf/checkpoints/cc_test.go
index 6245798..249b477 100644
--- a/internal/waf/checkpoints/cc_test.go
+++ b/internal/waf/checkpoints/cc_test.go
@@ -2,6 +2,7 @@ package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "github.com/iwind/TeaGo/maps"
"net/http"
"testing"
)
@@ -12,31 +13,31 @@ func TestCCCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(raw)
- req.RemoteAddr = "127.0.0.1"
+ req := requests.NewTestRequest(raw)
+ req.WAFRaw().RemoteAddr = "127.0.0.1"
checkpoint := new(CCCheckpoint)
checkpoint.Init()
checkpoint.Start()
- options := map[string]string{
+ options := maps.Map{
"period": "5",
}
t.Log(checkpoint.RequestValue(req, "requests", options))
t.Log(checkpoint.RequestValue(req, "requests", options))
- req.RemoteAddr = "127.0.0.2"
+ req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
- req.RemoteAddr = "127.0.0.1"
+ req.WAFRaw().RemoteAddr = "127.0.0.1"
t.Log(checkpoint.RequestValue(req, "requests", options))
- req.RemoteAddr = "127.0.0.2"
+ req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
- req.RemoteAddr = "127.0.0.2"
+ req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
- req.RemoteAddr = "127.0.0.2"
+ req.WAFRaw().RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
}
diff --git a/internal/waf/checkpoints/checkpoint_interface.go b/internal/waf/checkpoints/checkpoint_interface.go
index 0a8ac8d..532ae62 100644
--- a/internal/waf/checkpoints/checkpoint_interface.go
+++ b/internal/waf/checkpoints/checkpoint_interface.go
@@ -5,32 +5,32 @@ import (
"github.com/iwind/TeaGo/maps"
)
-// Check Point
+// CheckpointInterface Check Point
type CheckpointInterface interface {
- // initialize
+ // Init initialize
Init()
- // is request?
+ // IsRequest is request?
IsRequest() bool
- // is composed?
+ // IsComposed is composed?
IsComposed() bool
- // get request value
- RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
+ // RequestValue get request value
+ RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
- // get response value
- ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
+ // ResponseValue get response value
+ ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
- // param option list
+ // ParamOptions param option list
ParamOptions() *ParamOptions
- // options
+ // Options options
Options() []OptionInterface
- // start
+ // Start start
Start()
- // stop
+ // Stop stop
Stop()
}
diff --git a/internal/waf/checkpoints/request_all.go b/internal/waf/checkpoints/request_all.go
index 64664a6..30a5f98 100644
--- a/internal/waf/checkpoints/request_all.go
+++ b/internal/waf/checkpoints/request_all.go
@@ -5,32 +5,34 @@ import (
"github.com/iwind/TeaGo/maps"
)
-// ${requestAll}
+// RequestAllCheckpoint ${requestAll}
type RequestAllCheckpoint struct {
Checkpoint
}
-func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
valueBytes := []byte{}
- if len(req.RequestURI) > 0 {
- valueBytes = append(valueBytes, req.RequestURI...)
- } else if req.URL != nil {
- valueBytes = append(valueBytes, req.URL.RequestURI()...)
+ if len(req.WAFRaw().RequestURI) > 0 {
+ valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
+ } else if req.WAFRaw().URL != nil {
+ valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...)
}
- if req.Body != nil {
+ if req.WAFRaw().Body != nil {
valueBytes = append(valueBytes, ' ')
- if len(req.BodyData) == 0 {
- data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
+ var bodyData = req.WAFGetCacheBody()
+ if len(bodyData) == 0 {
+ data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
- req.BodyData = data
- req.RestoreBody(data)
+ bodyData = data
+ req.WAFSetCacheBody(data)
+ req.WAFRestoreBody(data)
}
- valueBytes = append(valueBytes, req.BodyData...)
+ valueBytes = append(valueBytes, bodyData...)
}
value = valueBytes
@@ -38,7 +40,7 @@ func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param stri
return
}
-func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = ""
if this.IsRequest() {
return this.RequestValue(req, param, options)
diff --git a/internal/waf/checkpoints/request_all_test.go b/internal/waf/checkpoints/request_all_test.go
index d8a12a0..5ee81d5 100644
--- a/internal/waf/checkpoints/request_all_test.go
+++ b/internal/waf/checkpoints/request_all_test.go
@@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
}
checkpoint := new(RequestAllCheckpoint)
- v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
+ v, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
if sysErr != nil {
t.Fatal(sysErr)
}
@@ -42,7 +42,7 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
}
checkpoint := new(RequestBodyCheckpoint)
- value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
+ value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
if err != nil {
t.Fatal(err)
}
@@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
checkpoint := new(RequestAllCheckpoint)
for i := 0; i < b.N; i++ {
- _, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
+ _, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
}
}
diff --git a/internal/waf/checkpoints/request_arg.go b/internal/waf/checkpoints/request_arg.go
index 813fc51..a9c51a5 100644
--- a/internal/waf/checkpoints/request_arg.go
+++ b/internal/waf/checkpoints/request_arg.go
@@ -9,11 +9,11 @@ type RequestArgCheckpoint struct {
Checkpoint
}
-func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- return req.URL.Query().Get(param), nil, nil
+func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ return req.WAFRaw().URL.Query().Get(param), nil, nil
}
-func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_arg_test.go b/internal/waf/checkpoints/request_arg_test.go
index a7cdaf3..6ac84f6 100644
--- a/internal/waf/checkpoints/request_arg_test.go
+++ b/internal/waf/checkpoints/request_arg_test.go
@@ -12,7 +12,7 @@ func TestArgParam_RequestValue(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
checkpoint := new(RequestArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "name", nil))
diff --git a/internal/waf/checkpoints/request_args.go b/internal/waf/checkpoints/request_args.go
index 9a3883c..a83dc3f 100644
--- a/internal/waf/checkpoints/request_args.go
+++ b/internal/waf/checkpoints/request_args.go
@@ -9,12 +9,12 @@ type RequestArgsCheckpoint struct {
Checkpoint
}
-func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.URL.RawQuery
+func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().URL.RawQuery
return
}
-func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_body.go b/internal/waf/checkpoints/request_body.go
index d6e54ca..a04d0ff 100644
--- a/internal/waf/checkpoints/request_body.go
+++ b/internal/waf/checkpoints/request_body.go
@@ -5,31 +5,33 @@ import (
"github.com/iwind/TeaGo/maps"
)
-// ${requestBody}
+// RequestBodyCheckpoint ${requestBody}
type RequestBodyCheckpoint struct {
Checkpoint
}
-func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- if req.Body == nil {
+func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ if req.WAFRaw().Body == nil {
value = ""
return
}
- if len(req.BodyData) == 0 {
- data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
+ var bodyData = req.WAFGetCacheBody()
+ if len(bodyData) == 0 {
+ data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
- req.BodyData = data
- req.RestoreBody(data)
+ bodyData = data
+ req.WAFSetCacheBody(data)
+ req.WAFRestoreBody(data)
}
- return req.BodyData, nil, nil
+ return bodyData, nil, nil
}
-func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_body_test.go b/internal/waf/checkpoints/request_body_test.go
index b1c982d..8bdb0d2 100644
--- a/internal/waf/checkpoints/request_body_test.go
+++ b/internal/waf/checkpoints/request_body_test.go
@@ -11,19 +11,20 @@ import (
)
func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
- req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
+ rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
if err != nil {
t.Fatal(err)
}
-
+ var req = requests.NewTestRequest(rawReq)
checkpoint := new(RequestBodyCheckpoint)
- t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
+ t.Log(checkpoint.RequestValue(req, "", nil))
- body, err := ioutil.ReadAll(req.Body)
+ body, err := ioutil.ReadAll(rawReq.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
+ t.Log(string(req.WAFGetCacheBody()))
}
func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
@@ -33,7 +34,7 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
}
checkpoint := new(RequestBodyCheckpoint)
- value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
+ value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/waf/checkpoints/request_content_type.go b/internal/waf/checkpoints/request_content_type.go
index 6a11132..6ff04cd 100644
--- a/internal/waf/checkpoints/request_content_type.go
+++ b/internal/waf/checkpoints/request_content_type.go
@@ -9,12 +9,12 @@ type RequestContentTypeCheckpoint struct {
Checkpoint
}
-func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.Header.Get("Content-Type")
+func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().Header.Get("Content-Type")
return
}
-func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_cookie.go b/internal/waf/checkpoints/request_cookie.go
index eb1ba91..33fd968 100644
--- a/internal/waf/checkpoints/request_cookie.go
+++ b/internal/waf/checkpoints/request_cookie.go
@@ -9,8 +9,8 @@ type RequestCookieCheckpoint struct {
Checkpoint
}
-func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- cookie, err := req.Cookie(param)
+func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ cookie, err := req.WAFRaw().Cookie(param)
if err != nil {
value = ""
return
@@ -20,7 +20,7 @@ func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param s
return
}
-func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_cookies.go b/internal/waf/checkpoints/request_cookies.go
index f284788..a9f1035 100644
--- a/internal/waf/checkpoints/request_cookies.go
+++ b/internal/waf/checkpoints/request_cookies.go
@@ -11,16 +11,16 @@ type RequestCookiesCheckpoint struct {
Checkpoint
}
-func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
var cookies = []string{}
- for _, cookie := range req.Cookies() {
+ for _, cookie := range req.WAFRaw().Cookies() {
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
}
value = strings.Join(cookies, "&")
return
}
-func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_form_arg.go b/internal/waf/checkpoints/request_form_arg.go
index 92fa0ab..1862784 100644
--- a/internal/waf/checkpoints/request_form_arg.go
+++ b/internal/waf/checkpoints/request_form_arg.go
@@ -6,33 +6,35 @@ import (
"net/url"
)
-// ${requestForm.arg}
+// RequestFormArgCheckpoint ${requestForm.arg}
type RequestFormArgCheckpoint struct {
Checkpoint
}
-func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- if req.Body == nil {
+func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ if req.WAFRaw().Body == nil {
value = ""
return
}
- if len(req.BodyData) == 0 {
- data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
+ var bodyData = req.WAFGetCacheBody()
+ if len(bodyData) == 0 {
+ data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
- req.BodyData = data
- req.RestoreBody(data)
+ bodyData = data
+ req.WAFSetCacheBody(data)
+ req.WAFRestoreBody(data)
}
// TODO improve performance
- values, _ := url.ParseQuery(string(req.BodyData))
+ values, _ := url.ParseQuery(string(bodyData))
return values.Get(param), nil, nil
}
-func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_form_arg_test.go b/internal/waf/checkpoints/request_form_arg_test.go
index 01c0396..5da0624 100644
--- a/internal/waf/checkpoints/request_form_arg_test.go
+++ b/internal/waf/checkpoints/request_form_arg_test.go
@@ -15,8 +15,8 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req := requests.NewTestRequest(rawReq)
+ req.WAFRaw().Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestFormArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "name", nil))
@@ -24,7 +24,7 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "Hello", nil))
t.Log(checkpoint.RequestValue(req, "encoded", nil))
- body, err := ioutil.ReadAll(req.Body)
+ body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/waf/checkpoints/request_general_header_length.go b/internal/waf/checkpoints/request_general_header_length.go
index 4f2a430..50e8251 100644
--- a/internal/waf/checkpoints/request_general_header_length.go
+++ b/internal/waf/checkpoints/request_general_header_length.go
@@ -14,7 +14,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool {
return true
}
-func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = false
headers := options.GetSlice("headers")
@@ -25,7 +25,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
length := options.GetInt("length")
for _, header := range headers {
- v := req.Header.Get(types.String(header))
+ v := req.WAFRaw().Header.Get(types.String(header))
if len(v) > length {
value = true
break
@@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
return
}
-func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return
}
diff --git a/internal/waf/checkpoints/request_header.go b/internal/waf/checkpoints/request_header.go
index 029def4..8b206d0 100644
--- a/internal/waf/checkpoints/request_header.go
+++ b/internal/waf/checkpoints/request_header.go
@@ -10,8 +10,8 @@ type RequestHeaderCheckpoint struct {
Checkpoint
}
-func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- v, found := req.Header[param]
+func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ v, found := req.WAFRaw().Header[param]
if !found {
value = ""
return
@@ -20,7 +20,7 @@ func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param s
return
}
-func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_headers.go b/internal/waf/checkpoints/request_headers.go
index c5ef280..0fdb225 100644
--- a/internal/waf/checkpoints/request_headers.go
+++ b/internal/waf/checkpoints/request_headers.go
@@ -11,9 +11,9 @@ type RequestHeadersCheckpoint struct {
Checkpoint
}
-func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
var headers = []string{}
- for k, v := range req.Header {
+ for k, v := range req.WAFRaw().Header {
for _, subV := range v {
headers = append(headers, k+": "+subV)
}
@@ -23,7 +23,7 @@ func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param
return
}
-func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_host.go b/internal/waf/checkpoints/request_host.go
index 60174d0..105f4a7 100644
--- a/internal/waf/checkpoints/request_host.go
+++ b/internal/waf/checkpoints/request_host.go
@@ -9,12 +9,12 @@ type RequestHostCheckpoint struct {
Checkpoint
}
-func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.Host
+func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().Host
return
}
-func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_host_test.go b/internal/waf/checkpoints/request_host_test.go
index fc1b449..b9274a7 100644
--- a/internal/waf/checkpoints/request_host_test.go
+++ b/internal/waf/checkpoints/request_host_test.go
@@ -12,8 +12,8 @@ func TestRequestHostCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
- req.Header.Set("Host", "cloud.teaos.cn")
+ req := requests.NewTestRequest(rawReq)
+ req.WAFRaw().Header.Set("Host", "cloud.teaos.cn")
checkpoint := new(RequestHostCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
diff --git a/internal/waf/checkpoints/request_json_arg.go b/internal/waf/checkpoints/request_json_arg.go
index 1a6414e..341db2d 100644
--- a/internal/waf/checkpoints/request_json_arg.go
+++ b/internal/waf/checkpoints/request_json_arg.go
@@ -8,24 +8,27 @@ import (
"strings"
)
-// ${requestJSON.arg}
+// RequestJSONArgCheckpoint ${requestJSON.arg}
type RequestJSONArgCheckpoint struct {
Checkpoint
}
-func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- if len(req.BodyData) == 0 {
- data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
+func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ var bodyData = req.WAFGetCacheBody()
+ if len(bodyData) == 0 {
+ data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
- req.BodyData = data
- defer req.RestoreBody(data)
+
+ bodyData = data
+ req.WAFSetCacheBody(data)
+ defer req.WAFRestoreBody(data)
}
// TODO improve performance
var m interface{} = nil
- err := json.Unmarshal(req.BodyData, &m)
+ err := json.Unmarshal(bodyData, &m)
if err != nil || m == nil {
return "", nil, err
}
@@ -37,7 +40,7 @@ func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param
return "", nil, nil
}
-func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestJSONArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_json_arg_test.go b/internal/waf/checkpoints/request_json_arg_test.go
index 00708be..63fae0b 100644
--- a/internal/waf/checkpoints/request_json_arg_test.go
+++ b/internal/waf/checkpoints/request_json_arg_test.go
@@ -20,7 +20,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
@@ -31,7 +31,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "books", nil))
t.Log(checkpoint.RequestValue(req, "books.1", nil))
- body, err := ioutil.ReadAll(req.Body)
+ body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}
@@ -50,7 +50,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
@@ -61,7 +61,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "0.books", nil))
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
- body, err := ioutil.ReadAll(req.Body)
+ body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}
@@ -80,7 +80,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
@@ -91,7 +91,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "0.books", nil))
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
- body, err := ioutil.ReadAll(req.Body)
+ body, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/waf/checkpoints/request_length.go b/internal/waf/checkpoints/request_length.go
index 9a09556..e26a18b 100644
--- a/internal/waf/checkpoints/request_length.go
+++ b/internal/waf/checkpoints/request_length.go
@@ -9,12 +9,12 @@ type RequestLengthCheckpoint struct {
Checkpoint
}
-func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.ContentLength
+func (this *RequestLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().ContentLength
return
}
-func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_method.go b/internal/waf/checkpoints/request_method.go
index b27deb0..3b85fc0 100644
--- a/internal/waf/checkpoints/request_method.go
+++ b/internal/waf/checkpoints/request_method.go
@@ -9,12 +9,12 @@ type RequestMethodCheckpoint struct {
Checkpoint
}
-func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.Method
+func (this *RequestMethodCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().Method
return
}
-func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestMethodCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_path.go b/internal/waf/checkpoints/request_path.go
index 7c934de..5e757bb 100644
--- a/internal/waf/checkpoints/request_path.go
+++ b/internal/waf/checkpoints/request_path.go
@@ -9,11 +9,11 @@ type RequestPathCheckpoint struct {
Checkpoint
}
-func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- return req.URL.Path, nil, nil
+func (this *RequestPathCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ return req.WAFRaw().URL.Path, nil, nil
}
-func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestPathCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_path_test.go b/internal/waf/checkpoints/request_path_test.go
index e100602..88f47cb 100644
--- a/internal/waf/checkpoints/request_path_test.go
+++ b/internal/waf/checkpoints/request_path_test.go
@@ -12,7 +12,7 @@ func TestRequestPathCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
checkpoint := new(RequestPathCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
}
diff --git a/internal/waf/checkpoints/request_proto.go b/internal/waf/checkpoints/request_proto.go
index f3cd372..235b2db 100644
--- a/internal/waf/checkpoints/request_proto.go
+++ b/internal/waf/checkpoints/request_proto.go
@@ -9,12 +9,12 @@ type RequestProtoCheckpoint struct {
Checkpoint
}
-func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.Proto
+func (this *RequestProtoCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().Proto
return
}
-func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestProtoCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_raw_remote_addr.go b/internal/waf/checkpoints/request_raw_remote_addr.go
index 9da7f11..7886c44 100644
--- a/internal/waf/checkpoints/request_raw_remote_addr.go
+++ b/internal/waf/checkpoints/request_raw_remote_addr.go
@@ -10,17 +10,17 @@ type RequestRawRemoteAddrCheckpoint struct {
Checkpoint
}
-func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- host, _, err := net.SplitHostPort(req.RemoteAddr)
+func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ host, _, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
if err == nil {
value = host
} else {
- value = req.RemoteAddr
+ value = req.WAFRaw().RemoteAddr
}
return
}
-func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_referer.go b/internal/waf/checkpoints/request_referer.go
index 0160579..775c084 100644
--- a/internal/waf/checkpoints/request_referer.go
+++ b/internal/waf/checkpoints/request_referer.go
@@ -9,12 +9,12 @@ type RequestRefererCheckpoint struct {
Checkpoint
}
-func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.Referer()
+func (this *RequestRefererCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().Referer()
return
}
-func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_remote_addr.go b/internal/waf/checkpoints/request_remote_addr.go
index b80e42b..dc26a10 100644
--- a/internal/waf/checkpoints/request_remote_addr.go
+++ b/internal/waf/checkpoints/request_remote_addr.go
@@ -3,56 +3,18 @@ package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
- "net"
- "strings"
)
type RequestRemoteAddrCheckpoint struct {
Checkpoint
}
-func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- // X-Forwarded-For
- forwardedFor := req.Header.Get("X-Forwarded-For")
- if len(forwardedFor) > 0 {
- commaIndex := strings.Index(forwardedFor, ",")
- if commaIndex > 0 {
- value = forwardedFor[:commaIndex]
- return
- }
- value = forwardedFor
- return
- }
-
- // Real-IP
- {
- realIP, ok := req.Header["X-Real-IP"]
- if ok && len(realIP) > 0 {
- value = realIP[0]
- return
- }
- }
-
- // Real-Ip
- {
- realIP, ok := req.Header["X-Real-Ip"]
- if ok && len(realIP) > 0 {
- value = realIP[0]
- return
- }
- }
-
- // Remote-Addr
- host, _, err := net.SplitHostPort(req.RemoteAddr)
- if err == nil {
- value = host
- } else {
- value = req.RemoteAddr
- }
+func (this *RequestRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRemoteIP()
return
}
-func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_remote_port.go b/internal/waf/checkpoints/request_remote_port.go
index 5f65c7e..f5aa158 100644
--- a/internal/waf/checkpoints/request_remote_port.go
+++ b/internal/waf/checkpoints/request_remote_port.go
@@ -11,8 +11,8 @@ type RequestRemotePortCheckpoint struct {
Checkpoint
}
-func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- _, port, err := net.SplitHostPort(req.RemoteAddr)
+func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ _, port, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
if err == nil {
value = types.Int(port)
} else {
@@ -21,7 +21,7 @@ func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, par
return
}
-func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestRemotePortCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_remote_user.go b/internal/waf/checkpoints/request_remote_user.go
index 1c27cc8..a2d1e20 100644
--- a/internal/waf/checkpoints/request_remote_user.go
+++ b/internal/waf/checkpoints/request_remote_user.go
@@ -9,8 +9,8 @@ type RequestRemoteUserCheckpoint struct {
Checkpoint
}
-func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- username, _, ok := req.BasicAuth()
+func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ username, _, ok := req.WAFRaw().BasicAuth()
if !ok {
value = ""
return
@@ -19,7 +19,7 @@ func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, par
return
}
-func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestRemoteUserCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_scheme.go b/internal/waf/checkpoints/request_scheme.go
index 11e27e1..05f98c6 100644
--- a/internal/waf/checkpoints/request_scheme.go
+++ b/internal/waf/checkpoints/request_scheme.go
@@ -9,12 +9,12 @@ type RequestSchemeCheckpoint struct {
Checkpoint
}
-func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.URL.Scheme
+func (this *RequestSchemeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().URL.Scheme
return
}
-func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestSchemeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_scheme_test.go b/internal/waf/checkpoints/request_scheme_test.go
index 461cf23..8738a3d 100644
--- a/internal/waf/checkpoints/request_scheme_test.go
+++ b/internal/waf/checkpoints/request_scheme_test.go
@@ -12,7 +12,7 @@ func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
checkpoint := new(RequestSchemeCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
}
diff --git a/internal/waf/checkpoints/request_upload.go b/internal/waf/checkpoints/request_upload.go
index d76656d..ed43dcf 100644
--- a/internal/waf/checkpoints/request_upload.go
+++ b/internal/waf/checkpoints/request_upload.go
@@ -11,63 +11,65 @@ import (
"strings"
)
-// ${requestUpload.arg}
+// RequestUploadCheckpoint ${requestUpload.arg}
type RequestUploadCheckpoint struct {
Checkpoint
}
-func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = ""
if param == "minSize" || param == "maxSize" {
value = 0
}
- if req.Method != http.MethodPost {
+ if req.WAFRaw().Method != http.MethodPost {
return
}
- if req.Body == nil {
+ if req.WAFRaw().Body == nil {
return
}
- if req.MultipartForm == nil {
- if len(req.BodyData) == 0 {
- data, err := req.ReadBody(32 * 1024 * 1024)
+ if req.WAFRaw().MultipartForm == nil {
+ var bodyData = req.WAFGetCacheBody()
+ if len(bodyData) == 0 {
+ data, err := req.WAFReadBody(32 * 1024 * 1024)
if err != nil {
sysErr = err
return
}
- req.BodyData = data
- defer req.RestoreBody(data)
+ bodyData = data
+ req.WAFSetCacheBody(data)
+ defer req.WAFRestoreBody(data)
}
- oldBody := req.Body
- req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData))
+ oldBody := req.WAFRaw().Body
+ req.WAFRaw().Body = ioutil.NopCloser(bytes.NewBuffer(bodyData))
- err := req.ParseMultipartForm(32 * 1024 * 1024)
+ err := req.WAFRaw().ParseMultipartForm(32 * 1024 * 1024)
// 还原
- req.Body = oldBody
+ req.WAFRaw().Body = oldBody
if err != nil {
userErr = err
return
}
- if req.MultipartForm == nil {
+ if req.WAFRaw().MultipartForm == nil {
return
}
}
if param == "field" { // field
fields := []string{}
- for field := range req.MultipartForm.File {
+ for field := range req.WAFRaw().MultipartForm.File {
fields = append(fields, field)
}
value = strings.Join(fields, ",")
} else if param == "minSize" { // minSize
minSize := int64(0)
- for _, files := range req.MultipartForm.File {
+ for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if minSize == 0 || minSize > file.Size {
minSize = file.Size
@@ -77,7 +79,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
value = minSize
} else if param == "maxSize" { // maxSize
maxSize := int64(0)
- for _, files := range req.MultipartForm.File {
+ for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if maxSize < file.Size {
maxSize = file.Size
@@ -87,7 +89,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
value = maxSize
} else if param == "name" { // name
names := []string{}
- for _, files := range req.MultipartForm.File {
+ for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if !lists.ContainsString(names, file.Filename) {
names = append(names, file.Filename)
@@ -97,7 +99,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
value = strings.Join(names, ",")
} else if param == "ext" { // ext
extensions := []string{}
- for _, files := range req.MultipartForm.File {
+ for _, files := range req.WAFRaw().MultipartForm.File {
for _, file := range files {
if len(file.Filename) > 0 {
exit := strings.ToLower(filepath.Ext(file.Filename))
@@ -113,7 +115,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
return
}
-func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_upload_test.go b/internal/waf/checkpoints/request_upload_test.go
index cc1ab64..0bc03a5 100644
--- a/internal/waf/checkpoints/request_upload_test.go
+++ b/internal/waf/checkpoints/request_upload_test.go
@@ -63,8 +63,8 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
t.Fatal()
}
- req := requests.NewRequest(rawReq)
- req.Header.Add("Content-Type", writer.FormDataContentType())
+ req := requests.NewTestRequest(rawReq)
+ req.WAFRaw().Header.Add("Content-Type", writer.FormDataContentType())
checkpoint := new(RequestUploadCheckpoint)
t.Log(checkpoint.RequestValue(req, "field", nil))
@@ -73,7 +73,7 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
t.Log(checkpoint.RequestValue(req, "name", nil))
t.Log(checkpoint.RequestValue(req, "ext", nil))
- data, err := ioutil.ReadAll(req.Body)
+ data, err := ioutil.ReadAll(req.WAFRaw().Body)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/waf/checkpoints/request_uri.go b/internal/waf/checkpoints/request_uri.go
index d927baf..bfe72fd 100644
--- a/internal/waf/checkpoints/request_uri.go
+++ b/internal/waf/checkpoints/request_uri.go
@@ -9,16 +9,16 @@ type RequestURICheckpoint struct {
Checkpoint
}
-func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- if len(req.RequestURI) > 0 {
- value = req.RequestURI
- } else if req.URL != nil {
- value = req.URL.RequestURI()
+func (this *RequestURICheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ if len(req.WAFRaw().RequestURI) > 0 {
+ value = req.WAFRaw().RequestURI
+ } else if req.WAFRaw().URL != nil {
+ value = req.WAFRaw().URL.RequestURI()
}
return
}
-func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestURICheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/request_user_agent.go b/internal/waf/checkpoints/request_user_agent.go
index 407fe50..a9c1bec 100644
--- a/internal/waf/checkpoints/request_user_agent.go
+++ b/internal/waf/checkpoints/request_user_agent.go
@@ -9,12 +9,12 @@ type RequestUserAgentCheckpoint struct {
Checkpoint
}
-func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
- value = req.UserAgent()
+func (this *RequestUserAgentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+ value = req.WAFRaw().UserAgent()
return
}
-func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *RequestUserAgentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/response_body.go b/internal/waf/checkpoints/response_body.go
index 4e48dac..a39fc67 100644
--- a/internal/waf/checkpoints/response_body.go
+++ b/internal/waf/checkpoints/response_body.go
@@ -16,12 +16,12 @@ func (this *ResponseBodyCheckpoint) IsRequest() bool {
return false
}
-func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = ""
return
}
-func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = ""
if resp != nil && resp.Body != nil {
if len(resp.BodyData) > 0 {
diff --git a/internal/waf/checkpoints/response_bytes_sent.go b/internal/waf/checkpoints/response_bytes_sent.go
index 9461f97..75a719a 100644
--- a/internal/waf/checkpoints/response_bytes_sent.go
+++ b/internal/waf/checkpoints/response_bytes_sent.go
@@ -14,12 +14,12 @@ func (this *ResponseBytesSentCheckpoint) IsRequest() bool {
return false
}
-func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseBytesSentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = 0
return
}
-func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseBytesSentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = 0
if resp != nil {
value = resp.ContentLength
diff --git a/internal/waf/checkpoints/response_general_header_length.go b/internal/waf/checkpoints/response_general_header_length.go
index f1ef6ff..00404fd 100644
--- a/internal/waf/checkpoints/response_general_header_length.go
+++ b/internal/waf/checkpoints/response_general_header_length.go
@@ -18,12 +18,12 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) IsComposed() bool {
return true
}
-func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return
}
-func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = false
headers := options.GetSlice("headers")
@@ -34,7 +34,7 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.R
length := options.GetInt("length")
for _, header := range headers {
- v := req.Header.Get(types.String(header))
+ v := req.WAFRaw().Header.Get(types.String(header))
if len(v) > length {
value = true
break
diff --git a/internal/waf/checkpoints/response_header.go b/internal/waf/checkpoints/response_header.go
index 5d23df5..839e657 100644
--- a/internal/waf/checkpoints/response_header.go
+++ b/internal/waf/checkpoints/response_header.go
@@ -14,12 +14,12 @@ func (this *ResponseHeaderCheckpoint) IsRequest() bool {
return false
}
-func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = ""
return
}
-func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if resp != nil && resp.Header != nil {
value = resp.Header.Get(param)
} else {
diff --git a/internal/waf/checkpoints/response_status.go b/internal/waf/checkpoints/response_status.go
index fb63a79..eb9a9bd 100644
--- a/internal/waf/checkpoints/response_status.go
+++ b/internal/waf/checkpoints/response_status.go
@@ -5,7 +5,7 @@ import (
"github.com/iwind/TeaGo/maps"
)
-// ${bytesSent}
+// ResponseStatusCheckpoint ${bytesSent}
type ResponseStatusCheckpoint struct {
Checkpoint
}
@@ -14,12 +14,12 @@ func (this *ResponseStatusCheckpoint) IsRequest() bool {
return false
}
-func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseStatusCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = 0
return
}
-func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *ResponseStatusCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if resp != nil {
value = resp.StatusCode
}
diff --git a/internal/waf/checkpoints/sample_request.go b/internal/waf/checkpoints/sample_request.go
index e542eda..1aa1197 100644
--- a/internal/waf/checkpoints/sample_request.go
+++ b/internal/waf/checkpoints/sample_request.go
@@ -5,16 +5,16 @@ import (
"github.com/iwind/TeaGo/maps"
)
-// just a sample checkpoint, copy and change it for your new checkpoint
+// SampleRequestCheckpoint just a sample checkpoint, copy and change it for your new checkpoint
type SampleRequestCheckpoint struct {
Checkpoint
}
-func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *SampleRequestCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return
}
-func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
+func (this *SampleRequestCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
diff --git a/internal/waf/checkpoints/utils.go b/internal/waf/checkpoints/utils.go
index e3393d7..6d79e13 100644
--- a/internal/waf/checkpoints/utils.go
+++ b/internal/waf/checkpoints/utils.go
@@ -1,6 +1,6 @@
package checkpoints
-// all check points list
+// AllCheckpoints all check points list
var AllCheckpoints = []*CheckpointDefinition{
{
Name: "通用请求Header长度限制",
diff --git a/internal/waf/get302_validator.go b/internal/waf/get302_validator.go
new file mode 100644
index 0000000..fbb562b
--- /dev/null
+++ b/internal/waf/get302_validator.go
@@ -0,0 +1,52 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeNode/internal/utils"
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "net/http"
+ "time"
+)
+
+var get302Validator = NewGet302Validator()
+
+type Get302Validator struct {
+}
+
+func NewGet302Validator() *Get302Validator {
+ return &Get302Validator{}
+}
+
+func (this *Get302Validator) Run(request requests.Request, writer http.ResponseWriter) {
+ var info = request.WAFRaw().URL.Query().Get("info")
+ if len(info) == 0 {
+ writer.WriteHeader(http.StatusBadRequest)
+ _, _ = writer.Write([]byte("invalid request"))
+ return
+ }
+ m, err := utils.SimpleDecryptMap(info)
+ if err != nil {
+ _, _ = writer.Write([]byte("invalid request"))
+ return
+ }
+
+ var timestamp = m.GetInt64("timestamp")
+ if time.Now().Unix()-timestamp > 5 { // 超过5秒认为失效
+ writer.WriteHeader(http.StatusBadRequest)
+ _, _ = writer.Write([]byte("invalid request"))
+ return
+ }
+
+ // 加入白名单
+ life := m.GetInt64("life")
+ if life <= 0 {
+ life = 600 // 默认10分钟
+ }
+ setId := m.GetString("setId")
+ SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
+
+ // 返回原始URL
+ var url = m.GetString("url")
+ http.Redirect(writer, request.WAFRaw(), url, http.StatusFound)
+}
diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go
new file mode 100644
index 0000000..5c53624
--- /dev/null
+++ b/internal/waf/ip_list.go
@@ -0,0 +1,82 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import (
+ "github.com/TeaOSLab/EdgeNode/internal/utils/expires"
+ "sync"
+ "sync/atomic"
+)
+
+var SharedIPWhiteList = NewIPList()
+var SharedIPBlackLIst = NewIPList()
+
+const IPTypeAll = "*"
+
+// IPList IP列表管理
+type IPList struct {
+ expireList *expires.List
+ ipMap map[string]int64 // ip => id
+ idMap map[int64]string // id => ip
+
+ id int64
+ locker sync.RWMutex
+}
+
+// NewIPList 获取新对象
+func NewIPList() *IPList {
+ var list = &IPList{
+ ipMap: map[string]int64{},
+ idMap: map[int64]string{},
+ }
+
+ e := expires.NewList()
+ list.expireList = e
+
+ go func() {
+ e.StartGC(func(itemId int64) {
+ list.remove(itemId)
+ })
+ }()
+
+ return list
+}
+
+// Add 添加IP
+func (this *IPList) Add(ipType string, ip string, expiresAt int64) {
+ ip = ip + "@" + ipType
+
+ var id = this.nextId()
+ this.expireList.Add(id, expiresAt)
+ this.locker.Lock()
+ this.ipMap[ip] = id
+ this.idMap[id] = ip
+ this.locker.Unlock()
+}
+
+// Contains 判断是否有某个IP
+func (this *IPList) Contains(ipType string, ip string) bool {
+ ip = ip + "@" + ipType
+
+ this.locker.RLock()
+ defer this.locker.RUnlock()
+ _, ok := this.ipMap[ip]
+ return ok
+}
+
+func (this *IPList) remove(id int64) {
+ this.locker.Lock()
+ ip, ok := this.idMap[id]
+ if ok {
+ ipId, ok := this.ipMap[ip]
+ if ok && ipId == id {
+ delete(this.ipMap, ip)
+ }
+ delete(this.idMap, id)
+ }
+ this.locker.Unlock()
+}
+
+func (this *IPList) nextId() int64 {
+ return atomic.AddInt64(&this.id, 1)
+}
diff --git a/internal/waf/ip_list_test.go b/internal/waf/ip_list_test.go
new file mode 100644
index 0000000..c175f43
--- /dev/null
+++ b/internal/waf/ip_list_test.go
@@ -0,0 +1,67 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package waf
+
+import (
+ "github.com/iwind/TeaGo/assert"
+ "github.com/iwind/TeaGo/logs"
+ "runtime"
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestNewIPList(t *testing.T) {
+ list := NewIPList()
+ list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix())
+ list.Add(IPTypeAll, "127.0.0.2", time.Now().Unix()+1)
+ list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()+2)
+ list.Add(IPTypeAll, "127.0.0.3", time.Now().Unix()+3)
+ list.Add(IPTypeAll, "127.0.0.10", time.Now().Unix()+10)
+
+ var ticker = time.NewTicker(1 * time.Second)
+ for range ticker.C {
+ t.Log("====")
+ logs.PrintAsJSON(list.ipMap, t)
+ logs.PrintAsJSON(list.idMap, t)
+ if len(list.idMap) == 0 {
+ break
+ }
+ }
+}
+
+func TestIPList_Contains(t *testing.T) {
+ a := assert.NewAssertion(t)
+
+ list := NewIPList()
+
+ for i := 0; i < 1_0000; i++ {
+ list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
+ }
+ a.IsTrue(list.Contains(IPTypeAll, "192.168.1.100"))
+ a.IsFalse(list.Contains(IPTypeAll, "192.168.2.100"))
+}
+
+func BenchmarkIPList_Add(b *testing.B) {
+ runtime.GOMAXPROCS(1)
+
+ list := NewIPList()
+ for i := 0; i < b.N; i++ {
+ list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
+ }
+ b.Log(len(list.ipMap))
+}
+
+func BenchmarkIPList_Has(b *testing.B) {
+ runtime.GOMAXPROCS(1)
+
+ list := NewIPList()
+
+ for i := 0; i < 1_0000; i++ {
+ list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
+ }
+
+ for i := 0; i < b.N; i++ {
+ list.Contains(IPTypeAll, "192.168.1.100")
+ }
+}
diff --git a/internal/waf/ip_table.go b/internal/waf/ip_table.go
deleted file mode 100644
index 7367b84..0000000
--- a/internal/waf/ip_table.go
+++ /dev/null
@@ -1,154 +0,0 @@
-package waf
-
-import (
- "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
- "github.com/iwind/TeaGo/lists"
- "github.com/iwind/TeaGo/types"
- stringutil "github.com/iwind/TeaGo/utils/string"
- "regexp"
- "strings"
- "time"
-)
-
-type IPAction = string
-
-var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
-
-const (
- IPActionAccept IPAction = "accept"
- IPActionReject IPAction = "reject"
-)
-
-// ip table
-type IPTable struct {
- Id string `yaml:"id" json:"id"`
- On bool `yaml:"on" json:"on"`
- IP string `yaml:"ip" json:"ip"` // single ip, cidr, ip range, TODO support *
- Port string `yaml:"port" json:"port"` // single port, range, *
- Action IPAction `yaml:"action" json:"action"` // accept, reject
- TimeFrom int64 `yaml:"timeFrom" json:"timeFrom"` // from timestamp
- TimeTo int64 `yaml:"timeTo" json:"timeTo"` // zero means forever
- Remark string `yaml:"remark" json:"remark"`
-
- // port
- minPort int
- maxPort int
-
- minPortWildcard bool
- maxPortWildcard bool
-
- ports []int
-
- // ip
- ipRange *shared.IPRangeConfig
-}
-
-func NewIPTable() *IPTable {
- return &IPTable{
- On: true,
- Id: stringutil.Rand(16),
- }
-}
-
-func (this *IPTable) Init() error {
- // parse port
- if RegexpDigitNumber.MatchString(this.Port) {
- this.minPort = types.Int(this.Port)
- this.maxPort = types.Int(this.Port)
- } else if regexp.MustCompile(`[:-]`).MatchString(this.Port) {
- pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2)
- if pieces[0] == "*" {
- this.minPortWildcard = true
- } else {
- this.minPort = types.Int(pieces[0])
- }
- if pieces[1] == "*" {
- this.maxPortWildcard = true
- } else {
- this.maxPort = types.Int(pieces[1])
- }
- } else if strings.Contains(this.Port, ",") {
- pieces := strings.Split(this.Port, ",")
- for _, piece := range pieces {
- piece = strings.TrimSpace(piece)
- if len(piece) > 0 {
- this.ports = append(this.ports, types.Int(piece))
- }
- }
- } else if this.Port == "*" {
- this.minPortWildcard = true
- this.maxPortWildcard = true
- }
-
- // parse ip
- if len(this.IP) > 0 {
- ipRange, err := shared.ParseIPRange(this.IP)
- if err != nil {
- return err
- }
- this.ipRange = ipRange
- }
-
- return nil
-}
-
-// check ip
-func (this *IPTable) Match(ip string, port int) (isMatched bool) {
- if !this.On {
- return
- }
-
- now := time.Now().Unix()
- if this.TimeFrom > 0 && now < this.TimeFrom {
- return
- }
- if this.TimeTo > 0 && now > this.TimeTo {
- return
- }
-
- if !this.matchPort(port) {
- return
- }
-
- if !this.matchIP(ip) {
- return
- }
-
- return true
-}
-
-func (this *IPTable) matchPort(port int) bool {
- if port == 0 {
- return false
- }
- if this.minPortWildcard {
- if this.maxPortWildcard {
- return true
- }
- if this.maxPort >= port {
- return true
- }
- }
- if this.maxPortWildcard {
- if this.minPortWildcard {
- return true
- }
- if this.minPort <= port {
- return true
- }
- }
- if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port {
- return true
- }
- if len(this.ports) > 0 {
- return lists.ContainsInt(this.ports, port)
- }
- return false
-}
-
-func (this *IPTable) matchIP(ip string) bool {
- if this.ipRange == nil {
- return false
- }
- return this.ipRange.Contains(ip)
-}
diff --git a/internal/waf/ip_table_test.go b/internal/waf/ip_table_test.go
deleted file mode 100644
index 5fbd3e8..0000000
--- a/internal/waf/ip_table_test.go
+++ /dev/null
@@ -1,142 +0,0 @@
-package waf
-
-import (
- "github.com/iwind/TeaGo/assert"
- "testing"
- "time"
-)
-
-func TestIPTable_MatchIP(t *testing.T) {
- a := assert.NewAssertion(t)
-
- {
- table := NewIPTable()
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- a.IsFalse(table.Match("192.168.1.100", 8080))
- }
-
- {
- table := NewIPTable()
- table.IP = "*"
- table.Port = "8080"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("port:", table.minPort, table.maxPort)
- a.IsTrue(table.Match("192.168.1.100", 8080))
- a.IsFalse(table.Match("192.168.1.100", 8081))
- }
-
- {
- table := NewIPTable()
- table.IP = "*"
- table.Port = "8080-8082"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("port:", table.minPort, table.maxPort)
- a.IsTrue(table.Match("192.168.1.100", 8080))
- a.IsTrue(table.Match("192.168.1.100", 8081))
- a.IsFalse(table.Match("192.168.1.100", 8083))
- }
-
- {
- table := NewIPTable()
- table.IP = "*"
- table.Port = "*-8082"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("port:", table.minPort, table.maxPort)
- a.IsTrue(table.Match("192.168.1.100", 8079))
- a.IsTrue(table.Match("192.168.1.100", 8080))
- a.IsTrue(table.Match("192.168.1.100", 8081))
- a.IsFalse(table.Match("192.168.1.100", 8083))
- }
-
- {
- table := NewIPTable()
- table.IP = "*"
- table.Port = "8080-*"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("port:", table.minPort, table.maxPort)
- a.IsFalse(table.Match("192.168.1.100", 8079))
- a.IsTrue(table.Match("192.168.1.100", 8080))
- a.IsTrue(table.Match("192.168.1.100", 8081))
- a.IsTrue(table.Match("192.168.1.100", 8083))
- }
-
- {
- table := NewIPTable()
- table.IP = "*"
- table.Port = "*"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("port:", table.minPort, table.maxPort)
- a.IsTrue(table.Match("192.168.1.100", 8079))
- a.IsTrue(table.Match("192.168.1.100", 8080))
- a.IsTrue(table.Match("192.168.1.100", 8081))
- a.IsTrue(table.Match("192.168.1.100", 8083))
- }
-
- {
- table := NewIPTable()
- table.IP = "192.168.1.100"
- table.Port = "*"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("port:", table.minPort, table.maxPort)
- a.IsTrue(table.Match("192.168.1.100", 8080))
- }
-
- {
- table := NewIPTable()
- table.IP = "192.168.1.99-192.168.1.101"
- table.Port = "*"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("port:", table.minPort, table.maxPort)
- a.IsTrue(table.Match("192.168.1.100", 8080))
- }
-
- {
- table := NewIPTable()
- table.IP = "192.168.1.99/24"
- table.Port = "*"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- t.Log("ip:", table.ipRange)
- a.IsTrue(table.Match("192.168.1.100", 8080))
- a.IsFalse(table.Match("192.168.2.100", 8080))
- }
-
- {
- table := NewIPTable()
- table.IP = "192.168.1.99/24"
- table.TimeTo = time.Now().Unix() - 10
- table.Port = "*"
- err := table.Init()
- if err != nil {
- t.Fatal(err)
- }
- a.IsFalse(table.Match("192.168.1.100", 8080))
- a.IsFalse(table.Match("192.168.2.100", 8080))
- }
-}
diff --git a/internal/waf/requests/request.go b/internal/waf/requests/request.go
index 28c29b9..b2f35a9 100644
--- a/internal/waf/requests/request.go
+++ b/internal/waf/requests/request.go
@@ -1,39 +1,28 @@
package requests
import (
- "bytes"
- "io"
- "io/ioutil"
"net/http"
)
-type Request struct {
- *http.Request
- BodyData []byte
-}
+type Request interface {
+ // WAFRaw 原始请求
+ WAFRaw() *http.Request
-func NewRequest(raw *http.Request) *Request {
- return &Request{
- Request: raw,
- }
-}
+ // WAFRemoteIP 客户端IP
+ WAFRemoteIP() string
-func (this *Request) Raw() *http.Request {
- return this.Request
-}
+ // WAFGetCacheBody 获取缓存中的Body
+ WAFGetCacheBody() []byte
-func (this *Request) ReadBody(max int64) (data []byte, err error) {
- if this.Request.ContentLength > 0 {
- data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max))
- }
- return
-}
+ // WAFSetCacheBody 设置Body
+ WAFSetCacheBody(body []byte)
-func (this *Request) RestoreBody(data []byte) {
- if len(data) > 0 {
- rawReader := bytes.NewBuffer(data)
- buf := make([]byte, 1024)
- _, _ = io.CopyBuffer(rawReader, this.Request.Body, buf)
- this.Request.Body = ioutil.NopCloser(rawReader)
- }
+ // WAFReadBody 读取Body
+ WAFReadBody(max int64) (data []byte, err error)
+
+ // WAFRestoreBody 恢复Body
+ WAFRestoreBody(data []byte)
+
+ // WAFServerId 服务ID
+ WAFServerId() int64
}
diff --git a/internal/waf/requests/test_request.go b/internal/waf/requests/test_request.go
new file mode 100644
index 0000000..4682d37
--- /dev/null
+++ b/internal/waf/requests/test_request.go
@@ -0,0 +1,67 @@
+// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
+
+package requests
+
+import (
+ "bytes"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+)
+
+type TestRequest struct {
+ req *http.Request
+ BodyData []byte
+}
+
+func NewTestRequest(raw *http.Request) *TestRequest {
+ return &TestRequest{
+ req: raw,
+ }
+}
+
+func (this *TestRequest) WAFSetCacheBody(bodyData []byte) {
+ this.BodyData = bodyData
+}
+
+func (this *TestRequest) WAFGetCacheBody() []byte {
+ return this.BodyData
+}
+
+func (this *TestRequest) WAFRaw() *http.Request {
+ return this.req
+}
+
+func (this *TestRequest) WAFRemoteAddr() string {
+ return this.req.RemoteAddr
+}
+
+func (this *TestRequest) WAFRemoteIP() string {
+ host, _, err := net.SplitHostPort(this.req.RemoteAddr)
+ if err != nil {
+ return this.req.RemoteAddr
+ } else {
+ return host
+ }
+}
+
+func (this *TestRequest) WAFReadBody(max int64) (data []byte, err error) {
+ if this.req.ContentLength > 0 {
+ data, err = ioutil.ReadAll(io.LimitReader(this.req.Body, max))
+ }
+ return
+}
+
+func (this *TestRequest) WAFRestoreBody(data []byte) {
+ if len(data) > 0 {
+ rawReader := bytes.NewBuffer(data)
+ buf := make([]byte, 1024)
+ _, _ = io.CopyBuffer(rawReader, this.req.Body, buf)
+ this.req.Body = ioutil.NopCloser(rawReader)
+ }
+}
+
+func (this *TestRequest) WAFServerId() int64 {
+ return 0
+}
diff --git a/internal/waf/rule.go b/internal/waf/rule.go
index 7455df1..a9ad080 100644
--- a/internal/waf/rule.go
+++ b/internal/waf/rule.go
@@ -183,7 +183,7 @@ func (this *Rule) Init() error {
return err
}
-func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
+func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) {
if this.singleCheckpoint != nil {
value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
if err != nil {
@@ -233,7 +233,7 @@ func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
return this.Test(value), nil
}
-func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
+func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) {
if this.singleCheckpoint != nil {
// if is request param
if this.singleCheckpoint.IsRequest() {
diff --git a/internal/waf/rule_group.go b/internal/waf/rule_group.go
index 1c7bea5..19577c4 100644
--- a/internal/waf/rule_group.go
+++ b/internal/waf/rule_group.go
@@ -23,12 +23,12 @@ func NewRuleGroup() *RuleGroup {
}
}
-func (this *RuleGroup) Init() error {
+func (this *RuleGroup) Init(waf *WAF) error {
this.hasRuleSets = len(this.RuleSets) > 0
if this.hasRuleSets {
for _, set := range this.RuleSets {
- err := set.Init()
+ err := set.Init(waf)
if err != nil {
return err
}
@@ -79,7 +79,7 @@ func (this *RuleGroup) RemoveRuleSet(id string) {
this.RuleSets = result
}
-func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) {
+func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, set *RuleSet, err error) {
if !this.hasRuleSets {
return
}
@@ -98,7 +98,7 @@ func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet
return
}
-func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
+func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
if !this.hasRuleSets {
return
}
diff --git a/internal/waf/rule_set.go b/internal/waf/rule_set.go
index 431a504..38b6e15 100644
--- a/internal/waf/rule_set.go
+++ b/internal/waf/rule_set.go
@@ -1,9 +1,13 @@
package waf
import (
+ "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
+ "github.com/iwind/TeaGo/lists"
+ "github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/utils/string"
+ "net/http"
)
type RuleConnector = string
@@ -14,16 +18,17 @@ const (
)
type RuleSet struct {
- Id string `yaml:"id" json:"id"`
- Code string `yaml:"code" json:"code"`
- IsOn bool `yaml:"isOn" json:"isOn"`
- Name string `yaml:"name" json:"name"`
- Description string `yaml:"description" json:"description"`
- Rules []*Rule `yaml:"rules" json:"rules"`
- Connector RuleConnector `yaml:"connector" json:"connector"` // rules connector
+ Id string `yaml:"id" json:"id"`
+ Code string `yaml:"code" json:"code"`
+ IsOn bool `yaml:"isOn" json:"isOn"`
+ Name string `yaml:"name" json:"name"`
+ Description string `yaml:"description" json:"description"`
+ Rules []*Rule `yaml:"rules" json:"rules"`
+ Connector RuleConnector `yaml:"connector" json:"connector"` // rules connector
+ Actions []*ActionConfig `yaml:"actions" json:"actions"`
- Action ActionString `yaml:"action" json:"action"`
- ActionOptions maps.Map `yaml:"actionOptions" json:"actionOptions"` // TODO TO BE IMPLEMENTED
+ actionCodes []string
+ actionInstances []ActionInterface
hasRules bool
}
@@ -35,7 +40,7 @@ func NewRuleSet() *RuleSet {
}
}
-func (this *RuleSet) Init() error {
+func (this *RuleSet) Init(waf *WAF) error {
this.hasRules = len(this.Rules) > 0
if this.hasRules {
for _, rule := range this.Rules {
@@ -45,6 +50,31 @@ func (this *RuleSet) Init() error {
}
}
}
+
+ // action codes
+ var actionCodes = []string{}
+ for _, action := range this.Actions {
+ if !lists.ContainsString(actionCodes, action.Code) {
+ actionCodes = append(actionCodes, action.Code)
+ }
+ }
+ this.actionCodes = actionCodes
+
+ // action instances
+ this.actionInstances = []ActionInterface{}
+ for _, action := range this.Actions {
+ instance := FindActionInstance(action.Code, action.Options)
+ if instance == nil {
+ remotelogs.Error("WAF_RULE_SET", "can not find instance for action '"+action.Code+"'")
+ } else {
+ this.actionInstances = append(this.actionInstances, instance)
+ }
+ err := instance.Init(waf)
+ if err != nil {
+ remotelogs.Error("WAF_RULE_SET", "init action '"+action.Code+"' failed: "+err.Error())
+ }
+ }
+
return nil
}
@@ -52,7 +82,75 @@ func (this *RuleSet) AddRule(rule ...*Rule) {
this.Rules = append(this.Rules, rule...)
}
-func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) {
+// AddAction 添加动作
+func (this *RuleSet) AddAction(code string, options maps.Map) {
+ if options == nil {
+ options = maps.Map{}
+ }
+ this.Actions = append(this.Actions, &ActionConfig{
+ Code: code,
+ Options: options,
+ })
+}
+
+// HasSpecialActions 除了Allow之外是否还有别的动作
+func (this *RuleSet) HasSpecialActions() bool {
+ for _, action := range this.Actions {
+ if action.Code != ActionAllow {
+ return true
+ }
+ }
+ return false
+}
+
+// HasAttackActions 检查是否含有攻击防御动作
+func (this *RuleSet) HasAttackActions() bool {
+ for _, action := range this.actionInstances {
+ if action.IsAttack() {
+ return true
+ }
+ }
+ return false
+}
+
+func (this *RuleSet) ActionCodes() []string {
+ return this.actionCodes
+}
+
+func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) bool {
+ // 先执行allow
+ for _, instance := range this.actionInstances {
+ if !instance.WillChange() {
+ if waf.onActionCallback != nil {
+ goNext := waf.onActionCallback(instance)
+ if !goNext {
+ return false
+ }
+ }
+ logs.Printf("perform1: %#v", instance) // TODO
+ instance.Perform(waf, group, this, req, writer)
+ }
+ }
+
+ // 再执行block|verify
+ for _, instance := range this.actionInstances {
+ // 只执行第一个可能改变请求的动作,其余的都会被忽略
+ if instance.WillChange() {
+ if waf.onActionCallback != nil {
+ goNext := waf.onActionCallback(instance)
+ if !goNext {
+ return false
+ }
+ }
+ logs.Printf("perform2: %#v", instance) // TODO
+ return instance.Perform(waf, group, this, req, writer)
+ }
+ }
+
+ return true
+}
+
+func (this *RuleSet) MatchRequest(req requests.Request) (b bool, err error) {
if !this.hasRules {
return false, nil
}
@@ -93,7 +191,7 @@ func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) {
return
}
-func (this *RuleSet) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
+func (this *RuleSet) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) {
if !this.hasRules {
return false, nil
}
diff --git a/internal/waf/rule_set_test.go b/internal/waf/rule_set_test.go
index 9d39377..5317643 100644
--- a/internal/waf/rule_set_test.go
+++ b/internal/waf/rule_set_test.go
@@ -28,7 +28,7 @@ func TestRuleSet_MatchRequest(t *testing.T) {
},
}
- err := set.Init()
+ err := set.Init(nil)
if err != nil {
t.Fatal(err)
}
@@ -37,7 +37,7 @@ func TestRuleSet_MatchRequest(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
t.Log(set.MatchRequest(req))
}
@@ -60,7 +60,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
},
}
- err := set.Init()
+ err := set.Init(nil)
if err != nil {
t.Fatal(err)
}
@@ -69,7 +69,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
a.IsTrue(set.MatchRequest(req))
}
@@ -102,7 +102,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
},
}
- err := set.Init()
+ err := set.Init(nil)
if err != nil {
b.Fatal(err)
}
@@ -111,7 +111,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
if err != nil {
b.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
for i := 0; i < b.N; i++ {
_, _ = set.MatchRequest(req)
}
@@ -132,7 +132,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
},
}
- err := set.Init()
+ err := set.Init(nil)
if err != nil {
b.Fatal(err)
}
@@ -141,7 +141,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
if err != nil {
b.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
for i := 0; i < b.N; i++ {
_, _ = set.MatchRequest(req)
}
diff --git a/internal/waf/rule_test.go b/internal/waf/rule_test.go
index 6a7731c..e9597b3 100644
--- a/internal/waf/rule_test.go
+++ b/internal/waf/rule_test.go
@@ -25,7 +25,7 @@ func TestRule_Init_Single(t *testing.T) {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
t.Log(rule.MatchRequest(req))
}
@@ -44,7 +44,7 @@ func TestRule_Init_Composite(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- req := requests.NewRequest(rawReq)
+ req := requests.NewTestRequest(rawReq)
t.Log(rule.MatchRequest(req))
}
diff --git a/internal/waf/template.go b/internal/waf/template.go
index 83ffe83..5bd06c3 100644
--- a/internal/waf/template.go
+++ b/internal/waf/template.go
@@ -20,7 +20,7 @@ func Template() *WAF {
set.Name = "Javascript事件"
set.Code = "1001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
@@ -36,7 +36,7 @@ func Template() *WAF {
set.Name = "Javascript函数"
set.Code = "1002"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
@@ -52,7 +52,7 @@ func Template() *WAF {
set.Name = "HTML标签"
set.Code = "1003"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
@@ -80,7 +80,7 @@ func Template() *WAF {
set.Name = "上传文件扩展名"
set.Code = "2001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestUpload.ext}",
Operator: RuleOperatorMatch,
@@ -108,7 +108,7 @@ func Template() *WAF {
set.Name = "Web Shell"
set.Code = "3001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
@@ -135,7 +135,7 @@ func Template() *WAF {
set.Name = "命令注入"
set.Code = "4001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
@@ -169,7 +169,7 @@ func Template() *WAF {
set.Name = "路径穿越"
set.Code = "5001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
@@ -197,7 +197,7 @@ func Template() *WAF {
set.Name = "特殊目录"
set.Code = "6001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestPath}",
Operator: RuleOperatorMatch,
@@ -225,7 +225,7 @@ func Template() *WAF {
set.Name = "Union SQL Injection"
set.Code = "7001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
@@ -243,7 +243,7 @@ func Template() *WAF {
set.Name = "SQL注释"
set.Code = "7002"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
@@ -261,7 +261,7 @@ func Template() *WAF {
set.Name = "SQL条件"
set.Code = "7003"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
@@ -297,7 +297,7 @@ func Template() *WAF {
set.Name = "SQL函数"
set.Code = "7004"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
@@ -315,7 +315,7 @@ func Template() *WAF {
set.Name = "SQL附加语句"
set.Code = "7005"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
@@ -345,7 +345,7 @@ func Template() *WAF {
set.Name = "常见网络爬虫"
set.Code = "20001"
set.Connector = RuleConnectorOr
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${userAgent}",
@@ -376,7 +376,7 @@ func Template() *WAF {
set.Description = "限制单IP在一定时间内的请求数"
set.Code = "8001"
set.Connector = RuleConnectorAnd
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${cc.requests}",
Operator: RuleOperatorGt,
diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go
index c2074b6..e10d691 100644
--- a/internal/waf/template_test.go
+++ b/internal/waf/template_test.go
@@ -2,6 +2,7 @@ package waf
import (
"bytes"
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
@@ -22,8 +23,8 @@ func Test_Template(t *testing.T) {
t.Fatal(err)
}
- template.OnAction(func(action ActionString) (goNext bool) {
- return action != ActionBlock
+ template.OnAction(func(action ActionInterface) (goNext bool) {
+ return action.Code() != ActionBlock
})
testTemplate1001(a, t, template)
@@ -40,7 +41,7 @@ func Test_Template(t *testing.T) {
func Test_Template2(t *testing.T) {
reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024)))
- req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader)
+ req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123", reader)
if err != nil {
t.Fatal(err)
}
@@ -52,7 +53,7 @@ func Test_Template2(t *testing.T) {
}
now := time.Now()
- goNext, _, set, err := waf.MatchRequest(req, nil)
+ goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -80,7 +81,7 @@ func BenchmarkTemplate(b *testing.B) {
b.Fatal(err)
}
- _, _, _, _ = waf.MatchRequest(req, nil)
+ _, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil)
}
}
@@ -89,7 +90,7 @@ func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -104,7 +105,7 @@ func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -119,7 +120,7 @@ func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -185,7 +186,7 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) {
req.Header.Add("Content-Type", writer.FormDataContentType())
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -200,7 +201,7 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -215,7 +216,7 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -231,7 +232,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -246,7 +247,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -263,7 +264,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -278,7 +279,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -301,7 +302,7 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil {
t.Fatal(err)
}
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
@@ -338,7 +339,7 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) {
t.Fatal(err)
}
req.Header.Set("User-Agent", bot)
- _, _, result, err := template.MatchRequest(req, nil)
+ _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}
diff --git a/internal/waf/waf.go b/internal/waf/waf.go
index 7d6fc8f..08c762c 100644
--- a/internal/waf/waf.go
+++ b/internal/waf/waf.go
@@ -22,13 +22,11 @@ type WAF struct {
Outbound []*RuleGroup `yaml:"outbound" json:"outbound"`
CreatedVersion string `yaml:"createdVersion" json:"createdVersion"`
- ActionBlock *BlockAction `yaml:"actionBlock" json:"actionBlock"` // action block config
-
- IPTables []*IPTable `yaml:"ipTables" json:"ipTables"` // IP table list
+ DefaultBlockAction *BlockAction
hasInboundRules bool
hasOutboundRules bool
- onActionCallback func(action ActionString) (goNext bool)
+ onActionCallback func(action ActionInterface) (goNext bool)
checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint
}
@@ -87,7 +85,7 @@ func (this *WAF) Init() error {
}
}
- err := group.Init()
+ err := group.Init(this)
if err != nil {
return err
}
@@ -103,7 +101,7 @@ func (this *WAF) Init() error {
}
}
- err := group.Init()
+ err := group.Init(this)
if err != nil {
return err
}
@@ -241,19 +239,24 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) {
this.Outbound = result
}
-func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
+func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
if !this.hasInboundRules {
return true, nil, nil, nil
}
- req := requests.NewRequest(rawReq)
-
// validate captcha
- if rawReq.URL.Path == "/WAFCAPTCHA" {
+ var rawPath = req.WAFRaw().URL.Path
+ if rawPath == CaptchaPath {
captchaValidator.Run(req, writer)
return
}
+ // Get 302验证
+ if rawPath == Get302Path {
+ get302Validator.Run(req, writer)
+ return
+ }
+
// match rules
for _, group := range this.Inbound {
if !group.IsOn {
@@ -264,31 +267,17 @@ func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter)
return true, nil, nil, err
}
if b {
- if this.onActionCallback == nil {
- if set.Action == ActionBlock && this.ActionBlock != nil {
- return this.ActionBlock.Perform(this, req, writer), group, set, nil
- } else {
- actionObject := FindActionInstance(set.Action, set.ActionOptions)
- if actionObject == nil {
- return true, group, set, errors.New("no action called '" + set.Action + "'")
- }
- goNext := actionObject.Perform(this, req, writer)
- return goNext, group, set, nil
- }
- } else {
- goNext = this.onActionCallback(set.Action)
- }
+ goNext := set.PerformActions(this, group, req, writer)
return goNext, group, set, nil
}
}
return true, nil, nil, nil
}
-func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
+func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
if !this.hasOutboundRules {
return true, nil, nil, nil
}
- req := requests.NewRequest(rawReq)
resp := requests.NewResponse(rawResp)
for _, group := range this.Outbound {
if !group.IsOn {
@@ -299,27 +288,14 @@ func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, wri
return true, nil, nil, err
}
if b {
- if this.onActionCallback == nil {
- if set.Action == ActionBlock && this.ActionBlock != nil {
- return this.ActionBlock.Perform(this, req, writer), group, set, nil
- } else {
- actionObject := FindActionInstance(set.Action, set.ActionOptions)
- if actionObject == nil {
- return true, group, set, errors.New("no action called '" + set.Action + "'")
- }
- goNext := actionObject.Perform(this, req, writer)
- return goNext, group, set, nil
- }
- } else {
- goNext = this.onActionCallback(set.Action)
- }
+ goNext := set.PerformActions(this, group, req, writer)
return goNext, group, set, nil
}
}
return true, nil, nil, nil
}
-// save to file path
+// Save save to file path
func (this *WAF) Save(path string) error {
if len(path) == 0 {
return errors.New("path should not be empty")
@@ -378,7 +354,7 @@ func (this *WAF) CountOutboundRuleSets() int {
return count
}
-func (this *WAF) OnAction(onActionCallback func(action ActionString) (goNext bool)) {
+func (this *WAF) OnAction(onActionCallback func(action ActionInterface) (goNext bool)) {
this.onActionCallback = onActionCallback
}
@@ -390,21 +366,21 @@ func (this *WAF) FindCheckpointInstance(prefix string) checkpoints.CheckpointInt
return nil
}
-// start
+// Start start
func (this *WAF) Start() {
for _, checkpoint := range this.checkpointsMap {
checkpoint.Start()
}
}
-// call stop() when the waf was deleted
+// Stop call stop() when the waf was deleted
func (this *WAF) Stop() {
for _, checkpoint := range this.checkpointsMap {
checkpoint.Stop()
}
}
-// merge with template
+// MergeTemplate merge with template
func (this *WAF) MergeTemplate() (changedItems []string) {
changedItems = []string{}
diff --git a/internal/waf/waf_test.go b/internal/waf/waf_test.go
index acca3d1..5395eb8 100644
--- a/internal/waf/waf_test.go
+++ b/internal/waf/waf_test.go
@@ -1,6 +1,7 @@
package waf
import (
+ "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert"
"net/http"
"testing"
@@ -24,7 +25,7 @@ func TestWAF_MatchRequest(t *testing.T) {
Value: "20",
},
}
- set.Action = ActionBlock
+ set.AddAction(ActionBlock, nil)
group := NewRuleGroup()
group.AddRuleSet(set)
@@ -37,15 +38,15 @@ func TestWAF_MatchRequest(t *testing.T) {
t.Fatal(err)
}
- waf.OnAction(func(action ActionString) (goNext bool) {
- return action != ActionBlock
+ waf.OnAction(func(action ActionInterface) (goNext bool) {
+ return action.Code() != ActionBlock
})
req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
if err != nil {
t.Fatal(err)
}
- goNext, _, set, err := waf.MatchRequest(req, nil)
+ goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil)
if err != nil {
t.Fatal(err)
}