mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-03 15:00:27 +08:00
实现IP黑白名单、国家|地区封禁、省份封禁
This commit is contained in:
@@ -55,6 +55,8 @@ function build() {
|
|||||||
rm -f $dist/deploy/.gitignore
|
rm -f $dist/deploy/.gitignore
|
||||||
cp -R $ROOT/installers $DIST/
|
cp -R $ROOT/installers $DIST/
|
||||||
cp -R $ROOT/resources $DIST/
|
cp -R $ROOT/resources $DIST/
|
||||||
|
rm -f $DIST/resources/ipdata/ip2region/global_region.csv
|
||||||
|
rm -f $DIST/resources/ipdata/ip2region/ip.merge.txt
|
||||||
|
|
||||||
# building installer
|
# building installer
|
||||||
echo "building installer ..."
|
echo "building installer ..."
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -110,4 +111,63 @@ func main() {
|
|||||||
|
|
||||||
logs.Println("done")
|
logs.Println("done")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查数据
|
||||||
|
if lists.ContainsString(os.Args, "check") {
|
||||||
|
dbs.NotifyReady()
|
||||||
|
|
||||||
|
data, err := ioutil.ReadFile(Tea.Root + "/resources/ipdata/ip2region/ip.merge.txt")
|
||||||
|
if err != nil {
|
||||||
|
logs.Println("[ERROR]" + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
logs.Println("[ERROR]file should not be empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lines := bytes.Split(data, []byte("\n"))
|
||||||
|
for index, line := range lines {
|
||||||
|
s := string(bytes.TrimSpace(line))
|
||||||
|
if len(s) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pieces := strings.Split(s, "|")
|
||||||
|
countryName := pieces[2]
|
||||||
|
provinceName := pieces[4]
|
||||||
|
|
||||||
|
if lists.ContainsString([]string{"0", "欧洲", "北美地区", "法国南部领地", "非洲地区", "亚太地区"}, countryName) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查国家
|
||||||
|
countryId, err := models.SharedRegionCountryDAO.FindCountryIdWithCountryName(countryName)
|
||||||
|
if err != nil {
|
||||||
|
logs.Println("[ERROR]" + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if countryId == 0 {
|
||||||
|
logs.Println("[ERROR]can not find country '"+countryName+"', index: ", index, "data: "+s)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查省份
|
||||||
|
if countryName == "中国" {
|
||||||
|
if lists.ContainsString([]string{"0"}, provinceName) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
provinceId, err := models.SharedRegionProvinceDAO.FindProvinceIdWithProvinceName(provinceName)
|
||||||
|
if err != nil {
|
||||||
|
logs.Println("[ERROR]" + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if provinceId == 0 {
|
||||||
|
logs.Println("[ERROR]can not find province '"+provinceName+"', index: ", index, "data: "+s)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logs.Println("done")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/iwind/TeaGo/Tea"
|
"github.com/iwind/TeaGo/Tea"
|
||||||
"github.com/iwind/TeaGo/dbs"
|
"github.com/iwind/TeaGo/dbs"
|
||||||
"github.com/iwind/TeaGo/types"
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -66,7 +67,7 @@ func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) {
|
|||||||
|
|
||||||
// 创建IP
|
// 创建IP
|
||||||
func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) {
|
func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) {
|
||||||
version, err := SharedIPListDAO.IncreaseVersion(listId)
|
version, err := SharedIPListDAO.IncreaseVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -106,7 +107,7 @@ func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, ex
|
|||||||
return errors.New("not found")
|
return errors.New("not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
version, err := SharedIPListDAO.IncreaseVersion(listId)
|
version, err := SharedIPListDAO.IncreaseVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -145,3 +146,17 @@ func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size in
|
|||||||
FindAll()
|
FindAll()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据版本号查找IP列表
|
||||||
|
func (this *IPItemDAO) ListIPItemsAfterVersion(version int64, size int64) (result []*IPItem, err error) {
|
||||||
|
_, err = this.Query().
|
||||||
|
// 这里不要设置状态参数,因为我们要知道哪些是删除的
|
||||||
|
Gt("version", version).
|
||||||
|
Where("(expiredAt=0 OR expiredAt>:expiredAt)").
|
||||||
|
Param("expiredAt", time.Now().Unix()).
|
||||||
|
Asc("version").
|
||||||
|
Limit(size).
|
||||||
|
Slice(&result).
|
||||||
|
FindAll()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package models
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||||
|
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
"github.com/iwind/TeaGo/Tea"
|
"github.com/iwind/TeaGo/Tea"
|
||||||
@@ -110,20 +111,20 @@ func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, time
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 增加版本
|
// 增加版本
|
||||||
func (this *IPListDAO) IncreaseVersion(listId int64) (int64, error) {
|
func (this *IPListDAO) IncreaseVersion() (int64, error) {
|
||||||
if listId <= 0 {
|
valueJSON, err := SharedSysSettingDAO.ReadSetting(SettingCodeIPListVersion)
|
||||||
return 0, errors.New("invalid listId")
|
|
||||||
}
|
|
||||||
op := NewIPListOperator()
|
|
||||||
op.Id = listId
|
|
||||||
op.Version = dbs.SQL("version+1")
|
|
||||||
_, err := this.Save(op)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
if len(valueJSON) == 0 {
|
||||||
|
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte("1"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
return this.Query().
|
value := types.Int64(string(valueJSON)) + 1
|
||||||
Pk(listId).
|
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value)))
|
||||||
Result("version").
|
return value, nil
|
||||||
FindInt64Col(0)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package models
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
"github.com/iwind/TeaGo/dbs"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
||||||
|
dbs.NotifyReady()
|
||||||
|
|
||||||
dao := NewIPListDAO()
|
dao := NewIPListDAO()
|
||||||
version, err := dao.IncreaseVersion(1)
|
version, err := dao.IncreaseVersion()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -18,8 +21,10 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
|||||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||||
runtime.GOMAXPROCS(1)
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
|
dbs.NotifyReady()
|
||||||
|
|
||||||
dao := NewIPListDAO()
|
dao := NewIPListDAO()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, _ = dao.IncreaseVersion(1)
|
_, _ = dao.IncreaseVersion()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ type IPList struct {
|
|||||||
State uint8 `field:"state"` // 状态
|
State uint8 `field:"state"` // 状态
|
||||||
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
||||||
Timeout string `field:"timeout"` // 默认超时时间
|
Timeout string `field:"timeout"` // 默认超时时间
|
||||||
Version uint64 `field:"version"` // 版本
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type IPListOperator struct {
|
type IPListOperator struct {
|
||||||
@@ -26,7 +25,6 @@ type IPListOperator struct {
|
|||||||
State interface{} // 状态
|
State interface{} // 状态
|
||||||
CreatedAt interface{} // 创建时间
|
CreatedAt interface{} // 创建时间
|
||||||
Timeout interface{} // 默认超时时间
|
Timeout interface{} // 默认超时时间
|
||||||
Version interface{} // 版本
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIPListOperator() *IPListOperator {
|
func NewIPListOperator() *IPListOperator {
|
||||||
|
|||||||
@@ -82,6 +82,15 @@ func (this *RegionCountryDAO) FindCountryIdWithDataId(dataId string) (int64, err
|
|||||||
FindInt64Col(0)
|
FindInt64Col(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据国家名查找国家ID
|
||||||
|
func (this *RegionCountryDAO) FindCountryIdWithCountryName(countryName string) (int64, error) {
|
||||||
|
return this.Query().
|
||||||
|
Where("JSON_CONTAINS(codes, :countryName)").
|
||||||
|
Param("countryName", "\""+countryName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号
|
||||||
|
ResultPk().
|
||||||
|
FindInt64Col(0)
|
||||||
|
}
|
||||||
|
|
||||||
// 根据数据ID创建国家
|
// 根据数据ID创建国家
|
||||||
func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, error) {
|
func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, error) {
|
||||||
op := NewRegionCountryOperator()
|
op := NewRegionCountryOperator()
|
||||||
|
|||||||
@@ -1 +1,18 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (this *RegionCountry) DecodeCodes() []string {
|
||||||
|
if len(this.Codes) == 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
result := []string{}
|
||||||
|
err := json.Unmarshal([]byte(this.Codes), &result)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@@ -80,6 +80,15 @@ func (this *RegionProvinceDAO) FindProvinceIdWithDataId(dataId string) (int64, e
|
|||||||
FindInt64Col(0)
|
FindInt64Col(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据省份名查找省份ID
|
||||||
|
func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(provinceName string) (int64, error) {
|
||||||
|
return this.Query().
|
||||||
|
Where("JSON_CONTAINS(codes, :provinceName)").
|
||||||
|
Param("provinceName", "\""+provinceName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号
|
||||||
|
ResultPk().
|
||||||
|
FindInt64Col(0)
|
||||||
|
}
|
||||||
|
|
||||||
// 创建省份
|
// 创建省份
|
||||||
func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, dataId string) (int64, error) {
|
func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, dataId string) (int64, error) {
|
||||||
op := NewRegionProvinceOperator()
|
op := NewRegionProvinceOperator()
|
||||||
|
|||||||
@@ -1 +1,18 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (this *RegionProvince) DecodeCodes() []string {
|
||||||
|
if len(this.Codes) == 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
result := []string{}
|
||||||
|
err := json.Unmarshal([]byte(this.Codes), &result)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ const (
|
|||||||
SettingCodeServerGlobalConfig SettingCode = "serverGlobalConfig" // 服务相关全局设置
|
SettingCodeServerGlobalConfig SettingCode = "serverGlobalConfig" // 服务相关全局设置
|
||||||
SettingCodeNodeMonitor SettingCode = "nodeMonitor" // 监控节点状态
|
SettingCodeNodeMonitor SettingCode = "nodeMonitor" // 监控节点状态
|
||||||
SettingCodeClusterHealthCheck SettingCode = "clusterHealthCheck" // 集群健康检查
|
SettingCodeClusterHealthCheck SettingCode = "clusterHealthCheck" // 集群健康检查
|
||||||
|
SettingCodeIPListVersion SettingCode = "ipListVersion" // IP名单的版本号
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewSysSettingDAO() *SysSettingDAO {
|
func NewSysSettingDAO() *SysSettingDAO {
|
||||||
|
|||||||
@@ -122,3 +122,32 @@ func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEn
|
|||||||
Reason: item.Reason,
|
Reason: item.Reason,
|
||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据版本列出一组IP
|
||||||
|
func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.ListIPItemsAfterVersionRequest) (*pb.ListIPItemsAfterVersionResponse, error) {
|
||||||
|
// 校验请求
|
||||||
|
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := []*pb.IPItem{}
|
||||||
|
items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(req.Version, req.Size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
result = append(result, &pb.IPItem{
|
||||||
|
Id: int64(item.Id),
|
||||||
|
IpFrom: item.IpFrom,
|
||||||
|
IpTo: item.IpTo,
|
||||||
|
Version: int64(item.Version),
|
||||||
|
ExpiredAt: int64(item.ExpiredAt),
|
||||||
|
Reason: "", // 这里我们不需要这个数据
|
||||||
|
ListId: int64(item.ListId),
|
||||||
|
IsDeleted: item.State == 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pb.ListIPItemsAfterVersionResponse{IpItems: result}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ type RegionCountryService struct {
|
|||||||
// 查找所有的国家列表
|
// 查找所有的国家列表
|
||||||
func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Context, req *pb.FindAllEnabledRegionCountriesRequest) (*pb.FindAllEnabledRegionCountriesResponse, error) {
|
func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Context, req *pb.FindAllEnabledRegionCountriesRequest) (*pb.FindAllEnabledRegionCountriesResponse, error) {
|
||||||
// 校验请求
|
// 校验请求
|
||||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -39,6 +39,7 @@ func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Cont
|
|||||||
result = append(result, &pb.RegionCountry{
|
result = append(result, &pb.RegionCountry{
|
||||||
Id: int64(country.Id),
|
Id: int64(country.Id),
|
||||||
Name: country.Name,
|
Name: country.Name,
|
||||||
|
Codes: country.DecodeCodes(),
|
||||||
Pinyin: pinyinStrings,
|
Pinyin: pinyinStrings,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ type RegionProvinceService struct {
|
|||||||
// 查找所有省份
|
// 查找所有省份
|
||||||
func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ctx context.Context, req *pb.FindAllEnabledRegionProvincesWithCountryIdRequest) (*pb.FindAllEnabledRegionProvincesWithCountryIdResponse, error) {
|
func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ctx context.Context, req *pb.FindAllEnabledRegionProvincesWithCountryIdRequest) (*pb.FindAllEnabledRegionProvincesWithCountryIdResponse, error) {
|
||||||
// 校验请求
|
// 校验请求
|
||||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -26,8 +26,9 @@ func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ct
|
|||||||
result := []*pb.RegionProvince{}
|
result := []*pb.RegionProvince{}
|
||||||
for _, province := range provinces {
|
for _, province := range provinces {
|
||||||
result = append(result, &pb.RegionProvince{
|
result = append(result, &pb.RegionProvince{
|
||||||
Id: int64(province.Id),
|
Id: int64(province.Id),
|
||||||
Name: province.Name,
|
Name: province.Name,
|
||||||
|
Codes: province.DecodeCodes(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user