国家、省份数据不再每个小时更新一次;WAF增加国家/地区、省份、城市、ISP等参数

This commit is contained in:
GoEdgeLab
2022-01-06 16:27:39 +08:00
parent da03455d10
commit 18c20deee5
16 changed files with 583 additions and 20 deletions

View File

@@ -0,0 +1,149 @@
package iplibrary
import (
"crypto/md5"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/types"
"io/ioutil"
"os"
"sync"
"time"
)
var SharedCityManager = NewCityManager()
func init() {
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedCityManager.Start()
})
})
}
// CityManager 中国省份信息管理
type CityManager struct {
cacheFile string
cityMap map[string]int64 // provinceName_cityName => cityName
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewCityManager() *CityManager {
return &CityManager{
cacheFile: Tea.Root + "/configs/region_city.json.cache",
cityMap: map[string]int64{},
}
}
func (this *CityManager) Start() {
// 从缓存中读取
err := this.load()
if err != nil {
remotelogs.ErrorObject("CITY_MANAGER", err)
}
// 第一次更新
err = this.loop()
if err != nil {
remotelogs.ErrorObject("City_MANAGER", err)
}
// 定时更新
ticker := utils.NewTicker(4 * time.Hour)
events.On(events.EventQuit, func() {
ticker.Stop()
})
for ticker.Next() {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("CITY_MANAGER", err)
}
}
}
func (this *CityManager) Lookup(provinceId int64, cityName string) (cityId int64) {
this.locker.RLock()
cityId, _ = this.cityMap[types.String(provinceId)+"_"+cityName]
this.locker.RUnlock()
return
}
// 从缓存中读取
func (this *CityManager) load() error {
data, err := ioutil.ReadFile(this.cacheFile)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
m := map[string]int64{}
err = json.Unmarshal(data, &m)
if err != nil {
return err
}
if m != nil && len(m) > 0 {
this.cityMap = m
}
return nil
}
// 更新城市信息
func (this *CityManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.RegionCityRPC().FindAllEnabledRegionCities(rpcClient.Context(), &pb.FindAllEnabledRegionCitiesRequest{})
if err != nil {
return err
}
m := map[string]int64{}
for _, city := range resp.RegionCities {
for _, code := range city.Codes {
m[types.String(city.RegionProvinceId)+"_"+code] = city.Id
}
}
// 检查是否有更新
data, err := json.Marshal(m)
if err != nil {
return err
}
hash := md5.New()
hash.Write(data)
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
if this.dataHash == dataHash {
return nil
}
this.dataHash = dataHash
this.locker.Lock()
this.cityMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存
err = ioutil.WriteFile(this.cacheFile, data, 0666)
return err
}

View File

@@ -0,0 +1,14 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import "testing"
func TestNewCityManager(t *testing.T) {
var manager = NewCityManager()
err := manager.loop()
if err != nil {
t.Fatal(err)
}
t.Log(manager.Lookup(16, "许昌市"))
}

View File

