实现IP黑白名单、国家|地区封禁、省份封禁

This commit is contained in:
GoEdgeLab
2020-11-09 10:44:00 +08:00
parent 48a3a28f32
commit 5417280163
14 changed files with 187 additions and 22 deletions

View File

@@ -55,6 +55,8 @@ function build() {
rm -f $dist/deploy/.gitignore
cp -R $ROOT/installers $DIST/
cp -R $ROOT/resources $DIST/
rm -f $DIST/resources/ipdata/ip2region/global_region.csv
rm -f $DIST/resources/ipdata/ip2region/ip.merge.txt
# building installer
echo "building installer ..."

View File

@@ -11,6 +11,7 @@ import (
"io/ioutil"
"os"
"regexp"
"strings"
)
func main() {
@@ -110,4 +111,63 @@ func main() {
logs.Println("done")
}
// 检查数据
if lists.ContainsString(os.Args, "check") {
dbs.NotifyReady()
data, err := ioutil.ReadFile(Tea.Root + "/resources/ipdata/ip2region/ip.merge.txt")
if err != nil {
logs.Println("[ERROR]" + err.Error())
return
}
if len(data) == 0 {
logs.Println("[ERROR]file should not be empty")
return
}
lines := bytes.Split(data, []byte("\n"))
for index, line := range lines {
s := string(bytes.TrimSpace(line))
if len(s) == 0 {
continue
}
pieces := strings.Split(s, "|")
countryName := pieces[2]
provinceName := pieces[4]
if lists.ContainsString([]string{"0", "欧洲", "北美地区", "法国南部领地", "非洲地区", "亚太地区"}, countryName) {
continue
}
// 检查国家
countryId, err := models.SharedRegionCountryDAO.FindCountryIdWithCountryName(countryName)
if err != nil {
logs.Println("[ERROR]" + err.Error())
return
}
if countryId == 0 {
logs.Println("[ERROR]can not find country '"+countryName+"', index: ", index, "data: "+s)
return
}
// 检查省份
if countryName == "中国" {
if lists.ContainsString([]string{"0"}, provinceName) {
continue
}
provinceId, err := models.SharedRegionProvinceDAO.FindProvinceIdWithProvinceName(provinceName)
if err != nil {
logs.Println("[ERROR]" + err.Error())
return
}
if provinceId == 0 {
logs.Println("[ERROR]can not find province '"+provinceName+"', index: ", index, "data: "+s)
return
}
}
}
logs.Println("done")
}
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
"time"
)
const (
@@ -66,7 +67,7 @@ func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) {
// 创建IP
func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) {
version, err := SharedIPListDAO.IncreaseVersion(listId)
version, err := SharedIPListDAO.IncreaseVersion()
if err != nil {
return 0, err
}
@@ -106,7 +107,7 @@ func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, ex
return errors.New("not found")
}
version, err := SharedIPListDAO.IncreaseVersion(listId)
version, err := SharedIPListDAO.IncreaseVersion()
if err != nil {
return err
}
@@ -145,3 +146,17 @@ func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size in
FindAll()
return
}
// 根据版本号查找IP列表
func (this *IPItemDAO) ListIPItemsAfterVersion(version int64, size int64) (result []*IPItem, err error) {
_, err = this.Query().
// 这里不要设置状态参数,因为我们要知道哪些是删除的
Gt("version", version).
Where("(expiredAt=0 OR expiredAt>:expiredAt)").
Param("expiredAt", time.Now().Unix()).
Asc("version").
Limit(size).
Slice(&result).
FindAll()
return
}

View File

@@ -2,6 +2,7 @@ package models
import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
@@ -110,20 +111,20 @@ func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, time
}
// 增加版本
func (this *IPListDAO) IncreaseVersion(listId int64) (int64, error) {
if listId <= 0 {
return 0, errors.New("invalid listId")
}
op := NewIPListOperator()
op.Id = listId
op.Version = dbs.SQL("version+1")
_, err := this.Save(op)
func (this *IPListDAO) IncreaseVersion() (int64, error) {
valueJSON, err := SharedSysSettingDAO.ReadSetting(SettingCodeIPListVersion)
if err != nil {
return 0, err
}
if len(valueJSON) == 0 {
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte("1"))
if err != nil {
return 0, err
}
return 1, nil
}
return this.Query().
Pk(listId).
Result("version").
FindInt64Col(0)
value := types.Int64(string(valueJSON)) + 1
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value)))
return value, nil
}

View File

@@ -2,13 +2,16 @@ package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
"runtime"
"testing"
)
func TestIPListDAO_IncreaseVersion(t *testing.T) {
dbs.NotifyReady()
dao := NewIPListDAO()
version, err := dao.IncreaseVersion(1)
version, err := dao.IncreaseVersion()
if err != nil {
t.Fatal(err)
}
@@ -18,8 +21,10 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) {
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
runtime.GOMAXPROCS(1)
dbs.NotifyReady()
dao := NewIPListDAO()
for i := 0; i < b.N; i++ {
_, _ = dao.IncreaseVersion(1)
_, _ = dao.IncreaseVersion()
}
}

