diff --git a/internal/waf/checkpoints/request_all.go b/internal/waf/checkpoints/request_all.go index cba0849..1290511 100644 --- a/internal/waf/checkpoints/request_all.go +++ b/internal/waf/checkpoints/request_all.go @@ -12,11 +12,11 @@ type RequestAllCheckpoint struct { } func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) { - var valueBytes = []byte{} + var valueBytes = [][]byte{} if len(req.WAFRaw().RequestURI) > 0 { - valueBytes = append(valueBytes, req.WAFRaw().RequestURI...) + valueBytes = append(valueBytes, []byte(req.WAFRaw().RequestURI)) } else if req.WAFRaw().URL != nil { - valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...) + valueBytes = append(valueBytes, []byte(req.WAFRaw().URL.RequestURI())) } if this.RequestBodyIsEmpty(req) { @@ -25,8 +25,6 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin } if req.WAFRaw().Body != nil { - valueBytes = append(valueBytes, ' ') - var bodyData = req.WAFGetCacheBody() hasRequestBody = true if len(bodyData) == 0 { @@ -39,7 +37,9 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin req.WAFSetCacheBody(data) req.WAFRestoreBody(data) } - valueBytes = append(valueBytes, bodyData...) + if len(bodyData) > 0 { + valueBytes = append(valueBytes, bodyData) + } } value = valueBytes diff --git a/internal/waf/checkpoints/request_all_test.go b/internal/waf/checkpoints/request_all_test.go index 9ac8def..5a28d8b 100644 --- a/internal/waf/checkpoints/request_all_test.go +++ b/internal/waf/checkpoints/request_all_test.go @@ -25,8 +25,14 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) { if userErr != nil { t.Fatal(userErr) } - t.Log(v) - t.Log(types.String(v)) + if v != nil { + vv, ok := v.([][]byte) + if ok { + for _, v2 := range vv { + t.Log(string(v2), ":", v2) + } + } + } body, err := io.ReadAll(req.Body) if err != nil { diff --git a/internal/waf/rule.go b/internal/waf/rule.go index f68cac4..7f582b8 100644 --- a/internal/waf/rule.go +++ b/internal/waf/rule.go @@ -29,14 +29,14 @@ var singleParamRegexp = regexp.MustCompile(`^\${[\w.-]+}$`) type Rule struct { Id int64 - Description string `yaml:"description" json:"description"` - Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName} - ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"` - Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ... - Value string `yaml:"value" json:"value"` // compared value - IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"` - CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"` - Priority int `yaml:"priority" json:"priority"` + Description string `yaml:"description" json:"description"` + Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName} + ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"` + Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ... + Value string `yaml:"value" json:"value"` // compared value + IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"` + CheckpointOptions map[string]any `yaml:"checkpointOptions" json:"checkpointOptions"` + Priority int `yaml:"priority" json:"priority"` checkpointFinder func(prefix string) checkpoints.CheckpointInterface @@ -93,7 +93,7 @@ func (this *Rule) Init() error { } } case RuleOperatorMatch: - v := this.Value + var v = this.Value if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") { v = "(?i)" + v } @@ -106,7 +106,7 @@ func (this *Rule) Init() error { } this.reg = reg case RuleOperatorNotMatch: - v := this.Value + var v = this.Value if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") { v = "(?i)" + v } @@ -239,9 +239,9 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo return this.Test(value), hasRequestBody, nil } - value := configutils.ParseVariables(this.Param, func(varName string) (value string) { - pieces := strings.SplitN(varName, ".", 2) - prefix := pieces[0] + var value = configutils.ParseVariables(this.Param, func(varName string) (value string) { + var pieces = strings.SplitN(varName, ".", 2) + var prefix = pieces[0] point, ok := this.multipleCheckpoints[prefix] if !ok { return "" @@ -255,7 +255,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo if err1 != nil { err = err1 } - return types.String(value1) + return this.stringifyValue(value1) } value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions, this.Id) @@ -265,7 +265,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo if err1 != nil { err = err1 } - return types.String(value1) + return this.stringifyValue(value1) }) if err != nil { @@ -312,9 +312,9 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( return this.Test(value), hasRequestBody, nil } - value := configutils.ParseVariables(this.Param, func(varName string) (value string) { - pieces := strings.SplitN(varName, ".", 2) - prefix := pieces[0] + var value = configutils.ParseVariables(this.Param, func(varName string) (value string) { + var pieces = strings.SplitN(varName, ".", 2) + var prefix = pieces[0] point, ok := this.multipleCheckpoints[prefix] if !ok { return "" @@ -329,7 +329,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( if err1 != nil { err = err1 } - return types.String(value1) + return this.stringifyValue(value1) } else { value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions, this.Id) if hasCheckRequestBody { @@ -338,7 +338,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( if err1 != nil { err = err1 } - return types.String(value1) + return this.stringifyValue(value1) } } @@ -350,7 +350,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( if err1 != nil { err = err1 } - return types.String(value1) + return this.stringifyValue(value1) } else { value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions, this.Id) if hasCheckRequestBody { @@ -359,7 +359,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( if err1 != nil { err = err1 } - return types.String(value1) + return this.stringifyValue(value1) } }) @@ -387,19 +387,19 @@ func (this *Rule) Test(value any) bool { return types.Float64(value) != this.floatValue case RuleOperatorEqString: if this.IsCaseInsensitive { - return strings.EqualFold(types.String(value), this.Value) + return strings.EqualFold(this.stringifyValue(value), this.Value) } else { - return types.String(value) == this.Value + return this.stringifyValue(value) == this.Value } case RuleOperatorNeqString: if this.IsCaseInsensitive { - return !strings.EqualFold(types.String(value), this.Value) + return !strings.EqualFold(this.stringifyValue(value), this.Value) } else { - return types.String(value) != this.Value + return this.stringifyValue(value) != this.Value } case RuleOperatorMatch, RuleOperatorWildcardMatch: if value == nil { - return false + value = "" } // strings @@ -413,6 +413,17 @@ func (this *Rule) Test(value any) bool { return false } + // bytes list + byteSlices, ok := value.([][]byte) + if ok { + for _, byteSlice := range byteSlices { + if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) { + return true + } + } + return false + } + // bytes byteSlice, ok := value.([]byte) if ok { @@ -420,10 +431,10 @@ func (this *Rule) Test(value any) bool { } // string - return utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife) + return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife) case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch: if value == nil { - return true + value = "" } stringList, ok := value.([]string) if ok { @@ -435,20 +446,31 @@ func (this *Rule) Test(value any) bool { return true } + // bytes list + byteSlices, ok := value.([][]byte) + if ok { + for _, byteSlice := range byteSlices { + if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) { + return false + } + } + return true + } + // bytes byteSlice, ok := value.([]byte) if ok { return !utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) } - return !utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife) + return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife) case RuleOperatorContains: if types.IsSlice(value) { _, isBytes := value.([]byte) if !isBytes { - ok := false + var ok = false lists.Each(value, func(k int, v any) { - if types.String(v) == this.Value { + if this.stringifyValue(v) == this.Value { ok = true } }) @@ -456,17 +478,17 @@ func (this *Rule) Test(value any) bool { } } if types.IsMap(value) { - lowerValue := "" + var lowerValue = "" if this.IsCaseInsensitive { lowerValue = strings.ToLower(this.Value) } for _, v := range maps.NewMap(value) { if this.IsCaseInsensitive { - if strings.ToLower(types.String(v)) == lowerValue { + if strings.ToLower(this.stringifyValue(v)) == lowerValue { return true } } else { - if types.String(v) == this.Value { + if this.stringifyValue(v) == this.Value { return true } } @@ -475,30 +497,30 @@ func (this *Rule) Test(value any) bool { } if this.IsCaseInsensitive { - return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + return strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value)) } else { - return strings.Contains(types.String(value), this.Value) + return strings.Contains(this.stringifyValue(value), this.Value) } case RuleOperatorNotContains: if this.IsCaseInsensitive { - return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + return !strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value)) } else { - return !strings.Contains(types.String(value), this.Value) + return !strings.Contains(this.stringifyValue(value), this.Value) } case RuleOperatorPrefix: if this.IsCaseInsensitive { - return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + return strings.HasPrefix(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value)) } else { - return strings.HasPrefix(types.String(value), this.Value) + return strings.HasPrefix(this.stringifyValue(value), this.Value) } case RuleOperatorSuffix: if this.IsCaseInsensitive { - return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + return strings.HasSuffix(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value)) } else { - return strings.HasSuffix(types.String(value), this.Value) + return strings.HasSuffix(this.stringifyValue(value), this.Value) } case RuleOperatorContainsAny: - var stringValue = types.String(value) + var stringValue = this.stringifyValue(value) if this.IsCaseInsensitive { stringValue = strings.ToLower(stringValue) } @@ -511,7 +533,7 @@ func (this *Rule) Test(value any) bool { } return false case RuleOperatorContainsAll: - var stringValue = types.String(value) + var stringValue = this.stringifyValue(value) if this.IsCaseInsensitive { stringValue = strings.ToLower(stringValue) } @@ -525,30 +547,30 @@ func (this *Rule) Test(value any) bool { } return false case RuleOperatorContainsBinary: - data, _ := base64.StdEncoding.DecodeString(types.String(this.Value)) + data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value)) if this.IsCaseInsensitive { - return bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data)) + return bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data)) } else { - return bytes.Contains([]byte(types.String(value)), data) + return bytes.Contains([]byte(this.stringifyValue(value)), data) } case RuleOperatorNotContainsBinary: - data, _ := base64.StdEncoding.DecodeString(types.String(this.Value)) + data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value)) if this.IsCaseInsensitive { - return !bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data)) + return !bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data)) } else { - return !bytes.Contains([]byte(types.String(value)), data) + return !bytes.Contains([]byte(this.stringifyValue(value)), data) } case RuleOperatorHasKey: if types.IsSlice(value) { - index := types.Int(this.Value) + var index = types.Int(this.Value) if index < 0 { return false } return reflect.ValueOf(value).Len() > index } else if types.IsMap(value) { - m := maps.NewMap(value) + var m = maps.NewMap(value) if this.IsCaseInsensitive { - lowerValue := strings.ToLower(this.Value) + var lowerValue = strings.ToLower(this.Value) for k := range m { if strings.ToLower(k) == lowerValue { return true @@ -567,9 +589,9 @@ func (this *Rule) Test(value any) bool { return stringutil.VersionCompare(this.Value, types.String(value)) < 0 case RuleOperatorVersionRange: if strings.Contains(this.Value, ",") { - versions := strings.SplitN(this.Value, ",", 2) - version1 := strings.TrimSpace(versions[0]) - version2 := strings.TrimSpace(versions[1]) + var versions = strings.SplitN(this.Value, ",", 2) + var version1 = strings.TrimSpace(versions[0]) + var version2 = strings.TrimSpace(versions[1]) if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 { return false } @@ -587,25 +609,25 @@ func (this *Rule) Test(value any) bool { } return this.isIP && ip.Equal(this.ipValue) case RuleOperatorGtIP: - ip := net.ParseIP(types.String(value)) + var ip = net.ParseIP(types.String(value)) if ip == nil { return false } return this.isIP && bytes.Compare(ip, this.ipValue) > 0 case RuleOperatorGteIP: - ip := net.ParseIP(types.String(value)) + var ip = net.ParseIP(types.String(value)) if ip == nil { return false } return this.isIP && bytes.Compare(ip, this.ipValue) >= 0 case RuleOperatorLtIP: - ip := net.ParseIP(types.String(value)) + var ip = net.ParseIP(types.String(value)) if ip == nil { return false } return this.isIP && bytes.Compare(ip, this.ipValue) < 0 case RuleOperatorLteIP: - ip := net.ParseIP(types.String(value)) + var ip = net.ParseIP(types.String(value)) if ip == nil { return false } @@ -624,7 +646,7 @@ func (this *Rule) Test(value any) bool { if div == 0 { return false } - rem := types.Int64(pieces[1]) + var rem = types.Int64(pieces[1]) return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem case RuleOperatorIPMod10: return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value) @@ -737,3 +759,25 @@ func (this *Rule) execFilter(value any) any { } return value } + +func (this *Rule) stringifyValue(value any) string { + if value == nil { + return "" + } + switch v := value.(type) { + case string: + return v + case []string: + return strings.Join(v, "") + case []byte: + return string(v) + case [][]byte: + var b = &bytes.Buffer{} + for _, vb := range v { + b.Write(vb) + } + return b.String() + default: + return types.String(v) + } +} diff --git a/internal/waf/rule_test.go b/internal/waf/rule_test.go index e9597b3..657bd32 100644 --- a/internal/waf/rule_test.go +++ b/internal/waf/rule_test.go @@ -205,6 +205,30 @@ func TestRule_Test(t *testing.T) { a.IsFalse(rule.Test("abc123")) } + { + var rule = NewRule() + rule.Operator = RuleOperatorMatch + rule.Value = "^\\d+" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test([]byte("123"))) + a.IsFalse(rule.Test([]byte("abc123"))) + } + + { + var rule = NewRule() + rule.Operator = RuleOperatorMatch + rule.Value = "^\\d+" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test([][]byte{[]byte("123"), []byte("456")})) + a.IsFalse(rule.Test([][]byte{[]byte("abc123")})) + } + { rule := NewRule() rule.Operator = RuleOperatorMatch @@ -265,6 +289,19 @@ func TestRule_Test(t *testing.T) { a.IsTrue(rule.Test([]string{"abc123"})) } + { + var rule = NewRule() + rule.Operator = RuleOperatorNotMatch + rule.Value = "^\\d+" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test([][]byte{[]byte("123"), []byte("456")})) + a.IsFalse(rule.Test([][]byte{[]byte("123"), []byte("abc")})) + a.IsTrue(rule.Test([][]byte{[]byte("abc123")})) + } + { rule := NewRule() rule.Operator = RuleOperatorMatch