@@ -36,6 +36,8 @@ type CountryManager struct {
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewCountryManager() *CountryManager {
@@ -59,7 +61,7 @@ func (this *CountryManager) Start() {
}
// 定时更新
ticker := utils.NewTicker(1 * time.Hour)
ticker := utils.NewTicker(4 * time.Hour)
events.On(events.EventQuit, func() {
ticker.Stop()
})
@@ -101,6 +103,10 @@ func (this *CountryManager) load() error {
// 更新国家信息
func (this *CountryManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
@@ -111,7 +117,7 @@ func (this *CountryManager) loop() error {
}
m := map[string]int64{}
for _, country := range resp.Countries {
for _, country := range resp.RegionCountries {
for _, code := range country.Codes {
m[code] = country.Id
}
@@ -132,6 +138,7 @@ func (this *CountryManager) loop() error {
this.locker.Lock()
this.countryMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存

View File

@@ -0,0 +1,148 @@
package iplibrary
import (
"crypto/md5"
"encoding/json"
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
_ "github.com/iwind/TeaGo/bootstrap"
"io/ioutil"
"os"
"sync"
"time"
)
var SharedProviderManager = NewProviderManager()
func init() {
events.On(events.EventLoaded, func() {
goman.New(func() {
SharedProviderManager.Start()
})
})
}
// ProviderManager 中国省份信息管理
type ProviderManager struct {
cacheFile string
providerMap map[string]int64 // name => id
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewProviderManager() *ProviderManager {
return &ProviderManager{
cacheFile: Tea.Root + "/configs/region_provider.json.cache",
providerMap: map[string]int64{},
}
}
func (this *ProviderManager) Start() {
// 从缓存中读取
err := this.load()
if err != nil {
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
}
// 第一次更新
err = this.loop()
if err != nil {
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
}
// 定时更新
ticker := utils.NewTicker(4 * time.Hour)
events.On(events.EventQuit, func() {
ticker.Stop()
})
for ticker.Next() {
err := this.loop()
if err != nil {
remotelogs.ErrorObject("PROVIDER_MANAGER", err)
}
}
}
func (this *ProviderManager) Lookup(providerName string) (providerId int64) {
this.locker.RLock()
providerId, _ = this.providerMap[providerName]
this.locker.RUnlock()
return
}
// 从缓存中读取
func (this *ProviderManager) load() error {
data, err := ioutil.ReadFile(this.cacheFile)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
m := map[string]int64{}
err = json.Unmarshal(data, &m)
if err != nil {
return err
}
if m != nil && len(m) > 0 {
this.providerMap = m
}
return nil
}
// 更新服务商信息
func (this *ProviderManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.RegionProviderRPC().FindAllEnabledRegionProviders(rpcClient.Context(), &pb.FindAllEnabledRegionProvidersRequest{})
if err != nil {
return err
}
m := map[string]int64{}
for _, provider := range resp.RegionProviders {
for _, code := range provider.Codes {
m[code] = provider.Id
}
}
// 检查是否有更新
data, err := json.Marshal(m)
if err != nil {
return err
}
hash := md5.New()
hash.Write(data)
dataHash := fmt.Sprintf("%x", hash.Sum(nil))
if this.dataHash == dataHash {
return nil
}
this.dataHash = dataHash
this.locker.Lock()
this.providerMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存
err = ioutil.WriteFile(this.cacheFile, data, 0666)
return err
}

View File

@@ -0,0 +1,15 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package iplibrary
import "testing"
func TestNewProviderManager(t *testing.T) {
var manager = NewProviderManager()
err := manager.loop()
if err != nil {
t.Fatal(err)
}
t.Log(manager.Lookup("阿里云"))
t.Log(manager.Lookup("阿里云2"))
}

View File

@@ -40,6 +40,8 @@ type ProvinceManager struct {
dataHash string // 国家JSON的md5
locker sync.RWMutex
isUpdated bool
}
func NewProvinceManager() *ProvinceManager {
@@ -63,7 +65,7 @@ func (this *ProvinceManager) Start() {
}
// 定时更新
ticker := utils.NewTicker(1 * time.Hour)
ticker := utils.NewTicker(4 * time.Hour)
events.On(events.EventQuit, func() {
ticker.Stop()
})
@@ -103,21 +105,25 @@ func (this *ProvinceManager) load() error {
return nil
}
// 更新国家信息
// 更新省份信息
func (this *ProvinceManager) loop() error {
if this.isUpdated {
return nil
}
rpcClient, err := rpc.SharedRPC()
if err != nil {
return err
}
resp, err := rpcClient.RegionProvinceRPC().FindAllEnabledRegionProvincesWithCountryId(rpcClient.Context(), &pb.FindAllEnabledRegionProvincesWithCountryIdRequest{
CountryId: ChinaCountryId,
RegionCountryId: ChinaCountryId,
})
if err != nil {
return err
}
m := map[string]int64{}
for _, province := range resp.Provinces {
for _, province := range resp.RegionProvinces {
for _, code := range province.Codes {
m[code] = province.Id
}
@@ -138,6 +144,7 @@ func (this *ProvinceManager) loop() error {
this.locker.Lock()
this.provinceMap = m
this.isUpdated = true
this.locker.Unlock()
// 保存到本地缓存

View File

@@ -8,6 +8,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
@@ -65,6 +66,7 @@ type HTTPRequest struct {
rewriteRule *serverconfigs.HTTPRewriteRule // 匹配到的重写规则
rewriteReplace string // 重写规则的目标
rewriteIsExternalURL bool // 重写目标是否为外部URL
remoteAddr string // 计算后的RemoteAddr
cacheRef *serverconfigs.HTTPCacheRef // 缓存设置
cacheKey string // 缓存使用的Key
@@ -865,6 +867,73 @@ func (this *HTTPRequest) Format(source string) string {
}
}
// geo
if prefix == "geo" {
result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true))
switch suffix {
case "country.name":
if result != nil {
return result.Country
}
return ""
case "country.id":
if result != nil {
return types.String(iplibrary.SharedCountryManager.Lookup(result.Country))
}
return "0"
case "province.name":
if result != nil {
return result.Province
}
return ""
case "province.id":
if result != nil {
return types.String(iplibrary.SharedProvinceManager.Lookup(result.Province))
}
return "0"
case "city.name":
if result != nil {
return result.City
}
return ""
case "city.id":
if result != nil {
var provinceId = iplibrary.SharedProvinceManager.Lookup(result.Province)
if provinceId > 0 {
return types.String(iplibrary.SharedCityManager.Lookup(provinceId, result.City))
} else {
return "0"
}
}
return "0"
}
}
// ips
if prefix == "isp" {
result, _ := iplibrary.SharedLibrary.Lookup(this.requestRemoteAddr(true))
switch suffix {
case "name":
if result != nil {
return result.ISP
}
case "id":
if result != nil {
return types.String(iplibrary.SharedProviderManager.Lookup(result.ISP))
}
return "0"
}
return ""
}
// os
// TODO
// browser
// TODO
return "${" + varName + "}"
})
}
@@ -878,12 +947,17 @@ func (this *HTTPRequest) addVarMapping(varMapping map[string]string) {
// 获取请求的客户端地址
func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
if supportVar && len(this.remoteAddr) > 0 {
return this.remoteAddr
}
if supportVar &&
this.web.RemoteAddr != nil &&
this.web.RemoteAddr.IsOn &&
!this.web.RemoteAddr.IsEmpty() {
var remoteAddr = this.Format(this.web.RemoteAddr.Value)
if net.ParseIP(remoteAddr) != nil {
this.remoteAddr = remoteAddr
return remoteAddr
}
}
@@ -896,6 +970,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
forwardedFor = forwardedFor[:commaIndex]
}
if net.ParseIP(forwardedFor) != nil {
if supportVar {
this.remoteAddr = forwardedFor
}
return forwardedFor
}
}
@@ -905,6 +982,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
realIP, ok := this.RawReq.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
if net.ParseIP(realIP[0]) != nil {
if supportVar {
this.remoteAddr = realIP[0]
}
return realIP[0]
}
}
@@ -915,6 +995,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
realIP, ok := this.RawReq.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
if net.ParseIP(realIP[0]) != nil {
if supportVar {
this.remoteAddr = realIP[0]
}
return realIP[0]
}
}
@@ -924,6 +1007,9 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
remoteAddr := this.RawReq.RemoteAddr
host, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
if supportVar {
this.remoteAddr = host
}
return host
} else {
return remoteAddr

View File

@@ -32,14 +32,13 @@ func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallP
m := map[int64]*waf.WAF{}
for _, p := range policies {
w, err := this.convertWAF(p)
if w != nil {
m[p.Id] = w
}
if err != nil {
remotelogs.Error("WAF", "initialize policy '"+strconv.FormatInt(p.Id, 10)+"' failed: "+err.Error())
continue
}
if w == nil {
continue
}
m[p.Id] = w
}
this.mapping = m
}
@@ -181,9 +180,9 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (
}
}
err := w.Init()
if err != nil {
return nil, err
errorList := w.Init()
if len(errorList) > 0 {
return w, errorList[0]
}
return w, nil

View File

@@ -81,6 +81,14 @@ func (this *RPCClient) RegionProvinceRPC() pb.RegionProvinceServiceClient {
return pb.NewRegionProvinceServiceClient(this.pickConn())
}
func (this *RPCClient) RegionCityRPC() pb.RegionCityServiceClient {
return pb.NewRegionCityServiceClient(this.pickConn())
}
func (this *RPCClient) RegionProviderRPC() pb.RegionProviderServiceClient {
return pb.NewRegionProviderServiceClient(this.pickConn())
}
func (this *RPCClient) IPListRPC() pb.IPListServiceClient {
return pb.NewIPListServiceClient(this.pickConn())
}

View File

@@ -1,6 +1,6 @@
package checkpoints
// check point definition
// CheckpointDefinition check point definition
type CheckpointDefinition struct {
Name string
Description string

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
)
type RequestGeoCityNameCheckpoint struct {
Checkpoint
}
func (this *RequestGeoCityNameCheckpoint) IsComposed() bool {
return false
}
func (this *RequestGeoCityNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Format("${geo.city.name}")
return
}
func (this *RequestGeoCityNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return this.RequestValue(req, param, options)
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
)
type RequestGeoCountryNameCheckpoint struct {
Checkpoint
}
func (this *RequestGeoCountryNameCheckpoint) IsComposed() bool {
return false
}
func (this *RequestGeoCountryNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Format("${geo.country.name}")
return
}
func (this *RequestGeoCountryNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return this.RequestValue(req, param, options)
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
)
type RequestGeoProvinceNameCheckpoint struct {
Checkpoint
}
func (this *RequestGeoProvinceNameCheckpoint) IsComposed() bool {
return false
}
func (this *RequestGeoProvinceNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Format("${geo.province.name}")
return
}
func (this *RequestGeoProvinceNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return this.RequestValue(req, param, options)
}

View File

@@ -0,0 +1,25 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
)
type RequestISPNameCheckpoint struct {
Checkpoint
}
func (this *RequestISPNameCheckpoint) IsComposed() bool {
return false
}
func (this *RequestISPNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
value = req.Format("${isp.name}")
return
}
func (this *RequestISPNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) {
return this.RequestValue(req, param, options)
}

View File

@@ -173,7 +173,7 @@ var AllCheckpoints = []*CheckpointDefinition{
{
Name: "所有Header信息",
Prefix: "headers",
Description: "使用\n隔开的Header信息字符串",
Description: "使用\\n隔开的Header信息字符串",
HasParams: false,
Instance: new(RequestHeadersCheckpoint),
},
@@ -184,6 +184,34 @@ var AllCheckpoints = []*CheckpointDefinition{
HasParams: true,
Instance: new(RequestHeaderCheckpoint),
},
{
Name: "国家/地区名称",
Prefix: "geoCountryName",
Description: "国家/地区名称",
HasParams: false,
Instance: new(RequestGeoCountryNameCheckpoint),
},
{
Name: "省份名称",
Prefix: "geoProvinceName",
Description: "中国省份名称",
HasParams: false,
Instance: new(RequestGeoProvinceNameCheckpoint),
},
{
Name: "城市名称",
Prefix: "geoCityName",
Description: "中国城市名称",
HasParams: false,
Instance: new(RequestGeoCityNameCheckpoint),
},
{
Name: "ISP名称",
Prefix: "ispName",
Description: "ISP名称",
HasParams: false,
Instance: new(RequestISPNameCheckpoint),
},
{
Name: "CC统计",
Prefix: "cc",
@@ -242,7 +270,7 @@ var AllCheckpoints = []*CheckpointDefinition{
},
}
// find a check point
// FindCheckpoint find a check point
func FindCheckpoint(prefix string) CheckpointInterface {
for _, def := range AllCheckpoints {
if def.Prefix == prefix {
@@ -252,7 +280,7 @@ func FindCheckpoint(prefix string) CheckpointInterface {
return nil
}
// find a check point definition
// FindCheckpointDefinition find a check point definition
func FindCheckpointDefinition(prefix string) *CheckpointDefinition {
for _, def := range AllCheckpoints {
if def.Prefix == prefix {

View File

@@ -62,7 +62,7 @@ func NewWAFFromFile(path string) (waf *WAF, err error) {
return waf, nil
}
func (this *WAF) Init() error {
func (this *WAF) Init() (resultErrors []error) {
// checkpoint
this.checkpointsMap = map[string]checkpoints.CheckpointInterface{}
for _, def := range checkpoints.AllCheckpoints {
@@ -86,7 +86,8 @@ func (this *WAF) Init() error {
err := group.Init(this)
if err != nil {
return err
// 这里我们不阻止其他规则正常加入
resultErrors = append(resultErrors, err)
}
}
}
@@ -102,7 +103,8 @@ func (this *WAF) Init() error {
err := group.Init(this)
if err != nil {
return err
// 这里我们不阻止其他规则正常加入
resultErrors = append(resultErrors, err)
}
}
}