mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 06:40:25 +08:00
WAF增加多个动作
This commit is contained in:
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// IP名单
|
||||
// IPList IP名单
|
||||
type IPList struct {
|
||||
itemsMap map[int64]*IPItem // id => item
|
||||
ipMap map[uint64][]int64 // ip => itemIds
|
||||
@@ -96,7 +96,7 @@ func (this *IPList) Delete(itemId int64) {
|
||||
this.isAll = len(this.ipMap[0]) > 0
|
||||
}
|
||||
|
||||
// 判断是否包含某个IP
|
||||
// Contains 判断是否包含某个IP
|
||||
func (this *IPList) Contains(ip uint64) bool {
|
||||
this.locker.RLock()
|
||||
if this.isAll {
|
||||
@@ -109,7 +109,7 @@ func (this *IPList) Contains(ip uint64) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// 是否包含一组IP
|
||||
// ContainsIPStrings 是否包含一组IP
|
||||
func (this *IPList) ContainsIPStrings(ipStrings []string) (found bool, item *IPItem) {
|
||||
if len(ipStrings) == 0 {
|
||||
return
|
||||
|
||||
@@ -68,12 +68,15 @@ type HTTPRequest struct {
|
||||
cacheKey string // 缓存使用的Key
|
||||
isCached bool // 是否已经被缓存
|
||||
isAttack bool // 是否是攻击请求
|
||||
bodyData []byte // 读取的Body内容
|
||||
|
||||
// WAF相关
|
||||
firewallPolicyId int64
|
||||
firewallRuleGroupId int64
|
||||
firewallRuleSetId int64
|
||||
firewallRuleId int64
|
||||
firewallActions []string
|
||||
tags []string
|
||||
|
||||
logAttrs map[string]string
|
||||
|
||||
@@ -1197,5 +1200,10 @@ func (this *HTTPRequest) canIgnore(err error) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// HTTP内部错误
|
||||
if strings.HasPrefix(err.Error(), "http:") || strings.HasPrefix(err.Error(), "http2:") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -128,6 +128,8 @@ func (this *HTTPRequest) log() {
|
||||
FirewallRuleGroupId: this.firewallRuleGroupId,
|
||||
FirewallRuleSetId: this.firewallRuleSetId,
|
||||
FirewallRuleId: this.firewallRuleId,
|
||||
FirewallActions: this.firewallActions,
|
||||
Tags: this.tags,
|
||||
|
||||
Attrs: this.logAttrs,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -152,27 +155,36 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
|
||||
|
||||
w.OnAction(func(action waf.ActionInterface) (goNext bool) {
|
||||
switch action.Code() {
|
||||
case waf.ActionTag:
|
||||
this.tags = action.(*waf.TagAction).Tags
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if ruleSet != nil {
|
||||
if ruleSet.Action != waf.ActionAllow {
|
||||
if ruleSet.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
|
||||
this.firewallRuleSetId = types.Int64(ruleSet.Id)
|
||||
|
||||
if ruleSet.Action == waf.ActionBlock {
|
||||
if ruleSet.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
|
||||
}
|
||||
|
||||
this.logAttrs["waf.action"] = ruleSet.Action
|
||||
this.firewallActions = ruleSet.ActionCodes()
|
||||
}
|
||||
|
||||
return !goNext, false
|
||||
@@ -208,28 +220,79 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
|
||||
return
|
||||
}
|
||||
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
|
||||
w.OnAction(func(action waf.ActionInterface) (goNext bool) {
|
||||
switch action.Code() {
|
||||
case waf.ActionTag:
|
||||
this.tags = action.(*waf.TagAction).Tags
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
|
||||
if err != nil {
|
||||
remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if ruleSet != nil {
|
||||
if ruleSet.Action != waf.ActionAllow {
|
||||
if ruleSet.HasSpecialActions() {
|
||||
this.firewallPolicyId = firewallPolicy.Id
|
||||
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
|
||||
this.firewallRuleSetId = types.Int64(ruleSet.Id)
|
||||
|
||||
if ruleSet.Action == waf.ActionBlock {
|
||||
if ruleSet.HasAttackActions() {
|
||||
this.isAttack = true
|
||||
}
|
||||
|
||||
// 添加统计
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Action)
|
||||
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions)
|
||||
}
|
||||
|
||||
this.logAttrs["waf.action"] = ruleSet.Action
|
||||
this.firewallActions = ruleSet.ActionCodes()
|
||||
}
|
||||
|
||||
return !goNext
|
||||
}
|
||||
|
||||
// WAFRaw 原始请求
|
||||
func (this *HTTPRequest) WAFRaw() *http.Request {
|
||||
return this.RawReq
|
||||
}
|
||||
|
||||
// WAFRemoteIP 客户端IP
|
||||
func (this *HTTPRequest) WAFRemoteIP() string {
|
||||
return this.requestRemoteAddr()
|
||||
}
|
||||
|
||||
// WAFGetCacheBody 获取缓存中的Body
|
||||
func (this *HTTPRequest) WAFGetCacheBody() []byte {
|
||||
return this.bodyData
|
||||
}
|
||||
|
||||
// WAFSetCacheBody 设置Body
|
||||
func (this *HTTPRequest) WAFSetCacheBody(body []byte) {
|
||||
this.bodyData = body
|
||||
}
|
||||
|
||||
// WAFReadBody 读取Body
|
||||
func (this *HTTPRequest) WAFReadBody(max int64) (data []byte, err error) {
|
||||
if this.RawReq.ContentLength > 0 {
|
||||
data, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, max))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// WAFRestoreBody 恢复Body
|
||||
func (this *HTTPRequest) WAFRestoreBody(data []byte) {
|
||||
if len(data) > 0 {
|
||||
rawReader := bytes.NewBuffer(data)
|
||||
buf := make([]byte, 1024)
|
||||
_, _ = io.CopyBuffer(rawReader, this.RawReq.Body, buf)
|
||||
this.RawReq.Body = ioutil.NopCloser(rawReader)
|
||||
}
|
||||
}
|
||||
|
||||
// WAFServerId 服务ID
|
||||
func (this *HTTPRequest) WAFServerId() int64 {
|
||||
return this.Server.Id
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
http2 "golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2"
|
||||
"sync"
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
package nodes
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
)
|
||||
|
||||
// TrafficListener 用于统计流量的网络监听
|
||||
type TrafficListener struct {
|
||||
@@ -18,6 +21,17 @@ func (this *TrafficListener) Accept() (net.Conn, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 是否在WAF名单中
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err == nil {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackLIst.Contains(waf.IPTypeAll, ip) {
|
||||
go func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
return NewTrafficConn(conn), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,20 +11,20 @@ import (
|
||||
|
||||
var sharedWAFManager = NewWAFManager()
|
||||
|
||||
// WAF管理器
|
||||
// WAFManager WAF管理器
|
||||
type WAFManager struct {
|
||||
mapping map[int64]*waf.WAF // policyId => WAF
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
// 获取新对象
|
||||
// NewWAFManager 获取新对象
|
||||
func NewWAFManager() *WAFManager {
|
||||
return &WAFManager{
|
||||
mapping: map[int64]*waf.WAF{},
|
||||
}
|
||||
}
|
||||
|
||||
// 更新策略
|
||||
// UpdatePolicies 更新策略
|
||||
func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
@@ -44,7 +44,7 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP
|
||||
this.mapping = m
|
||||
}
|
||||
|
||||
// 查找WAF
|
||||
// FindWAF 查找WAF
|
||||
func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
|
||||
this.locker.RLock()
|
||||
w, _ := this.mapping[policyId]
|
||||
@@ -78,14 +78,15 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
// rule sets
|
||||
for _, set := range group.Sets {
|
||||
s := &waf.RuleSet{
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
Action: set.Action,
|
||||
ActionOptions: set.ActionOptions,
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
}
|
||||
for _, a := range set.Actions {
|
||||
s.AddAction(a.Code, a.Options)
|
||||
}
|
||||
|
||||
// rules
|
||||
@@ -132,14 +133,16 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
// rule sets
|
||||
for _, set := range group.Sets {
|
||||
s := &waf.RuleSet{
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
Action: set.Action,
|
||||
ActionOptions: set.ActionOptions,
|
||||
Id: strconv.FormatInt(set.Id, 10),
|
||||
Code: set.Code,
|
||||
IsOn: set.IsOn,
|
||||
Name: set.Name,
|
||||
Description: set.Description,
|
||||
Connector: set.Connector,
|
||||
}
|
||||
|
||||
for _, a := range set.Actions {
|
||||
s.AddAction(a.Code, a.Options)
|
||||
}
|
||||
|
||||
// rules
|
||||
@@ -164,10 +167,11 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
|
||||
|
||||
// action
|
||||
if policy.BlockOptions != nil {
|
||||
w.ActionBlock = &waf.BlockAction{
|
||||
w.DefaultBlockAction = &waf.BlockAction{
|
||||
StatusCode: policy.BlockOptions.StatusCode,
|
||||
Body: policy.BlockOptions.Body,
|
||||
URL: "",
|
||||
URL: policy.BlockOptions.URL,
|
||||
Timeout: policy.BlockOptions.Timeout,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -113,6 +113,10 @@ func (this *RPCClient) MetricStatRPC() pb.MetricStatServiceClient {
|
||||
return pb.NewMetricStatServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
func (this *RPCClient) FirewallService() pb.FirewallServiceClient {
|
||||
return pb.NewFirewallServiceClient(this.pickConn())
|
||||
}
|
||||
|
||||
// Context 节点上下文信息
|
||||
func (this *RPCClient) Context() context.Context {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/monitor"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
@@ -132,17 +133,19 @@ func (this *HTTPRequestStatManager) AddUserAgent(serverId int64, userAgent strin
|
||||
}
|
||||
|
||||
// AddFirewallRuleGroupId 添加防火墙拦截动作
|
||||
func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, action string) {
|
||||
func (this *HTTPRequestStatManager) AddFirewallRuleGroupId(serverId int64, firewallRuleGroupId int64, actions []*waf.ActionConfig) {
|
||||
if firewallRuleGroupId <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
this.totalAttackRequests ++
|
||||
this.totalAttackRequests++
|
||||
|
||||
select {
|
||||
case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action:
|
||||
default:
|
||||
// 超出容量我们就丢弃
|
||||
for _, action := range actions {
|
||||
select {
|
||||
case this.firewallRuleGroupChan <- strconv.FormatInt(serverId, 10) + "@" + strconv.FormatInt(firewallRuleGroupId, 10) + "@" + action.Code:
|
||||
default:
|
||||
// 超出容量我们就丢弃
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
159
internal/utils/encrypt.go
Normal file
159
internal/utils/encrypt.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/rands"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
)
|
||||
|
||||
var (
|
||||
simpleEncryptMagicKey = rands.HexString(32)
|
||||
)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventReload, func() {
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
simpleEncryptMagicKey = stringutil.Md5(nodeConfig.NodeId + "@" + nodeConfig.Secret)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// SimpleEncrypt 加密特殊信息
|
||||
func SimpleEncrypt(data []byte) []byte {
|
||||
var method = &AES256CFBMethod{}
|
||||
err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
|
||||
if err != nil {
|
||||
logs.Println("[SimpleEncrypt]" + err.Error())
|
||||
return data
|
||||
}
|
||||
|
||||
dst, err := method.Encrypt(data)
|
||||
if err != nil {
|
||||
logs.Println("[SimpleEncrypt]" + err.Error())
|
||||
return data
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// SimpleDecrypt 解密特殊信息
|
||||
func SimpleDecrypt(data []byte) []byte {
|
||||
var method = &AES256CFBMethod{}
|
||||
err := method.Init([]byte(simpleEncryptMagicKey), []byte(simpleEncryptMagicKey[:16]))
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
|
||||
src, err := method.Decrypt(data)
|
||||
if err != nil {
|
||||
logs.Println("[MagicKeyEncode]" + err.Error())
|
||||
return data
|
||||
}
|
||||
return src
|
||||
}
|
||||
|
||||
func SimpleEncryptMap(m maps.Map) (base64String string, err error) {
|
||||
mJSON, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data := SimpleEncrypt(mJSON)
|
||||
return base64.StdEncoding.EncodeToString(data), nil
|
||||
}
|
||||
|
||||
func SimpleDecryptMap(base64String string) (maps.Map, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mJSON := SimpleDecrypt(data)
|
||||
var result = maps.Map{}
|
||||
err = json.Unmarshal(mJSON, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type AES256CFBMethod struct {
|
||||
block cipher.Block
|
||||
iv []byte
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Init(key, iv []byte) error {
|
||||
// 判断key是否为32长度
|
||||
l := len(key)
|
||||
if l > 32 {
|
||||
key = key[:32]
|
||||
} else if l < 32 {
|
||||
key = append(key, bytes.Repeat([]byte{' '}, 32-l)...)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.block = block
|
||||
|
||||
// 判断iv长度
|
||||
l2 := len(iv)
|
||||
if l2 > aes.BlockSize {
|
||||
iv = iv[:aes.BlockSize]
|
||||
} else if l2 < aes.BlockSize {
|
||||
iv = append(iv, bytes.Repeat([]byte{' '}, aes.BlockSize-l2)...)
|
||||
}
|
||||
this.iv = iv
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) {
|
||||
if len(src) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r != nil {
|
||||
err = errors.New("encrypt failed")
|
||||
}
|
||||
}()
|
||||
|
||||
dst = make([]byte, len(src))
|
||||
|
||||
encrypter := cipher.NewCFBEncrypter(this.block, this.iv)
|
||||
encrypter.XORKeyStream(dst, src)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) {
|
||||
if len(dst) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r != nil {
|
||||
err = errors.New("decrypt failed")
|
||||
}
|
||||
}()
|
||||
|
||||
src = make([]byte, len(dst))
|
||||
decrypter := cipher.NewCFBDecrypter(this.block, this.iv)
|
||||
decrypter.XORKeyStream(src, dst)
|
||||
|
||||
return
|
||||
}
|
||||
52
internal/utils/encrypt_test.go
Normal file
52
internal/utils/encrypt_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSimpleEncrypt(t *testing.T) {
|
||||
var arr = []string{"Hello", "World", "People"}
|
||||
for _, s := range arr {
|
||||
var value = []byte(s)
|
||||
encoded := SimpleEncrypt(value)
|
||||
t.Log(encoded, string(encoded))
|
||||
decoded := SimpleDecrypt(encoded)
|
||||
t.Log(decoded, string(decoded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleEncrypt_Concurrent(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
var arr = []string{"Hello", "World", "People"}
|
||||
wg.Add(len(arr))
|
||||
for _, s := range arr {
|
||||
go func(s string) {
|
||||
defer wg.Done()
|
||||
t.Log(string(SimpleDecrypt(SimpleEncrypt([]byte(s)))))
|
||||
}(s)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSimpleEncryptMap(t *testing.T) {
|
||||
var m = maps.Map{
|
||||
"s": "Hello",
|
||||
"i": 20,
|
||||
"b": true,
|
||||
}
|
||||
encodedResult, err := SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("result:", encodedResult)
|
||||
|
||||
decodedResult, err := SimpleDecryptMap(encodedResult)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(decodedResult)
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type List struct {
|
||||
itemsMap map[int64]int64 // itemId => timestamp
|
||||
|
||||
locker sync.Mutex
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
func NewList() *List {
|
||||
@@ -21,10 +22,7 @@ func NewList() *List {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *List) Add(itemId int64, expiredAt int64) {
|
||||
if expiredAt <= time.Now().Unix() {
|
||||
return
|
||||
}
|
||||
func (this *List) Add(itemId int64, expiresAt int64) {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
@@ -34,17 +32,17 @@ func (this *List) Add(itemId int64, expiredAt int64) {
|
||||
this.removeItem(itemId)
|
||||
}
|
||||
|
||||
expireItemMap, ok := this.expireMap[expiredAt]
|
||||
expireItemMap, ok := this.expireMap[expiresAt]
|
||||
if ok {
|
||||
expireItemMap[itemId] = true
|
||||
} else {
|
||||
expireItemMap = ItemMap{
|
||||
itemId: true,
|
||||
}
|
||||
this.expireMap[expiredAt] = expireItemMap
|
||||
this.expireMap[expiresAt] = expireItemMap
|
||||
}
|
||||
|
||||
this.itemsMap[itemId] = expiredAt
|
||||
this.itemsMap[itemId] = expiresAt
|
||||
}
|
||||
|
||||
func (this *List) Remove(itemId int64) {
|
||||
@@ -64,21 +62,22 @@ func (this *List) GC(timestamp int64, callback func(itemId int64)) {
|
||||
}
|
||||
|
||||
func (this *List) StartGC(callback func(itemId int64)) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
this.ticker = time.NewTicker(1 * time.Second)
|
||||
lastTimestamp := int64(0)
|
||||
for range ticker.C {
|
||||
for range this.ticker.C {
|
||||
timestamp := time.Now().Unix()
|
||||
if lastTimestamp == 0 {
|
||||
lastTimestamp = timestamp - 3600
|
||||
}
|
||||
|
||||
// 防止死循环
|
||||
if lastTimestamp > timestamp {
|
||||
continue
|
||||
}
|
||||
|
||||
for i := lastTimestamp; i <= timestamp; i++ {
|
||||
this.GC(timestamp, callback)
|
||||
if timestamp >= lastTimestamp {
|
||||
for i := lastTimestamp; i <= timestamp; i++ {
|
||||
this.GC(i, callback)
|
||||
}
|
||||
} else {
|
||||
for i := timestamp; i <= lastTimestamp; i++ {
|
||||
this.GC(i, callback)
|
||||
}
|
||||
}
|
||||
|
||||
// 这样做是为了防止系统时钟突变
|
||||
|
||||
@@ -58,6 +58,10 @@ func TestList_Start_GC(t *testing.T) {
|
||||
list.Add(2, time.Now().Unix()+1)
|
||||
list.Add(3, time.Now().Unix()+2)
|
||||
list.Add(4, time.Now().Unix()+5)
|
||||
list.Add(5, time.Now().Unix()+5)
|
||||
list.Add(6, time.Now().Unix()+6)
|
||||
list.Add(7, time.Now().Unix()+6)
|
||||
list.Add(8, time.Now().Unix()+6)
|
||||
|
||||
go func() {
|
||||
list.StartGC(func(itemId int64) {
|
||||
@@ -66,7 +70,7 @@ func TestList_Start_GC(t *testing.T) {
|
||||
})
|
||||
}()
|
||||
|
||||
time.Sleep(10 * time.Second)
|
||||
time.Sleep(20 * time.Second)
|
||||
}
|
||||
|
||||
func TestList_ManyItems(t *testing.T) {
|
||||
|
||||
35
internal/utils/jsonutils/map.go
Normal file
35
internal/utils/jsonutils/map.go
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package jsonutils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
func MapToObject(m maps.Map, ptr interface{}) error {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
mJSON, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(mJSON, ptr)
|
||||
}
|
||||
|
||||
func ObjectToMap(ptr interface{}) (maps.Map, error) {
|
||||
if ptr == nil {
|
||||
return maps.Map{}, nil
|
||||
}
|
||||
ptrJSON, err := json.Marshal(ptr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result = maps.Map{}
|
||||
err = json.Unmarshal(ptrJSON, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
46
internal/utils/jsonutils/map_test.go
Normal file
46
internal/utils/jsonutils/map_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package jsonutils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMapToObject(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
type typeA struct {
|
||||
B int `json:"b"`
|
||||
C bool `json:"c"`
|
||||
}
|
||||
|
||||
{
|
||||
var obj = &typeA{B: 1, C: true}
|
||||
m, err := ObjectToMap(obj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
PrintT(m, t)
|
||||
a.IsTrue(m.GetInt("b") == 1)
|
||||
a.IsTrue(m.GetBool("c") == true)
|
||||
}
|
||||
|
||||
{
|
||||
var obj = &typeA{}
|
||||
err := MapToObject(maps.Map{
|
||||
"b": 1024,
|
||||
"c": true,
|
||||
}, obj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if obj == nil {
|
||||
t.Fatal("obj should not be nil")
|
||||
}
|
||||
a.IsTrue(obj.B == 1024)
|
||||
a.IsTrue(obj.C == true)
|
||||
PrintT(obj, t)
|
||||
}
|
||||
}
|
||||
17
internal/utils/jsonutils/utils.go
Normal file
17
internal/utils/jsonutils/utils.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package jsonutils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func PrintT(obj interface{}, t *testing.T) {
|
||||
data, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
} else {
|
||||
t.Log(string(data))
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,23 @@ import (
|
||||
type AllowAction struct {
|
||||
}
|
||||
|
||||
func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *AllowAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *AllowAction) Code() string {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
func (this *AllowAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *AllowAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// do nothing
|
||||
return true
|
||||
}
|
||||
|
||||
21
internal/waf/action_base.go
Normal file
21
internal/waf/action_base.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import "net/http"
|
||||
|
||||
type BaseAction struct {
|
||||
}
|
||||
|
||||
// CloseConn 关闭连接
|
||||
func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
|
||||
// 断开连接
|
||||
hijack, ok := writer.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, err := hijack.Hijack()
|
||||
if err == nil {
|
||||
return conn.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -23,12 +23,48 @@ type BlockAction struct {
|
||||
StatusCode int `yaml:"statusCode" json:"statusCode"`
|
||||
Body string `yaml:"body" json:"body"` // supports HTML
|
||||
URL string `yaml:"url" json:"url"`
|
||||
Timeout int32 `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *BlockAction) Init(waf *WAF) error {
|
||||
if waf.DefaultBlockAction != nil {
|
||||
if this.StatusCode <= 0 {
|
||||
this.StatusCode = waf.DefaultBlockAction.StatusCode
|
||||
}
|
||||
if len(this.Body) == 0 {
|
||||
this.Body = waf.DefaultBlockAction.Body
|
||||
}
|
||||
if len(this.URL) == 0 {
|
||||
this.URL = waf.DefaultBlockAction.URL
|
||||
}
|
||||
if this.Timeout <= 0 {
|
||||
this.Timeout = waf.DefaultBlockAction.Timeout
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *BlockAction) Code() string {
|
||||
return ActionBlock
|
||||
}
|
||||
|
||||
func (this *BlockAction) IsAttack() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *BlockAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
if this.Timeout > 0 {
|
||||
// 加入到黑名单
|
||||
SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(this.Timeout))
|
||||
}
|
||||
|
||||
if writer != nil {
|
||||
// if status code eq 444, we close the connection
|
||||
if this.StatusCode == 444 {
|
||||
// close the connection
|
||||
defer func() {
|
||||
hijack, ok := writer.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, _ := hijack.Hijack()
|
||||
@@ -37,7 +73,7 @@ func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer htt
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// output response
|
||||
if this.StatusCode > 0 {
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -13,27 +16,63 @@ var captchaSalt = stringutil.Rand(32)
|
||||
|
||||
const (
|
||||
CaptchaSeconds = 600 // 10 minutes
|
||||
CaptchaPath = "/WAF/VERIFY/CAPTCHA"
|
||||
)
|
||||
|
||||
type CaptchaAction struct {
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
Language string `yaml:"language" json:"language"` // 语言,zh-CN, en-US ...
|
||||
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// TEAWEB_CAPTCHA:
|
||||
cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
|
||||
if err == nil && cookie != nil && len(cookie.Value) > 32 {
|
||||
m := cookie.Value[:32]
|
||||
timestamp := cookie.Value[32:]
|
||||
if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
|
||||
return true
|
||||
func (this *CaptchaAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Code() string {
|
||||
return ActionCaptcha
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 是否在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
refURL := request.WAFRaw().URL.String()
|
||||
|
||||
// 覆盖配置
|
||||
if strings.HasPrefix(refURL, CaptchaPath) {
|
||||
info := request.WAFRaw().URL.Query().Get("info")
|
||||
if len(info) > 0 {
|
||||
m, err := utils.SimpleDecryptMap(info)
|
||||
if err == nil && m != nil {
|
||||
refURL = m.GetString("url")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
refURL := request.URL.String()
|
||||
if len(request.Referer()) > 0 {
|
||||
refURL = request.Referer()
|
||||
var captchaConfig = maps.Map{
|
||||
"action": this,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"url": refURL,
|
||||
"setId": set.Id,
|
||||
}
|
||||
http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
|
||||
info, err := utils.SimpleEncryptMap(captchaConfig)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
13
internal/waf/action_category.go
Normal file
13
internal/waf/action_category.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
|
||||
type ActionCategory = string
|
||||
|
||||
const (
|
||||
ActionCategoryAllow ActionCategory = firewallconfigs.HTTPFirewallActionCategoryAllow
|
||||
ActionCategoryBlock ActionCategory = firewallconfigs.HTTPFirewallActionCategoryBlock
|
||||
ActionCategoryVerify ActionCategory = firewallconfigs.HTTPFirewallActionCategoryVerify
|
||||
)
|
||||
10
internal/waf/action_config.go
Normal file
10
internal/waf/action_config.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import "github.com/iwind/TeaGo/maps"
|
||||
|
||||
type ActionConfig struct {
|
||||
Code string `yaml:"code" json:"code"`
|
||||
Options maps.Map `yaml:"options" json:"options"`
|
||||
}
|
||||
@@ -2,11 +2,12 @@ package waf
|
||||
|
||||
import "reflect"
|
||||
|
||||
// action definition
|
||||
// ActionDefinition action definition
|
||||
type ActionDefinition struct {
|
||||
Name string
|
||||
Code ActionString
|
||||
Description string
|
||||
Category string // category: block, verify, allow
|
||||
Instance ActionInterface
|
||||
Type reflect.Type
|
||||
}
|
||||
|
||||
71
internal/waf/action_get_302.go
Normal file
71
internal/waf/action_get_302.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
Get302Path = "/WAF/VERIFY/GET"
|
||||
)
|
||||
|
||||
// Get302Action
|
||||
// 原理: origin url --> 302 verify url --> origin url
|
||||
// TODO 将来支持meta refresh验证
|
||||
type Get302Action struct {
|
||||
BaseAction
|
||||
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
}
|
||||
|
||||
func (this *Get302Action) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Get302Action) Code() string {
|
||||
return ActionGet302
|
||||
}
|
||||
|
||||
func (this *Get302Action) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *Get302Action) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 仅限于Get
|
||||
if request.WAFRaw().Method != http.MethodGet {
|
||||
return true
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
var m = maps.Map{
|
||||
"url": request.WAFRaw().URL.String(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"setId": set.Id,
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), Get302Path+"?info="+url.QueryEscape(info), http.StatusFound)
|
||||
|
||||
// 关闭连接
|
||||
_ = this.CloseConn(writer)
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -10,13 +10,29 @@ type GoGroupAction struct {
|
||||
GroupId string `yaml:"groupId" json:"groupId"`
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
group := waf.FindRuleGroup(this.GroupId)
|
||||
if group == nil || !group.IsOn {
|
||||
func (this *GoGroupAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Code() string {
|
||||
return ActionGoGroup
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
nextGroup := waf.FindRuleGroup(this.GroupId)
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true
|
||||
}
|
||||
|
||||
b, set, err := group.MatchRequest(request)
|
||||
b, nextSet, err := nextGroup.MatchRequest(request)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return true
|
||||
@@ -26,9 +42,5 @@ func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer h
|
||||
return true
|
||||
}
|
||||
|
||||
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||
if actionObject == nil {
|
||||
return true
|
||||
}
|
||||
return actionObject.Perform(waf, request, writer)
|
||||
return nextSet.PerformActions(waf, nextGroup, request, writer)
|
||||
}
|
||||
|
||||
@@ -11,17 +11,33 @@ type GoSetAction struct {
|
||||
SetId string `yaml:"setId" json:"setId"`
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
group := waf.FindRuleGroup(this.GroupId)
|
||||
if group == nil || !group.IsOn {
|
||||
func (this *GoSetAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Code() string {
|
||||
return ActionGoSet
|
||||
}
|
||||
|
||||
func (this *GoSetAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *GoSetAction) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
nextGroup := waf.FindRuleGroup(this.GroupId)
|
||||
if nextGroup == nil || !nextGroup.IsOn {
|
||||
return true
|
||||
}
|
||||
set := group.FindRuleSet(this.SetId)
|
||||
if set == nil || !set.IsOn {
|
||||
nextSet := nextGroup.FindRuleSet(this.SetId)
|
||||
if nextSet == nil || !nextSet.IsOn {
|
||||
return true
|
||||
}
|
||||
|
||||
b, err := set.MatchRequest(request)
|
||||
b, err := nextSet.MatchRequest(request)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return true
|
||||
@@ -29,9 +45,5 @@ func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer htt
|
||||
if !b {
|
||||
return true
|
||||
}
|
||||
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||
if actionObject == nil {
|
||||
return true
|
||||
}
|
||||
return actionObject.Perform(waf, request, writer)
|
||||
return nextSet.PerformActions(waf, nextGroup, request, writer)
|
||||
}
|
||||
|
||||
25
internal/waf/action_interface.go
Normal file
25
internal/waf/action_interface.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ActionInterface interface {
|
||||
// Init 初始化
|
||||
Init(waf *WAF) error
|
||||
|
||||
// Code 代号
|
||||
Code() string
|
||||
|
||||
// IsAttack 是否为拦截攻击动作
|
||||
IsAttack() bool
|
||||
|
||||
// WillChange determine if the action will change the request
|
||||
WillChange() bool
|
||||
|
||||
// Perform perform the action
|
||||
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool)
|
||||
}
|
||||
@@ -8,6 +8,22 @@ import (
|
||||
type LogAction struct {
|
||||
}
|
||||
|
||||
func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
func (this *LogAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *LogAction) Code() string {
|
||||
return ActionLog
|
||||
}
|
||||
|
||||
func (this *LogAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *LogAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
86
internal/waf/action_notify.go
Normal file
86
internal/waf/action_notify.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type notifyTask struct {
|
||||
ServerId int64
|
||||
HttpFirewallPolicyId int64
|
||||
HttpFirewallRuleGroupId int64
|
||||
HttpFirewallRuleSetId int64
|
||||
CreatedAt int64
|
||||
}
|
||||
|
||||
var notifyChan = make(chan *notifyTask, 128)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
go func() {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_NOTIFY_ACTION", "create rpc client failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for task := range notifyChan {
|
||||
_, err = rpcClient.FirewallService().NotifyHTTPFirewallEvent(rpcClient.Context(), &pb.NotifyHTTPFirewallEventRequest{
|
||||
ServerId: task.ServerId,
|
||||
HttpFirewallPolicyId: task.HttpFirewallPolicyId,
|
||||
HttpFirewallRuleGroupId: task.HttpFirewallRuleGroupId,
|
||||
HttpFirewallRuleSetId: task.HttpFirewallRuleSetId,
|
||||
CreatedAt: task.CreatedAt,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_NOTIFY_ACTION", "notify failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
type NotifyAction struct {
|
||||
}
|
||||
|
||||
func (this *NotifyAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *NotifyAction) Code() string {
|
||||
return ActionNotify
|
||||
}
|
||||
|
||||
func (this *NotifyAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// WillChange determine if the action will change the request
|
||||
func (this *NotifyAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Perform perform the action
|
||||
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
select {
|
||||
case notifyChan <- ¬ifyTask{
|
||||
ServerId: request.WAFServerId(),
|
||||
HttpFirewallPolicyId: types.Int64(waf.Id),
|
||||
HttpFirewallRuleGroupId: types.Int64(group.Id),
|
||||
HttpFirewallRuleSetId: types.Int64(set.Id),
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
88
internal/waf/action_post_307.go
Normal file
88
internal/waf/action_post_307.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Post307Action struct {
|
||||
Life int32 `yaml:"life" json:"life"`
|
||||
|
||||
BaseAction
|
||||
}
|
||||
|
||||
func (this *Post307Action) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Post307Action) Code() string {
|
||||
return ActionPost307
|
||||
}
|
||||
|
||||
func (this *Post307Action) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *Post307Action) WillChange() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
var cookieName = "WAF_VALIDATOR_ID"
|
||||
|
||||
// 仅限于POST
|
||||
if request.WAFRaw().Method != http.MethodPost {
|
||||
return true
|
||||
}
|
||||
|
||||
// 是否已经在白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 判断是否有Cookie
|
||||
cookie, err := request.WAFRaw().Cookie(cookieName)
|
||||
if err == nil && cookie != nil {
|
||||
m, err := utils.SimpleDecryptMap(cookie.Value)
|
||||
if err == nil && m.GetString("remoteIP") == request.WAFRemoteIP() && time.Now().Unix() < m.GetInt64("timestamp")+10 {
|
||||
var life = m.GetInt64("life")
|
||||
if life <= 0 {
|
||||
life = 600 // 默认10分钟
|
||||
}
|
||||
var setId = m.GetString("setId")
|
||||
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
var m = maps.Map{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"life": this.Life,
|
||||
"setId": set.Id,
|
||||
"remoteIP": request.WAFRemoteIP(),
|
||||
}
|
||||
info, err := utils.SimpleEncryptMap(m)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_POST_302_ACTION", "encode info failed: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
// 设置Cookie
|
||||
http.SetCookie(writer, &http.Cookie{
|
||||
Name: cookieName,
|
||||
Path: "/",
|
||||
MaxAge: 10,
|
||||
Value: info,
|
||||
})
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
|
||||
|
||||
// 关闭连接
|
||||
_ = this.CloseConn(writer)
|
||||
|
||||
return true
|
||||
}
|
||||
120
internal/waf/action_record_ip.go
Normal file
120
internal/waf/action_record_ip.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type recordIPTask struct {
|
||||
ip string
|
||||
listId int64
|
||||
expiredAt int64
|
||||
level string
|
||||
}
|
||||
|
||||
var recordIPTaskChan = make(chan *recordIPTask, 1024)
|
||||
|
||||
func init() {
|
||||
events.On(events.EventLoaded, func() {
|
||||
go func() {
|
||||
rpcClient, err := rpc.SharedRPC()
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create rpc client failed: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for task := range recordIPTaskChan {
|
||||
ipType := "ipv4"
|
||||
if strings.Contains(task.ip, ":") {
|
||||
ipType = "ipv6"
|
||||
}
|
||||
_, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiredAt,
|
||||
Reason: "触发WAF规则自动加入",
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
type RecordIPAction struct {
|
||||
BaseAction
|
||||
|
||||
Type string `yaml:"type" json:"type"`
|
||||
IPListId int64 `yaml:"ipListId" json:"ipListId"`
|
||||
Level string `yaml:"level" json:"level"`
|
||||
Timeout int32 `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) Code() string {
|
||||
return ActionRecordIP
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) IsAttack() bool {
|
||||
return this.Type == "black"
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) WillChange() bool {
|
||||
return this.Type == "black"
|
||||
}
|
||||
|
||||
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
// 是否在本地白名单中
|
||||
if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 先加入本地的黑名单
|
||||
timeout := this.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 86400 // 1天
|
||||
}
|
||||
expiredAt := time.Now().Unix() + int64(timeout)
|
||||
|
||||
if this.Type == "black" {
|
||||
_ = this.CloseConn(writer)
|
||||
|
||||
SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt)
|
||||
} else {
|
||||
// 加入本地白名单
|
||||
timeout := this.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 86400 // 1天
|
||||
}
|
||||
SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt)
|
||||
}
|
||||
|
||||
// 上报
|
||||
if this.IPListId > 0 {
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
ip: request.WAFRemoteIP(),
|
||||
listId: this.IPListId,
|
||||
expiredAt: expiredAt,
|
||||
level: this.Level,
|
||||
}:
|
||||
default:
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return this.Type != "black"
|
||||
}
|
||||
30
internal/waf/action_tag.go
Normal file
30
internal/waf/action_tag.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type TagAction struct {
|
||||
Tags []string `yaml:"tags" json:"tags"`
|
||||
}
|
||||
|
||||
func (this *TagAction) Init(waf *WAF) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *TagAction) Code() string {
|
||||
return ActionTag
|
||||
}
|
||||
|
||||
func (this *TagAction) IsAttack() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *TagAction) WillChange() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
return true
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ActionString = string
|
||||
|
||||
const (
|
||||
ActionLog = "log" // allow and log
|
||||
ActionBlock = "block" // block
|
||||
ActionCaptcha = "captcha" // block and show captcha
|
||||
ActionAllow = "allow" // allow
|
||||
ActionGoGroup = "go_group" // go to next rule group
|
||||
ActionGoSet = "go_set" // go to next rule set
|
||||
)
|
||||
|
||||
type ActionInterface interface {
|
||||
Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
|
||||
}
|
||||
88
internal/waf/action_types.go
Normal file
88
internal/waf/action_types.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package waf
|
||||
|
||||
import "reflect"
|
||||
|
||||
type ActionString = string
|
||||
|
||||
const (
|
||||
ActionLog ActionString = "log" // allow and log
|
||||
ActionBlock ActionString = "block" // block
|
||||
ActionCaptcha ActionString = "captcha" // block and show captcha
|
||||
ActionNotify ActionString = "notify" // 告警
|
||||
ActionGet302 ActionString = "get_302" // 针对GET的302重定向认证
|
||||
ActionPost307 ActionString = "post_307" // 针对POST的307重定向认证
|
||||
ActionRecordIP ActionString = "record_ip" // 记录IP
|
||||
ActionTag ActionString = "tag" // 标签
|
||||
ActionAllow ActionString = "allow" // allow
|
||||
ActionGoGroup ActionString = "go_group" // go to next rule group
|
||||
ActionGoSet ActionString = "go_set" // go to next rule set
|
||||
)
|
||||
|
||||
var AllActions = []*ActionDefinition{
|
||||
{
|
||||
Name: "阻止",
|
||||
Code: ActionBlock,
|
||||
Instance: new(BlockAction),
|
||||
Type: reflect.TypeOf(new(BlockAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "允许通过",
|
||||
Code: ActionAllow,
|
||||
Instance: new(AllowAction),
|
||||
Type: reflect.TypeOf(new(AllowAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "允许并记录日志",
|
||||
Code: ActionLog,
|
||||
Instance: new(LogAction),
|
||||
Type: reflect.TypeOf(new(LogAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "Captcha验证码",
|
||||
Code: ActionCaptcha,
|
||||
Instance: new(CaptchaAction),
|
||||
Type: reflect.TypeOf(new(CaptchaAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "告警",
|
||||
Code: ActionNotify,
|
||||
Instance: new(NotifyAction),
|
||||
Type: reflect.TypeOf(new(NotifyAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "GET 302",
|
||||
Code: ActionGet302,
|
||||
Instance: new(Get302Action),
|
||||
Type: reflect.TypeOf(new(Get302Action)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "POST 307",
|
||||
Code: ActionPost307,
|
||||
Instance: new(Post307Action),
|
||||
Type: reflect.TypeOf(new(Post307Action)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "记录IP",
|
||||
Code: ActionRecordIP,
|
||||
Instance: new(RecordIPAction),
|
||||
Type: reflect.TypeOf(new(RecordIPAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "标签",
|
||||
Code: ActionTag,
|
||||
Instance: new(TagAction),
|
||||
Type: reflect.TypeOf(new(TagAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则分组",
|
||||
Code: ActionGoGroup,
|
||||
Instance: new(GoGroupAction),
|
||||
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则集",
|
||||
Code: ActionGoSet,
|
||||
Instance: new(GoSetAction),
|
||||
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
|
||||
},
|
||||
}
|
||||
@@ -1,45 +1,12 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
var AllActions = []*ActionDefinition{
|
||||
{
|
||||
Name: "阻止",
|
||||
Code: ActionBlock,
|
||||
Instance: new(BlockAction),
|
||||
},
|
||||
{
|
||||
Name: "允许通过",
|
||||
Code: ActionAllow,
|
||||
Instance: new(AllowAction),
|
||||
},
|
||||
{
|
||||
Name: "允许并记录日志",
|
||||
Code: ActionLog,
|
||||
Instance: new(LogAction),
|
||||
},
|
||||
{
|
||||
Name: "Captcha验证码",
|
||||
Code: ActionCaptcha,
|
||||
Instance: new(CaptchaAction),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则分组",
|
||||
Code: ActionGoGroup,
|
||||
Instance: new(GoGroupAction),
|
||||
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
|
||||
},
|
||||
{
|
||||
Name: "跳到下一个规则集",
|
||||
Code: ActionGoSet,
|
||||
Instance: new(GoSetAction),
|
||||
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
|
||||
},
|
||||
}
|
||||
|
||||
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
|
||||
for _, def := range AllActions {
|
||||
if def.Code == action {
|
||||
@@ -49,15 +16,13 @@ func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
|
||||
instance := ptrValue.Interface().(ActionInterface)
|
||||
|
||||
if len(options) > 0 {
|
||||
count := def.Type.NumField()
|
||||
for i := 0; i < count; i++ {
|
||||
field := def.Type.Field(i)
|
||||
tag, ok := field.Tag.Lookup("yaml")
|
||||
if ok {
|
||||
v, ok := options[tag]
|
||||
if ok && reflect.TypeOf(v) == field.Type {
|
||||
ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
|
||||
}
|
||||
optionsJSON, err := json.Marshal(options)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_FindActionInstance", "encode options to json failed: "+err.Error())
|
||||
} else {
|
||||
err = json.Unmarshal(optionsJSON, instance)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_FindActionInstance", "decode options from json failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package waf
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -16,11 +17,20 @@ func TestFindActionInstance(t *testing.T) {
|
||||
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
|
||||
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
|
||||
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
|
||||
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
|
||||
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
|
||||
|
||||
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
|
||||
}
|
||||
|
||||
func TestFindActionInstance_Options(t *testing.T) {
|
||||
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
|
||||
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
|
||||
//logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
|
||||
logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{
|
||||
"timeout": 3600,
|
||||
}), t)
|
||||
}
|
||||
|
||||
func BenchmarkFindActionInstance(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
@@ -3,29 +3,64 @@ package waf
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/dchest/captcha"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var captchaValidator = &CaptchaValidator{}
|
||||
var captchaValidator = NewCaptchaValidator()
|
||||
|
||||
type CaptchaValidator struct {
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
|
||||
if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(request, writer)
|
||||
func NewCaptchaValidator() *CaptchaValidator {
|
||||
return &CaptchaValidator{}
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) {
|
||||
var info = request.WAFRaw().URL.Query().Get("info")
|
||||
if len(info) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
return
|
||||
}
|
||||
m, err := utils.SimpleDecryptMap(info)
|
||||
if err != nil {
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
timestamp := m.GetInt64("timestamp")
|
||||
if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期
|
||||
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
var actionConfig = &CaptchaAction{}
|
||||
err = jsonutils.MapToObject(m.GetMap("action"), actionConfig)
|
||||
if err != nil {
|
||||
http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
|
||||
var setId = m.GetInt64("setId")
|
||||
var originURL = m.GetString("url")
|
||||
|
||||
if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 {
|
||||
this.validate(actionConfig, setId, originURL, request, writer)
|
||||
} else {
|
||||
this.show(request, writer)
|
||||
this.show(actionConfig, request, writer)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
|
||||
func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) {
|
||||
// show captcha
|
||||
captchaId := captcha.NewLen(6)
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
@@ -35,48 +70,86 @@ func (this *CaptchaValidator) show(request *requests.Request, writer http.Respon
|
||||
return
|
||||
}
|
||||
|
||||
var lang = actionConfig.Language
|
||||
if len(lang) == 0 {
|
||||
acceptLanguage := request.WAFRaw().Header.Get("Accept-Language")
|
||||
if len(acceptLanguage) > 0 {
|
||||
langIndex := strings.Index(acceptLanguage, ",")
|
||||
if langIndex > 0 {
|
||||
lang = acceptLanguage[:langIndex]
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(lang) == 0 {
|
||||
lang = "en-US"
|
||||
}
|
||||
|
||||
var msgTitle = ""
|
||||
var msgPrompt = ""
|
||||
var msgButtonTitle = ""
|
||||
|
||||
switch lang {
|
||||
case "en-US":
|
||||
msgTitle = "Verify Yourself"
|
||||
msgPrompt = "Input verify code above:"
|
||||
msgButtonTitle = "Verify Yourself"
|
||||
case "zh-CN":
|
||||
msgTitle = "身份验证"
|
||||
msgPrompt = "请输入上面的验证码"
|
||||
msgButtonTitle = "提交验证"
|
||||
default:
|
||||
msgTitle = "Verify Yourself"
|
||||
msgPrompt = "Input verify code above:"
|
||||
msgButtonTitle = "Verify Yourself"
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
_, _ = writer.Write([]byte(`<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Verify Yourself</title>
|
||||
<title>` + msgTitle + `</title>
|
||||
<script type="text/javascript">
|
||||
if (window.addEventListener != null) {
|
||||
window.addEventListener("load", function () {
|
||||
document.getElementById("GOEDGE_WAF_CAPTCHA_CODE").focus()
|
||||
})
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<form method="POST">
|
||||
<input type="hidden" name="TEAWEB_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
|
||||
<input type="hidden" name="GOEDGE_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
|
||||
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
|
||||
<div>
|
||||
<p>Input verify code above:</p>
|
||||
<input type="text" name="TEAWEB_WAF_CAPTCHA_CODE" maxlength="6" size="18" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px"/>
|
||||
<p>` + msgPrompt + `</p>
|
||||
<input type="text" name="GOEDGE_WAF_CAPTCHA_CODE" id="GOEDGE_WAF_CAPTCHA_CODE" maxlength="6" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px; width: 160px"/>
|
||||
</div>
|
||||
<div>
|
||||
<button type="submit" onclick="window.location = '/webhook'" style="line-height:24px;margin-top:10px">Verify Yourself</button>
|
||||
<button type="submit" style="line-height:24px;margin-top:10px">` + msgButtonTitle + `</button>
|
||||
</div>
|
||||
</form>
|
||||
</body>
|
||||
</html>`))
|
||||
}
|
||||
|
||||
func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
|
||||
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, originURL string, request requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||
captchaId := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")
|
||||
if len(captchaId) > 0 {
|
||||
captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
|
||||
captchaCode := request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE")
|
||||
if captcha.VerifyString(captchaId, captchaCode) {
|
||||
// set cookie
|
||||
timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
|
||||
m := stringutil.Md5(captchaSalt + timestamp)
|
||||
http.SetCookie(writer, &http.Cookie{
|
||||
Name: "TEAWEB_WAF_CAPTCHA",
|
||||
Value: m + timestamp,
|
||||
MaxAge: CaptchaSeconds, // TODO 这个时间可以设置
|
||||
Path: "/", // all of dirs
|
||||
})
|
||||
var life = CaptchaSeconds
|
||||
if actionConfig.Life > 0 {
|
||||
life = types.Int(actionConfig.Life)
|
||||
}
|
||||
|
||||
rawURL := request.URL.Query().Get("url")
|
||||
http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
|
||||
// 加入到白名单
|
||||
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO
|
||||
|
||||
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)
|
||||
|
||||
return false
|
||||
} else {
|
||||
http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
|
||||
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusSeeOther)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,14 +5,12 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ${cc.arg}
|
||||
// CCCheckpoint ${cc.arg}
|
||||
// TODO implement more traffic rules
|
||||
type CCCheckpoint struct {
|
||||
Checkpoint
|
||||
@@ -32,7 +30,7 @@ func (this *CCCheckpoint) Start() {
|
||||
this.cache = ttlcache.NewCache()
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = 0
|
||||
|
||||
if this.cache == nil {
|
||||
@@ -66,12 +64,12 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
var key = ""
|
||||
switch userType {
|
||||
case "ip":
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
case "cookie":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
cookie, _ := req.Cookie(userField)
|
||||
cookie, _ := req.WAFRaw().Cookie(userField)
|
||||
if cookie != nil {
|
||||
v := cookie.Value
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
@@ -82,9 +80,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
}
|
||||
case "get":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
v := req.URL.Query().Get(userField)
|
||||
v := req.WAFRaw().URL.Query().Get(userField)
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
v = v[userIndex:]
|
||||
}
|
||||
@@ -92,9 +90,9 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
}
|
||||
case "post":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
v := req.PostFormValue(userField)
|
||||
v := req.WAFRaw().PostFormValue(userField)
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
v = v[userIndex:]
|
||||
}
|
||||
@@ -102,19 +100,19 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
}
|
||||
case "header":
|
||||
if len(userField) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
} else {
|
||||
v := req.Header.Get(userField)
|
||||
v := req.WAFRaw().Header.Get(userField)
|
||||
if userIndex > 0 && len(v) > userIndex {
|
||||
v = v[userIndex:]
|
||||
}
|
||||
key = "USER@" + userType + "@" + userField + "@" + v
|
||||
}
|
||||
default:
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
}
|
||||
if len(key) == 0 {
|
||||
key = this.ip(req)
|
||||
key = req.WAFRemoteIP()
|
||||
}
|
||||
value = this.cache.IncreaseInt64(key, int64(1), time.Now().Unix()+period)
|
||||
}
|
||||
@@ -122,7 +120,7 @@ func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, opti
|
||||
return
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
@@ -210,38 +208,3 @@ func (this *CCCheckpoint) Stop() {
|
||||
this.cache = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (this *CCCheckpoint) ip(req *requests.Request) string {
|
||||
// X-Forwarded-For
|
||||
forwardedFor := req.Header.Get("X-Forwarded-For")
|
||||
if len(forwardedFor) > 0 {
|
||||
commaIndex := strings.Index(forwardedFor, ",")
|
||||
if commaIndex > 0 {
|
||||
return forwardedFor[:commaIndex]
|
||||
}
|
||||
return forwardedFor
|
||||
}
|
||||
|
||||
// Real-IP
|
||||
{
|
||||
realIP, ok := req.Header["X-Real-IP"]
|
||||
if ok && len(realIP) > 0 {
|
||||
return realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Real-Ip
|
||||
{
|
||||
realIP, ok := req.Header["X-Real-Ip"]
|
||||
if ok && len(realIP) > 0 {
|
||||
return realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Remote-Addr
|
||||
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return req.RemoteAddr
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package checkpoints
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
@@ -12,31 +13,31 @@ func TestCCCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(raw)
|
||||
req.RemoteAddr = "127.0.0.1"
|
||||
req := requests.NewTestRequest(raw)
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.1"
|
||||
|
||||
checkpoint := new(CCCheckpoint)
|
||||
checkpoint.Init()
|
||||
checkpoint.Start()
|
||||
|
||||
options := map[string]string{
|
||||
options := maps.Map{
|
||||
"period": "5",
|
||||
}
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.1"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.1"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
|
||||
req.RemoteAddr = "127.0.0.2"
|
||||
req.WAFRaw().RemoteAddr = "127.0.0.2"
|
||||
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||
}
|
||||
|
||||
@@ -5,32 +5,32 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// Check Point
|
||||
// CheckpointInterface Check Point
|
||||
type CheckpointInterface interface {
|
||||
// initialize
|
||||
// Init initialize
|
||||
Init()
|
||||
|
||||
// is request?
|
||||
// IsRequest is request?
|
||||
IsRequest() bool
|
||||
|
||||
// is composed?
|
||||
// IsComposed is composed?
|
||||
IsComposed() bool
|
||||
|
||||
// get request value
|
||||
RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
// RequestValue get request value
|
||||
RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
|
||||
// get response value
|
||||
ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
// ResponseValue get response value
|
||||
ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error)
|
||||
|
||||
// param option list
|
||||
// ParamOptions param option list
|
||||
ParamOptions() *ParamOptions
|
||||
|
||||
// options
|
||||
// Options options
|
||||
Options() []OptionInterface
|
||||
|
||||
// start
|
||||
// Start start
|
||||
Start()
|
||||
|
||||
// stop
|
||||
// Stop stop
|
||||
Stop()
|
||||
}
|
||||
|
||||
@@ -5,32 +5,34 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// ${requestAll}
|
||||
// RequestAllCheckpoint ${requestAll}
|
||||
type RequestAllCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
valueBytes := []byte{}
|
||||
if len(req.RequestURI) > 0 {
|
||||
valueBytes = append(valueBytes, req.RequestURI...)
|
||||
} else if req.URL != nil {
|
||||
valueBytes = append(valueBytes, req.URL.RequestURI()...)
|
||||
if len(req.WAFRaw().RequestURI) > 0 {
|
||||
valueBytes = append(valueBytes, req.WAFRaw().RequestURI...)
|
||||
} else if req.WAFRaw().URL != nil {
|
||||
valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...)
|
||||
}
|
||||
|
||||
if req.Body != nil {
|
||||
if req.WAFRaw().Body != nil {
|
||||
valueBytes = append(valueBytes, ' ')
|
||||
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
|
||||
req.BodyData = data
|
||||
req.RestoreBody(data)
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
req.WAFRestoreBody(data)
|
||||
}
|
||||
valueBytes = append(valueBytes, req.BodyData...)
|
||||
valueBytes = append(valueBytes, bodyData...)
|
||||
}
|
||||
|
||||
value = valueBytes
|
||||
@@ -38,7 +40,7 @@ func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param stri
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = ""
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestAllCheckpoint)
|
||||
v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
v, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
if sysErr != nil {
|
||||
t.Fatal(sysErr)
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
|
||||
|
||||
checkpoint := new(RequestAllCheckpoint)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
_, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,11 @@ type RequestArgCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return req.URL.Query().Get(param), nil, nil
|
||||
func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return req.WAFRaw().URL.Query().Get(param), nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func TestArgParam_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
|
||||
checkpoint := new(RequestArgCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestArgsCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.URL.RawQuery
|
||||
func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().URL.RawQuery
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -5,31 +5,33 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// ${requestBody}
|
||||
// RequestBodyCheckpoint ${requestBody}
|
||||
type RequestBodyCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.Body == nil {
|
||||
func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.WAFRaw().Body == nil {
|
||||
value = ""
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
|
||||
req.BodyData = data
|
||||
req.RestoreBody(data)
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
req.WAFRestoreBody(data)
|
||||
}
|
||||
|
||||
return req.BodyData, nil, nil
|
||||
return bodyData, nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -11,19 +11,20 @@ import (
|
||||
)
|
||||
|
||||
func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
|
||||
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var req = requests.NewTestRequest(rawReq)
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := ioutil.ReadAll(rawReq.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(string(body))
|
||||
t.Log(string(req.WAFGetCacheBody()))
|
||||
}
|
||||
|
||||
func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
@@ -33,7 +34,7 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
|
||||
}
|
||||
|
||||
checkpoint := new(RequestBodyCheckpoint)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||
value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestContentTypeCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.Header.Get("Content-Type")
|
||||
func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Header.Get("Content-Type")
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ type RequestCookieCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
cookie, err := req.Cookie(param)
|
||||
func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
cookie, err := req.WAFRaw().Cookie(param)
|
||||
if err != nil {
|
||||
value = ""
|
||||
return
|
||||
@@ -20,7 +20,7 @@ func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param s
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -11,16 +11,16 @@ type RequestCookiesCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
var cookies = []string{}
|
||||
for _, cookie := range req.Cookies() {
|
||||
for _, cookie := range req.WAFRaw().Cookies() {
|
||||
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
|
||||
}
|
||||
value = strings.Join(cookies, "&")
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -6,33 +6,35 @@ import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// ${requestForm.arg}
|
||||
// RequestFormArgCheckpoint ${requestForm.arg}
|
||||
type RequestFormArgCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.Body == nil {
|
||||
func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if req.WAFRaw().Body == nil {
|
||||
value = ""
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
|
||||
req.BodyData = data
|
||||
req.RestoreBody(data)
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
req.WAFRestoreBody(data)
|
||||
}
|
||||
|
||||
// TODO improve performance
|
||||
values, _ := url.ParseQuery(string(req.BodyData))
|
||||
values, _ := url.ParseQuery(string(bodyData))
|
||||
return values.Get(param), nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
req.WAFRaw().Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
checkpoint := new(RequestFormArgCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||
@@ -24,7 +24,7 @@ func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Log(checkpoint.RequestValue(req, "Hello", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "encoded", nil))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := ioutil.ReadAll(req.WAFRaw().Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = false
|
||||
|
||||
headers := options.GetSlice("headers")
|
||||
@@ -25,7 +25,7 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
|
||||
length := options.GetInt("length")
|
||||
|
||||
for _, header := range headers {
|
||||
v := req.Header.Get(types.String(header))
|
||||
v := req.WAFRaw().Header.Get(types.String(header))
|
||||
if len(v) > length {
|
||||
value = true
|
||||
break
|
||||
@@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Req
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ type RequestHeaderCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
v, found := req.Header[param]
|
||||
func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
v, found := req.WAFRaw().Header[param]
|
||||
if !found {
|
||||
value = ""
|
||||
return
|
||||
@@ -20,7 +20,7 @@ func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param s
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ type RequestHeadersCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
var headers = []string{}
|
||||
for k, v := range req.Header {
|
||||
for k, v := range req.WAFRaw().Header {
|
||||
for _, subV := range v {
|
||||
headers = append(headers, k+": "+subV)
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestHostCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.Host
|
||||
func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Host
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ func TestRequestHostCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req.Header.Set("Host", "cloud.teaos.cn")
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
req.WAFRaw().Header.Set("Host", "cloud.teaos.cn")
|
||||
|
||||
checkpoint := new(RequestHostCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||
|
||||
@@ -8,24 +8,27 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ${requestJSON.arg}
|
||||
// RequestJSONArgCheckpoint ${requestJSON.arg}
|
||||
type RequestJSONArgCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||
if err != nil {
|
||||
return "", err, nil
|
||||
}
|
||||
req.BodyData = data
|
||||
defer req.RestoreBody(data)
|
||||
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
defer req.WAFRestoreBody(data)
|
||||
}
|
||||
|
||||
// TODO improve performance
|
||||
var m interface{} = nil
|
||||
err := json.Unmarshal(req.BodyData, &m)
|
||||
err := json.Unmarshal(bodyData, &m)
|
||||
if err != nil || m == nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@@ -37,7 +40,7 @@ func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestJSONArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
checkpoint := new(RequestJSONArgCheckpoint)
|
||||
@@ -31,7 +31,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
|
||||
t.Log(checkpoint.RequestValue(req, "books", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "books.1", nil))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := ioutil.ReadAll(req.WAFRaw().Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
checkpoint := new(RequestJSONArgCheckpoint)
|
||||
@@ -61,7 +61,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
|
||||
t.Log(checkpoint.RequestValue(req, "0.books", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := ioutil.ReadAll(req.WAFRaw().Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
checkpoint := new(RequestJSONArgCheckpoint)
|
||||
@@ -91,7 +91,7 @@ func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
|
||||
t.Log(checkpoint.RequestValue(req, "0.books", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
|
||||
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := ioutil.ReadAll(req.WAFRaw().Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestLengthCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.ContentLength
|
||||
func (this *RequestLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().ContentLength
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestMethodCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.Method
|
||||
func (this *RequestMethodCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Method
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestMethodCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,11 +9,11 @@ type RequestPathCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return req.URL.Path, nil, nil
|
||||
func (this *RequestPathCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return req.WAFRaw().URL.Path, nil, nil
|
||||
}
|
||||
|
||||
func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestPathCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func TestRequestPathCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
checkpoint := new(RequestPathCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestProtoCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.Proto
|
||||
func (this *RequestProtoCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Proto
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestProtoCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -10,17 +10,17 @@ type RequestRawRemoteAddrCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
host, _, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
|
||||
if err == nil {
|
||||
value = host
|
||||
} else {
|
||||
value = req.RemoteAddr
|
||||
value = req.WAFRaw().RemoteAddr
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestRefererCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.Referer()
|
||||
func (this *RequestRefererCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().Referer()
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -3,56 +3,18 @@ package checkpoints
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type RequestRemoteAddrCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
// X-Forwarded-For
|
||||
forwardedFor := req.Header.Get("X-Forwarded-For")
|
||||
if len(forwardedFor) > 0 {
|
||||
commaIndex := strings.Index(forwardedFor, ",")
|
||||
if commaIndex > 0 {
|
||||
value = forwardedFor[:commaIndex]
|
||||
return
|
||||
}
|
||||
value = forwardedFor
|
||||
return
|
||||
}
|
||||
|
||||
// Real-IP
|
||||
{
|
||||
realIP, ok := req.Header["X-Real-IP"]
|
||||
if ok && len(realIP) > 0 {
|
||||
value = realIP[0]
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Real-Ip
|
||||
{
|
||||
realIP, ok := req.Header["X-Real-Ip"]
|
||||
if ok && len(realIP) > 0 {
|
||||
value = realIP[0]
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Remote-Addr
|
||||
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
value = host
|
||||
} else {
|
||||
value = req.RemoteAddr
|
||||
}
|
||||
func (this *RequestRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRemoteIP()
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ type RequestRemotePortCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
_, port, err := net.SplitHostPort(req.RemoteAddr)
|
||||
func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
_, port, err := net.SplitHostPort(req.WAFRaw().RemoteAddr)
|
||||
if err == nil {
|
||||
value = types.Int(port)
|
||||
} else {
|
||||
@@ -21,7 +21,7 @@ func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, par
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestRemotePortCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ type RequestRemoteUserCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
username, _, ok := req.BasicAuth()
|
||||
func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
username, _, ok := req.WAFRaw().BasicAuth()
|
||||
if !ok {
|
||||
value = ""
|
||||
return
|
||||
@@ -19,7 +19,7 @@ func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, par
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestRemoteUserCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestSchemeCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.URL.Scheme
|
||||
func (this *RequestSchemeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().URL.Scheme
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestSchemeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
checkpoint := new(RequestSchemeCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||
}
|
||||
|
||||
@@ -11,63 +11,65 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ${requestUpload.arg}
|
||||
// RequestUploadCheckpoint ${requestUpload.arg}
|
||||
type RequestUploadCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = ""
|
||||
if param == "minSize" || param == "maxSize" {
|
||||
value = 0
|
||||
}
|
||||
|
||||
if req.Method != http.MethodPost {
|
||||
if req.WAFRaw().Method != http.MethodPost {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Body == nil {
|
||||
if req.WAFRaw().Body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if req.MultipartForm == nil {
|
||||
if len(req.BodyData) == 0 {
|
||||
data, err := req.ReadBody(32 * 1024 * 1024)
|
||||
if req.WAFRaw().MultipartForm == nil {
|
||||
var bodyData = req.WAFGetCacheBody()
|
||||
if len(bodyData) == 0 {
|
||||
data, err := req.WAFReadBody(32 * 1024 * 1024)
|
||||
if err != nil {
|
||||
sysErr = err
|
||||
return
|
||||
}
|
||||
|
||||
req.BodyData = data
|
||||
defer req.RestoreBody(data)
|
||||
bodyData = data
|
||||
req.WAFSetCacheBody(data)
|
||||
defer req.WAFRestoreBody(data)
|
||||
}
|
||||
oldBody := req.Body
|
||||
req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData))
|
||||
oldBody := req.WAFRaw().Body
|
||||
req.WAFRaw().Body = ioutil.NopCloser(bytes.NewBuffer(bodyData))
|
||||
|
||||
err := req.ParseMultipartForm(32 * 1024 * 1024)
|
||||
err := req.WAFRaw().ParseMultipartForm(32 * 1024 * 1024)
|
||||
|
||||
// 还原
|
||||
req.Body = oldBody
|
||||
req.WAFRaw().Body = oldBody
|
||||
|
||||
if err != nil {
|
||||
userErr = err
|
||||
return
|
||||
}
|
||||
|
||||
if req.MultipartForm == nil {
|
||||
if req.WAFRaw().MultipartForm == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if param == "field" { // field
|
||||
fields := []string{}
|
||||
for field := range req.MultipartForm.File {
|
||||
for field := range req.WAFRaw().MultipartForm.File {
|
||||
fields = append(fields, field)
|
||||
}
|
||||
value = strings.Join(fields, ",")
|
||||
} else if param == "minSize" { // minSize
|
||||
minSize := int64(0)
|
||||
for _, files := range req.MultipartForm.File {
|
||||
for _, files := range req.WAFRaw().MultipartForm.File {
|
||||
for _, file := range files {
|
||||
if minSize == 0 || minSize > file.Size {
|
||||
minSize = file.Size
|
||||
@@ -77,7 +79,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
|
||||
value = minSize
|
||||
} else if param == "maxSize" { // maxSize
|
||||
maxSize := int64(0)
|
||||
for _, files := range req.MultipartForm.File {
|
||||
for _, files := range req.WAFRaw().MultipartForm.File {
|
||||
for _, file := range files {
|
||||
if maxSize < file.Size {
|
||||
maxSize = file.Size
|
||||
@@ -87,7 +89,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
|
||||
value = maxSize
|
||||
} else if param == "name" { // name
|
||||
names := []string{}
|
||||
for _, files := range req.MultipartForm.File {
|
||||
for _, files := range req.WAFRaw().MultipartForm.File {
|
||||
for _, file := range files {
|
||||
if !lists.ContainsString(names, file.Filename) {
|
||||
names = append(names, file.Filename)
|
||||
@@ -97,7 +99,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
|
||||
value = strings.Join(names, ",")
|
||||
} else if param == "ext" { // ext
|
||||
extensions := []string{}
|
||||
for _, files := range req.MultipartForm.File {
|
||||
for _, files := range req.WAFRaw().MultipartForm.File {
|
||||
for _, file := range files {
|
||||
if len(file.Filename) > 0 {
|
||||
exit := strings.ToLower(filepath.Ext(file.Filename))
|
||||
@@ -113,7 +115,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param s
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -63,8 +63,8 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Fatal()
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req.Header.Add("Content-Type", writer.FormDataContentType())
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
req.WAFRaw().Header.Add("Content-Type", writer.FormDataContentType())
|
||||
|
||||
checkpoint := new(RequestUploadCheckpoint)
|
||||
t.Log(checkpoint.RequestValue(req, "field", nil))
|
||||
@@ -73,7 +73,7 @@ func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
|
||||
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||
t.Log(checkpoint.RequestValue(req, "ext", nil))
|
||||
|
||||
data, err := ioutil.ReadAll(req.Body)
|
||||
data, err := ioutil.ReadAll(req.WAFRaw().Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -9,16 +9,16 @@ type RequestURICheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if len(req.RequestURI) > 0 {
|
||||
value = req.RequestURI
|
||||
} else if req.URL != nil {
|
||||
value = req.URL.RequestURI()
|
||||
func (this *RequestURICheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if len(req.WAFRaw().RequestURI) > 0 {
|
||||
value = req.WAFRaw().RequestURI
|
||||
} else if req.WAFRaw().URL != nil {
|
||||
value = req.WAFRaw().URL.RequestURI()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestURICheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ type RequestUserAgentCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.UserAgent()
|
||||
func (this *RequestUserAgentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = req.WAFRaw().UserAgent()
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *RequestUserAgentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -16,12 +16,12 @@ func (this *ResponseBodyCheckpoint) IsRequest() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = ""
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = ""
|
||||
if resp != nil && resp.Body != nil {
|
||||
if len(resp.BodyData) > 0 {
|
||||
|
||||
@@ -14,12 +14,12 @@ func (this *ResponseBytesSentCheckpoint) IsRequest() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseBytesSentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = 0
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseBytesSentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = 0
|
||||
if resp != nil {
|
||||
value = resp.ContentLength
|
||||
|
||||
@@ -18,12 +18,12 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) IsComposed() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = false
|
||||
|
||||
headers := options.GetSlice("headers")
|
||||
@@ -34,7 +34,7 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req *requests.R
|
||||
length := options.GetInt("length")
|
||||
|
||||
for _, header := range headers {
|
||||
v := req.Header.Get(types.String(header))
|
||||
v := req.WAFRaw().Header.Get(types.String(header))
|
||||
if len(v) > length {
|
||||
value = true
|
||||
break
|
||||
|
||||
@@ -14,12 +14,12 @@ func (this *ResponseHeaderCheckpoint) IsRequest() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = ""
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if resp != nil && resp.Header != nil {
|
||||
value = resp.Header.Get(param)
|
||||
} else {
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// ${bytesSent}
|
||||
// ResponseStatusCheckpoint ${bytesSent}
|
||||
type ResponseStatusCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
@@ -14,12 +14,12 @@ func (this *ResponseStatusCheckpoint) IsRequest() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseStatusCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
value = 0
|
||||
return
|
||||
}
|
||||
|
||||
func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *ResponseStatusCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if resp != nil {
|
||||
value = resp.StatusCode
|
||||
}
|
||||
|
||||
@@ -5,16 +5,16 @@ import (
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
)
|
||||
|
||||
// just a sample checkpoint, copy and change it for your new checkpoint
|
||||
// SampleRequestCheckpoint just a sample checkpoint, copy and change it for your new checkpoint
|
||||
type SampleRequestCheckpoint struct {
|
||||
Checkpoint
|
||||
}
|
||||
|
||||
func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *SampleRequestCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
func (this *SampleRequestCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
|
||||
if this.IsRequest() {
|
||||
return this.RequestValue(req, param, options)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package checkpoints
|
||||
|
||||
// all check points list
|
||||
// AllCheckpoints all check points list
|
||||
var AllCheckpoints = []*CheckpointDefinition{
|
||||
{
|
||||
Name: "通用请求Header长度限制",
|
||||
|
||||
52
internal/waf/get302_validator.go
Normal file
52
internal/waf/get302_validator.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var get302Validator = NewGet302Validator()
|
||||
|
||||
type Get302Validator struct {
|
||||
}
|
||||
|
||||
func NewGet302Validator() *Get302Validator {
|
||||
return &Get302Validator{}
|
||||
}
|
||||
|
||||
func (this *Get302Validator) Run(request requests.Request, writer http.ResponseWriter) {
|
||||
var info = request.WAFRaw().URL.Query().Get("info")
|
||||
if len(info) == 0 {
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
return
|
||||
}
|
||||
m, err := utils.SimpleDecryptMap(info)
|
||||
if err != nil {
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
var timestamp = m.GetInt64("timestamp")
|
||||
if time.Now().Unix()-timestamp > 5 { // 超过5秒认为失效
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = writer.Write([]byte("invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
// 加入白名单
|
||||
life := m.GetInt64("life")
|
||||
if life <= 0 {
|
||||
life = 600 // 默认10分钟
|
||||
}
|
||||
setId := m.GetString("setId")
|
||||
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
|
||||
|
||||
// 返回原始URL
|
||||
var url = m.GetString("url")
|
||||
http.Redirect(writer, request.WAFRaw(), url, http.StatusFound)
|
||||
}
|
||||
82
internal/waf/ip_list.go
Normal file
82
internal/waf/ip_list.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var SharedIPWhiteList = NewIPList()
|
||||
var SharedIPBlackLIst = NewIPList()
|
||||
|
||||
const IPTypeAll = "*"
|
||||
|
||||
// IPList IP列表管理
|
||||
type IPList struct {
|
||||
expireList *expires.List
|
||||
ipMap map[string]int64 // ip => id
|
||||
idMap map[int64]string // id => ip
|
||||
|
||||
id int64
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
// NewIPList 获取新对象
|
||||
func NewIPList() *IPList {
|
||||
var list = &IPList{
|
||||
ipMap: map[string]int64{},
|
||||
idMap: map[int64]string{},
|
||||
}
|
||||
|
||||
e := expires.NewList()
|
||||
list.expireList = e
|
||||
|
||||
go func() {
|
||||
e.StartGC(func(itemId int64) {
|
||||
list.remove(itemId)
|
||||
})
|
||||
}()
|
||||
|
||||
return list
|
||||
}
|
||||
|
||||
// Add 添加IP
|
||||
func (this *IPList) Add(ipType string, ip string, expiresAt int64) {
|
||||
ip = ip + "@" + ipType
|
||||
|
||||
var id = this.nextId()
|
||||
this.expireList.Add(id, expiresAt)
|
||||
this.locker.Lock()
|
||||
this.ipMap[ip] = id
|
||||
this.idMap[id] = ip
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// Contains 判断是否有某个IP
|
||||
func (this *IPList) Contains(ipType string, ip string) bool {
|
||||
ip = ip + "@" + ipType
|
||||
|
||||
this.locker.RLock()
|
||||
defer this.locker.RUnlock()
|
||||
_, ok := this.ipMap[ip]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (this *IPList) remove(id int64) {
|
||||
this.locker.Lock()
|
||||
ip, ok := this.idMap[id]
|
||||
if ok {
|
||||
ipId, ok := this.ipMap[ip]
|
||||
if ok && ipId == id {
|
||||
delete(this.ipMap, ip)
|
||||
}
|
||||
delete(this.idMap, id)
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *IPList) nextId() int64 {
|
||||
return atomic.AddInt64(&this.id, 1)
|
||||
}
|
||||
67
internal/waf/ip_list_test.go
Normal file
67
internal/waf/ip_list_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewIPList(t *testing.T) {
|
||||
list := NewIPList()
|
||||
list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix())
|
||||
list.Add(IPTypeAll, "127.0.0.2", time.Now().Unix()+1)
|
||||
list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()+2)
|
||||
list.Add(IPTypeAll, "127.0.0.3", time.Now().Unix()+3)
|
||||
list.Add(IPTypeAll, "127.0.0.10", time.Now().Unix()+10)
|
||||
|
||||
var ticker = time.NewTicker(1 * time.Second)
|
||||
for range ticker.C {
|
||||
t.Log("====")
|
||||
logs.PrintAsJSON(list.ipMap, t)
|
||||
logs.PrintAsJSON(list.idMap, t)
|
||||
if len(list.idMap) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPList_Contains(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
list := NewIPList()
|
||||
|
||||
for i := 0; i < 1_0000; i++ {
|
||||
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
}
|
||||
a.IsTrue(list.Contains(IPTypeAll, "192.168.1.100"))
|
||||
a.IsFalse(list.Contains(IPTypeAll, "192.168.2.100"))
|
||||
}
|
||||
|
||||
func BenchmarkIPList_Add(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
list := NewIPList()
|
||||
for i := 0; i < b.N; i++ {
|
||||
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
}
|
||||
b.Log(len(list.ipMap))
|
||||
}
|
||||
|
||||
func BenchmarkIPList_Has(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
list := NewIPList()
|
||||
|
||||
for i := 0; i < 1_0000; i++ {
|
||||
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
list.Contains(IPTypeAll, "192.168.1.100")
|
||||
}
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IPAction = string
|
||||
|
||||
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
|
||||
|
||||
const (
|
||||
IPActionAccept IPAction = "accept"
|
||||
IPActionReject IPAction = "reject"
|
||||
)
|
||||
|
||||
// ip table
|
||||
type IPTable struct {
|
||||
Id string `yaml:"id" json:"id"`
|
||||
On bool `yaml:"on" json:"on"`
|
||||
IP string `yaml:"ip" json:"ip"` // single ip, cidr, ip range, TODO support *
|
||||
Port string `yaml:"port" json:"port"` // single port, range, *
|
||||
Action IPAction `yaml:"action" json:"action"` // accept, reject
|
||||
TimeFrom int64 `yaml:"timeFrom" json:"timeFrom"` // from timestamp
|
||||
TimeTo int64 `yaml:"timeTo" json:"timeTo"` // zero means forever
|
||||
Remark string `yaml:"remark" json:"remark"`
|
||||
|
||||
// port
|
||||
minPort int
|
||||
maxPort int
|
||||
|
||||
minPortWildcard bool
|
||||
maxPortWildcard bool
|
||||
|
||||
ports []int
|
||||
|
||||
// ip
|
||||
ipRange *shared.IPRangeConfig
|
||||
}
|
||||
|
||||
func NewIPTable() *IPTable {
|
||||
return &IPTable{
|
||||
On: true,
|
||||
Id: stringutil.Rand(16),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *IPTable) Init() error {
|
||||
// parse port
|
||||
if RegexpDigitNumber.MatchString(this.Port) {
|
||||
this.minPort = types.Int(this.Port)
|
||||
this.maxPort = types.Int(this.Port)
|
||||
} else if regexp.MustCompile(`[:-]`).MatchString(this.Port) {
|
||||
pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2)
|
||||
if pieces[0] == "*" {
|
||||
this.minPortWildcard = true
|
||||
} else {
|
||||
this.minPort = types.Int(pieces[0])
|
||||
}
|
||||
if pieces[1] == "*" {
|
||||
this.maxPortWildcard = true
|
||||
} else {
|
||||
this.maxPort = types.Int(pieces[1])
|
||||
}
|
||||
} else if strings.Contains(this.Port, ",") {
|
||||
pieces := strings.Split(this.Port, ",")
|
||||
for _, piece := range pieces {
|
||||
piece = strings.TrimSpace(piece)
|
||||
if len(piece) > 0 {
|
||||
this.ports = append(this.ports, types.Int(piece))
|
||||
}
|
||||
}
|
||||
} else if this.Port == "*" {
|
||||
this.minPortWildcard = true
|
||||
this.maxPortWildcard = true
|
||||
}
|
||||
|
||||
// parse ip
|
||||
if len(this.IP) > 0 {
|
||||
ipRange, err := shared.ParseIPRange(this.IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
this.ipRange = ipRange
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// check ip
|
||||
func (this *IPTable) Match(ip string, port int) (isMatched bool) {
|
||||
if !this.On {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if this.TimeFrom > 0 && now < this.TimeFrom {
|
||||
return
|
||||
}
|
||||
if this.TimeTo > 0 && now > this.TimeTo {
|
||||
return
|
||||
}
|
||||
|
||||
if !this.matchPort(port) {
|
||||
return
|
||||
}
|
||||
|
||||
if !this.matchIP(ip) {
|
||||
return
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *IPTable) matchPort(port int) bool {
|
||||
if port == 0 {
|
||||
return false
|
||||
}
|
||||
if this.minPortWildcard {
|
||||
if this.maxPortWildcard {
|
||||
return true
|
||||
}
|
||||
if this.maxPort >= port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if this.maxPortWildcard {
|
||||
if this.minPortWildcard {
|
||||
return true
|
||||
}
|
||||
if this.minPort <= port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port {
|
||||
return true
|
||||
}
|
||||
if len(this.ports) > 0 {
|
||||
return lists.ContainsInt(this.ports, port)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *IPTable) matchIP(ip string) bool {
|
||||
if this.ipRange == nil {
|
||||
return false
|
||||
}
|
||||
return this.ipRange.Contains(ip)
|
||||
}
|
||||
@@ -1,142 +0,0 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPTable_MatchIP(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a.IsFalse(table.Match("192.168.1.100", 8080))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "*"
|
||||
table.Port = "8080"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("port:", table.minPort, table.maxPort)
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
a.IsFalse(table.Match("192.168.1.100", 8081))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "*"
|
||||
table.Port = "8080-8082"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("port:", table.minPort, table.maxPort)
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||
a.IsFalse(table.Match("192.168.1.100", 8083))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "*"
|
||||
table.Port = "*-8082"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("port:", table.minPort, table.maxPort)
|
||||
a.IsTrue(table.Match("192.168.1.100", 8079))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||
a.IsFalse(table.Match("192.168.1.100", 8083))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "*"
|
||||
table.Port = "8080-*"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("port:", table.minPort, table.maxPort)
|
||||
a.IsFalse(table.Match("192.168.1.100", 8079))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8083))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "*"
|
||||
table.Port = "*"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("port:", table.minPort, table.maxPort)
|
||||
a.IsTrue(table.Match("192.168.1.100", 8079))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||
a.IsTrue(table.Match("192.168.1.100", 8083))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "192.168.1.100"
|
||||
table.Port = "*"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("port:", table.minPort, table.maxPort)
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "192.168.1.99-192.168.1.101"
|
||||
table.Port = "*"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("port:", table.minPort, table.maxPort)
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "192.168.1.99/24"
|
||||
table.Port = "*"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ip:", table.ipRange)
|
||||
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||
a.IsFalse(table.Match("192.168.2.100", 8080))
|
||||
}
|
||||
|
||||
{
|
||||
table := NewIPTable()
|
||||
table.IP = "192.168.1.99/24"
|
||||
table.TimeTo = time.Now().Unix() - 10
|
||||
table.Port = "*"
|
||||
err := table.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a.IsFalse(table.Match("192.168.1.100", 8080))
|
||||
a.IsFalse(table.Match("192.168.2.100", 8080))
|
||||
}
|
||||
}
|
||||
@@ -1,39 +1,28 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
*http.Request
|
||||
BodyData []byte
|
||||
}
|
||||
type Request interface {
|
||||
// WAFRaw 原始请求
|
||||
WAFRaw() *http.Request
|
||||
|
||||
func NewRequest(raw *http.Request) *Request {
|
||||
return &Request{
|
||||
Request: raw,
|
||||
}
|
||||
}
|
||||
// WAFRemoteIP 客户端IP
|
||||
WAFRemoteIP() string
|
||||
|
||||
func (this *Request) Raw() *http.Request {
|
||||
return this.Request
|
||||
}
|
||||
// WAFGetCacheBody 获取缓存中的Body
|
||||
WAFGetCacheBody() []byte
|
||||
|
||||
func (this *Request) ReadBody(max int64) (data []byte, err error) {
|
||||
if this.Request.ContentLength > 0 {
|
||||
data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max))
|
||||
}
|
||||
return
|
||||
}
|
||||
// WAFSetCacheBody 设置Body
|
||||
WAFSetCacheBody(body []byte)
|
||||
|
||||
func (this *Request) RestoreBody(data []byte) {
|
||||
if len(data) > 0 {
|
||||
rawReader := bytes.NewBuffer(data)
|
||||
buf := make([]byte, 1024)
|
||||
_, _ = io.CopyBuffer(rawReader, this.Request.Body, buf)
|
||||
this.Request.Body = ioutil.NopCloser(rawReader)
|
||||
}
|
||||
// WAFReadBody 读取Body
|
||||
WAFReadBody(max int64) (data []byte, err error)
|
||||
|
||||
// WAFRestoreBody 恢复Body
|
||||
WAFRestoreBody(data []byte)
|
||||
|
||||
// WAFServerId 服务ID
|
||||
WAFServerId() int64
|
||||
}
|
||||
|
||||
67
internal/waf/requests/test_request.go
Normal file
67
internal/waf/requests/test_request.go
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package requests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type TestRequest struct {
|
||||
req *http.Request
|
||||
BodyData []byte
|
||||
}
|
||||
|
||||
func NewTestRequest(raw *http.Request) *TestRequest {
|
||||
return &TestRequest{
|
||||
req: raw,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFSetCacheBody(bodyData []byte) {
|
||||
this.BodyData = bodyData
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFGetCacheBody() []byte {
|
||||
return this.BodyData
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFRaw() *http.Request {
|
||||
return this.req
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFRemoteAddr() string {
|
||||
return this.req.RemoteAddr
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFRemoteIP() string {
|
||||
host, _, err := net.SplitHostPort(this.req.RemoteAddr)
|
||||
if err != nil {
|
||||
return this.req.RemoteAddr
|
||||
} else {
|
||||
return host
|
||||
}
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFReadBody(max int64) (data []byte, err error) {
|
||||
if this.req.ContentLength > 0 {
|
||||
data, err = ioutil.ReadAll(io.LimitReader(this.req.Body, max))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFRestoreBody(data []byte) {
|
||||
if len(data) > 0 {
|
||||
rawReader := bytes.NewBuffer(data)
|
||||
buf := make([]byte, 1024)
|
||||
_, _ = io.CopyBuffer(rawReader, this.req.Body, buf)
|
||||
this.req.Body = ioutil.NopCloser(rawReader)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *TestRequest) WAFServerId() int64 {
|
||||
return 0
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func (this *Rule) Init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
|
||||
func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) {
|
||||
if this.singleCheckpoint != nil {
|
||||
value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
|
||||
if err != nil {
|
||||
@@ -233,7 +233,7 @@ func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
|
||||
return this.Test(value), nil
|
||||
}
|
||||
|
||||
func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
|
||||
func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) {
|
||||
if this.singleCheckpoint != nil {
|
||||
// if is request param
|
||||
if this.singleCheckpoint.IsRequest() {
|
||||
|
||||
@@ -23,12 +23,12 @@ func NewRuleGroup() *RuleGroup {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *RuleGroup) Init() error {
|
||||
func (this *RuleGroup) Init(waf *WAF) error {
|
||||
this.hasRuleSets = len(this.RuleSets) > 0
|
||||
|
||||
if this.hasRuleSets {
|
||||
for _, set := range this.RuleSets {
|
||||
err := set.Init()
|
||||
err := set.Init(waf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -79,7 +79,7 @@ func (this *RuleGroup) RemoveRuleSet(id string) {
|
||||
this.RuleSets = result
|
||||
}
|
||||
|
||||
func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) {
|
||||
func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, set *RuleSet, err error) {
|
||||
if !this.hasRuleSets {
|
||||
return
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
|
||||
func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
|
||||
if !this.hasRuleSets {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/utils/string"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type RuleConnector = string
|
||||
@@ -14,16 +18,17 @@ const (
|
||||
)
|
||||
|
||||
type RuleSet struct {
|
||||
Id string `yaml:"id" json:"id"`
|
||||
Code string `yaml:"code" json:"code"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Rules []*Rule `yaml:"rules" json:"rules"`
|
||||
Connector RuleConnector `yaml:"connector" json:"connector"` // rules connector
|
||||
Id string `yaml:"id" json:"id"`
|
||||
Code string `yaml:"code" json:"code"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Rules []*Rule `yaml:"rules" json:"rules"`
|
||||
Connector RuleConnector `yaml:"connector" json:"connector"` // rules connector
|
||||
Actions []*ActionConfig `yaml:"actions" json:"actions"`
|
||||
|
||||
Action ActionString `yaml:"action" json:"action"`
|
||||
ActionOptions maps.Map `yaml:"actionOptions" json:"actionOptions"` // TODO TO BE IMPLEMENTED
|
||||
actionCodes []string
|
||||
actionInstances []ActionInterface
|
||||
|
||||
hasRules bool
|
||||
}
|
||||
@@ -35,7 +40,7 @@ func NewRuleSet() *RuleSet {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *RuleSet) Init() error {
|
||||
func (this *RuleSet) Init(waf *WAF) error {
|
||||
this.hasRules = len(this.Rules) > 0
|
||||
if this.hasRules {
|
||||
for _, rule := range this.Rules {
|
||||
@@ -45,6 +50,31 @@ func (this *RuleSet) Init() error {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// action codes
|
||||
var actionCodes = []string{}
|
||||
for _, action := range this.Actions {
|
||||
if !lists.ContainsString(actionCodes, action.Code) {
|
||||
actionCodes = append(actionCodes, action.Code)
|
||||
}
|
||||
}
|
||||
this.actionCodes = actionCodes
|
||||
|
||||
// action instances
|
||||
this.actionInstances = []ActionInterface{}
|
||||
for _, action := range this.Actions {
|
||||
instance := FindActionInstance(action.Code, action.Options)
|
||||
if instance == nil {
|
||||
remotelogs.Error("WAF_RULE_SET", "can not find instance for action '"+action.Code+"'")
|
||||
} else {
|
||||
this.actionInstances = append(this.actionInstances, instance)
|
||||
}
|
||||
err := instance.Init(waf)
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RULE_SET", "init action '"+action.Code+"' failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -52,7 +82,75 @@ func (this *RuleSet) AddRule(rule ...*Rule) {
|
||||
this.Rules = append(this.Rules, rule...)
|
||||
}
|
||||
|
||||
func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) {
|
||||
// AddAction 添加动作
|
||||
func (this *RuleSet) AddAction(code string, options maps.Map) {
|
||||
if options == nil {
|
||||
options = maps.Map{}
|
||||
}
|
||||
this.Actions = append(this.Actions, &ActionConfig{
|
||||
Code: code,
|
||||
Options: options,
|
||||
})
|
||||
}
|
||||
|
||||
// HasSpecialActions 除了Allow之外是否还有别的动作
|
||||
func (this *RuleSet) HasSpecialActions() bool {
|
||||
for _, action := range this.Actions {
|
||||
if action.Code != ActionAllow {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasAttackActions 检查是否含有攻击防御动作
|
||||
func (this *RuleSet) HasAttackActions() bool {
|
||||
for _, action := range this.actionInstances {
|
||||
if action.IsAttack() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *RuleSet) ActionCodes() []string {
|
||||
return this.actionCodes
|
||||
}
|
||||
|
||||
func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) bool {
|
||||
// 先执行allow
|
||||
for _, instance := range this.actionInstances {
|
||||
if !instance.WillChange() {
|
||||
if waf.onActionCallback != nil {
|
||||
goNext := waf.onActionCallback(instance)
|
||||
if !goNext {
|
||||
return false
|
||||
}
|
||||
}
|
||||
logs.Printf("perform1: %#v", instance) // TODO
|
||||
instance.Perform(waf, group, this, req, writer)
|
||||
}
|
||||
}
|
||||
|
||||
// 再执行block|verify
|
||||
for _, instance := range this.actionInstances {
|
||||
// 只执行第一个可能改变请求的动作,其余的都会被忽略
|
||||
if instance.WillChange() {
|
||||
if waf.onActionCallback != nil {
|
||||
goNext := waf.onActionCallback(instance)
|
||||
if !goNext {
|
||||
return false
|
||||
}
|
||||
}
|
||||
logs.Printf("perform2: %#v", instance) // TODO
|
||||
return instance.Perform(waf, group, this, req, writer)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (this *RuleSet) MatchRequest(req requests.Request) (b bool, err error) {
|
||||
if !this.hasRules {
|
||||
return false, nil
|
||||
}
|
||||
@@ -93,7 +191,7 @@ func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (this *RuleSet) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
|
||||
func (this *RuleSet) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) {
|
||||
if !this.hasRules {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func TestRuleSet_MatchRequest(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := set.Init()
|
||||
err := set.Init(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func TestRuleSet_MatchRequest(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
t.Log(set.MatchRequest(req))
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := set.Init()
|
||||
err := set.Init(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
a.IsTrue(set.MatchRequest(req))
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
|
||||
},
|
||||
}
|
||||
|
||||
err := set.Init()
|
||||
err := set.Init(nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -111,7 +111,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = set.MatchRequest(req)
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
|
||||
},
|
||||
}
|
||||
|
||||
err := set.Init()
|
||||
err := set.Init(nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -141,7 +141,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = set.MatchRequest(req)
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func TestRule_Init_Single(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
t.Log(rule.MatchRequest(req))
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestRule_Init_Composite(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := requests.NewRequest(rawReq)
|
||||
req := requests.NewTestRequest(rawReq)
|
||||
t.Log(rule.MatchRequest(req))
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ func Template() *WAF {
|
||||
set.Name = "Javascript事件"
|
||||
set.Code = "1001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestURI}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -36,7 +36,7 @@ func Template() *WAF {
|
||||
set.Name = "Javascript函数"
|
||||
set.Code = "1002"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestURI}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -52,7 +52,7 @@ func Template() *WAF {
|
||||
set.Name = "HTML标签"
|
||||
set.Code = "1003"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestURI}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -80,7 +80,7 @@ func Template() *WAF {
|
||||
set.Name = "上传文件扩展名"
|
||||
set.Code = "2001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestUpload.ext}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -108,7 +108,7 @@ func Template() *WAF {
|
||||
set.Name = "Web Shell"
|
||||
set.Code = "3001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestAll}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -135,7 +135,7 @@ func Template() *WAF {
|
||||
set.Name = "命令注入"
|
||||
set.Code = "4001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestURI}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -169,7 +169,7 @@ func Template() *WAF {
|
||||
set.Name = "路径穿越"
|
||||
set.Code = "5001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestURI}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -197,7 +197,7 @@ func Template() *WAF {
|
||||
set.Name = "特殊目录"
|
||||
set.Code = "6001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestPath}",
|
||||
Operator: RuleOperatorMatch,
|
||||
@@ -225,7 +225,7 @@ func Template() *WAF {
|
||||
set.Name = "Union SQL Injection"
|
||||
set.Code = "7001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestAll}",
|
||||
@@ -243,7 +243,7 @@ func Template() *WAF {
|
||||
set.Name = "SQL注释"
|
||||
set.Code = "7002"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestAll}",
|
||||
@@ -261,7 +261,7 @@ func Template() *WAF {
|
||||
set.Name = "SQL条件"
|
||||
set.Code = "7003"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestAll}",
|
||||
@@ -297,7 +297,7 @@ func Template() *WAF {
|
||||
set.Name = "SQL函数"
|
||||
set.Code = "7004"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestAll}",
|
||||
@@ -315,7 +315,7 @@ func Template() *WAF {
|
||||
set.Name = "SQL附加语句"
|
||||
set.Code = "7005"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
|
||||
set.AddRule(&Rule{
|
||||
Param: "${requestAll}",
|
||||
@@ -345,7 +345,7 @@ func Template() *WAF {
|
||||
set.Name = "常见网络爬虫"
|
||||
set.Code = "20001"
|
||||
set.Connector = RuleConnectorOr
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
|
||||
set.AddRule(&Rule{
|
||||
Param: "${userAgent}",
|
||||
@@ -376,7 +376,7 @@ func Template() *WAF {
|
||||
set.Description = "限制单IP在一定时间内的请求数"
|
||||
set.Code = "8001"
|
||||
set.Connector = RuleConnectorAnd
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
set.AddRule(&Rule{
|
||||
Param: "${cc.requests}",
|
||||
Operator: RuleOperatorGt,
|
||||
|
||||
@@ -2,6 +2,7 @@ package waf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
@@ -22,8 +23,8 @@ func Test_Template(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
template.OnAction(func(action ActionString) (goNext bool) {
|
||||
return action != ActionBlock
|
||||
template.OnAction(func(action ActionInterface) (goNext bool) {
|
||||
return action.Code() != ActionBlock
|
||||
})
|
||||
|
||||
testTemplate1001(a, t, template)
|
||||
@@ -40,7 +41,7 @@ func Test_Template(t *testing.T) {
|
||||
|
||||
func Test_Template2(t *testing.T) {
|
||||
reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024)))
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader)
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123", reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -52,7 +53,7 @@ func Test_Template2(t *testing.T) {
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
goNext, _, set, err := waf.MatchRequest(req, nil)
|
||||
goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -80,7 +81,7 @@ func BenchmarkTemplate(b *testing.B) {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
_, _, _, _ = waf.MatchRequest(req, nil)
|
||||
_, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +90,7 @@ func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -104,7 +105,7 @@ func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -119,7 +120,7 @@ func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -185,7 +186,7 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
|
||||
req.Header.Add("Content-Type", writer.FormDataContentType())
|
||||
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -200,7 +201,7 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -215,7 +216,7 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -231,7 +232,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -246,7 +247,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -263,7 +264,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -278,7 +279,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -301,7 +302,7 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -338,7 +339,7 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("User-Agent", bot)
|
||||
_, _, result, err := template.MatchRequest(req, nil)
|
||||
_, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -22,13 +22,11 @@ type WAF struct {
|
||||
Outbound []*RuleGroup `yaml:"outbound" json:"outbound"`
|
||||
CreatedVersion string `yaml:"createdVersion" json:"createdVersion"`
|
||||
|
||||
ActionBlock *BlockAction `yaml:"actionBlock" json:"actionBlock"` // action block config
|
||||
|
||||
IPTables []*IPTable `yaml:"ipTables" json:"ipTables"` // IP table list
|
||||
DefaultBlockAction *BlockAction
|
||||
|
||||
hasInboundRules bool
|
||||
hasOutboundRules bool
|
||||
onActionCallback func(action ActionString) (goNext bool)
|
||||
onActionCallback func(action ActionInterface) (goNext bool)
|
||||
|
||||
checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint
|
||||
}
|
||||
@@ -87,7 +85,7 @@ func (this *WAF) Init() error {
|
||||
}
|
||||
}
|
||||
|
||||
err := group.Init()
|
||||
err := group.Init(this)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -103,7 +101,7 @@ func (this *WAF) Init() error {
|
||||
}
|
||||
}
|
||||
|
||||
err := group.Init()
|
||||
err := group.Init(this)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -241,19 +239,24 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) {
|
||||
this.Outbound = result
|
||||
}
|
||||
|
||||
func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
|
||||
func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
|
||||
if !this.hasInboundRules {
|
||||
return true, nil, nil, nil
|
||||
}
|
||||
|
||||
req := requests.NewRequest(rawReq)
|
||||
|
||||
// validate captcha
|
||||
if rawReq.URL.Path == "/WAFCAPTCHA" {
|
||||
var rawPath = req.WAFRaw().URL.Path
|
||||
if rawPath == CaptchaPath {
|
||||
captchaValidator.Run(req, writer)
|
||||
return
|
||||
}
|
||||
|
||||
// Get 302验证
|
||||
if rawPath == Get302Path {
|
||||
get302Validator.Run(req, writer)
|
||||
return
|
||||
}
|
||||
|
||||
// match rules
|
||||
for _, group := range this.Inbound {
|
||||
if !group.IsOn {
|
||||
@@ -264,31 +267,17 @@ func (this *WAF) MatchRequest(rawReq *http.Request, writer http.ResponseWriter)
|
||||
return true, nil, nil, err
|
||||
}
|
||||
if b {
|
||||
if this.onActionCallback == nil {
|
||||
if set.Action == ActionBlock && this.ActionBlock != nil {
|
||||
return this.ActionBlock.Perform(this, req, writer), group, set, nil
|
||||
} else {
|
||||
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||
if actionObject == nil {
|
||||
return true, group, set, errors.New("no action called '" + set.Action + "'")
|
||||
}
|
||||
goNext := actionObject.Perform(this, req, writer)
|
||||
return goNext, group, set, nil
|
||||
}
|
||||
} else {
|
||||
goNext = this.onActionCallback(set.Action)
|
||||
}
|
||||
goNext := set.PerformActions(this, group, req, writer)
|
||||
return goNext, group, set, nil
|
||||
}
|
||||
}
|
||||
return true, nil, nil, nil
|
||||
}
|
||||
|
||||
func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
|
||||
func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) {
|
||||
if !this.hasOutboundRules {
|
||||
return true, nil, nil, nil
|
||||
}
|
||||
req := requests.NewRequest(rawReq)
|
||||
resp := requests.NewResponse(rawResp)
|
||||
for _, group := range this.Outbound {
|
||||
if !group.IsOn {
|
||||
@@ -299,27 +288,14 @@ func (this *WAF) MatchResponse(rawReq *http.Request, rawResp *http.Response, wri
|
||||
return true, nil, nil, err
|
||||
}
|
||||
if b {
|
||||
if this.onActionCallback == nil {
|
||||
if set.Action == ActionBlock && this.ActionBlock != nil {
|
||||
return this.ActionBlock.Perform(this, req, writer), group, set, nil
|
||||
} else {
|
||||
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||
if actionObject == nil {
|
||||
return true, group, set, errors.New("no action called '" + set.Action + "'")
|
||||
}
|
||||
goNext := actionObject.Perform(this, req, writer)
|
||||
return goNext, group, set, nil
|
||||
}
|
||||
} else {
|
||||
goNext = this.onActionCallback(set.Action)
|
||||
}
|
||||
goNext := set.PerformActions(this, group, req, writer)
|
||||
return goNext, group, set, nil
|
||||
}
|
||||
}
|
||||
return true, nil, nil, nil
|
||||
}
|
||||
|
||||
// save to file path
|
||||
// Save save to file path
|
||||
func (this *WAF) Save(path string) error {
|
||||
if len(path) == 0 {
|
||||
return errors.New("path should not be empty")
|
||||
@@ -378,7 +354,7 @@ func (this *WAF) CountOutboundRuleSets() int {
|
||||
return count
|
||||
}
|
||||
|
||||
func (this *WAF) OnAction(onActionCallback func(action ActionString) (goNext bool)) {
|
||||
func (this *WAF) OnAction(onActionCallback func(action ActionInterface) (goNext bool)) {
|
||||
this.onActionCallback = onActionCallback
|
||||
}
|
||||
|
||||
@@ -390,21 +366,21 @@ func (this *WAF) FindCheckpointInstance(prefix string) checkpoints.CheckpointInt
|
||||
return nil
|
||||
}
|
||||
|
||||
// start
|
||||
// Start start
|
||||
func (this *WAF) Start() {
|
||||
for _, checkpoint := range this.checkpointsMap {
|
||||
checkpoint.Start()
|
||||
}
|
||||
}
|
||||
|
||||
// call stop() when the waf was deleted
|
||||
// Stop call stop() when the waf was deleted
|
||||
func (this *WAF) Stop() {
|
||||
for _, checkpoint := range this.checkpointsMap {
|
||||
checkpoint.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// merge with template
|
||||
// MergeTemplate merge with template
|
||||
func (this *WAF) MergeTemplate() (changedItems []string) {
|
||||
changedItems = []string{}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package waf
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"net/http"
|
||||
"testing"
|
||||
@@ -24,7 +25,7 @@ func TestWAF_MatchRequest(t *testing.T) {
|
||||
Value: "20",
|
||||
},
|
||||
}
|
||||
set.Action = ActionBlock
|
||||
set.AddAction(ActionBlock, nil)
|
||||
|
||||
group := NewRuleGroup()
|
||||
group.AddRuleSet(set)
|
||||
@@ -37,15 +38,15 @@ func TestWAF_MatchRequest(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
waf.OnAction(func(action ActionString) (goNext bool) {
|
||||
return action != ActionBlock
|
||||
waf.OnAction(func(action ActionInterface) (goNext bool) {
|
||||
return action.Code() != ActionBlock
|
||||
})
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
goNext, _, set, err := waf.MatchRequest(req, nil)
|
||||
goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user