diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index f651c12..dad26e7 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -46,7 +46,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { } // 检查是否在临时黑名单中 - if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) { + if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeServer, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) { this.disableLog = true this.Close() diff --git a/internal/utils/encrypt.go b/internal/utils/encrypt.go index 5abcb00..8f923f4 100644 --- a/internal/utils/encrypt.go +++ b/internal/utils/encrypt.go @@ -74,7 +74,7 @@ func SimpleEncryptMap(m maps.Map) (base64String string, err error) { if err != nil { return "", err } - data := SimpleEncrypt(mJSON) + var data = SimpleEncrypt(mJSON) return base64.StdEncoding.EncodeToString(data), nil } @@ -83,7 +83,7 @@ func SimpleDecryptMap(base64String string) (maps.Map, error) { if err != nil { return nil, err } - mJSON := SimpleDecrypt(data) + var mJSON = SimpleDecrypt(data) var result = maps.Map{} err = json.Unmarshal(mJSON, &result) if err != nil { @@ -92,6 +92,25 @@ func SimpleDecryptMap(base64String string) (maps.Map, error) { return result, nil } +func SimpleEncryptObject(ptr any) (string, error) { + mJSON, err := json.Marshal(ptr) + if err != nil { + return "", err + } + var data = SimpleEncrypt(mJSON) + return base64.StdEncoding.EncodeToString(data), nil +} + +func SimpleDecryptObjet(base64String string, ptr any) error { + data, err := base64.StdEncoding.DecodeString(base64String) + if err != nil { + return err + } + var mJSON = SimpleDecrypt(data) + err = json.Unmarshal(mJSON, ptr) + return err +} + type AES256CFBMethod struct { block cipher.Block iv []byte @@ -99,7 +118,7 @@ type AES256CFBMethod struct { func (this *AES256CFBMethod) Init(key, iv []byte) error { // 判断key是否为32长度 - l := len(key) + var l = len(key) if l > 32 { key = key[:32] } else if l < 32 { @@ -113,7 +132,7 @@ func (this *AES256CFBMethod) Init(key, iv []byte) error { this.block = block // 判断iv长度 - l2 := len(iv) + var l2 = len(iv) if l2 > aes.BlockSize { iv = iv[:aes.BlockSize] } else if l2 < aes.BlockSize { @@ -130,7 +149,7 @@ func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) { } defer func() { - r := recover() + var r = recover() if r != nil { err = errors.New("encrypt failed") } @@ -138,7 +157,7 @@ func (this *AES256CFBMethod) Encrypt(src []byte) (dst []byte, err error) { dst = make([]byte, len(src)) - encrypter := cipher.NewCFBEncrypter(this.block, this.iv) + var encrypter = cipher.NewCFBEncrypter(this.block, this.iv) encrypter.XORKeyStream(dst, src) return @@ -157,7 +176,7 @@ func (this *AES256CFBMethod) Decrypt(dst []byte) (src []byte, err error) { }() src = make([]byte, len(dst)) - decrypter := cipher.NewCFBDecrypter(this.block, this.iv) + var decrypter = cipher.NewCFBDecrypter(this.block, this.iv) decrypter.XORKeyStream(src, dst) return diff --git a/internal/utils/encrypt_test.go b/internal/utils/encrypt_test.go index 3a5d411..2bf89c9 100644 --- a/internal/utils/encrypt_test.go +++ b/internal/utils/encrypt_test.go @@ -1,32 +1,60 @@ // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. -package utils +package utils_test import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/maps" "sync" "testing" ) func TestSimpleEncrypt(t *testing.T) { + var a = assert.NewAssertion(t) + var arr = []string{"Hello", "World", "People"} for _, s := range arr { var value = []byte(s) - encoded := SimpleEncrypt(value) + var encoded = utils.SimpleEncrypt(value) t.Log(encoded, string(encoded)) - decoded := SimpleDecrypt(encoded) + var decoded = utils.SimpleDecrypt(encoded) t.Log(decoded, string(decoded)) + a.IsTrue(s == string(decoded)) } } +func TestSimpleEncryptObject(t *testing.T) { + var a = assert.NewAssertion(t) + + type Obj struct { + Name string `json:"name"` + Age int `json:"age"` + } + + encoded, err := utils.SimpleEncryptObject(&Obj{Name: "lily", Age: 20}) + if err != nil { + t.Fatal(err) + } + + var obj = &Obj{} + err = utils.SimpleDecryptObjet(encoded, obj) + if err != nil { + t.Fatal(err) + } + t.Logf("%#v", obj) + a.IsTrue(obj.Name == "lily") + a.IsTrue(obj.Age == 20) +} + func TestSimpleEncrypt_Concurrent(t *testing.T) { - wg := sync.WaitGroup{} + var wg = sync.WaitGroup{} var arr = []string{"Hello", "World", "People"} wg.Add(len(arr)) for _, s := range arr { go func(s string) { defer wg.Done() - t.Log(string(SimpleDecrypt(SimpleEncrypt([]byte(s))))) + t.Log(string(utils.SimpleDecrypt(utils.SimpleEncrypt([]byte(s))))) }(s) } wg.Wait() @@ -38,13 +66,13 @@ func TestSimpleEncryptMap(t *testing.T) { "i": 20, "b": true, } - encodedResult, err := SimpleEncryptMap(m) + encodedResult, err := utils.SimpleEncryptMap(m) if err != nil { t.Fatal(err) } t.Log("result:", encodedResult) - decodedResult, err := SimpleDecryptMap(encodedResult) + decodedResult, err := utils.SimpleDecryptMap(encodedResult) if err != nil { t.Fatal(err) } diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index 89fa1eb..5390431 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -1,6 +1,7 @@ package waf import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" @@ -28,6 +29,8 @@ type BlockAction struct { Timeout int32 `yaml:"timeout" json:"timeout"` TimeoutMax int32 `yaml:"timeoutMax" json:"timeoutMax"` Scope string `yaml:"scope" json:"scope"` + + FailBlockScopeAll bool `yaml:"failBlockScopeAll" json:"failBlockScopeAll"` } func (this *BlockAction) Init(waf *WAF) error { @@ -45,7 +48,10 @@ func (this *BlockAction) Init(waf *WAF) error { this.Timeout = waf.DefaultBlockAction.Timeout this.TimeoutMax = waf.DefaultBlockAction.TimeoutMax // 只有没有填写封锁时长的时候才会使用默认的封锁时长最大值 } + + this.FailBlockScopeAll = waf.DefaultBlockAction.FailBlockScopeAll } + return nil } @@ -74,7 +80,7 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque timeout = timeout + int32(rands.Int64()%int64(timeoutMax-timeout+1)) } - SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, waf.UseLocalFirewall, group.Id, set.Id, "") + SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, waf.UseLocalFirewall && (this.FailBlockScopeAll || this.Scope == firewallconfigs.FirewallScopeGlobal), group.Id, set.Id, "") if writer != nil { // close the connection diff --git a/internal/waf/action_captcha.go b/internal/waf/action_captcha.go index a581f8e..cdc5e9a 100644 --- a/internal/waf/action_captcha.go +++ b/internal/waf/action_captcha.go @@ -6,7 +6,6 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" wafutils "github.com/TeaOSLab/EdgeNode/internal/waf/utils" - "github.com/iwind/TeaGo/maps" "net/http" "net/url" "strings" @@ -135,24 +134,32 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req // 覆盖配置 if strings.HasPrefix(refURL, CaptchaPath) { - info := req.WAFRaw().URL.Query().Get("info") + var info = req.WAFRaw().URL.Query().Get("info") if len(info) > 0 { - m, err := utils.SimpleDecryptMap(info) - if err == nil && m != nil { - refURL = m.GetString("url") + var oldArg = &InfoArg{} + decodeErr := oldArg.Decode(info) + if decodeErr == nil && oldArg.IsValid() { + refURL = oldArg.URL + } else { + // 兼容老版本 + m, err := utils.SimpleDecryptMap(info) + if err == nil && m != nil { + refURL = m.GetString("url") + } } } } - var captchaConfig = maps.Map{ - "actionId": this.ActionId(), - "timestamp": time.Now().Unix(), - "url": refURL, - "policyId": waf.Id, - "groupId": group.Id, - "setId": set.Id, + var captchaConfig = &InfoArg{ + ActionId: this.ActionId(), + Timestamp: time.Now().Unix(), + URL: refURL, + PolicyId: waf.Id, + GroupId: group.Id, + SetId: set.Id, + UseLocalFirewall: waf.UseLocalFirewall && (this.FailBlockScopeAll || this.Scope == firewallconfigs.AllowScopeGlobal), } - info, err := utils.SimpleEncryptMap(captchaConfig) + info, err := utils.SimpleEncryptObject(captchaConfig) if err != nil { remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error()) return PerformResult{ @@ -161,11 +168,11 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req } // 占用一次失败次数 - CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit) + CaptchaIncreaseFails(req, this, waf.Id, group.Id, set.Id, CaptchaPageCodeInit, waf.UseLocalFirewall && (this.FailBlockScopeAll || this.Scope == firewallconfigs.FirewallScopeGlobal)) req.DisableStat() req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect) - http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect) + http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info)+"&from="+url.QueryEscape(refURL), http.StatusTemporaryRedirect) return PerformResult{} } diff --git a/internal/waf/action_get_302.go b/internal/waf/action_get_302.go index 40884ec..fc2c657 100644 --- a/internal/waf/action_get_302.go +++ b/internal/waf/action_get_302.go @@ -4,7 +4,6 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" - "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" "net/http" "net/url" @@ -56,16 +55,17 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ } } - var m = maps.Map{ - "url": request.WAFRaw().URL.String(), - "timestamp": time.Now().Unix(), - "life": this.Life, - "scope": this.Scope, - "policyId": waf.Id, - "groupId": group.Id, - "setId": set.Id, + var m = InfoArg{ + URL: request.WAFRaw().URL.String(), + Timestamp: time.Now().Unix(), + Life: this.Life, + Scope: this.Scope, + PolicyId: waf.Id, + GroupId: group.Id, + SetId: set.Id, + UseLocalFirewall: false, } - info, err := utils.SimpleEncryptMap(m) + info, err := utils.SimpleEncryptObject(m) if err != nil { remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error()) return PerformResult{ diff --git a/internal/waf/action_js_cookie.go b/internal/waf/action_js_cookie.go index b7a1305..3b34537 100644 --- a/internal/waf/action_js_cookie.go +++ b/internal/waf/action_js_cookie.go @@ -22,10 +22,32 @@ type JSCookieAction struct { MaxFails int `yaml:"maxFails" json:"maxFails"` // 最大失败次数 FailBlockTimeout int `yaml:"failBlockTimeout" json:"failBlockTimeout"` // 失败拦截时间 Scope string `yaml:"scope" json:"scope"` + + FailBlockScopeAll bool `yaml:"failBlockScopeAll" json:"failBlockScopeAll"` } func (this *JSCookieAction) Init(waf *WAF) error { - this.Scope = firewallconfigs.FirewallScopeGlobal + + if waf.DefaultJSCookieAction != nil { + if this.Life <= 0 { + this.Life = waf.DefaultJSCookieAction.Life + } + if this.MaxFails <= 0 { + this.MaxFails = waf.DefaultJSCookieAction.MaxFails + } + if this.FailBlockTimeout <= 0 { + this.FailBlockTimeout = waf.DefaultJSCookieAction.FailBlockTimeout + } + if len(this.Scope) == 0 { + this.Scope = waf.DefaultJSCookieAction.Scope + } + + this.FailBlockScopeAll = waf.DefaultJSCookieAction.FailBlockScopeAll + } + + if len(this.Scope) == 0 { + this.Scope = firewallconfigs.FirewallScopeGlobal + } return nil } @@ -107,19 +129,19 @@ window.location.reload(); _, _ = writer.Write([]byte(respHTML)) // 记录失败次数 - this.increaseFails(req, waf.Id, group.Id, set.Id) + this.increaseFails(req, waf.Id, group.Id, set.Id, waf.UseLocalFirewall && (this.FailBlockScopeAll || this.Scope == firewallconfigs.FirewallScopeGlobal)) return PerformResult{} } -func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64) (goNext bool) { +func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64, useLocalFirewall bool) (goNext bool) { var maxFails = this.MaxFails var failBlockTimeout = this.FailBlockTimeout if maxFails <= 0 { maxFails = 10 // 默认10次 - } else if maxFails <= 3 { - maxFails = 3 // 不能小于3,防止意外刷新出现 + } else if maxFails <= 5 { + maxFails = 5 // 不能小于3,防止意外刷新出现 } if failBlockTimeout <= 0 { failBlockTimeout = 1800 // 默认1800s @@ -129,7 +151,7 @@ func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, var countFails = counters.SharedCounter.IncreaseKey(key, 300) if int(countFails) >= maxFails { - SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次") + SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeServer, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次") return false } diff --git a/internal/waf/action_post_307.go b/internal/waf/action_post_307.go index d8e67d0..b60057e 100644 --- a/internal/waf/action_post_307.go +++ b/internal/waf/action_post_307.go @@ -4,7 +4,6 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" - "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" "io" "net/http" @@ -35,7 +34,7 @@ func (this *Post307Action) WillChange() bool { } func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { - var cookieName = "WAF_VALIDATOR_ID" + const cookieName = "WAF_VALIDATOR_ID" // 仅限于POST if request.WAFRaw().Method != http.MethodPost { @@ -52,32 +51,64 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req } // 判断是否有Cookie - cookie, err := request.WAFRaw().Cookie(cookieName) - if err == nil && cookie != nil { - m, err := utils.SimpleDecryptMap(cookie.Value) - if err == nil && m.GetString("remoteIP") == request.WAFRemoteIP() && time.Now().Unix() < m.GetInt64("timestamp")+10 { - var life = m.GetInt64("life") + cookie, cookieErr := request.WAFRaw().Cookie(cookieName) + if cookieErr == nil && cookie != nil { + var remoteIP string + var life int64 + var setId int64 + var policyId int64 + var groupId int64 + var timestamp int64 + + var infoArg = &InfoArg{} + var success bool + decodeErr := infoArg.Decode(cookie.Value) + if decodeErr == nil && infoArg.IsValid() { + success = true + + remoteIP = infoArg.RemoteIP + life = int64(infoArg.Life) + setId = infoArg.SetId + policyId = infoArg.PolicyId + groupId = infoArg.GroupId + timestamp = infoArg.Timestamp + } else { + // 兼容老版本 + m, decodeMapErr := utils.SimpleDecryptMap(cookie.Value) + if decodeMapErr == nil { + success = true + + remoteIP = m.GetString("remoteIP") + timestamp = m.GetInt64("timestamp") + life = m.GetInt64("life") + setId = m.GetInt64("setId") + groupId = m.GetInt64("groupId") + policyId = m.GetInt64("policyId") + } + } + + if success && remoteIP == request.WAFRemoteIP() && time.Now().Unix() < timestamp+10 { if life <= 0 { life = 600 // 默认10分钟 } - var setId = types.String(m.GetInt64("setId")) - SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "") + SharedIPWhiteList.RecordIP("set:"+types.String(setId), this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, policyId, false, groupId, setId, "") return PerformResult{ ContinueRequest: true, } } } - var m = maps.Map{ - "timestamp": time.Now().Unix(), - "life": this.Life, - "scope": this.Scope, - "policyId": waf.Id, - "groupId": group.Id, - "setId": set.Id, - "remoteIP": request.WAFRemoteIP(), + var m = &InfoArg{ + Timestamp: time.Now().Unix(), + Life: this.Life, + Scope: this.Scope, + PolicyId: waf.Id, + GroupId: group.Id, + SetId: set.Id, + RemoteIP: request.WAFRemoteIP(), + UseLocalFirewall: false, } - info, err := utils.SimpleEncryptMap(m) + info, err := utils.SimpleEncryptObject(m) if err != nil { remotelogs.Error("WAF_POST_307_ACTION", "encode info failed: "+err.Error()) return PerformResult{ diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go index 573df36..c8e6692 100644 --- a/internal/waf/action_record_ip.go +++ b/internal/waf/action_record_ip.go @@ -178,7 +178,7 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re // 上报 if ipListId > 0 && ipListIsAvailable { var serverId int64 - if this.Scope == firewallconfigs.FirewallScopeService { + if this.Scope == firewallconfigs.FirewallScopeServer { serverId = request.WAFServerId() } diff --git a/internal/waf/captcha_counter.go b/internal/waf/captcha_counter.go index abaf389..db9f503 100644 --- a/internal/waf/captcha_counter.go +++ b/internal/waf/captcha_counter.go @@ -20,7 +20,7 @@ const ( ) // CaptchaIncreaseFails 增加Captcha失败次数,以便后续操作 -func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, pageCode CaptchaPageCode) (goNext bool) { +func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, pageCode CaptchaPageCode, useLocalFirewall bool) (goNext bool) { var maxFails = actionConfig.MaxFails var failBlockTimeout = actionConfig.FailBlockTimeout if maxFails > 0 && failBlockTimeout > 0 { @@ -29,7 +29,7 @@ func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, pol } var countFails = counters.SharedCounter.IncreaseKey(CaptchaCacheKey(req, pageCode), 300) if int(countFails) >= maxFails { - SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次") + SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeServer, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次") return false } } diff --git a/internal/waf/captcha_validator.go b/internal/waf/captcha_validator.go index 253107d..e0775c8 100644 --- a/internal/waf/captcha_validator.go +++ b/internal/waf/captcha_validator.go @@ -38,34 +38,80 @@ func NewCaptchaValidator() *CaptchaValidator { } func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) { + var realURL string + var urlObj = req.WAFRaw().URL + if urlObj != nil { + realURL = urlObj.Query().Get("from") + } + var info = req.WAFRaw().URL.Query().Get("info") if len(info) == 0 { - req.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) - writer.WriteHeader(http.StatusBadRequest) - _, _ = writer.Write([]byte("invalid request")) - return - } - m, err := utils.SimpleDecryptMap(info) - if err != nil { - req.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) - writer.WriteHeader(http.StatusBadRequest) - _, _ = writer.Write([]byte("invalid request")) + if len(realURL) > 0 { + req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect) + http.Redirect(writer, req.WAFRaw(), realURL, http.StatusTemporaryRedirect) + } else { + req.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte("invalid request (001)")) + } return } - var timestamp = m.GetInt64("timestamp") - if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期 + var success bool + var actionId int64 + var setId int64 + var originURL string + var policyId int64 + var groupId int64 + var useLocalFirewall bool + var timestamp int64 + + var infoArg = &InfoArg{} + decodeErr := infoArg.Decode(info) + if decodeErr == nil && infoArg.IsValid() { + success = true + + actionId = infoArg.ActionId + setId = infoArg.SetId + originURL = infoArg.URL + policyId = infoArg.PolicyId + groupId = infoArg.GroupId + useLocalFirewall = infoArg.UseLocalFirewall + timestamp = infoArg.Timestamp + } else { + // 兼容老版本 + m, decodeMapErr := utils.SimpleDecryptMap(info) + if decodeMapErr == nil { + success = true + + actionId = m.GetInt64("actionId") + setId = m.GetInt64("setId") + originURL = m.GetString("url") + policyId = m.GetInt64("policyId") + groupId = m.GetInt64("groupId") + useLocalFirewall = m.GetBool("useLocalFirewall") + timestamp = m.GetInt64("timestamp") + } + } + + if !success { + if len(realURL) > 0 { + req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect) + http.Redirect(writer, req.WAFRaw(), realURL, http.StatusTemporaryRedirect) + } else { + req.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte("invalid request (005)")) + } + return + } + + if timestamp < fasttime.Now().Unix()-600 { // 10分钟之后信息过期 req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect) - http.Redirect(writer, req.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect) + http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect) return } - var actionId = m.GetInt64("actionId") - var setId = m.GetInt64("setId") - var originURL = m.GetString("url") - var policyId = m.GetInt64("policyId") - var groupId = m.GetInt64("groupId") - var waf = SharedWAFManager.FindWAF(policyId) if waf == nil { req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect) @@ -102,23 +148,23 @@ func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWrit if req.WAFRaw().Method == http.MethodPost && len(req.WAFRaw().FormValue(captchaIdName)) > 0 { switch captchaType { case firewallconfigs.CaptchaTypeOneClick: - this.validateOneClickForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer) + this.validateOneClickForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer, useLocalFirewall) case firewallconfigs.CaptchaTypeSlide: - this.validateSlideForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer) + this.validateSlideForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer, useLocalFirewall) case firewallconfigs.CaptchaTypeGeeTest: - this.validateGeeTestForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer) + this.validateGeeTestForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer, useLocalFirewall) default: - this.validateVerifyCodeForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer) + this.validateVerifyCodeForm(captchaActionConfig, policyId, groupId, setId, originURL, req, writer, useLocalFirewall) } } else { var captchaId = req.WAFRaw().URL.Query().Get(captchaIdName) if len(captchaId) > 0 { // 增加计数 - CaptchaIncreaseFails(req, captchaActionConfig, policyId, groupId, setId, CaptchaPageCodeImage) + CaptchaIncreaseFails(req, captchaActionConfig, policyId, groupId, setId, CaptchaPageCodeImage, useLocalFirewall) this.showImage(captchaActionConfig, req, writer, captchaType) } else { // 增加计数 - CaptchaIncreaseFails(req, captchaActionConfig, policyId, groupId, setId, CaptchaPageCodeShow) + CaptchaIncreaseFails(req, captchaActionConfig, policyId, groupId, setId, CaptchaPageCodeShow, useLocalFirewall) this.show(captchaActionConfig, setId, originURL, req, writer, captchaType) } } @@ -310,7 +356,7 @@ func (this *CaptchaValidator) showVerifyImage(actionConfig *CaptchaAction, req r } } -func (this *CaptchaValidator) validateVerifyCodeForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) { +func (this *CaptchaValidator) validateVerifyCodeForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter, useLocalFirewall bool) (allow bool) { var captchaId = req.WAFRaw().FormValue(captchaIdName) if len(captchaId) > 0 { var captchaCode = req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_CODE") @@ -332,7 +378,7 @@ func (this *CaptchaValidator) validateVerifyCodeForm(actionConfig *CaptchaAction return false } else { // 增加计数 - if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit) { + if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit, useLocalFirewall) { return false } @@ -459,7 +505,7 @@ func (this *CaptchaValidator) showOneClickForm(actionConfig *CaptchaAction, req _, _ = writer.Write([]byte(msgHTML)) } -func (this *CaptchaValidator) validateOneClickForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) { +func (this *CaptchaValidator) validateOneClickForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter, useLocalFirewall bool) (allow bool) { var captchaId = req.WAFRaw().FormValue(captchaIdName) var nonce = req.WAFRaw().FormValue("nonce") if len(captchaId) > 0 { @@ -486,7 +532,7 @@ func (this *CaptchaValidator) validateOneClickForm(actionConfig *CaptchaAction, } } else { // 增加计数 - if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit) { + if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit, useLocalFirewall) { return false } @@ -658,7 +704,7 @@ func (this *CaptchaValidator) showSlideForm(actionConfig *CaptchaAction, req req _, _ = writer.Write([]byte(msgHTML)) } -func (this *CaptchaValidator) validateSlideForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) { +func (this *CaptchaValidator) validateSlideForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter, useLocalFirewall bool) (allow bool) { var captchaId = req.WAFRaw().FormValue(captchaIdName) var nonce = req.WAFRaw().FormValue("nonce") if len(captchaId) > 0 { @@ -685,7 +731,7 @@ func (this *CaptchaValidator) validateSlideForm(actionConfig *CaptchaAction, pol } } else { // 增加计数 - if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit) { + if !CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit, useLocalFirewall) { return false } @@ -697,7 +743,7 @@ func (this *CaptchaValidator) validateSlideForm(actionConfig *CaptchaAction, pol return true } -func (this *CaptchaValidator) validateGeeTestForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) { +func (this *CaptchaValidator) validateGeeTestForm(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter, useLocalFirewall bool) (allow bool) { var geeTestConfig = actionConfig.GeeTestConfig if geeTestConfig == nil || !geeTestConfig.IsOn { return @@ -719,7 +765,7 @@ func (this *CaptchaValidator) validateGeeTestForm(actionConfig *CaptchaAction, p writer.WriteHeader(http.StatusOK) } else { // 增加计数 - CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit) + CaptchaIncreaseFails(req, actionConfig, policyId, groupId, setId, CaptchaPageCodeSubmit, useLocalFirewall) writer.WriteHeader(http.StatusBadRequest) } diff --git a/internal/waf/get302_validator.go b/internal/waf/get302_validator.go index 5afcd19..d157d10 100644 --- a/internal/waf/get302_validator.go +++ b/internal/waf/get302_validator.go @@ -24,36 +24,68 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW if len(info) == 0 { request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest) - _, _ = writer.Write([]byte("invalid request")) - return - } - m, err := utils.SimpleDecryptMap(info) - if err != nil { - request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) - writer.WriteHeader(http.StatusBadRequest) - _, _ = writer.Write([]byte("invalid request")) + _, _ = writer.Write([]byte("invalid request (002)")) + return + } + + var timestamp int64 + var life int64 + var setId int64 + var policyId int64 + var groupId int64 + var scope string + var url string + + var infoArg = &InfoArg{} + decodeErr := infoArg.Decode(info) + var success bool + if decodeErr == nil && infoArg.IsValid() { + success = true + + timestamp = infoArg.Timestamp + life = int64(infoArg.Life) + setId = infoArg.SetId + policyId = infoArg.PolicyId + groupId = infoArg.GroupId + scope = infoArg.Scope + url = infoArg.URL + } else { + // 兼容老版本 + m, decodeMapErr := utils.SimpleDecryptMap(info) + if decodeMapErr == nil { + success = true + + timestamp = m.GetInt64("timestamp") + life = m.GetInt64("life") + setId = m.GetInt64("setId") + policyId = m.GetInt64("policyId") + groupId = m.GetInt64("groupId") + scope = m.GetString("scope") + url = m.GetString("url") + } + } + + if !success { + request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte("invalid request (003)")) return } - var timestamp = m.GetInt64("timestamp") if time.Now().Unix()-timestamp > 5 { // 超过5秒认为失效 request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest) - _, _ = writer.Write([]byte("invalid request")) + _, _ = writer.Write([]byte("invalid request (004)")) return } // 加入白名单 - var life = m.GetInt64("life") if life <= 0 { life = 600 // 默认10分钟 } - var setId = types.String(m.GetInt64("setId")) - SharedIPWhiteList.RecordIP("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "") + SharedIPWhiteList.RecordIP("set:"+types.String(setId), scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, policyId, false, groupId, setId, "") // 返回原始URL - var url = m.GetString("url") - request.ProcessResponseHeaders(writer.Header(), http.StatusFound) http.Redirect(writer, request.WAFRaw(), url, http.StatusFound) } diff --git a/internal/waf/info_arg.go b/internal/waf/info_arg.go new file mode 100644 index 0000000..f1ac887 --- /dev/null +++ b/internal/waf/info_arg.go @@ -0,0 +1,46 @@ +// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" + "net/url" +) + +type InfoArg struct { + ActionId int64 `json:"1,omitempty"` + Timestamp int64 `json:"2,omitempty"` + URL string `json:"3,omitempty"` + PolicyId int64 `json:"4,omitempty"` + GroupId int64 `json:"5,omitempty"` + SetId int64 `json:"6,omitempty"` + UseLocalFirewall bool `json:"7,omitempty"` + Life int32 `json:"8,omitempty"` + Scope string `json:"9,omitempty"` + RemoteIP string `json:"10,omitempty"` +} + +func (this *InfoArg) IsValid() bool { + return this.Timestamp > 0 +} + +func (this *InfoArg) Encode() (string, error) { + if this.Timestamp <= 0 { + this.Timestamp = fasttime.Now().Unix() + } + + return utils.SimpleEncryptObject(this) +} + +func (this *InfoArg) URLEncoded() (string, error) { + encodedString, err := this.Encode() + if err != nil { + return "", err + } + return url.QueryEscape(encodedString), nil +} + +func (this *InfoArg) Decode(encodedString string) error { + return utils.SimpleDecryptObjet(encodedString, this) +} diff --git a/internal/waf/info_arg_test.go b/internal/waf/info_arg_test.go new file mode 100644 index 0000000..0bcb000 --- /dev/null +++ b/internal/waf/info_arg_test.go @@ -0,0 +1,44 @@ +// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package waf_test + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf" + "github.com/iwind/TeaGo/types" + "testing" + "time" +) + +func TestInfoArg_Encode(t *testing.T) { + var info = &waf.InfoArg{ + ActionId: 1, + Timestamp: time.Now().Unix(), + URL: "https://example.com/hello", + PolicyId: 2, + GroupId: 3, + SetId: 4, + UseLocalFirewall: true, + Scope: "global", + } + + encodedString, err := info.Encode() + if err != nil { + t.Fatal(err) + } + t.Log("["+types.String(len(encodedString))+"]", encodedString) + + { + urlEncodedString, encodeErr := info.URLEncoded() + if encodeErr != nil { + t.Fatal(encodeErr) + } + t.Log("["+types.String(len(urlEncodedString))+"]", urlEncodedString) + } + + var info2 = &waf.InfoArg{} + err = info2.Decode(encodedString) + if err != nil { + t.Fatal(err) + } + t.Logf("%+v", info2) +} diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index d395199..37103fc 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -89,7 +89,7 @@ func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serv switch scope { case firewallconfigs.FirewallScopeGlobal: ip = "*@" + ip + "@" + ipType - case firewallconfigs.FirewallScopeService: + case firewallconfigs.FirewallScopeServer: ip = types.String(serverId) + "@" + ip + "@" + ipType default: ip = "*@" + ip + "@" + ipType @@ -127,7 +127,7 @@ func (this *IPList) RecordIP(ipType string, if this.listType == IPListTypeDeny { // 作用域 var scopeServerId int64 - if scope == firewallconfigs.FirewallScopeService { + if scope == firewallconfigs.FirewallScopeServer { scopeServerId = serverId } @@ -167,7 +167,7 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, switch scope { case firewallconfigs.FirewallScopeGlobal: ip = "*@" + ip + "@" + ipType - case firewallconfigs.FirewallScopeService: + case firewallconfigs.FirewallScopeServer: ip = types.String(serverId) + "@" + ip + "@" + ipType default: ip = "*@" + ip + "@" + ipType @@ -184,7 +184,7 @@ func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.Firewal switch scope { case firewallconfigs.FirewallScopeGlobal: ip = "*@" + ip + "@" + ipType - case firewallconfigs.FirewallScopeService: + case firewallconfigs.FirewallScopeServer: ip = types.String(serverId) + "@" + ip + "@" + ipType default: ip = "*@" + ip + "@" + ipType diff --git a/internal/waf/ip_list_test.go b/internal/waf/ip_list_test.go index da299d8..0530930 100644 --- a/internal/waf/ip_list_test.go +++ b/internal/waf/ip_list_test.go @@ -23,7 +23,7 @@ func TestNewIPList(t *testing.T) { list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) - list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3) + list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeServer, 1, "127.0.0.3", time.Now().Unix()+3) list.Add(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10) list.RemoveIP("127.0.0.1", 1, false) diff --git a/internal/waf/waf.go b/internal/waf/waf.go index 17badca..4bed603 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -26,9 +26,12 @@ type WAF struct { UseLocalFirewall bool `yaml:"useLocalFirewall" json:"useLocalFirewall"` SYNFlood *firewallconfigs.SYNFloodConfig `yaml:"synFlood" json:"synFlood"` - DefaultBlockAction *BlockAction - DefaultPageAction *PageAction - DefaultCaptchaAction *CaptchaAction + DefaultBlockAction *BlockAction + DefaultPageAction *PageAction + DefaultCaptchaAction *CaptchaAction + DefaultJSCookieAction *JSCookieAction + DefaultPost307Action *Post307Action + DefaultGet302Action *Get302Action hasInboundRules bool hasOutboundRules bool diff --git a/internal/waf/waf_manager.go b/internal/waf/waf_manager.go index b8d4a57..73df256 100644 --- a/internal/waf/waf_manager.go +++ b/internal/waf/waf_manager.go @@ -176,11 +176,12 @@ func (this *WAFManager) ConvertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( // block action if policy.BlockOptions != nil { w.DefaultBlockAction = &BlockAction{ - StatusCode: policy.BlockOptions.StatusCode, - Body: policy.BlockOptions.Body, - URL: policy.BlockOptions.URL, - Timeout: policy.BlockOptions.Timeout, - TimeoutMax: policy.BlockOptions.TimeoutMax, + StatusCode: policy.BlockOptions.StatusCode, + Body: policy.BlockOptions.Body, + URL: policy.BlockOptions.URL, + Timeout: policy.BlockOptions.Timeout, + TimeoutMax: policy.BlockOptions.TimeoutMax, + FailBlockScopeAll: policy.BlockOptions.FailBlockScopeAll, } } @@ -214,6 +215,33 @@ func (this *WAFManager) ConvertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( } } + // get302 + if policy.Get302Options != nil { + w.DefaultGet302Action = &Get302Action{ + Life: policy.Get302Options.Life, + Scope: policy.Get302Options.Scope, + } + } + + // post307 + if policy.Post307Options != nil { + w.DefaultPost307Action = &Post307Action{ + Life: policy.Post307Options.Life, + Scope: policy.Post307Options.Scope, + } + } + + // jscookie + if policy.JSCookieOptions != nil { + w.DefaultJSCookieAction = &JSCookieAction{ + Life: policy.JSCookieOptions.Life, + MaxFails: policy.JSCookieOptions.MaxFails, + FailBlockTimeout: policy.JSCookieOptions.FailBlockTimeout, + Scope: policy.JSCookieOptions.Scope, + FailBlockScopeAll: policy.JSCookieOptions.FailBlockScopeAll, + } + } + errorList := w.Init() if len(errorList) > 0 { return w, errorList[0]