diff --git a/internal/waf/checkpoints/checkpoint.go b/internal/waf/checkpoints/checkpoint.go index ef3b58a..f2764ba 100644 --- a/internal/waf/checkpoints/checkpoint.go +++ b/internal/waf/checkpoints/checkpoint.go @@ -6,6 +6,7 @@ import ( ) type Checkpoint struct { + priority int } func (this *Checkpoint) Init() { @@ -36,6 +37,14 @@ func (this *Checkpoint) Stop() { } +func (this *Checkpoint) SetPriority(priority int) { + this.priority = priority +} + +func (this *Checkpoint) Priority() int { + return this.priority +} + func (this *Checkpoint) RequestBodyIsEmpty(req requests.Request) bool { if req.WAFRaw().ContentLength == 0 { return true diff --git a/internal/waf/checkpoints/checkpoint_definition.go b/internal/waf/checkpoints/checkpoint_definition.go index 0857aef..1c07222 100644 --- a/internal/waf/checkpoints/checkpoint_definition.go +++ b/internal/waf/checkpoints/checkpoint_definition.go @@ -7,4 +7,5 @@ type CheckpointDefinition struct { Prefix string HasParams bool // has sub params Instance CheckpointInterface + Priority int } diff --git a/internal/waf/checkpoints/checkpoint_interface.go b/internal/waf/checkpoints/checkpoint_interface.go index 6afdd49..5deacbb 100644 --- a/internal/waf/checkpoints/checkpoint_interface.go +++ b/internal/waf/checkpoints/checkpoint_interface.go @@ -33,4 +33,10 @@ type CheckpointInterface interface { // Stop stop Stop() + + // SetPriority set priority + SetPriority(priority int) + + // get priority + Priority() int } diff --git a/internal/waf/checkpoints/utils.go b/internal/waf/checkpoints/utils.go index 23aa7e2..dbf954f 100644 --- a/internal/waf/checkpoints/utils.go +++ b/internal/waf/checkpoints/utils.go @@ -8,6 +8,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "通用Header比如Cache-Control、Accept之类的长度限制,防止缓冲区溢出攻击", HasParams: false, Instance: new(RequestGeneralHeaderLengthCheckpoint), + Priority: 100, }, { Name: "客户端地址(IP)", @@ -15,6 +16,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "试图通过分析X-Forwarded-For等Header获取的客户端地址,比如192.168.1.100", HasParams: false, Instance: new(RequestRemoteAddrCheckpoint), + Priority: 100, }, { Name: "客户端源地址(IP)", @@ -22,6 +24,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "直接连接的客户端地址,比如192.168.1.100", HasParams: false, Instance: new(RequestRawRemoteAddrCheckpoint), + Priority: 100, }, { Name: "客户端端口", @@ -29,6 +32,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "直接连接的客户端地址端口", HasParams: false, Instance: new(RequestRemotePortCheckpoint), + Priority: 100, }, { Name: "客户端用户名", @@ -36,6 +40,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "通过BasicAuth登录的客户端用户名", HasParams: false, Instance: new(RequestRemoteUserCheckpoint), + Priority: 100, }, { Name: "请求URI", @@ -43,6 +48,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "包含URL参数的请求URI,类似于 /hello/world?lang=go", HasParams: false, Instance: new(RequestURICheckpoint), + Priority: 100, }, { Name: "请求路径", @@ -50,6 +56,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "不包含URL参数的请求路径,类似于 /hello/world", HasParams: false, Instance: new(RequestPathCheckpoint), + Priority: 100, }, { Name: "请求URL", @@ -57,6 +64,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "完整的请求URL,包含协议、域名、请求路径、参数等,类似于 https://example.com/hello?name=lily", HasParams: false, Instance: new(RequestURLCheckpoint), + Priority: 100, }, { Name: "请求内容长度", @@ -64,6 +72,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "请求Header中的Content-Length", HasParams: false, Instance: new(RequestLengthCheckpoint), + Priority: 100, }, { Name: "请求体内容", @@ -71,6 +80,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "通常在POST或者PUT等操作时会附带请求体,最大限制32M", HasParams: false, Instance: new(RequestBodyCheckpoint), + Priority: 5, }, { Name: "请求URI和请求体组合", @@ -78,6 +88,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "${requestURI}和${requestBody}组合", HasParams: false, Instance: new(RequestAllCheckpoint), + Priority: 5, }, { Name: "请求表单参数", @@ -85,6 +96,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "获取POST或者其他方法发送的表单参数,最大请求体限制32M", HasParams: true, Instance: new(RequestFormArgCheckpoint), + Priority: 5, }, { Name: "上传文件", @@ -92,6 +104,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "获取POST上传的文件信息,最大请求体限制32M", HasParams: true, Instance: new(RequestUploadCheckpoint), + Priority: 20, }, { Name: "请求JSON参数", @@ -99,6 +112,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "获取POST或者其他方法发送的JSON,最大请求体限制32M,使用点(.)符号表示多级数据", HasParams: true, Instance: new(RequestJSONArgCheckpoint), + Priority: 5, }, { Name: "请求方法", @@ -106,6 +120,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "比如GET、POST", HasParams: false, Instance: new(RequestMethodCheckpoint), + Priority: 100, }, { Name: "请求协议", @@ -113,6 +128,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "比如http或https", HasParams: false, Instance: new(RequestSchemeCheckpoint), + Priority: 100, }, { Name: "HTTP协议版本", @@ -120,6 +136,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "比如HTTP/1.1", HasParams: false, Instance: new(RequestProtoCheckpoint), + Priority: 100, }, { Name: "主机名", @@ -127,6 +144,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "比如teaos.cn", HasParams: false, Instance: new(RequestHostCheckpoint), + Priority: 100, }, { Name: "请求来源URL", @@ -134,6 +152,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "请求Header中的Referer值", HasParams: false, Instance: new(RequestRefererCheckpoint), + Priority: 100, }, { Name: "客户端信息", @@ -141,6 +160,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "比如Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103", HasParams: false, Instance: new(RequestUserAgentCheckpoint), + Priority: 100, }, { Name: "内容类型", @@ -148,6 +168,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "请求Header的Content-Type", HasParams: false, Instance: new(RequestContentTypeCheckpoint), + Priority: 100, }, { Name: "所有cookie组合字符串", @@ -155,6 +176,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "比如sid=IxZVPFhE&city=beijing&uid=18237", HasParams: false, Instance: new(RequestCookiesCheckpoint), + Priority: 100, }, { Name: "单个cookie值", @@ -162,6 +184,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "单个cookie值", HasParams: true, Instance: new(RequestCookieCheckpoint), + Priority: 100, }, { Name: "所有URL参数组合", @@ -169,6 +192,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "比如name=lu&age=20", HasParams: false, Instance: new(RequestArgsCheckpoint), + Priority: 100, }, { Name: "单个URL参数值", @@ -176,6 +200,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "单个URL参数值", HasParams: true, Instance: new(RequestArgCheckpoint), + Priority: 100, }, { Name: "所有Header信息", @@ -183,6 +208,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "使用\\n隔开的Header信息字符串", HasParams: false, Instance: new(RequestHeadersCheckpoint), + Priority: 100, }, { Name: "单个Header值", @@ -190,6 +216,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "单个Header值", HasParams: true, Instance: new(RequestHeaderCheckpoint), + Priority: 100, }, { Name: "国家/地区名称", @@ -197,6 +224,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "国家/地区名称", HasParams: false, Instance: new(RequestGeoCountryNameCheckpoint), + Priority: 90, }, { Name: "省份名称", @@ -204,6 +232,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "中国省份名称", HasParams: false, Instance: new(RequestGeoProvinceNameCheckpoint), + Priority: 90, }, { Name: "城市名称", @@ -211,6 +240,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "中国城市名称", HasParams: false, Instance: new(RequestGeoCityNameCheckpoint), + Priority: 90, }, { Name: "ISP名称", @@ -218,6 +248,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "ISP名称", HasParams: false, Instance: new(RequestISPNameCheckpoint), + Priority: 90, }, { Name: "CC统计(旧)", @@ -225,6 +256,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "统计某段时间段内的请求信息", HasParams: true, Instance: new(CCCheckpoint), + Priority: 10, }, { Name: "CC统计(新)", @@ -232,6 +264,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "统计某段时间段内的请求信息", HasParams: true, Instance: new(CC2Checkpoint), + Priority: 10, }, { Name: "防盗链", @@ -239,6 +272,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "阻止一些域名访问引用本站资源", HasParams: true, Instance: new(RequestRefererBlockCheckpoint), + Priority: 20, }, { Name: "通用响应Header长度限制", @@ -246,6 +280,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "通用Header比如Cache-Control、Accept之类的长度限制,防止缓冲区溢出攻击", HasParams: false, Instance: new(ResponseGeneralHeaderLengthCheckpoint), + Priority: 100, }, { Name: "响应状态码", @@ -253,6 +288,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "响应状态码,比如200、404、500", HasParams: false, Instance: new(ResponseStatusCheckpoint), + Priority: 100, }, { Name: "响应Header", @@ -260,6 +296,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "响应Header值", HasParams: true, Instance: new(ResponseHeaderCheckpoint), + Priority: 100, }, { Name: "响应内容", @@ -267,6 +304,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "响应内容字符串", HasParams: false, Instance: new(ResponseBodyCheckpoint), + Priority: 5, }, { Name: "响应内容长度", @@ -274,6 +312,7 @@ var AllCheckpoints = []*CheckpointDefinition{ Description: "响应内容长度,通过响应的Header Content-Length获取", HasParams: false, Instance: new(ResponseBytesSentCheckpoint), + Priority: 100, }, } @@ -281,6 +320,7 @@ var AllCheckpoints = []*CheckpointDefinition{ func FindCheckpoint(prefix string) CheckpointInterface { for _, def := range AllCheckpoints { if def.Prefix == prefix { + def.Instance.SetPriority(def.Priority) return def.Instance } } diff --git a/internal/waf/rule.go b/internal/waf/rule.go index 28c3531..9bff108 100644 --- a/internal/waf/rule.go +++ b/internal/waf/rule.go @@ -35,6 +35,7 @@ type Rule struct { Value string `yaml:"value" json:"value"` // compared value IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"` CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"` + Priority int `yaml:"priority" json:"priority"` checkpointFinder func(prefix string) checkpoints.CheckpointInterface @@ -132,9 +133,9 @@ func (this *Rule) Init() error { } if singleParamRegexp.MatchString(this.Param) { - param := this.Param[2 : len(this.Param)-1] - pieces := strings.SplitN(param, ".", 2) - prefix := pieces[0] + var param = this.Param[2 : len(this.Param)-1] + var pieces = strings.SplitN(param, ".", 2) + var prefix = pieces[0] if len(pieces) == 1 { this.singleParam = "" } else { @@ -142,18 +143,20 @@ func (this *Rule) Init() error { } if this.checkpointFinder != nil { - checkpoint := this.checkpointFinder(prefix) + var checkpoint = this.checkpointFinder(prefix) if checkpoint == nil { return errors.New("no check point '" + prefix + "' found") } this.singleCheckpoint = checkpoint + this.Priority = checkpoint.Priority() } else { - checkpoint := checkpoints.FindCheckpoint(prefix) + var checkpoint = checkpoints.FindCheckpoint(prefix) if checkpoint == nil { return errors.New("no check point '" + prefix + "' found") } checkpoint.Init() this.singleCheckpoint = checkpoint + this.Priority = checkpoint.Priority() } return nil @@ -162,22 +165,24 @@ func (this *Rule) Init() error { this.multipleCheckpoints = map[string]checkpoints.CheckpointInterface{} var err error = nil configutils.ParseVariables(this.Param, func(varName string) (value string) { - pieces := strings.SplitN(varName, ".", 2) - prefix := pieces[0] + var pieces = strings.SplitN(varName, ".", 2) + var prefix = pieces[0] if this.checkpointFinder != nil { - checkpoint := this.checkpointFinder(prefix) + var checkpoint = this.checkpointFinder(prefix) if checkpoint == nil { err = errors.New("no check point '" + prefix + "' found") } else { this.multipleCheckpoints[prefix] = checkpoint + this.Priority = checkpoint.Priority() } } else { - checkpoint := checkpoints.FindCheckpoint(prefix) + var checkpoint = checkpoints.FindCheckpoint(prefix) if checkpoint == nil { err = errors.New("no check point '" + prefix + "' found") } else { checkpoint.Init() this.multipleCheckpoints[prefix] = checkpoint + this.Priority = checkpoint.Priority() } } return "" diff --git a/internal/waf/rule_set.go b/internal/waf/rule_set.go index 8d2b200..49cdd00 100644 --- a/internal/waf/rule_set.go +++ b/internal/waf/rule_set.go @@ -52,6 +52,11 @@ func (this *RuleSet) Init(waf *WAF) error { return errors.New("init rule '" + rule.Param + " " + rule.Operator + " " + types.String(rule.Value) + "' failed: " + err.Error()) } } + + // sort by priority + sort.Slice(this.Rules, func(i, j int) bool { + return this.Rules[i].Priority > this.Rules[j].Priority + }) } // action codes diff --git a/internal/waf/waf.go b/internal/waf/waf.go index 5c9e9f3..6d6975c 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -73,6 +73,7 @@ func (this *WAF) Init() (resultErrors []error) { for _, def := range checkpoints.AllCheckpoints { instance := reflect.New(reflect.Indirect(reflect.ValueOf(def.Instance)).Type()).Interface().(checkpoints.CheckpointInterface) instance.Init() + instance.SetPriority(def.Priority) this.checkpointsMap[def.Prefix] = instance }