View File

@@ -12,7 +12,6 @@ type IPList struct {
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
Timeout string `field:"timeout"` // 默认超时时间
Version uint64 `field:"version"` // 版本
}
type IPListOperator struct {
@@ -26,7 +25,6 @@ type IPListOperator struct {
State interface{} // 状态
CreatedAt interface{} // 创建时间
Timeout interface{} // 默认超时时间
Version interface{} // 版本
}
func NewIPListOperator() *IPListOperator {

View File

@@ -82,6 +82,15 @@ func (this *RegionCountryDAO) FindCountryIdWithDataId(dataId string) (int64, err
FindInt64Col(0)
}
// 根据国家名查找国家ID
func (this *RegionCountryDAO) FindCountryIdWithCountryName(countryName string) (int64, error) {
return this.Query().
Where("JSON_CONTAINS(codes, :countryName)").
Param("countryName", "\""+countryName+"\""). // 查询的需要是个JSON字符串所以这里加双引号
ResultPk().
FindInt64Col(0)
}
// 根据数据ID创建国家
func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, error) {
op := NewRegionCountryOperator()

View File

@@ -1 +1,18 @@
package models
import (
"encoding/json"
"github.com/iwind/TeaGo/logs"
)
func (this *RegionCountry) DecodeCodes() []string {
if len(this.Codes) == 0 {
return []string{}
}
result := []string{}
err := json.Unmarshal([]byte(this.Codes), &result)
if err != nil {
logs.Error(err)
}
return result
}

View File

@@ -80,6 +80,15 @@ func (this *RegionProvinceDAO) FindProvinceIdWithDataId(dataId string) (int64, e
FindInt64Col(0)
}
// 根据省份名查找省份ID
func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(provinceName string) (int64, error) {
return this.Query().
Where("JSON_CONTAINS(codes, :provinceName)").
Param("provinceName", "\""+provinceName+"\""). // 查询的需要是个JSON字符串所以这里加双引号
ResultPk().
FindInt64Col(0)
}
// 创建省份
func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, dataId string) (int64, error) {
op := NewRegionProvinceOperator()

View File

@@ -1 +1,18 @@
package models
import (
"encoding/json"
"github.com/iwind/TeaGo/logs"
)
func (this *RegionProvince) DecodeCodes() []string {
if len(this.Codes) == 0 {
return []string{}
}
result := []string{}
err := json.Unmarshal([]byte(this.Codes), &result)
if err != nil {
logs.Error(err)
}
return result
}

View File

@@ -18,6 +18,7 @@ const (
SettingCodeServerGlobalConfig SettingCode = "serverGlobalConfig" // 服务相关全局设置
SettingCodeNodeMonitor SettingCode = "nodeMonitor" // 监控节点状态
SettingCodeClusterHealthCheck SettingCode = "clusterHealthCheck" // 集群健康检查
SettingCodeIPListVersion SettingCode = "ipListVersion" // IP名单的版本号
)
func NewSysSettingDAO() *SysSettingDAO {

View File

@@ -122,3 +122,32 @@ func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEn
Reason: item.Reason,
}}, nil
}
// 根据版本列出一组IP
func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.ListIPItemsAfterVersionRequest) (*pb.ListIPItemsAfterVersionResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
result := []*pb.IPItem{}
items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(req.Version, req.Size)
if err != nil {
return nil, err
}
for _, item := range items {
result = append(result, &pb.IPItem{
Id: int64(item.Id),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
ExpiredAt: int64(item.ExpiredAt),
Reason: "", // 这里我们不需要这个数据
ListId: int64(item.ListId),
IsDeleted: item.State == 0,
})
}
return &pb.ListIPItemsAfterVersionResponse{IpItems: result}, nil
}

View File

@@ -15,7 +15,7 @@ type RegionCountryService struct {
// 查找所有的国家列表
func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Context, req *pb.FindAllEnabledRegionCountriesRequest) (*pb.FindAllEnabledRegionCountriesResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
@@ -39,6 +39,7 @@ func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Cont
result = append(result, &pb.RegionCountry{
Id: int64(country.Id),
Name: country.Name,
Codes: country.DecodeCodes(),
Pinyin: pinyinStrings,
})
}

View File

@@ -14,7 +14,7 @@ type RegionProvinceService struct {
// 查找所有省份
func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ctx context.Context, req *pb.FindAllEnabledRegionProvincesWithCountryIdRequest) (*pb.FindAllEnabledRegionProvincesWithCountryIdResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
if err != nil {
return nil, err
}
@@ -26,8 +26,9 @@ func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ct
result := []*pb.RegionProvince{}
for _, province := range provinces {
result = append(result, &pb.RegionProvince{
Id: int64(province.Id),
Name: province.Name,
Id: int64(province.Id),
Name: province.Name,
Codes: province.DecodeCodes(),
})
}