WAF checkpoint返回值支持[][]byte

This commit is contained in:
刘祥超
2023-12-05 17:18:53 +08:00
parent facea1ed96
commit 9f77f62308
4 changed files with 157 additions and 70 deletions

View File

@@ -12,11 +12,11 @@ type RequestAllCheckpoint struct {
} }
func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) { func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map, ruleId int64) (value any, hasRequestBody bool, sysErr error, userErr error) {
var valueBytes = []byte{} var valueBytes = [][]byte{}
if len(req.WAFRaw().RequestURI) > 0 { if len(req.WAFRaw().RequestURI) > 0 {
valueBytes = append(valueBytes, req.WAFRaw().RequestURI...) valueBytes = append(valueBytes, []byte(req.WAFRaw().RequestURI))
} else if req.WAFRaw().URL != nil { } else if req.WAFRaw().URL != nil {
valueBytes = append(valueBytes, req.WAFRaw().URL.RequestURI()...) valueBytes = append(valueBytes, []byte(req.WAFRaw().URL.RequestURI()))
} }
if this.RequestBodyIsEmpty(req) { if this.RequestBodyIsEmpty(req) {
@@ -25,8 +25,6 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin
} }
if req.WAFRaw().Body != nil { if req.WAFRaw().Body != nil {
valueBytes = append(valueBytes, ' ')
var bodyData = req.WAFGetCacheBody() var bodyData = req.WAFGetCacheBody()
hasRequestBody = true hasRequestBody = true
if len(bodyData) == 0 { if len(bodyData) == 0 {
@@ -39,7 +37,9 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin
req.WAFSetCacheBody(data) req.WAFSetCacheBody(data)
req.WAFRestoreBody(data) req.WAFRestoreBody(data)
} }
valueBytes = append(valueBytes, bodyData...) if len(bodyData) > 0 {
valueBytes = append(valueBytes, bodyData)
}
} }
value = valueBytes value = valueBytes

View File

@@ -25,8 +25,14 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
if userErr != nil { if userErr != nil {
t.Fatal(userErr) t.Fatal(userErr)
} }
t.Log(v) if v != nil {
t.Log(types.String(v)) vv, ok := v.([][]byte)
if ok {
for _, v2 := range vv {
t.Log(string(v2), ":", v2)
}
}
}
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
if err != nil { if err != nil {

View File

@@ -29,14 +29,14 @@ var singleParamRegexp = regexp.MustCompile(`^\${[\w.-]+}$`)
type Rule struct { type Rule struct {
Id int64 Id int64
Description string `yaml:"description" json:"description"` Description string `yaml:"description" json:"description"`
Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName} Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"` ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"`
Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ... Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ...
Value string `yaml:"value" json:"value"` // compared value Value string `yaml:"value" json:"value"` // compared value
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"` IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"` CheckpointOptions map[string]any `yaml:"checkpointOptions" json:"checkpointOptions"`
Priority int `yaml:"priority" json:"priority"` Priority int `yaml:"priority" json:"priority"`
checkpointFinder func(prefix string) checkpoints.CheckpointInterface checkpointFinder func(prefix string) checkpoints.CheckpointInterface
@@ -93,7 +93,7 @@ func (this *Rule) Init() error {
} }
} }
case RuleOperatorMatch: case RuleOperatorMatch:
v := this.Value var v = this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") { if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v v = "(?i)" + v
} }
@@ -106,7 +106,7 @@ func (this *Rule) Init() error {
} }
this.reg = reg this.reg = reg
case RuleOperatorNotMatch: case RuleOperatorNotMatch:
v := this.Value var v = this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") { if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v v = "(?i)" + v
} }
@@ -239,9 +239,9 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
return this.Test(value), hasRequestBody, nil return this.Test(value), hasRequestBody, nil
} }
value := configutils.ParseVariables(this.Param, func(varName string) (value string) { var value = configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2) var pieces = strings.SplitN(varName, ".", 2)
prefix := pieces[0] var prefix = pieces[0]
point, ok := this.multipleCheckpoints[prefix] point, ok := this.multipleCheckpoints[prefix]
if !ok { if !ok {
return "" return ""
@@ -255,7 +255,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
if err1 != nil { if err1 != nil {
err = err1 err = err1
} }
return types.String(value1) return this.stringifyValue(value1)
} }
value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions, this.Id) value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions, this.Id)
@@ -265,7 +265,7 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody boo
if err1 != nil { if err1 != nil {
err = err1 err = err1
} }
return types.String(value1) return this.stringifyValue(value1)
}) })
if err != nil { if err != nil {
@@ -312,9 +312,9 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
return this.Test(value), hasRequestBody, nil return this.Test(value), hasRequestBody, nil
} }
value := configutils.ParseVariables(this.Param, func(varName string) (value string) { var value = configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2) var pieces = strings.SplitN(varName, ".", 2)
prefix := pieces[0] var prefix = pieces[0]
point, ok := this.multipleCheckpoints[prefix] point, ok := this.multipleCheckpoints[prefix]
if !ok { if !ok {
return "" return ""
@@ -329,7 +329,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil { if err1 != nil {
err = err1 err = err1
} }
return types.String(value1) return this.stringifyValue(value1)
} else { } else {
value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions, this.Id) value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions, this.Id)
if hasCheckRequestBody { if hasCheckRequestBody {
@@ -338,7 +338,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil { if err1 != nil {
err = err1 err = err1
} }
return types.String(value1) return this.stringifyValue(value1)
} }
} }
@@ -350,7 +350,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil { if err1 != nil {
err = err1 err = err1
} }
return types.String(value1) return this.stringifyValue(value1)
} else { } else {
value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions, this.Id) value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions, this.Id)
if hasCheckRequestBody { if hasCheckRequestBody {
@@ -359,7 +359,7 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (
if err1 != nil { if err1 != nil {
err = err1 err = err1
} }
return types.String(value1) return this.stringifyValue(value1)
} }
}) })
@@ -387,19 +387,19 @@ func (this *Rule) Test(value any) bool {
return types.Float64(value) != this.floatValue return types.Float64(value) != this.floatValue
case RuleOperatorEqString: case RuleOperatorEqString:
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return strings.EqualFold(types.String(value), this.Value) return strings.EqualFold(this.stringifyValue(value), this.Value)
} else { } else {
return types.String(value) == this.Value return this.stringifyValue(value) == this.Value
} }
case RuleOperatorNeqString: case RuleOperatorNeqString:
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return !strings.EqualFold(types.String(value), this.Value) return !strings.EqualFold(this.stringifyValue(value), this.Value)
} else { } else {
return types.String(value) != this.Value return this.stringifyValue(value) != this.Value
} }
case RuleOperatorMatch, RuleOperatorWildcardMatch: case RuleOperatorMatch, RuleOperatorWildcardMatch:
if value == nil { if value == nil {
return false value = ""
} }
// strings // strings
@@ -413,6 +413,17 @@ func (this *Rule) Test(value any) bool {
return false return false
} }
// bytes list
byteSlices, ok := value.([][]byte)
if ok {
for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) {
return true
}
}
return false
}
// bytes // bytes
byteSlice, ok := value.([]byte) byteSlice, ok := value.([]byte)
if ok { if ok {
@@ -420,10 +431,10 @@ func (this *Rule) Test(value any) bool {
} }
// string // string
return utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife) return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife)
case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch: case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch:
if value == nil { if value == nil {
return true value = ""
} }
stringList, ok := value.([]string) stringList, ok := value.([]string)
if ok { if ok {
@@ -435,20 +446,31 @@ func (this *Rule) Test(value any) bool {
return true return true
} }
// bytes list
byteSlices, ok := value.([][]byte)
if ok {
for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) {
return false
}
}
return true
}
// bytes // bytes
byteSlice, ok := value.([]byte) byteSlice, ok := value.([]byte)
if ok { if ok {
return !utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) return !utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife)
} }
return !utils.MatchStringCache(this.reg, types.String(value), this.regCacheLife) return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife)
case RuleOperatorContains: case RuleOperatorContains:
if types.IsSlice(value) { if types.IsSlice(value) {
_, isBytes := value.([]byte) _, isBytes := value.([]byte)
if !isBytes { if !isBytes {
ok := false var ok = false
lists.Each(value, func(k int, v any) { lists.Each(value, func(k int, v any) {
if types.String(v) == this.Value { if this.stringifyValue(v) == this.Value {
ok = true ok = true
} }
}) })
@@ -456,17 +478,17 @@ func (this *Rule) Test(value any) bool {
} }
} }
if types.IsMap(value) { if types.IsMap(value) {
lowerValue := "" var lowerValue = ""
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
lowerValue = strings.ToLower(this.Value) lowerValue = strings.ToLower(this.Value)
} }
for _, v := range maps.NewMap(value) { for _, v := range maps.NewMap(value) {
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
if strings.ToLower(types.String(v)) == lowerValue { if strings.ToLower(this.stringifyValue(v)) == lowerValue {
return true return true
} }
} else { } else {
if types.String(v) == this.Value { if this.stringifyValue(v) == this.Value {
return true return true
} }
} }
@@ -475,30 +497,30 @@ func (this *Rule) Test(value any) bool {
} }
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) return strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else { } else {
return strings.Contains(types.String(value), this.Value) return strings.Contains(this.stringifyValue(value), this.Value)
} }
case RuleOperatorNotContains: case RuleOperatorNotContains:
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) return !strings.Contains(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else { } else {
return !strings.Contains(types.String(value), this.Value) return !strings.Contains(this.stringifyValue(value), this.Value)
} }
case RuleOperatorPrefix: case RuleOperatorPrefix:
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) return strings.HasPrefix(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else { } else {
return strings.HasPrefix(types.String(value), this.Value) return strings.HasPrefix(this.stringifyValue(value), this.Value)
} }
case RuleOperatorSuffix: case RuleOperatorSuffix:
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) return strings.HasSuffix(strings.ToLower(this.stringifyValue(value)), strings.ToLower(this.Value))
} else { } else {
return strings.HasSuffix(types.String(value), this.Value) return strings.HasSuffix(this.stringifyValue(value), this.Value)
} }
case RuleOperatorContainsAny: case RuleOperatorContainsAny:
var stringValue = types.String(value) var stringValue = this.stringifyValue(value)
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
stringValue = strings.ToLower(stringValue) stringValue = strings.ToLower(stringValue)
} }
@@ -511,7 +533,7 @@ func (this *Rule) Test(value any) bool {
} }
return false return false
case RuleOperatorContainsAll: case RuleOperatorContainsAll:
var stringValue = types.String(value) var stringValue = this.stringifyValue(value)
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
stringValue = strings.ToLower(stringValue) stringValue = strings.ToLower(stringValue)
} }
@@ -525,30 +547,30 @@ func (this *Rule) Test(value any) bool {
} }
return false return false
case RuleOperatorContainsBinary: case RuleOperatorContainsBinary:
data, _ := base64.StdEncoding.DecodeString(types.String(this.Value)) data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data)) return bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data))
} else { } else {
return bytes.Contains([]byte(types.String(value)), data) return bytes.Contains([]byte(this.stringifyValue(value)), data)
} }
case RuleOperatorNotContainsBinary: case RuleOperatorNotContainsBinary:
data, _ := base64.StdEncoding.DecodeString(types.String(this.Value)) data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
return !bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data)) return !bytes.Contains(bytes.ToUpper([]byte(this.stringifyValue(value))), bytes.ToUpper(data))
} else { } else {
return !bytes.Contains([]byte(types.String(value)), data) return !bytes.Contains([]byte(this.stringifyValue(value)), data)
} }
case RuleOperatorHasKey: case RuleOperatorHasKey:
if types.IsSlice(value) { if types.IsSlice(value) {
index := types.Int(this.Value) var index = types.Int(this.Value)
if index < 0 { if index < 0 {
return false return false
} }
return reflect.ValueOf(value).Len() > index return reflect.ValueOf(value).Len() > index
} else if types.IsMap(value) { } else if types.IsMap(value) {
m := maps.NewMap(value) var m = maps.NewMap(value)
if this.IsCaseInsensitive { if this.IsCaseInsensitive {
lowerValue := strings.ToLower(this.Value) var lowerValue = strings.ToLower(this.Value)
for k := range m { for k := range m {
if strings.ToLower(k) == lowerValue { if strings.ToLower(k) == lowerValue {
return true return true
@@ -567,9 +589,9 @@ func (this *Rule) Test(value any) bool {
return stringutil.VersionCompare(this.Value, types.String(value)) < 0 return stringutil.VersionCompare(this.Value, types.String(value)) < 0
case RuleOperatorVersionRange: case RuleOperatorVersionRange:
if strings.Contains(this.Value, ",") { if strings.Contains(this.Value, ",") {
versions := strings.SplitN(this.Value, ",", 2) var versions = strings.SplitN(this.Value, ",", 2)
version1 := strings.TrimSpace(versions[0]) var version1 = strings.TrimSpace(versions[0])
version2 := strings.TrimSpace(versions[1]) var version2 = strings.TrimSpace(versions[1])
if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 { if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 {
return false return false
} }
@@ -587,25 +609,25 @@ func (this *Rule) Test(value any) bool {
} }
return this.isIP && ip.Equal(this.ipValue) return this.isIP && ip.Equal(this.ipValue)
case RuleOperatorGtIP: case RuleOperatorGtIP:
ip := net.ParseIP(types.String(value)) var ip = net.ParseIP(types.String(value))
if ip == nil { if ip == nil {
return false return false
} }
return this.isIP && bytes.Compare(ip, this.ipValue) > 0 return this.isIP && bytes.Compare(ip, this.ipValue) > 0
case RuleOperatorGteIP: case RuleOperatorGteIP:
ip := net.ParseIP(types.String(value)) var ip = net.ParseIP(types.String(value))
if ip == nil { if ip == nil {
return false return false
} }
return this.isIP && bytes.Compare(ip, this.ipValue) >= 0 return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
case RuleOperatorLtIP: case RuleOperatorLtIP:
ip := net.ParseIP(types.String(value)) var ip = net.ParseIP(types.String(value))
if ip == nil { if ip == nil {
return false return false
} }
return this.isIP && bytes.Compare(ip, this.ipValue) < 0 return this.isIP && bytes.Compare(ip, this.ipValue) < 0
case RuleOperatorLteIP: case RuleOperatorLteIP:
ip := net.ParseIP(types.String(value)) var ip = net.ParseIP(types.String(value))
if ip == nil { if ip == nil {
return false return false
} }
@@ -624,7 +646,7 @@ func (this *Rule) Test(value any) bool {
if div == 0 { if div == 0 {
return false return false
} }
rem := types.Int64(pieces[1]) var rem = types.Int64(pieces[1])
return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem
case RuleOperatorIPMod10: case RuleOperatorIPMod10:
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value) return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value)
@@ -737,3 +759,25 @@ func (this *Rule) execFilter(value any) any {
} }
return value return value
} }
func (this *Rule) stringifyValue(value any) string {
if value == nil {
return ""
}
switch v := value.(type) {
case string:
return v
case []string:
return strings.Join(v, "")
case []byte:
return string(v)
case [][]byte:
var b = &bytes.Buffer{}
for _, vb := range v {
b.Write(vb)
}
return b.String()
default:
return types.String(v)
}
}

View File

@@ -205,6 +205,30 @@ func TestRule_Test(t *testing.T) {
a.IsFalse(rule.Test("abc123")) a.IsFalse(rule.Test("abc123"))
} }
{
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "^\\d+"
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsTrue(rule.Test([]byte("123")))
a.IsFalse(rule.Test([]byte("abc123")))
}
{
var rule = NewRule()
rule.Operator = RuleOperatorMatch
rule.Value = "^\\d+"
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsTrue(rule.Test([][]byte{[]byte("123"), []byte("456")}))
a.IsFalse(rule.Test([][]byte{[]byte("abc123")}))
}
{ {
rule := NewRule() rule := NewRule()
rule.Operator = RuleOperatorMatch rule.Operator = RuleOperatorMatch
@@ -265,6 +289,19 @@ func TestRule_Test(t *testing.T) {
a.IsTrue(rule.Test([]string{"abc123"})) a.IsTrue(rule.Test([]string{"abc123"}))
} }
{
var rule = NewRule()
rule.Operator = RuleOperatorNotMatch
rule.Value = "^\\d+"
err := rule.Init()
if err != nil {
t.Fatal(err)
}
a.IsFalse(rule.Test([][]byte{[]byte("123"), []byte("456")}))
a.IsFalse(rule.Test([][]byte{[]byte("123"), []byte("abc")}))
a.IsTrue(rule.Test([][]byte{[]byte("abc123")}))
}
{ {
rule := NewRule() rule := NewRule()
rule.Operator = RuleOperatorMatch rule.Operator = RuleOperatorMatch