diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index eabd4ce..1e4855c 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -194,7 +194,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir } // 规则测试 - w := sharedWAFManager.FindWAF(firewallPolicy.Id) + w := waf.SharedWAFManager.FindWAF(firewallPolicy.Id) if w == nil { return } @@ -261,7 +261,7 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi return } - w := sharedWAFManager.FindWAF(firewallPolicy.Id) + w := waf.SharedWAFManager.FindWAF(firewallPolicy.Id) if w == nil { return } diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 753519a..594b7f9 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -21,6 +21,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/trackers" "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/andybalholm/brotli" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/lists" @@ -865,7 +866,7 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) { } // WAF策略 - sharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies()) + waf.SharedWAFManager.UpdatePolicies(config.FindAllFirewallPolicies()) iplibrary.SharedActionManager.UpdateActions(config.FirewallActions) // 统计指标 diff --git a/internal/waf/action_allow.go b/internal/waf/action_allow.go index 46fda78..2f55770 100644 --- a/internal/waf/action_allow.go +++ b/internal/waf/action_allow.go @@ -6,6 +6,7 @@ import ( ) type AllowAction struct { + BaseAction } func (this *AllowAction) Init(waf *WAF) error { diff --git a/internal/waf/action_base.go b/internal/waf/action_base.go index 93bf304..16e9fdc 100644 --- a/internal/waf/action_base.go +++ b/internal/waf/action_base.go @@ -7,6 +7,17 @@ import ( ) type BaseAction struct { + currentActionId int64 +} + +// ActionId 读取ActionId +func (this *BaseAction) ActionId() int64 { + return this.currentActionId +} + +// SetActionId 设置Id +func (this *BaseAction) SetActionId(actionId int64) { + this.currentActionId = actionId } // CloseConn 关闭连接 diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index fffade1..39cc898 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -20,6 +20,8 @@ var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://") var httpClient = utils.SharedHttpClient(5 * time.Second) type BlockAction struct { + BaseAction + StatusCode int `yaml:"statusCode" json:"statusCode"` Body string `yaml:"body" json:"body"` // supports HTML URL string `yaml:"url" json:"url"` diff --git a/internal/waf/action_captcha.go b/internal/waf/action_captcha.go index 14147f2..048ad0d 100644 --- a/internal/waf/action_captcha.go +++ b/internal/waf/action_captcha.go @@ -18,16 +18,71 @@ const ( ) type CaptchaAction struct { - Life int32 `yaml:"life" json:"life"` - MaxFails int `yaml:"maxFails" json:"maxFails"` // 最大失败次数 - FailBlockTimeout int `yaml:"failBlockTimeout" json:"failBlockTimeout"` // 失败拦截时间 + BaseAction - Language string `yaml:"language" json:"language"` // 语言,zh-CN, en-US ... + Life int32 `yaml:"life" json:"life"` + MaxFails int `yaml:"maxFails" json:"maxFails"` // 最大失败次数 + FailBlockTimeout int `yaml:"failBlockTimeout" json:"failBlockTimeout"` // 失败拦截时间 + FailBlockScopeAll bool `yaml:"failBlockScopeAll" json:"failBlockScopeAll"` // 是否全局有效 + + CountLetters int8 `yaml:"countLetters" json:"countLetters"` + + UIIsOn bool `yaml:"uiIsOn" json:"uiIsOn"` // 是否使用自定义UI + UITitle string `yaml:"uiTitle" json:"uiTitle"` // 消息标题 + UIPrompt string `yaml:"uiPrompt" json:"uiPrompt"` // 消息提示 + UIButtonTitle string `yaml:"uiButtonTitle" json:"uiButtonTitle"` // 按钮标题 + UIShowRequestId bool `yaml:"uiShowRequestId" json:"uiShowRequestId"` // 是否显示请求ID + UICss string `yaml:"uiCss" json:"uiCss"` // CSS样式 + UIFooter string `yaml:"uiFooter" json:"uiFooter"` // 页脚 + UIBody string `yaml:"uiBody" json:"uiBody"` // 内容轮廓 + + Lang string `yaml:"lang" json:"lang"` // 语言,zh-CN, en-US ... AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单 Scope string `yaml:"scope" json:"scope"` } func (this *CaptchaAction) Init(waf *WAF) error { + if waf.DefaultCaptchaAction != nil { + if this.Life <= 0 { + this.Life = waf.DefaultCaptchaAction.Life + } + if this.MaxFails <= 0 { + this.MaxFails = waf.DefaultCaptchaAction.MaxFails + } + if this.FailBlockTimeout <= 0 { + this.FailBlockTimeout = waf.DefaultCaptchaAction.FailBlockTimeout + } + this.FailBlockScopeAll = waf.DefaultCaptchaAction.FailBlockScopeAll + + if this.CountLetters <= 0 { + this.CountLetters = waf.DefaultCaptchaAction.CountLetters + } + + this.UIIsOn = waf.DefaultCaptchaAction.UIIsOn + if len(this.UITitle) == 0 { + this.UITitle = waf.DefaultCaptchaAction.UITitle + } + if len(this.UIPrompt) == 0 { + this.UIPrompt = waf.DefaultCaptchaAction.UIPrompt + } + if len(this.UIButtonTitle) == 0 { + this.UIButtonTitle = waf.DefaultCaptchaAction.UIButtonTitle + } + this.UIShowRequestId = waf.DefaultCaptchaAction.UIShowRequestId + if len(this.UICss) == 0 { + this.UICss = waf.DefaultCaptchaAction.UICss + } + if len(this.UIFooter) == 0 { + this.UIFooter = waf.DefaultCaptchaAction.UIFooter + } + if len(this.UIBody) == 0 { + this.UIBody = waf.DefaultCaptchaAction.UIBody + } + if len(this.Lang) == 0 { + this.Lang = waf.DefaultCaptchaAction.Lang + } + } + return nil } @@ -49,7 +104,7 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req return true } - refURL := request.WAFRaw().URL.String() + var refURL = request.WAFRaw().URL.String() // 覆盖配置 if strings.HasPrefix(refURL, CaptchaPath) { @@ -63,14 +118,12 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req } var captchaConfig = maps.Map{ - "action": this, - "timestamp": time.Now().Unix(), - "maxFails": this.MaxFails, - "failBlockTimeout": this.FailBlockTimeout, - "url": refURL, - "policyId": waf.Id, - "groupId": group.Id, - "setId": set.Id, + "actionId": this.ActionId(), + "timestamp": time.Now().Unix(), + "url": refURL, + "policyId": waf.Id, + "groupId": group.Id, + "setId": set.Id, } info, err := utils.SimpleEncryptMap(captchaConfig) if err != nil { diff --git a/internal/waf/action_go_group.go b/internal/waf/action_go_group.go index 201dfd0..fe8afc3 100644 --- a/internal/waf/action_go_group.go +++ b/internal/waf/action_go_group.go @@ -8,6 +8,8 @@ import ( ) type GoGroupAction struct { + BaseAction + GroupId string `yaml:"groupId" json:"groupId"` } diff --git a/internal/waf/action_go_set.go b/internal/waf/action_go_set.go index 0f515e5..ea998e9 100644 --- a/internal/waf/action_go_set.go +++ b/internal/waf/action_go_set.go @@ -8,6 +8,8 @@ import ( ) type GoSetAction struct { + BaseAction + GroupId string `yaml:"groupId" json:"groupId"` SetId string `yaml:"setId" json:"setId"` } diff --git a/internal/waf/action_interface.go b/internal/waf/action_interface.go index 256b58e..a9de7ac 100644 --- a/internal/waf/action_interface.go +++ b/internal/waf/action_interface.go @@ -11,6 +11,12 @@ type ActionInterface interface { // Init 初始化 Init(waf *WAF) error + // ActionId 读取ActionId + ActionId() int64 + + // SetActionId 设置ID + SetActionId(id int64) + // Code 代号 Code() string @@ -20,6 +26,6 @@ type ActionInterface interface { // WillChange determine if the action will change the request WillChange() bool - // Perform perform the action + // Perform the action Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) } diff --git a/internal/waf/action_log.go b/internal/waf/action_log.go index 74c85ac..5d7f334 100644 --- a/internal/waf/action_log.go +++ b/internal/waf/action_log.go @@ -6,6 +6,7 @@ import ( ) type LogAction struct { + BaseAction } func (this *LogAction) Init(waf *WAF) error { diff --git a/internal/waf/action_notify.go b/internal/waf/action_notify.go index d2bc6b4..33a5485 100644 --- a/internal/waf/action_notify.go +++ b/internal/waf/action_notify.go @@ -50,6 +50,7 @@ func init() { } type NotifyAction struct { + BaseAction } func (this *NotifyAction) Init(waf *WAF) error { @@ -69,7 +70,7 @@ func (this *NotifyAction) WillChange() bool { return false } -// Perform perform the action +// Perform the action func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { select { case notifyChan <- ¬ifyTask{ diff --git a/internal/waf/action_tag.go b/internal/waf/action_tag.go index b39794f..1e84512 100644 --- a/internal/waf/action_tag.go +++ b/internal/waf/action_tag.go @@ -6,6 +6,8 @@ import ( ) type TagAction struct { + BaseAction + Tags []string `yaml:"tags" json:"tags"` } diff --git a/internal/waf/action_utils.go b/internal/waf/action_utils.go index 39f1259..ced1bec 100644 --- a/internal/waf/action_utils.go +++ b/internal/waf/action_utils.go @@ -5,15 +5,19 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/iwind/TeaGo/maps" "reflect" + "sync/atomic" ) +var seedActionId int64 = 1 + func FindActionInstance(action ActionString, options maps.Map) ActionInterface { for _, def := range AllActions { if def.Code == action { if def.Type != nil { // create new instance - ptrValue := reflect.New(def.Type) - instance := ptrValue.Interface().(ActionInterface) + var ptrValue = reflect.New(def.Type) + var instance = ptrValue.Interface().(ActionInterface) + instance.SetActionId(atomic.AddInt64(&seedActionId, 1)) if len(options) > 0 { optionsJSON, err := json.Marshal(options) diff --git a/internal/waf/captcha_validator.go b/internal/waf/captcha_validator.go index 1148004..81786cb 100644 --- a/internal/waf/captcha_validator.go +++ b/internal/waf/captcha_validator.go @@ -6,7 +6,6 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/ttlcache" "github.com/TeaOSLab/EdgeNode/internal/utils" - "github.com/TeaOSLab/EdgeNode/internal/utils/jsonutils" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/dchest/captcha" "github.com/iwind/TeaGo/logs" @@ -26,8 +25,8 @@ func NewCaptchaValidator() *CaptchaValidator { return &CaptchaValidator{} } -func (this *CaptchaValidator) Run(request requests.Request, writer http.ResponseWriter) { - var info = request.WAFRaw().URL.Query().Get("info") +func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWriter) { + var info = req.WAFRaw().URL.Query().Get("info") if len(info) == 0 { writer.WriteHeader(http.StatusBadRequest) _, _ = writer.Write([]byte("invalid request")) @@ -41,35 +40,48 @@ func (this *CaptchaValidator) Run(request requests.Request, writer http.Response var timestamp = m.GetInt64("timestamp") if timestamp < time.Now().Unix()-600 { // 10分钟之后信息过期 - http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect) - return - } - - var actionConfig = &CaptchaAction{} - err = jsonutils.MapToObject(m.GetMap("action"), actionConfig) - if err != nil { - http.Redirect(writer, request.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect) + http.Redirect(writer, req.WAFRaw(), m.GetString("url"), http.StatusTemporaryRedirect) return } + var actionId = m.GetInt64("actionId") var setId = m.GetInt64("setId") var originURL = m.GetString("url") - var maxFails = m.GetInt("maxFails") - var failBlockTimeout = m.GetInt("failBlockTimeout") var policyId = m.GetInt64("policyId") var groupId = m.GetInt64("groupId") - if request.WAFRaw().Method == http.MethodPost && len(request.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 { - this.validate(actionConfig, maxFails, failBlockTimeout, policyId, groupId, setId, originURL, request, writer) + + var waf = SharedWAFManager.FindWAF(policyId) + if waf == nil { + http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect) + return + } + var actionConfig = waf.FindAction(actionId) + if actionConfig == nil { + http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect) + return + } + captchaActionConfig, ok := actionConfig.(*CaptchaAction) + if !ok { + http.Redirect(writer, req.WAFRaw(), originURL, http.StatusTemporaryRedirect) + return + } + + if req.WAFRaw().Method == http.MethodPost && len(req.WAFRaw().FormValue("GOEDGE_WAF_CAPTCHA_ID")) > 0 { + this.validate(captchaActionConfig, policyId, groupId, setId, originURL, req, writer) } else { // 增加计数 - this.IncreaseFails(request, maxFails, failBlockTimeout, policyId, groupId, setId) - this.show(actionConfig, request, writer) + this.IncreaseFails(req, captchaActionConfig, policyId, groupId, setId) + this.show(captchaActionConfig, req, writer) } } -func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests.Request, writer http.ResponseWriter) { +func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Request, writer http.ResponseWriter) { // show captcha - var captchaId = captcha.NewLen(6) + var countLetters = 6 + if actionConfig.CountLetters > 0 && actionConfig.CountLetters <= 10 { + countLetters = int(actionConfig.CountLetters) + } + var captchaId = captcha.NewLen(countLetters) var buf = bytes.NewBuffer([]byte{}) err := captcha.WriteImage(buf, captchaId, 200, 100) if err != nil { @@ -77,9 +89,9 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests return } - var lang = actionConfig.Language + var lang = actionConfig.Lang if len(lang) == 0 { - acceptLanguage := request.WAFRaw().Header.Get("Accept-Language") + var acceptLanguage = req.WAFRaw().Header.Get("Accept-Language") if len(acceptLanguage) > 0 { langIndex := strings.Index(acceptLanguage, ",") if langIndex > 0 { @@ -114,12 +126,62 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, request requests msgRequestId = "Request ID" } + var msgCss = "" + var requestIdBox = `
` + msgRequestId + `: ` + req.Format("${requestId}") + `
` + var msgFooter = "" + var body = `
+ +
+ ` + ` +
+
+

` + msgPrompt + `

+ +
+
+ +
+
+` + requestIdBox + ` +` + msgFooter + `` + + // 默认设置 + if actionConfig.UIIsOn { + if len(actionConfig.UITitle) > 0 { + msgTitle = actionConfig.UITitle + } + if len(actionConfig.UIPrompt) > 0 { + msgPrompt = actionConfig.UIPrompt + } + if len(actionConfig.UIButtonTitle) > 0 { + msgButtonTitle = actionConfig.UIButtonTitle + } + if len(actionConfig.UICss) > 0 { + msgCss = actionConfig.UICss + } + if !actionConfig.UIShowRequestId { + requestIdBox = "" + } + if len(actionConfig.UIFooter) > 0 { + msgFooter = actionConfig.UIFooter + } + if len(actionConfig.UIBody) > 0 { + var index = strings.Index(actionConfig.UIBody, "${body}") + if index < 0 { + body = actionConfig.UIBody + body + } else { + body = actionConfig.UIBody[:index] + body + actionConfig.UIBody[index+7:] // 7是"${body}"的长度 + } + } + } + writer.Header().Set("Content-Type", "text/html; charset=utf-8") _, _ = writer.Write([]byte(` ` + msgTitle + ` +