diff --git a/build/build.sh b/build/build.sh index 49f4babb..6d003ea7 100755 --- a/build/build.sh +++ b/build/build.sh @@ -55,6 +55,8 @@ function build() { rm -f $dist/deploy/.gitignore cp -R $ROOT/installers $DIST/ cp -R $ROOT/resources $DIST/ + rm -f $DIST/resources/ipdata/ip2region/global_region.csv + rm -f $DIST/resources/ipdata/ip2region/ip.merge.txt # building installer echo "building installer ..." diff --git a/cmd/ip2region/main.go b/cmd/ip2region/main.go index 7db20726..fb7a3405 100644 --- a/cmd/ip2region/main.go +++ b/cmd/ip2region/main.go @@ -11,6 +11,7 @@ import ( "io/ioutil" "os" "regexp" + "strings" ) func main() { @@ -110,4 +111,63 @@ func main() { logs.Println("done") } + + // 检查数据 + if lists.ContainsString(os.Args, "check") { + dbs.NotifyReady() + + data, err := ioutil.ReadFile(Tea.Root + "/resources/ipdata/ip2region/ip.merge.txt") + if err != nil { + logs.Println("[ERROR]" + err.Error()) + return + } + if len(data) == 0 { + logs.Println("[ERROR]file should not be empty") + return + } + lines := bytes.Split(data, []byte("\n")) + for index, line := range lines { + s := string(bytes.TrimSpace(line)) + if len(s) == 0 { + continue + } + pieces := strings.Split(s, "|") + countryName := pieces[2] + provinceName := pieces[4] + + if lists.ContainsString([]string{"0", "欧洲", "北美地区", "法国南部领地", "非洲地区", "亚太地区"}, countryName) { + continue + } + + // 检查国家 + countryId, err := models.SharedRegionCountryDAO.FindCountryIdWithCountryName(countryName) + if err != nil { + logs.Println("[ERROR]" + err.Error()) + return + } + if countryId == 0 { + logs.Println("[ERROR]can not find country '"+countryName+"', index: ", index, "data: "+s) + return + } + + // 检查省份 + if countryName == "中国" { + if lists.ContainsString([]string{"0"}, provinceName) { + continue + } + + provinceId, err := models.SharedRegionProvinceDAO.FindProvinceIdWithProvinceName(provinceName) + if err != nil { + logs.Println("[ERROR]" + err.Error()) + return + } + if provinceId == 0 { + logs.Println("[ERROR]can not find province '"+provinceName+"', index: ", index, "data: "+s) + return + } + } + } + + logs.Println("done") + } } diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 02425485..6d431053 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -6,6 +6,7 @@ import ( "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/types" + "time" ) const ( @@ -66,7 +67,7 @@ func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) { // 创建IP func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) { - version, err := SharedIPListDAO.IncreaseVersion(listId) + version, err := SharedIPListDAO.IncreaseVersion() if err != nil { return 0, err } @@ -106,7 +107,7 @@ func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, ex return errors.New("not found") } - version, err := SharedIPListDAO.IncreaseVersion(listId) + version, err := SharedIPListDAO.IncreaseVersion() if err != nil { return err } @@ -145,3 +146,17 @@ func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size in FindAll() return } + +// 根据版本号查找IP列表 +func (this *IPItemDAO) ListIPItemsAfterVersion(version int64, size int64) (result []*IPItem, err error) { + _, err = this.Query(). + // 这里不要设置状态参数,因为我们要知道哪些是删除的 + Gt("version", version). + Where("(expiredAt=0 OR expiredAt>:expiredAt)"). + Param("expiredAt", time.Now().Unix()). + Asc("version"). + Limit(size). + Slice(&result). + FindAll() + return +} diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index b744733e..9fbd16a9 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -2,6 +2,7 @@ package models import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" @@ -110,20 +111,20 @@ func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, time } // 增加版本 -func (this *IPListDAO) IncreaseVersion(listId int64) (int64, error) { - if listId <= 0 { - return 0, errors.New("invalid listId") - } - op := NewIPListOperator() - op.Id = listId - op.Version = dbs.SQL("version+1") - _, err := this.Save(op) +func (this *IPListDAO) IncreaseVersion() (int64, error) { + valueJSON, err := SharedSysSettingDAO.ReadSetting(SettingCodeIPListVersion) if err != nil { return 0, err } + if len(valueJSON) == 0 { + err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte("1")) + if err != nil { + return 0, err + } + return 1, nil + } - return this.Query(). - Pk(listId). - Result("version"). - FindInt64Col(0) + value := types.Int64(string(valueJSON)) + 1 + err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value))) + return value, nil } diff --git a/internal/db/models/ip_list_dao_test.go b/internal/db/models/ip_list_dao_test.go index c29c4aab..5f80a2be 100644 --- a/internal/db/models/ip_list_dao_test.go +++ b/internal/db/models/ip_list_dao_test.go @@ -2,13 +2,16 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" "runtime" "testing" ) func TestIPListDAO_IncreaseVersion(t *testing.T) { + dbs.NotifyReady() + dao := NewIPListDAO() - version, err := dao.IncreaseVersion(1) + version, err := dao.IncreaseVersion() if err != nil { t.Fatal(err) } @@ -18,8 +21,10 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) { func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { runtime.GOMAXPROCS(1) + dbs.NotifyReady() + dao := NewIPListDAO() for i := 0; i < b.N; i++ { - _, _ = dao.IncreaseVersion(1) + _, _ = dao.IncreaseVersion() } } diff --git a/internal/db/models/ip_list_model.go b/internal/db/models/ip_list_model.go index d045c650..9e8e7fef 100644 --- a/internal/db/models/ip_list_model.go +++ b/internal/db/models/ip_list_model.go @@ -12,7 +12,6 @@ type IPList struct { State uint8 `field:"state"` // 状态 CreatedAt uint64 `field:"createdAt"` // 创建时间 Timeout string `field:"timeout"` // 默认超时时间 - Version uint64 `field:"version"` // 版本 } type IPListOperator struct { @@ -26,7 +25,6 @@ type IPListOperator struct { State interface{} // 状态 CreatedAt interface{} // 创建时间 Timeout interface{} // 默认超时时间 - Version interface{} // 版本 } func NewIPListOperator() *IPListOperator { diff --git a/internal/db/models/region_country_dao.go b/internal/db/models/region_country_dao.go index a0973690..b7430c43 100644 --- a/internal/db/models/region_country_dao.go +++ b/internal/db/models/region_country_dao.go @@ -82,6 +82,15 @@ func (this *RegionCountryDAO) FindCountryIdWithDataId(dataId string) (int64, err FindInt64Col(0) } +// 根据国家名查找国家ID +func (this *RegionCountryDAO) FindCountryIdWithCountryName(countryName string) (int64, error) { + return this.Query(). + Where("JSON_CONTAINS(codes, :countryName)"). + Param("countryName", "\""+countryName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号 + ResultPk(). + FindInt64Col(0) +} + // 根据数据ID创建国家 func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, error) { op := NewRegionCountryOperator() diff --git a/internal/db/models/region_country_model_ext.go b/internal/db/models/region_country_model_ext.go index 2640e7f9..0b0d46f2 100644 --- a/internal/db/models/region_country_model_ext.go +++ b/internal/db/models/region_country_model_ext.go @@ -1 +1,18 @@ package models + +import ( + "encoding/json" + "github.com/iwind/TeaGo/logs" +) + +func (this *RegionCountry) DecodeCodes() []string { + if len(this.Codes) == 0 { + return []string{} + } + result := []string{} + err := json.Unmarshal([]byte(this.Codes), &result) + if err != nil { + logs.Error(err) + } + return result +} diff --git a/internal/db/models/region_province_dao.go b/internal/db/models/region_province_dao.go index 8e5914f4..9871f69d 100644 --- a/internal/db/models/region_province_dao.go +++ b/internal/db/models/region_province_dao.go @@ -80,6 +80,15 @@ func (this *RegionProvinceDAO) FindProvinceIdWithDataId(dataId string) (int64, e FindInt64Col(0) } +// 根据省份名查找省份ID +func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(provinceName string) (int64, error) { + return this.Query(). + Where("JSON_CONTAINS(codes, :provinceName)"). + Param("provinceName", "\""+provinceName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号 + ResultPk(). + FindInt64Col(0) +} + // 创建省份 func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, dataId string) (int64, error) { op := NewRegionProvinceOperator() diff --git a/internal/db/models/region_province_model_ext.go b/internal/db/models/region_province_model_ext.go index 2640e7f9..a7c5a97a 100644 --- a/internal/db/models/region_province_model_ext.go +++ b/internal/db/models/region_province_model_ext.go @@ -1 +1,18 @@ package models + +import ( + "encoding/json" + "github.com/iwind/TeaGo/logs" +) + +func (this *RegionProvince) DecodeCodes() []string { + if len(this.Codes) == 0 { + return []string{} + } + result := []string{} + err := json.Unmarshal([]byte(this.Codes), &result) + if err != nil { + logs.Error(err) + } + return result +} diff --git a/internal/db/models/sys_setting_dao.go b/internal/db/models/sys_setting_dao.go index b108519d..f28c8df9 100644 --- a/internal/db/models/sys_setting_dao.go +++ b/internal/db/models/sys_setting_dao.go @@ -18,6 +18,7 @@ const ( SettingCodeServerGlobalConfig SettingCode = "serverGlobalConfig" // 服务相关全局设置 SettingCodeNodeMonitor SettingCode = "nodeMonitor" // 监控节点状态 SettingCodeClusterHealthCheck SettingCode = "clusterHealthCheck" // 集群健康检查 + SettingCodeIPListVersion SettingCode = "ipListVersion" // IP名单的版本号 ) func NewSysSettingDAO() *SysSettingDAO { diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 3083707f..7a3e824f 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -122,3 +122,32 @@ func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEn Reason: item.Reason, }}, nil } + +// 根据版本列出一组IP +func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.ListIPItemsAfterVersionRequest) (*pb.ListIPItemsAfterVersionResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode) + if err != nil { + return nil, err + } + + result := []*pb.IPItem{} + items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(req.Version, req.Size) + if err != nil { + return nil, err + } + for _, item := range items { + result = append(result, &pb.IPItem{ + Id: int64(item.Id), + IpFrom: item.IpFrom, + IpTo: item.IpTo, + Version: int64(item.Version), + ExpiredAt: int64(item.ExpiredAt), + Reason: "", // 这里我们不需要这个数据 + ListId: int64(item.ListId), + IsDeleted: item.State == 0, + }) + } + + return &pb.ListIPItemsAfterVersionResponse{IpItems: result}, nil +} diff --git a/internal/rpc/services/service_region_country.go b/internal/rpc/services/service_region_country.go index 62f05165..fa72527d 100644 --- a/internal/rpc/services/service_region_country.go +++ b/internal/rpc/services/service_region_country.go @@ -15,7 +15,7 @@ type RegionCountryService struct { // 查找所有的国家列表 func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Context, req *pb.FindAllEnabledRegionCountriesRequest) (*pb.FindAllEnabledRegionCountriesResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode) if err != nil { return nil, err } @@ -39,6 +39,7 @@ func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Cont result = append(result, &pb.RegionCountry{ Id: int64(country.Id), Name: country.Name, + Codes: country.DecodeCodes(), Pinyin: pinyinStrings, }) } diff --git a/internal/rpc/services/service_region_province.go b/internal/rpc/services/service_region_province.go index 64052d6d..cb1d54bb 100644 --- a/internal/rpc/services/service_region_province.go +++ b/internal/rpc/services/service_region_province.go @@ -14,7 +14,7 @@ type RegionProvinceService struct { // 查找所有省份 func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ctx context.Context, req *pb.FindAllEnabledRegionProvincesWithCountryIdRequest) (*pb.FindAllEnabledRegionProvincesWithCountryIdResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode) if err != nil { return nil, err } @@ -26,8 +26,9 @@ func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ct result := []*pb.RegionProvince{} for _, province := range provinces { result = append(result, &pb.RegionProvince{ - Id: int64(province.Id), - Name: province.Name, + Id: int64(province.Id), + Name: province.Name, + Codes: province.DecodeCodes(), }) }