diff --git a/internal/iplibrary/manager_city.go b/internal/iplibrary/manager_city.go new file mode 100644 index 0000000..a07e701 --- /dev/null +++ b/internal/iplibrary/manager_city.go @@ -0,0 +1,149 @@ +package iplibrary + +import ( + "crypto/md5" + "encoding/json" + "fmt" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/goman" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/Tea" + _ "github.com/iwind/TeaGo/bootstrap" + "github.com/iwind/TeaGo/types" + "io/ioutil" + "os" + "sync" + "time" +) + +var SharedCityManager = NewCityManager() + +func init() { + events.On(events.EventLoaded, func() { + goman.New(func() { + SharedCityManager.Start() + }) + }) +} + +// CityManager 中国省份信息管理 +type CityManager struct { + cacheFile string + + cityMap map[string]int64 // provinceName_cityName => cityName + dataHash string // 国家JSON的md5 + + locker sync.RWMutex + + isUpdated bool +} + +func NewCityManager() *CityManager { + return &CityManager{ + cacheFile: Tea.Root + "/configs/region_city.json.cache", + cityMap: map[string]int64{}, + } +} + +func (this *CityManager) Start() { + // 从缓存中读取 + err := this.load() + if err != nil { + remotelogs.ErrorObject("CITY_MANAGER", err) + } + + // 第一次更新 + err = this.loop() + if err != nil { + remotelogs.ErrorObject("City_MANAGER", err) + } + + // 定时更新 + ticker := utils.NewTicker(4 * time.Hour) + events.On(events.EventQuit, func() { + ticker.Stop() + }) + for ticker.Next() { + err := this.loop() + if err != nil { + remotelogs.ErrorObject("CITY_MANAGER", err) + } + } +} + +func (this *CityManager) Lookup(provinceId int64, cityName string) (cityId int64) { + this.locker.RLock() + cityId, _ = this.cityMap[types.String(provinceId)+"_"+cityName] + this.locker.RUnlock() + return +} + +// 从缓存中读取 +func (this *CityManager) load() error { + data, err := ioutil.ReadFile(this.cacheFile) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + m := map[string]int64{} + err = json.Unmarshal(data, &m) + if err != nil { + return err + } + if m != nil && len(m) > 0 { + this.cityMap = m + } + + return nil +} + +// 更新城市信息 +func (this *CityManager) loop() error { + if this.isUpdated { + return nil + } + + rpcClient, err := rpc.SharedRPC() + if err != nil { + return err + } + resp, err := rpcClient.RegionCityRPC().FindAllEnabledRegionCities(rpcClient.Context(), &pb.FindAllEnabledRegionCitiesRequest{}) + if err != nil { + return err + } + + m := map[string]int64{} + for _, city := range resp.RegionCities { + for _, code := range city.Codes { + m[types.String(city.RegionProvinceId)+"_"+code] = city.Id + } + } + + // 检查是否有更新 + data, err := json.Marshal(m) + if err != nil { + return err + } + hash := md5.New() + hash.Write(data) + dataHash := fmt.Sprintf("%x", hash.Sum(nil)) + if this.dataHash == dataHash { + return nil + } + this.dataHash = dataHash + + this.locker.Lock() + this.cityMap = m + this.isUpdated = true + this.locker.Unlock() + + // 保存到本地缓存 + + err = ioutil.WriteFile(this.cacheFile, data, 0666) + return err +} diff --git a/internal/iplibrary/manager_city_test.go b/internal/iplibrary/manager_city_test.go new file mode 100644 index 0000000..426dee2 --- /dev/null +++ b/internal/iplibrary/manager_city_test.go @@ -0,0 +1,14 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package iplibrary + +import "testing" + +func TestNewCityManager(t *testing.T) { + var manager = NewCityManager() + err := manager.loop() + if err != nil { + t.Fatal(err) + } + t.Log(manager.Lookup(16, "许昌市")) +} diff --git a/internal/iplibrary/manager_country.go b/internal/iplibrary/manager_country.go index f75c10e..9c85b20 100644 --- a/internal/iplibrary/manager_country.go +++ b/internal/iplibrary/manager_country.go @@ -36,6 +36,8 @@ type CountryManager struct { dataHash string // 国家JSON的md5 locker sync.RWMutex + + isUpdated bool } func NewCountryManager() *CountryManager { @@ -59,7 +61,7 @@ func (this *CountryManager) Start() { } // 定时更新 - ticker := utils.NewTicker(1 * time.Hour) + ticker := utils.NewTicker(4 * time.Hour) events.On(events.EventQuit, func() { ticker.Stop() }) @@ -101,6 +103,10 @@ func (this *CountryManager) load() error { // 更新国家信息 func (this *CountryManager) loop() error { + if this.isUpdated { + return nil + } + rpcClient, err := rpc.SharedRPC() if err != nil { return err @@ -111,7 +117,7 @@ func (this *CountryManager) loop() error { } m := map[string]int64{} - for _, country := range resp.Countries { + for _, country := range resp.RegionCountries { for _, code := range country.Codes { m[code] = country.Id } @@ -132,6 +138,7 @@ func (this *CountryManager) loop() error { this.locker.Lock() this.countryMap = m + this.isUpdated = true this.locker.Unlock() // 保存到本地缓存 diff --git a/internal/iplibrary/manager_provider.go b/internal/iplibrary/manager_provider.go new file mode 100644 index 0000000..b6c9a7c --- /dev/null +++ b/internal/iplibrary/manager_provider.go @@ -0,0 +1,148 @@ +package iplibrary + +import ( + "crypto/md5" + "encoding/json" + "fmt" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/goman" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/Tea" + _ "github.com/iwind/TeaGo/bootstrap" + "io/ioutil" + "os" + "sync" + "time" +) + +var SharedProviderManager = NewProviderManager() + +func init() { + events.On(events.EventLoaded, func() { + goman.New(func() { + SharedProviderManager.Start() + }) + }) +} + +// ProviderManager 中国省份信息管理 +type ProviderManager struct { + cacheFile string + + providerMap map[string]int64 // name => id + dataHash string // 国家JSON的md5 + + locker sync.RWMutex + + isUpdated bool +} + +func NewProviderManager() *ProviderManager { + return &ProviderManager{ + cacheFile: Tea.Root + "/configs/region_provider.json.cache", + providerMap: map[string]int64{}, + } +} + +func (this *ProviderManager) Start() { + // 从缓存中读取 + err := this.load() + if err != nil { + remotelogs.ErrorObject("PROVIDER_MANAGER", err) + } + + // 第一次更新 + err = this.loop() + if err != nil { + remotelogs.ErrorObject("PROVIDER_MANAGER", err) + } + + // 定时更新 + ticker := utils.NewTicker(4 * time.Hour) + events.On(events.EventQuit, func() { + ticker.Stop() + }) + for ticker.Next() { + err := this.loop() + if err != nil { + remotelogs.ErrorObject("PROVIDER_MANAGER", err) + } + } +} + +func (this *ProviderManager) Lookup(providerName string) (providerId int64) { + this.locker.RLock() + providerId, _ = this.providerMap[providerName] + this.locker.RUnlock() + return +} + +// 从缓存中读取 +func (this *ProviderManager) load() error { + data, err := ioutil.ReadFile(this.cacheFile) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + m := map[string]int64{} + err = json.Unmarshal(data, &m) + if err != nil { + return err + } + if m != nil && len(m) > 0 { + this.providerMap = m + } + + return nil +} + +// 更新服务商信息 +func (this *ProviderManager) loop() error { + if this.isUpdated { + return nil + } + + rpcClient, err := rpc.SharedRPC() + if err != nil { + return err + } + resp, err := rpcClient.RegionProviderRPC().FindAllEnabledRegionProviders(rpcClient.Context(), &pb.FindAllEnabledRegionProvidersRequest{}) + if err != nil { + return err + } + + m := map[string]int64{} + for _, provider := range resp.RegionProviders { + for _, code := range provider.Codes { + m[code] = provider.Id + } + } + + // 检查是否有更新 + data, err := json.Marshal(m) + if err != nil { + return err + } + hash := md5.New() + hash.Write(data) + dataHash := fmt.Sprintf("%x", hash.Sum(nil)) + if this.dataHash == dataHash { + return nil + } + this.dataHash = dataHash + + this.locker.Lock() + this.providerMap = m + this.isUpdated = true + this.locker.Unlock() + + // 保存到本地缓存 + + err = ioutil.WriteFile(this.cacheFile, data, 0666) + return err +} diff --git a/internal/iplibrary/manager_provider_test.go b/internal/iplibrary/manager_provider_test.go new file mode 100644 index 0000000..5163f7c --- /dev/null +++ b/internal/iplibrary/manager_provider_test.go @@ -0,0 +1,15 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package iplibrary + +import "testing" + +func TestNewProviderManager(t *testing.T) { + var manager = NewProviderManager() + err := manager.loop() + if err != nil { + t.Fatal(err) + } + t.Log(manager.Lookup("阿里云")) + t.Log(manager.Lookup("阿里云2")) +} diff --git a/internal/iplibrary/manager_province.go b/internal/iplibrary/manager_province.go index 32ed1d8..a8923e6 100644 --- a/internal/iplibrary/manager_province.go +++ b/internal/iplibrary/manager_province.go @@ -40,6 +40,8 @@ type ProvinceManager struct { dataHash string // 国家JSON的md5 locker sync.RWMutex + + isUpdated bool } func NewProvinceManager() *ProvinceManager { @@ -63,7 +65,7 @@ func (this *ProvinceManager) Start() { } // 定时更新 - ticker := utils.NewTicker(1 * time.Hour) + ticker := utils.NewTicker(4 * time.Hour) events.On(events.EventQuit, func() { ticker.Stop() }) @@ -103,21 +105,25 @@ func (this *ProvinceManager) load() error { return nil } -// 更新国家信息 +// 更新省份信息 func (this *ProvinceManager) loop() error { + if this.isUpdated { + return nil + } + rpcClient, err := rpc.SharedRPC() if err != nil { return err } resp, err := rpcClient.RegionProvinceRPC().FindAllEnabledRegionProvincesWithCountryId(rpcClient.Context(), &pb.FindAllEnabledRegionProvincesWithCountryIdRequest{ - CountryId: ChinaCountryId, + RegionCountryId: ChinaCountryId, }) if err != nil { return err } m := map[string]int64{} - for _, province := range resp.Provinces { + for _, province := range resp.RegionProvinces { for _, code := range province.Codes { m[code] = province.Id } @@ -138,6 +144,7 @@ func (this *ProvinceManager) loop() error { this.locker.Lock() this.provinceMap = m + this.isUpdated = true this.locker.Unlock() // 保存到本地缓存 diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 2da70a1..5cf9092 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -8,6 +8,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" + "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/metrics" "github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/utils" @@ -65,6 +66,7 @@ type HTTPRequest struct { rewriteRule *serverconfigs.HTTPRewriteRule // 匹配到的重写规则 rewriteReplace string // 重写规则的目标 rewriteIsExternalURL bool // 重写目标是否为外部URL + remoteAddr string // 计算后的RemoteAddr cacheRef *serverconfigs.HTTPCacheRef // 缓存设置 cacheKey string // 缓存使用的Key @@ -865,6 +867,73 @@ func (this *HTTPRequest) Format(source string) string { } } + // geo + if prefix == "geo" { + result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true)) + + switch suffix { + case "country.name": + if result != nil { + return result.Country + } + return "" + case "country.id": + if result != nil { + return types.String(iplibrary.SharedCountryManager.Lookup(result.Country)) + } + return "0" + case "province.name": + if result != nil { + return result.Province + } + return "" + case "province.id": + if result != nil { + return types.String(iplibrary.SharedProvinceManager.Lookup(result.Province)) + } + return "0" + case "city.name": + if result != nil { + return result.City + } + return "" + case "city.id": + if result != nil { + var provinceId = iplibrary.SharedProvinceManager.Lookup(result.Province) + if provinceId > 0 { + return types.String(iplibrary.SharedCityManager.Lookup(provinceId, result.City)) + } else { + return "0" + } + } + return "0" + } + } + + // ips + if prefix == "isp" { + result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true)) + + switch suffix { + case "name": + if result != nil { + return result.ISP + } + case "id": + if result != nil { + return types.String(iplibrary.SharedProviderManager.Lookup(result.ISP)) + } + return "0" + } + return "" + } + + // os + // TODO + + // browser + // TODO + return "${" + varName + "}" }) } @@ -878,12 +947,17 @@ func (this *HTTPRequest) addVarMapping(varMapping map[string]string) { // 获取请求的客户端地址 func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string { + if supportVar && len(this.remoteAddr) > 0 { + return this.remoteAddr + } + if supportVar && this.web.RemoteAddr != nil && this.web.RemoteAddr.IsOn && !this.web.RemoteAddr.IsEmpty() { var remoteAddr = this.Format(this.web.RemoteAddr.Value) if net.ParseIP(remoteAddr) != nil { + this.remoteAddr = remoteAddr return remoteAddr } } @@ -896,6 +970,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string { forwardedFor = forwardedFor[:commaIndex] } if net.ParseIP(forwardedFor) != nil { + if supportVar { + this.remoteAddr = forwardedFor + } return forwardedFor } } @@ -905,6 +982,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string { realIP, ok := this.RawReq.Header["X-Real-IP"] if ok && len(realIP) > 0 { if net.ParseIP(realIP[0]) != nil { + if supportVar { + this.remoteAddr = realIP[0] + } return realIP[0] } } @@ -915,6 +995,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string { realIP, ok := this.RawReq.Header["X-Real-Ip"] if ok && len(realIP) > 0 { if net.ParseIP(realIP[0]) != nil { + if supportVar { + this.remoteAddr = realIP[0] + } return realIP[0] } } @@ -924,6 +1007,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string { remoteAddr := this.RawReq.RemoteAddr host, _, err := net.SplitHostPort(remoteAddr) if err == nil { + if supportVar { + this.remoteAddr = host + } return host } else { return remoteAddr diff --git a/internal/nodes/waf_manager.go b/internal/nodes/waf_manager.go index 849e769..be285dd 100644 --- a/internal/nodes/waf_manager.go +++ b/internal/nodes/waf_manager.go @@ -32,14 +32,13 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP m := map[int64]*waf.WAF{} for _, p := range policies { w, err := this.convertWAF(p) + if w != nil { + m[p.Id] = w + } if err != nil { remotelogs.Error("WAF", "initialize policy '"+strconv.FormatInt(p.Id, 10)+"' failed: "+err.Error()) continue } - if w == nil { - continue - } - m[p.Id] = w } this.mapping = m } @@ -181,9 +180,9 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( } } - err := w.Init() - if err != nil { - return nil, err + errorList := w.Init() + if len(errorList) > 0 { + return w, errorList[0] } return w, nil diff --git a/internal/rpc/rpc_client.go b/internal/rpc/rpc_client.go index f34eed7..39db210 100644 --- a/internal/rpc/rpc_client.go +++ b/internal/rpc/rpc_client.go @@ -81,6 +81,14 @@ func (this *RPCClient) RegionProvinceRPC() pb.RegionProvinceServiceClient { return pb.NewRegionProvinceServiceClient(this.pickConn()) } +func (this *RPCClient) RegionCityRPC() pb.RegionCityServiceClient { + return pb.NewRegionCityServiceClient(this.pickConn()) +} + +func (this *RPCClient) RegionProviderRPC() pb.RegionProviderServiceClient { + return pb.NewRegionProviderServiceClient(this.pickConn()) +} + func (this *RPCClient) IPListRPC() pb.IPListServiceClient { return pb.NewIPListServiceClient(this.pickConn()) } diff --git a/internal/waf/checkpoints/checkpoint_definition.go b/internal/waf/checkpoints/checkpoint_definition.go index c65b9a4..0857aef 100644 --- a/internal/waf/checkpoints/checkpoint_definition.go +++ b/internal/waf/checkpoints/checkpoint_definition.go @@ -1,6 +1,6 @@ package checkpoints -// check point definition +// CheckpointDefinition check point definition type CheckpointDefinition struct { Name string Description string diff --git a/internal/waf/checkpoints/request_geo_city_name.go b/internal/waf/checkpoints/request_geo_city_name.go new file mode 100644 index 0000000..e79ab87 --- /dev/null +++ b/internal/waf/checkpoints/request_geo_city_name.go @@ -0,0 +1,25 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" +) + +type RequestGeoCityNameCheckpoint struct { + Checkpoint +} + +func (this *RequestGeoCityNameCheckpoint) IsComposed() bool { + return false +} + +func (this *RequestGeoCityNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.Format("${geo.city.name}") + return +} + +func (this *RequestGeoCityNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + return this.RequestValue(req, param, options) +} diff --git a/internal/waf/checkpoints/request_geo_country_name.go b/internal/waf/checkpoints/request_geo_country_name.go new file mode 100644 index 0000000..438310f --- /dev/null +++ b/internal/waf/checkpoints/request_geo_country_name.go @@ -0,0 +1,25 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" +) + +type RequestGeoCountryNameCheckpoint struct { + Checkpoint +} + +func (this *RequestGeoCountryNameCheckpoint) IsComposed() bool { + return false +} + +func (this *RequestGeoCountryNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.Format("${geo.country.name}") + return +} + +func (this *RequestGeoCountryNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + return this.RequestValue(req, param, options) +} diff --git a/internal/waf/checkpoints/request_geo_province_name.go b/internal/waf/checkpoints/request_geo_province_name.go new file mode 100644 index 0000000..37d7b88 --- /dev/null +++ b/internal/waf/checkpoints/request_geo_province_name.go @@ -0,0 +1,25 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" +) + +type RequestGeoProvinceNameCheckpoint struct { + Checkpoint +} + +func (this *RequestGeoProvinceNameCheckpoint) IsComposed() bool { + return false +} + +func (this *RequestGeoProvinceNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.Format("${geo.province.name}") + return +} + +func (this *RequestGeoProvinceNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + return this.RequestValue(req, param, options) +} diff --git a/internal/waf/checkpoints/request_isp_name.go b/internal/waf/checkpoints/request_isp_name.go new file mode 100644 index 0000000..68c6b7d --- /dev/null +++ b/internal/waf/checkpoints/request_isp_name.go @@ -0,0 +1,25 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" +) + +type RequestISPNameCheckpoint struct { + Checkpoint +} + +func (this *RequestISPNameCheckpoint) IsComposed() bool { + return false +} + +func (this *RequestISPNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + value = req.Format("${isp.name}") + return +} + +func (this *RequestISPNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { + return this.RequestValue(req, param, options) +} diff --git a/internal/waf/checkpoints/utils.go b/internal/waf/checkpoints/utils.go index 945955f..120d1aa 100644 --- a/internal/waf/checkpoints/utils.go +++ b/internal/waf/checkpoints/utils.go @@ -173,7 +173,7 @@ var AllCheckpoints = []*CheckpointDefinition{ { Name: "所有Header信息", Prefix: "headers", - Description: "使用\n隔开的Header信息字符串", + Description: "使用\\n隔开的Header信息字符串", HasParams: false, Instance: new(RequestHeadersCheckpoint), }, @@ -184,6 +184,34 @@ var AllCheckpoints = []*CheckpointDefinition{ HasParams: true, Instance: new(RequestHeaderCheckpoint), }, + { + Name: "国家/地区名称", + Prefix: "geoCountryName", + Description: "国家/地区名称", + HasParams: false, + Instance: new(RequestGeoCountryNameCheckpoint), + }, + { + Name: "省份名称", + Prefix: "geoProvinceName", + Description: "中国省份名称", + HasParams: false, + Instance: new(RequestGeoProvinceNameCheckpoint), + }, + { + Name: "城市名称", + Prefix: "geoCityName", + Description: "中国城市名称", + HasParams: false, + Instance: new(RequestGeoCityNameCheckpoint), + }, + { + Name: "ISP名称", + Prefix: "ispName", + Description: "ISP名称", + HasParams: false, + Instance: new(RequestISPNameCheckpoint), + }, { Name: "CC统计(旧)", Prefix: "cc", @@ -242,7 +270,7 @@ var AllCheckpoints = []*CheckpointDefinition{ }, } -// find a check point +// FindCheckpoint find a check point func FindCheckpoint(prefix string) CheckpointInterface { for _, def := range AllCheckpoints { if def.Prefix == prefix { @@ -252,7 +280,7 @@ func FindCheckpoint(prefix string) CheckpointInterface { return nil } -// find a check point definition +// FindCheckpointDefinition find a check point definition func FindCheckpointDefinition(prefix string) *CheckpointDefinition { for _, def := range AllCheckpoints { if def.Prefix == prefix { diff --git a/internal/waf/waf.go b/internal/waf/waf.go index cfbb96d..1117afb 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -62,7 +62,7 @@ func NewWAFFromFile(path string) (waf *WAF, err error) { return waf, nil } -func (this *WAF) Init() error { +func (this *WAF) Init() (resultErrors []error) { // checkpoint this.checkpointsMap = map[string]checkpoints.CheckpointInterface{} for _, def := range checkpoints.AllCheckpoints { @@ -86,7 +86,8 @@ func (this *WAF) Init() error { err := group.Init(this) if err != nil { - return err + // 这里我们不阻止其他规则正常加入 + resultErrors = append(resultErrors, err) } } } @@ -102,7 +103,8 @@ func (this *WAF) Init() error { err := group.Init(this) if err != nil { - return err + // 这里我们不阻止其他规则正常加入 + resultErrors = append(resultErrors, err) } } }