diff --git a/build/build.sh b/build/build.sh index a367731..31e0d24 100755 --- a/build/build.sh +++ b/build/build.sh @@ -38,6 +38,7 @@ function build() { cp $ROOT/configs/api.template.yaml $DIST/configs cp -R $ROOT/www $DIST/ cp -R $ROOT/pages $DIST/ + cp -R $ROOT/resources $DIST/ echo "building ..." env GOOS=${OS} GOARCH=${ARCH} go build -o $DIST/bin/${NAME} -ldflags="-s -w" $ROOT/../cmd/edge-node/main.go diff --git a/build/configs/.gitignore b/build/configs/.gitignore index b227765..d10de96 100644 --- a/build/configs/.gitignore +++ b/build/configs/.gitignore @@ -1,3 +1,4 @@ node.json api.yaml -cluster.yaml \ No newline at end of file +cluster.yaml +*.cache \ No newline at end of file diff --git a/build/resources/ipdata/ip2region/ip2region.db b/build/resources/ipdata/ip2region/ip2region.db new file mode 100644 index 0000000..b688511 Binary files /dev/null and b/build/resources/ipdata/ip2region/ip2region.db differ diff --git a/go.mod b/go.mod index ee1ce1d..335fea3 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/go-yaml/yaml v2.1.0+incompatible github.com/golang/protobuf v1.4.2 github.com/iwind/TeaGo v0.0.0-20201020081413-7cf62d6f420f + github.com/lionsoul2014/ip2region v2.2.0-release+incompatible github.com/shirou/gopsutil v2.20.9+incompatible golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7 google.golang.org/grpc v1.32.0 diff --git a/go.sum b/go.sum index 33fdc9d..b9a88f3 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,7 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2 github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/go-redis/redis/v8 v8.0.0-beta.7/go.mod h1:FGJAWDWFht1sQ4qxyJHZZbVyvnVcKQN0E3u5/5lRz+g= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= @@ -64,6 +65,9 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lionsoul2014/ip2region v1.9.0-release h1:b4FxevWljlOb+Z3qtAMQIvel6az21p7OeZ0K1wn/3mI= +github.com/lionsoul2014/ip2region v2.2.0-release+incompatible h1:1qp9iks+69h7IGLazAplzS9Ca14HAxuD5c0rbFdPGy4= +github.com/lionsoul2014/ip2region v2.2.0-release+incompatible/go.mod h1:+ZBN7PBoh5gG6/y0ZQ85vJDBe21WnfbRrQQwTfliJJI= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= diff --git a/internal/events/events.go b/internal/events/events.go index 7b57e2d..2b2f827 100644 --- a/internal/events/events.go +++ b/internal/events/events.go @@ -3,5 +3,6 @@ package events type Event = string const ( - EventQuit Event = "quit" // quit node gracefully + EventStart Event = "start" // start loading + EventQuit Event = "quit" // quit node gracefully ) diff --git a/internal/grids/grid.go b/internal/grids/grid.go index 37fa5c3..03a62ef 100644 --- a/internal/grids/grid.go +++ b/internal/grids/grid.go @@ -84,7 +84,7 @@ func (this *Grid) WriteInt64(key []byte, value int64, lifeSeconds int64) { Key: key, Type: ItemInt64, ValueInt64: value, - ExpireAt: time.Now().Unix() + lifeSeconds, + ExpireAt: UnixTime() + lifeSeconds, }) } diff --git a/internal/grids/grid_test.go b/internal/grids/grid_test.go index 1aa1bea..41fb100 100644 --- a/internal/grids/grid_test.go +++ b/internal/grids/grid_test.go @@ -61,9 +61,12 @@ func TestMemoryGrid_Compress(t *testing.T) { } func BenchmarkMemoryGrid_Performance(b *testing.B) { + runtime.GOMAXPROCS(1) + grid := NewGrid(1024) for i := 0; i < b.N; i++ { - grid.WriteInt64([]byte("key:"+strconv.Itoa(i)), int64(i), 3600) + key := "key:" + strconv.Itoa(i) + grid.WriteInt64([]byte(key), int64(i), 3600) } } diff --git a/internal/grids/time.go b/internal/grids/time.go new file mode 100644 index 0000000..b1fc21e --- /dev/null +++ b/internal/grids/time.go @@ -0,0 +1,26 @@ +package grids + +import ( + "time" +) + +var unixTime = time.Now().Unix() +var unixTimerIsReady = false + +func init() { + ticker := time.NewTicker(500 * time.Millisecond) + go func() { + for range ticker.C { + unixTimerIsReady = true + unixTime = time.Now().Unix() + } + }() +} + +// 最快获取时间戳的方式,通常用在不需要特别精确时间戳的场景 +func UnixTime() int64 { + if unixTimerIsReady { + return unixTime + } + return time.Now().Unix() +} diff --git a/internal/grids/time_test.go b/internal/grids/time_test.go new file mode 100644 index 0000000..50c2fa7 --- /dev/null +++ b/internal/grids/time_test.go @@ -0,0 +1,13 @@ +package grids + +import ( + "testing" + "time" +) + +func TestUnixTime(t *testing.T) { + for i := 0; i < 5; i++ { + t.Log(UnixTime(), "real:", time.Now().Unix()) + time.Sleep(1 * time.Second) + } +} diff --git a/internal/iplibrary/init.go b/internal/iplibrary/init.go new file mode 100644 index 0000000..660c3f7 --- /dev/null +++ b/internal/iplibrary/init.go @@ -0,0 +1,5 @@ +package iplibrary + +func init() { + +} diff --git a/internal/iplibrary/ip_item.go b/internal/iplibrary/ip_item.go new file mode 100644 index 0000000..1b7c6bd --- /dev/null +++ b/internal/iplibrary/ip_item.go @@ -0,0 +1,26 @@ +package iplibrary + +import "github.com/TeaOSLab/EdgeNode/internal/utils" + +type IPItem struct { + Id int64 + IPFrom uint32 + IPTo uint32 + ExpiredAt int64 +} + +func (this *IPItem) Contains(ip uint32) bool { + if this.IPTo == 0 { + if this.IPFrom != ip { + return false + } + } else { + if this.IPFrom > ip || this.IPTo < ip { + return false + } + } + if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() { + return false + } + return true +} diff --git a/internal/iplibrary/ip_item_test.go b/internal/iplibrary/ip_item_test.go new file mode 100644 index 0000000..632cb0a --- /dev/null +++ b/internal/iplibrary/ip_item_test.go @@ -0,0 +1,73 @@ +package iplibrary + +import ( + "github.com/iwind/TeaGo/assert" + "testing" + "time" +) + +func TestIPItem_Contains(t *testing.T) { + a := assert.NewAssertion(t) + + { + item := &IPItem{ + IPFrom: IP2Long("192.168.1.100"), + IPTo: 0, + ExpiredAt: 0, + } + a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + } + + { + item := &IPItem{ + IPFrom: IP2Long("192.168.1.100"), + IPTo: 0, + ExpiredAt: time.Now().Unix() + 1, + } + a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + } + + { + item := &IPItem{ + IPFrom: IP2Long("192.168.1.100"), + IPTo: 0, + ExpiredAt: time.Now().Unix() - 1, + } + a.IsFalse(item.Contains(IP2Long("192.168.1.100"))) + } + { + item := &IPItem{ + IPFrom: IP2Long("192.168.1.100"), + IPTo: 0, + ExpiredAt: 0, + } + a.IsFalse(item.Contains(IP2Long("192.168.1.101"))) + } + + { + item := &IPItem{ + IPFrom: IP2Long("192.168.1.1"), + IPTo: IP2Long("192.168.1.101"), + ExpiredAt: 0, + } + a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + } + + { + item := &IPItem{ + IPFrom: IP2Long("192.168.1.1"), + IPTo: IP2Long("192.168.1.100"), + ExpiredAt: 0, + } + a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + } + + { + item := &IPItem{ + IPFrom: IP2Long("192.168.1.1"), + IPTo: IP2Long("192.168.1.101"), + ExpiredAt: 0, + } + a.IsTrue(item.Contains(IP2Long("192.168.1.1"))) + } +} diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go new file mode 100644 index 0000000..8daeef3 --- /dev/null +++ b/internal/iplibrary/ip_list.go @@ -0,0 +1,45 @@ +package iplibrary + +import ( + "sync" +) + +// IP名单 +type IPList struct { + itemsMap map[int64]*IPItem // id => item + + locker sync.RWMutex +} + +func NewIPList() *IPList { + return &IPList{ + itemsMap: map[int64]*IPItem{}, + } +} + +func (this *IPList) Add(item *IPItem) { + this.locker.Lock() + this.itemsMap[item.Id] = item + this.locker.Unlock() +} + +func (this *IPList) Delete(itemId int64) { + this.locker.Lock() + delete(this.itemsMap, itemId) + this.locker.Unlock() +} + +// 判断是否包含某个IP +func (this *IPList) Contains(ip uint32) bool { + // TODO 优化查询速度,可能需要把items分成两组,一组是单个的,一组是按照范围的,按照范围的再进行二分法查找 + this.locker.RLock() + for _, item := range this.itemsMap { + if item.Contains(ip) { + this.locker.RUnlock() + return true + } + } + this.locker.RUnlock() + + return false +} diff --git a/internal/iplibrary/ip_list_test.go b/internal/iplibrary/ip_list_test.go new file mode 100644 index 0000000..f5a7315 --- /dev/null +++ b/internal/iplibrary/ip_list_test.go @@ -0,0 +1,53 @@ +package iplibrary + +import ( + "runtime" + "strconv" + "testing" + "time" +) + +func TestNewIPList_Memory(t *testing.T) { + list := NewIPList() + + for i := 0; i < 200_0000; i++ { + list.Add(&IPItem{ + IPFrom: 1, + IPTo: 2, + ExpiredAt: time.Now().Unix(), + }) + } + + t.Log("ok") +} + +func TestIPList_Contains(t *testing.T) { + list := NewIPList() + for i := 0; i < 255; i++ { + list.Add(&IPItem{ + Id: int64(i), + IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)), + IPTo: 0, + ExpiredAt: 0, + }) + } + t.Log(list.Contains(IP2Long("192.168.1.100"))) + t.Log(list.Contains(IP2Long("192.168.2.100"))) +} + +func BenchmarkIPList_Contains(b *testing.B) { + runtime.GOMAXPROCS(1) + + list := NewIPList() + for i := 0; i < 10_000; i++ { + list.Add(&IPItem{ + Id: int64(i), + IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)), + IPTo: 0, + ExpiredAt: time.Now().Unix() + 60, + }) + } + for i := 0; i < b.N; i++ { + _ = list.Contains(IP2Long("192.168.1.100")) + } +} diff --git a/internal/iplibrary/ip_utils.go b/internal/iplibrary/ip_utils.go new file mode 100644 index 0000000..16c7dd6 --- /dev/null +++ b/internal/iplibrary/ip_utils.go @@ -0,0 +1,19 @@ +package iplibrary + +import ( + "encoding/binary" + "net" +) + +// 将IP转换为整型 +func IP2Long(ip string) uint32 { + s := net.ParseIP(ip) + if s == nil { + return 0 + } + + if len(s) == 16 { + return binary.BigEndian.Uint32(s[12:16]) + } + return binary.BigEndian.Uint32(s) +} diff --git a/internal/iplibrary/ip_utils_test.go b/internal/iplibrary/ip_utils_test.go new file mode 100644 index 0000000..da7007b --- /dev/null +++ b/internal/iplibrary/ip_utils_test.go @@ -0,0 +1,21 @@ +package iplibrary + +import ( + "runtime" + "testing" +) + +func TestIP2Long(t *testing.T) { + t.Log(IP2Long("192.168.1.100")) + t.Log(IP2Long("192.168.1.101")) + t.Log(IP2Long("202.106.0.20")) + t.Log(IP2Long("192.168.1")) // wrong ip, should return 0 +} + +func BenchmarkIP2Long(b *testing.B) { + runtime.GOMAXPROCS(1) + + for i := 0; i < b.N; i++ { + _ = IP2Long("192.168.1.100") + } +} diff --git a/internal/iplibrary/library_interface.go b/internal/iplibrary/library_interface.go new file mode 100644 index 0000000..df0ef7c --- /dev/null +++ b/internal/iplibrary/library_interface.go @@ -0,0 +1,12 @@ +package iplibrary + +type LibraryInterface interface { + // 加载数据库文件 + Load(dbPath string) error + + // 查询IP + Lookup(ip string) (*Result, error) + + // 关闭数据库文件 + Close() +} diff --git a/internal/iplibrary/library_ip2region.go b/internal/iplibrary/library_ip2region.go new file mode 100644 index 0000000..2f20f98 --- /dev/null +++ b/internal/iplibrary/library_ip2region.go @@ -0,0 +1,72 @@ +package iplibrary + +import ( + "fmt" + "github.com/TeaOSLab/EdgeNode/internal/errors" + "github.com/iwind/TeaGo/logs" + "github.com/lionsoul2014/ip2region/binding/golang/ip2region" +) + +type IP2RegionLibrary struct { + db *ip2region.Ip2Region +} + +func (this *IP2RegionLibrary) Load(dbPath string) error { + db, err := ip2region.New(dbPath) + if err != nil { + return err + } + this.db = db + + return nil +} + +func (this *IP2RegionLibrary) Lookup(ip string) (*Result, error) { + if this.db == nil { + return nil, errors.New("library has not been loaded") + } + + defer func() { + // 防止panic发生 + err := recover() + if err != nil { + logs.Println("[IP2RegionLibrary]panic: " + fmt.Sprintf("%#v", err)) + } + }() + + info, err := this.db.MemorySearch(ip) + if err != nil { + return nil, err + } + + if info.Country == "0" { + info.Country = "" + } + if info.Region == "0" { + info.Region = "" + } + if info.Province == "0" { + info.Province = "" + } + if info.City == "0" { + info.City = "" + } + if info.ISP == "0" { + info.ISP = "" + } + + return &Result{ + CityId: info.CityId, + Country: info.Country, + Region: info.Region, + Province: info.Province, + City: info.City, + ISP: info.ISP, + }, nil +} + +func (this *IP2RegionLibrary) Close() { + if this.db != nil { + this.db.Close() + } +} diff --git a/internal/iplibrary/library_ip2region_test.go b/internal/iplibrary/library_ip2region_test.go new file mode 100644 index 0000000..fef27ef --- /dev/null +++ b/internal/iplibrary/library_ip2region_test.go @@ -0,0 +1,55 @@ +package iplibrary + +import ( + "github.com/iwind/TeaGo/Tea" + _ "github.com/iwind/TeaGo/bootstrap" + "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/rands" + "runtime" + "strconv" + "testing" + "time" +) + +func TestIP2RegionLibrary_Lookup(t *testing.T) { + library := &IP2RegionLibrary{} + err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db") + if err != nil { + t.Fatal(err) + } + result, err := library.Lookup("114.240.223.47") + if err != nil { + t.Fatal(err) + } + logs.PrintAsJSON(result, t) +} + +func TestIP2RegionLibrary_Memory(t *testing.T) { + library := &IP2RegionLibrary{} + err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db") + if err != nil { + t.Fatal(err) + } + + before := time.Now() + + for i := 0; i < 1_000_000; i++ { + _, _ = library.Lookup(strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254))) + } + + t.Log("cost:", time.Since(before).Seconds()*1000, "ms") +} + +func BenchmarkIP2RegionLibrary_Lookup(b *testing.B) { + runtime.GOMAXPROCS(1) + + library := &IP2RegionLibrary{} + err := library.Load(Tea.Root + "/resources/ipdata/ip2region/ip2region.db") + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + _, _ = library.Lookup(strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254)) + "." + strconv.Itoa(rands.Int(0, 254))) + } +} diff --git a/internal/iplibrary/manager.go b/internal/iplibrary/manager.go new file mode 100644 index 0000000..0e7e797 --- /dev/null +++ b/internal/iplibrary/manager.go @@ -0,0 +1,95 @@ +package iplibrary + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeNode/internal/errors" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/files" + "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/types" + "regexp" + "strings" +) + +var SharedManager = NewManager() +var SharedLibrary LibraryInterface + +func init() { + events.On(events.EventStart, func() { + // 初始化 + library, err := SharedManager.Load() + if err != nil { + logs.Println("[IP_LIBRARY]" + err.Error()) + return + } + SharedLibrary = library + }) +} + +type Manager struct { + code string +} + +func NewManager() *Manager { + return &Manager{} +} + +func (this *Manager) Load() (LibraryInterface, error) { + nodeConfig, err := nodeconfigs.SharedNodeConfig() + if err != nil { + return nil, err + } + config := nodeConfig.GlobalConfig + if config == nil { + config = &serverconfigs.GlobalConfig{} + } + + // 当前正在使用的IP库代号 + code := config.IPLibrary.Code + if len(code) == 0 { + code = serverconfigs.DefaultIPLibraryType + } + + dir := Tea.Root + "/resources/ipdata/" + code + var lastVersion int64 = -1 + lastFilename := "" + for _, file := range files.NewFile(dir).List() { + filename := file.Name() + + reg := regexp.MustCompile(`^` + regexp.QuoteMeta(code) + `.(\d+)\.`) + if reg.MatchString(filename) { // 先查找有版本号的 + result := reg.FindStringSubmatch(filename) + version := types.Int64(result[1]) + if version > lastVersion { + lastVersion = version + lastFilename = filename + } + } else if strings.HasPrefix(filename, code+".") { // 后查找默认的 + if lastVersion == -1 { + lastFilename = filename + lastVersion = 0 + } + } + } + + if len(lastFilename) == 0 { + return nil, errors.New("ip library file not found") + } + + var libraryPtr LibraryInterface + switch code { + case serverconfigs.IPLibraryTypeIP2Region: + libraryPtr = &IP2RegionLibrary{} + default: + return nil, errors.New("invalid ip library code '" + code + "'") + } + + err = libraryPtr.Load(dir + "/" + lastFilename) + if err != nil { + return nil, err + } + + return libraryPtr, nil +} diff --git a/internal/iplibrary/manager_country.go b/internal/iplibrary/manager_country.go new file mode 100644 index 0000000..7dd792e --- /dev/null +++ b/internal/iplibrary/manager_country.go @@ -0,0 +1,137 @@ +package iplibrary + +import ( + "crypto/md5" + "encoding/json" + "fmt" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/logs" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/Tea" + _ "github.com/iwind/TeaGo/bootstrap" + "io/ioutil" + "os" + "sync" + "time" +) + +var SharedCountryManager = NewCountryManager() + +func init() { + events.On(events.EventStart, func() { + go SharedCountryManager.Start() + }) +} + +// 国家信息管理 +type CountryManager struct { + cacheFile string + + countryMap map[string]int64 // countryName => countryId + dataHash string // 国家JSON的md5 + + locker sync.RWMutex +} + +func NewCountryManager() *CountryManager { + return &CountryManager{ + cacheFile: Tea.Root + "/configs/region_country.json.cache", + countryMap: map[string]int64{}, + } +} + +func (this *CountryManager) Start() { + // 从缓存中读取 + err := this.load() + if err != nil { + logs.Error("COUNTRY_MANAGER", err.Error()) + } + + // 第一次更新 + err = this.loop() + if err != nil { + logs.Error("COUNTRY_MANAGER", err.Error()) + } + + // 定时更新 + ticker := utils.NewTicker(1 * time.Hour) + events.On(events.EventQuit, func() { + ticker.Stop() + }) + for range ticker.C { + err := this.loop() + if err != nil { + logs.Error("COUNTRY_MANAGER", err.Error()) + } + } +} + +func (this *CountryManager) Lookup(countryName string) (countryId int64) { + this.locker.RLock() + countryId, _ = this.countryMap[countryName] + this.locker.RUnlock() + return countryId +} + +// 从缓存中读取 +func (this *CountryManager) load() error { + data, err := ioutil.ReadFile(this.cacheFile) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + m := map[string]int64{} + err = json.Unmarshal(data, &m) + if err != nil { + return err + } + if m != nil && len(m) > 0 { + this.countryMap = m + } + + return nil +} + +// 更新国家信息 +func (this *CountryManager) loop() error { + rpcClient, err := rpc.SharedRPC() + if err != nil { + return err + } + resp, err := rpcClient.RegionCountryRPC().FindAllEnabledRegionCountries(rpcClient.Context(), &pb.FindAllEnabledRegionCountriesRequest{}) + if err != nil { + return err + } + + m := map[string]int64{} + for _, country := range resp.Countries { + for _, code := range country.Codes { + m[code] = country.Id + } + } + + // 检查是否有更新 + data, err := json.Marshal(m) + if err != nil { + return err + } + hash := md5.New() + hash.Write(data) + dataHash := fmt.Sprintf("%x", hash.Sum(nil)) + if this.dataHash == dataHash { + return nil + } + this.dataHash = dataHash + + this.locker.Lock() + this.countryMap = m + this.locker.Unlock() + + // 保存到本地缓存 + err = ioutil.WriteFile(this.cacheFile, data, 0666) + return err +} diff --git a/internal/iplibrary/manager_country_test.go b/internal/iplibrary/manager_country_test.go new file mode 100644 index 0000000..ab9dd08 --- /dev/null +++ b/internal/iplibrary/manager_country_test.go @@ -0,0 +1,57 @@ +package iplibrary + +import ( + "runtime" + "testing" +) + +func TestCountryManager_load(t *testing.T) { + manager := NewCountryManager() + err := manager.load() + if err != nil { + t.Fatal(err) + } + t.Log("ok", manager.countryMap) +} + +func TestCountryManager_loop(t *testing.T) { + manager := NewCountryManager() + err := manager.loop() + if err != nil { + t.Fatal(err) + } + t.Log("ok", manager.countryMap) +} + +func TestCountryManager_loop_skip(t *testing.T) { + manager := NewCountryManager() + for i := 0; i < 10; i++ { + err := manager.loop() + if err != nil { + t.Fatal(err) + } + } +} + +func TestCountryManager_Lookup(t *testing.T) { + manager := NewCountryManager() + err := manager.load() + if err != nil { + t.Fatal(err) + } + t.Log(manager.Lookup("中国"), manager.Lookup("美国 ")) +} + +func BenchmarkCountryManager_Lookup(b *testing.B) { + runtime.GOMAXPROCS(1) + + manager := NewCountryManager() + err := manager.load() + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + _ = manager.Lookup("中国") + } +} diff --git a/internal/iplibrary/manager_ip_list.go b/internal/iplibrary/manager_ip_list.go new file mode 100644 index 0000000..78d4adb --- /dev/null +++ b/internal/iplibrary/manager_ip_list.go @@ -0,0 +1,124 @@ +package iplibrary + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/logs" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/iwind/TeaGo/Tea" + "sync" + "time" +) + +var SharedIPListManager = NewIPListManager() + +func init() { + events.On(events.EventStart, func() { + go SharedIPListManager.Start() + }) +} + +// IP名单管理 +type IPListManager struct { + // 缓存文件 + // 每行一个数据:id|from|to|expiredAt + cacheFile string + + version int64 + pageSize int64 + + listMap map[int64]*IPList + locker sync.Mutex +} + +func NewIPListManager() *IPListManager { + return &IPListManager{ + cacheFile: Tea.Root + "/configs/ip_list.cache", + pageSize: 1000, + listMap: map[int64]*IPList{}, + } +} + +func (this *IPListManager) Start() { + // TODO 从缓存当中读取数据 + + // 第一次读取 + err := this.loop() + if err != nil { + logs.Println("IP_LIST_MANAGER", err.Error()) + } + + ticker := time.NewTicker(60 * time.Second) // TODO 未来改成可以手动触发IP变更事件 + events.On(events.EventQuit, func() { + ticker.Stop() + }) + for range ticker.C { + err := this.loop() + if err != nil { + logs.Println("IP_LIST_MANAGER", err.Error()) + } + } +} + +func (this *IPListManager) loop() error { + for { + hasNext, err := this.fetch() + if err != nil { + return err + } + if !hasNext { + break + } + } + + // TODO 写入到缓存当中 + + return nil +} + +func (this *IPListManager) fetch() (hasNext bool, err error) { + rpcClient, err := rpc.SharedRPC() + if err != nil { + return false, err + } + itemsResp, err := rpcClient.IPItemRPC().ListIPItemsAfterVersion(rpcClient.Context(), &pb.ListIPItemsAfterVersionRequest{ + Version: this.version, + Size: this.pageSize, + }) + if err != nil { + return false, err + } + items := itemsResp.IpItems + if len(items) == 0 { + return false, nil + } + this.locker.Lock() + for _, item := range items { + list, ok := this.listMap[item.ListId] + if !ok { + list = NewIPList() + this.listMap[item.ListId] = list + } + if item.IsDeleted { + list.Delete(item.Id) + continue + } + list.Add(&IPItem{ + Id: item.Id, + IPFrom: IP2Long(item.IpFrom), + IPTo: IP2Long(item.IpTo), + ExpiredAt: item.ExpiredAt, + }) + } + this.locker.Unlock() + this.version = items[len(items)-1].Version + + return true, nil +} + +func (this *IPListManager) FindList(listId int64) *IPList { + this.locker.Lock() + list, _ := this.listMap[listId] + this.locker.Unlock() + return list +} diff --git a/internal/iplibrary/manager_ip_list_test.go b/internal/iplibrary/manager_ip_list_test.go new file mode 100644 index 0000000..82f3faa --- /dev/null +++ b/internal/iplibrary/manager_ip_list_test.go @@ -0,0 +1,12 @@ +package iplibrary + +import "testing" + +func TestIPListManager_loop(t *testing.T) { + manager := NewIPListManager() + manager.pageSize = 2 + err := manager.loop() + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/iplibrary/manager_province.go b/internal/iplibrary/manager_province.go new file mode 100644 index 0000000..1570a23 --- /dev/null +++ b/internal/iplibrary/manager_province.go @@ -0,0 +1,144 @@ +package iplibrary + +import ( + "crypto/md5" + "encoding/json" + "fmt" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/logs" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/Tea" + _ "github.com/iwind/TeaGo/bootstrap" + "io/ioutil" + "os" + "sync" + "time" +) + +const ( + ChinaCountryId int64 = 1 +) + +var SharedProvinceManager = NewProvinceManager() + +func init() { + events.On(events.EventStart, func() { + go SharedProvinceManager.Start() + }) +} + +// 国家信息管理 +type ProvinceManager struct { + cacheFile string + + provinceMap map[string]int64 // provinceName => provinceId + dataHash string // 国家JSON的md5 + + locker sync.RWMutex +} + +func NewProvinceManager() *ProvinceManager { + return &ProvinceManager{ + cacheFile: Tea.Root + "/configs/region_province.json.cache", + provinceMap: map[string]int64{}, + } +} + +func (this *ProvinceManager) Start() { + // 从缓存中读取 + err := this.load() + if err != nil { + logs.Error("PROVINCE_MANAGER", err.Error()) + } + + // 第一次更新 + err = this.loop() + if err != nil { + logs.Error("PROVINCE_MANAGER", err.Error()) + } + + // 定时更新 + ticker := utils.NewTicker(1 * time.Hour) + events.On(events.EventQuit, func() { + ticker.Stop() + }) + for range ticker.C { + err := this.loop() + if err != nil { + logs.Error("PROVINCE_MANAGER", err.Error()) + } + } +} + +func (this *ProvinceManager) Lookup(provinceName string) (provinceId int64) { + this.locker.RLock() + provinceId, _ = this.provinceMap[provinceName] + this.locker.RUnlock() + return provinceId +} + +// 从缓存中读取 +func (this *ProvinceManager) load() error { + data, err := ioutil.ReadFile(this.cacheFile) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + m := map[string]int64{} + err = json.Unmarshal(data, &m) + if err != nil { + return err + } + if m != nil && len(m) > 0 { + this.provinceMap = m + } + + return nil +} + +// 更新国家信息 +func (this *ProvinceManager) loop() error { + rpcClient, err := rpc.SharedRPC() + if err != nil { + return err + } + resp, err := rpcClient.RegionProvinceRPC().FindAllEnabledRegionProvincesWithCountryId(rpcClient.Context(), &pb.FindAllEnabledRegionProvincesWithCountryIdRequest{ + CountryId: ChinaCountryId, + }) + if err != nil { + return err + } + + m := map[string]int64{} + for _, province := range resp.Provinces { + for _, code := range province.Codes { + m[code] = province.Id + } + } + + // 检查是否有更新 + data, err := json.Marshal(m) + if err != nil { + return err + } + hash := md5.New() + hash.Write(data) + dataHash := fmt.Sprintf("%x", hash.Sum(nil)) + if this.dataHash == dataHash { + return nil + } + this.dataHash = dataHash + + this.locker.Lock() + this.provinceMap = m + this.locker.Unlock() + + // 保存到本地缓存 + + err = ioutil.WriteFile(this.cacheFile, data, 0666) + return err +} diff --git a/internal/iplibrary/manager_province_test.go b/internal/iplibrary/manager_province_test.go new file mode 100644 index 0000000..1cc93f9 --- /dev/null +++ b/internal/iplibrary/manager_province_test.go @@ -0,0 +1,57 @@ +package iplibrary + +import ( + "runtime" + "testing" +) + +func TestProvinceManager_load(t *testing.T) { + manager := NewProvinceManager() + err := manager.load() + if err != nil { + t.Fatal(err) + } + t.Log("ok", manager.provinceMap) +} + +func TestProvinceManager_loop(t *testing.T) { + manager := NewProvinceManager() + err := manager.loop() + if err != nil { + t.Fatal(err) + } + t.Log("ok", manager.provinceMap) +} + +func TestProvinceManager_loop_skip(t *testing.T) { + manager := NewProvinceManager() + for i := 0; i < 10; i++ { + err := manager.loop() + if err != nil { + t.Fatal(err) + } + } +} + +func TestProvinceManager_Lookup(t *testing.T) { + manager := NewProvinceManager() + err := manager.load() + if err != nil { + t.Fatal(err) + } + t.Log(manager.Lookup("安徽省"), manager.Lookup("北京市")) +} + +func BenchmarkProvinceManager_Lookup(b *testing.B) { + runtime.GOMAXPROCS(1) + + manager := NewProvinceManager() + err := manager.load() + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + _ = manager.Lookup("安徽省") + } +} diff --git a/internal/iplibrary/manager_test.go b/internal/iplibrary/manager_test.go new file mode 100644 index 0000000..a4faa9c --- /dev/null +++ b/internal/iplibrary/manager_test.go @@ -0,0 +1,26 @@ +package iplibrary + +import ( + _ "github.com/iwind/TeaGo/bootstrap" + "github.com/iwind/TeaGo/dbs" + "testing" +) + +func TestManager_Load(t *testing.T) { + dbs.NotifyReady() + + manager := NewManager() + lib, err := manager.Load() + if err != nil { + t.Fatal(err) + } + t.Log(lib.Lookup("1.2.3.4")) + t.Log(lib.Lookup("2.3.4.5")) + t.Log(lib.Lookup("200.200.200.200")) + t.Log(lib.Lookup("202.106.0.20")) +} + +func TestNewManager(t *testing.T) { + dbs.NotifyReady() + t.Log(SharedLibrary) +} diff --git a/internal/iplibrary/result.go b/internal/iplibrary/result.go new file mode 100644 index 0000000..c0a62ec --- /dev/null +++ b/internal/iplibrary/result.go @@ -0,0 +1,10 @@ +package iplibrary + +type Result struct { + CityId int64 + Country string + Region string + Province string + City string + ISP string +} diff --git a/internal/iplibrary/updater.go b/internal/iplibrary/updater.go new file mode 100644 index 0000000..3ac1d25 --- /dev/null +++ b/internal/iplibrary/updater.go @@ -0,0 +1,141 @@ +package iplibrary + +import ( + "fmt" + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeNode/internal/errors" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/logs" + "os" + "time" +) + +func init() { + events.On(events.EventStart, func() { + updater := NewUpdater() + updater.Start() + }) +} + +// IP库更新程序 +type Updater struct { +} + +// 获取新对象 +func NewUpdater() *Updater { + return &Updater{} +} + +// 开始更新 +func (this *Updater) Start() { + // 这里不需要太频繁检查更新,因为通常不需要更新IP库 + ticker := time.NewTicker(1 * time.Hour) + go func() { + for range ticker.C { + err := this.loop() + if err != nil { + logs.Println("[IP_LIBRARY]" + err.Error()) + } + } + }() +} + +// 单次任务 +func (this *Updater) loop() error { + nodeConfig, err := nodeconfigs.SharedNodeConfig() + if err != nil { + return err + } + if nodeConfig.GlobalConfig == nil { + return nil + } + code := nodeConfig.GlobalConfig.IPLibrary.Code + if len(code) == 0 { + code = serverconfigs.DefaultIPLibraryType + } + + rpcClient, err := rpc.SharedRPC() + if err != nil { + return err + } + libraryResp, err := rpcClient.IPLibraryRPC().FindLatestIPLibraryWithType(rpcClient.Context(), &pb.FindLatestIPLibraryWithTypeRequest{Type: code}) + if err != nil { + return err + } + lib := libraryResp.IpLibrary + if lib == nil || lib.File == nil { + return nil + } + + typeInfo := serverconfigs.FindIPLibraryWithType(code) + if typeInfo == nil { + return errors.New("invalid ip library code '" + code + "'") + } + + path := Tea.Root + "/resources/ipdata/" + code + "/" + code + "." + fmt.Sprintf("%d", lib.CreatedAt) + typeInfo.GetString("ext") + + // 是否已经存在 + _, err = os.Stat(path) + if err == nil { + return nil + } + + // 开始下载 + fileChunkIdsResp, err := rpcClient.FileChunkRPC().FindAllFileChunkIds(rpcClient.Context(), &pb.FindAllFileChunkIdsRequest{FileId: lib.File.Id}) + if err != nil { + return err + } + chunkIds := fileChunkIdsResp.FileChunkIds + if len(chunkIds) == 0 { + return nil + } + isOk := false + + fp, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0666) + if err != nil { + return err + } + + defer func() { + // 如果保存不成功就直接删除 + if !isOk { + _ = fp.Close() + _ = os.Remove(path) + } + }() + for _, chunkId := range chunkIds { + chunkResp, err := rpcClient.FileChunkRPC().DownloadFileChunk(rpcClient.Context(), &pb.DownloadFileChunkRequest{FileChunkId: chunkId}) + if err != nil { + return err + } + chunk := chunkResp.FileChunk + + if chunk == nil { + continue + } + _, err = fp.Write(chunk.Data) + if err != nil { + return err + } + } + + err = fp.Close() + if err != nil { + return err + } + + // 重新加载 + library, err := SharedManager.Load() + if err != nil { + return err + } + SharedLibrary = library + + isOk = true + + return nil +} diff --git a/internal/iplibrary/updater_test.go b/internal/iplibrary/updater_test.go new file mode 100644 index 0000000..17a98f1 --- /dev/null +++ b/internal/iplibrary/updater_test.go @@ -0,0 +1,18 @@ +package iplibrary + +import ( + _ "github.com/iwind/TeaGo/bootstrap" + "github.com/iwind/TeaGo/dbs" + "testing" +) + +func TestUpdater_loop(t *testing.T) { + dbs.NotifyReady() + + updater := NewUpdater() + err := updater.loop() + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 96dcfb4..aec3816 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -1,22 +1,85 @@ package nodes import ( + "github.com/TeaOSLab/EdgeNode/internal/iplibrary" + "github.com/TeaOSLab/EdgeNode/internal/logs" "github.com/TeaOSLab/EdgeNode/internal/waf" - "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" "net/http" ) // 调用WAF func (this *HTTPRequest) doWAFRequest() (blocked bool) { + // 检查配置是否为空 + if this.web.FirewallPolicy == nil || this.web.FirewallPolicy.Inbound == nil || !this.web.FirewallPolicy.Inbound.IsOn { + return + } + + // 检查IP白名单 + remoteAddr := this.requestRemoteAddr() + inbound := this.web.FirewallPolicy.Inbound + if inbound.WhiteListRef != nil && inbound.WhiteListRef.IsOn && inbound.WhiteListRef.ListId > 0 { + list := iplibrary.SharedIPListManager.FindList(inbound.WhiteListRef.ListId) + if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) { + return + } + } + + // 检查IP黑名单 + if inbound.BlackListRef != nil && inbound.BlackListRef.IsOn && inbound.BlackListRef.ListId > 0 { + list := iplibrary.SharedIPListManager.FindList(inbound.BlackListRef.ListId) + if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) { + // TODO 可以配置对封禁的处理方式等 + this.writer.WriteHeader(http.StatusForbidden) + this.writer.Close() + + return true + } + } + + // 检查地区封禁 + if iplibrary.SharedLibrary != nil { + if this.web.FirewallPolicy.Inbound.Region != nil && this.web.FirewallPolicy.Inbound.Region.IsOn { + regionConfig := this.web.FirewallPolicy.Inbound.Region + if regionConfig.IsNotEmpty() { + result, err := iplibrary.SharedLibrary.Lookup(remoteAddr) + if err != nil { + logs.Error("REQUEST", "iplibrary lookup failed: "+err.Error()) + } else if result != nil { + // 检查国家级别封禁 + if len(regionConfig.DenyCountryIds) > 0 && len(result.Country) > 0 { + countryId := iplibrary.SharedCountryManager.Lookup(result.Country) + if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) { + // TODO 可以配置对封禁的处理方式等 + this.writer.WriteHeader(http.StatusForbidden) + this.writer.Close() + return true + } + } + + // 检查省份封禁 + if len(regionConfig.DenyProvinceIds) > 0 && len(result.Province) > 0 { + provinceId := iplibrary.SharedProvinceManager.Lookup(result.Province) + if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) { + // TODO 可以配置对封禁的处理方式等 + this.writer.WriteHeader(http.StatusForbidden) + this.writer.Close() + return true + } + } + } + } + } + } + w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id) if w == nil { return } - goNext, ruleGroup, ruleSet, err := w.MatchRequest(this.RawReq, this.writer) if err != nil { - logs.Error(err) + logs.Error("REQUEST", this.rawURI+": "+err.Error()) return } @@ -42,7 +105,7 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) { goNext, ruleGroup, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer) if err != nil { - logs.Error(err) + logs.Error("REQUEST", this.rawURI+": "+err.Error()) return } diff --git a/internal/nodes/node.go b/internal/nodes/node.go index c7520f1..808f4db 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -53,6 +53,9 @@ func (this *Node) Test() error { // 启动 func (this *Node) Start() { + // 启动事件 + events.Notify(events.EventStart) + // 处理信号 this.listenSignals() diff --git a/internal/rpc/rpc_client.go b/internal/rpc/rpc_client.go index 980f8f6..a595192 100644 --- a/internal/rpc/rpc_client.go +++ b/internal/rpc/rpc_client.go @@ -61,6 +61,34 @@ func (this *RPCClient) APINodeRPC() pb.APINodeServiceClient { return pb.NewAPINodeServiceClient(this.pickConn()) } +func (this *RPCClient) IPLibraryRPC() pb.IPLibraryServiceClient { + return pb.NewIPLibraryServiceClient(this.pickConn()) +} + +func (this *RPCClient) RegionCountryRPC() pb.RegionCountryServiceClient { + return pb.NewRegionCountryServiceClient(this.pickConn()) +} + +func (this *RPCClient) RegionProvinceRPC() pb.RegionProvinceServiceClient { + return pb.NewRegionProvinceServiceClient(this.pickConn()) +} + +func (this *RPCClient) IPListRPC() pb.IPListServiceClient { + return pb.NewIPListServiceClient(this.pickConn()) +} + +func (this *RPCClient) IPItemRPC() pb.IPItemServiceClient { + return pb.NewIPItemServiceClient(this.pickConn()) +} + +func (this *RPCClient) FileRPC() pb.FileServiceClient { + return pb.NewFileServiceClient(this.pickConn()) +} + +func (this *RPCClient) FileChunkRPC() pb.FileChunkServiceClient { + return pb.NewFileChunkServiceClient(this.pickConn()) +} + // 节点上下文信息 func (this *RPCClient) Context() context.Context { ctx := context.Background()