mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-02 22:10:26 +08:00
实现IP黑白名单、国家|地区封禁、省份封禁
This commit is contained in:
@@ -55,6 +55,8 @@ function build() {
|
||||
rm -f $dist/deploy/.gitignore
|
||||
cp -R $ROOT/installers $DIST/
|
||||
cp -R $ROOT/resources $DIST/
|
||||
rm -f $DIST/resources/ipdata/ip2region/global_region.csv
|
||||
rm -f $DIST/resources/ipdata/ip2region/ip.merge.txt
|
||||
|
||||
# building installer
|
||||
echo "building installer ..."
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -110,4 +111,63 @@ func main() {
|
||||
|
||||
logs.Println("done")
|
||||
}
|
||||
|
||||
// 检查数据
|
||||
if lists.ContainsString(os.Args, "check") {
|
||||
dbs.NotifyReady()
|
||||
|
||||
data, err := ioutil.ReadFile(Tea.Root + "/resources/ipdata/ip2region/ip.merge.txt")
|
||||
if err != nil {
|
||||
logs.Println("[ERROR]" + err.Error())
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
logs.Println("[ERROR]file should not be empty")
|
||||
return
|
||||
}
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
for index, line := range lines {
|
||||
s := string(bytes.TrimSpace(line))
|
||||
if len(s) == 0 {
|
||||
continue
|
||||
}
|
||||
pieces := strings.Split(s, "|")
|
||||
countryName := pieces[2]
|
||||
provinceName := pieces[4]
|
||||
|
||||
if lists.ContainsString([]string{"0", "欧洲", "北美地区", "法国南部领地", "非洲地区", "亚太地区"}, countryName) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查国家
|
||||
countryId, err := models.SharedRegionCountryDAO.FindCountryIdWithCountryName(countryName)
|
||||
if err != nil {
|
||||
logs.Println("[ERROR]" + err.Error())
|
||||
return
|
||||
}
|
||||
if countryId == 0 {
|
||||
logs.Println("[ERROR]can not find country '"+countryName+"', index: ", index, "data: "+s)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查省份
|
||||
if countryName == "中国" {
|
||||
if lists.ContainsString([]string{"0"}, provinceName) {
|
||||
continue
|
||||
}
|
||||
|
||||
provinceId, err := models.SharedRegionProvinceDAO.FindProvinceIdWithProvinceName(provinceName)
|
||||
if err != nil {
|
||||
logs.Println("[ERROR]" + err.Error())
|
||||
return
|
||||
}
|
||||
if provinceId == 0 {
|
||||
logs.Println("[ERROR]can not find province '"+provinceName+"', index: ", index, "data: "+s)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logs.Println("done")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -66,7 +67,7 @@ func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) {
|
||||
|
||||
// 创建IP
|
||||
func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) {
|
||||
version, err := SharedIPListDAO.IncreaseVersion(listId)
|
||||
version, err := SharedIPListDAO.IncreaseVersion()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -106,7 +107,7 @@ func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, ex
|
||||
return errors.New("not found")
|
||||
}
|
||||
|
||||
version, err := SharedIPListDAO.IncreaseVersion(listId)
|
||||
version, err := SharedIPListDAO.IncreaseVersion()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -145,3 +146,17 @@ func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size in
|
||||
FindAll()
|
||||
return
|
||||
}
|
||||
|
||||
// 根据版本号查找IP列表
|
||||
func (this *IPItemDAO) ListIPItemsAfterVersion(version int64, size int64) (result []*IPItem, err error) {
|
||||
_, err = this.Query().
|
||||
// 这里不要设置状态参数,因为我们要知道哪些是删除的
|
||||
Gt("version", version).
|
||||
Where("(expiredAt=0 OR expiredAt>:expiredAt)").
|
||||
Param("expiredAt", time.Now().Unix()).
|
||||
Asc("version").
|
||||
Limit(size).
|
||||
Slice(&result).
|
||||
FindAll()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
@@ -110,20 +111,20 @@ func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, time
|
||||
}
|
||||
|
||||
// 增加版本
|
||||
func (this *IPListDAO) IncreaseVersion(listId int64) (int64, error) {
|
||||
if listId <= 0 {
|
||||
return 0, errors.New("invalid listId")
|
||||
}
|
||||
op := NewIPListOperator()
|
||||
op.Id = listId
|
||||
op.Version = dbs.SQL("version+1")
|
||||
_, err := this.Save(op)
|
||||
func (this *IPListDAO) IncreaseVersion() (int64, error) {
|
||||
valueJSON, err := SharedSysSettingDAO.ReadSetting(SettingCodeIPListVersion)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(valueJSON) == 0 {
|
||||
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte("1"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
return this.Query().
|
||||
Pk(listId).
|
||||
Result("version").
|
||||
FindInt64Col(0)
|
||||
value := types.Int64(string(valueJSON)) + 1
|
||||
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value)))
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package models
|
||||
|
||||
import (
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
||||
dbs.NotifyReady()
|
||||
|
||||
dao := NewIPListDAO()
|
||||
version, err := dao.IncreaseVersion(1)
|
||||
version, err := dao.IncreaseVersion()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -18,8 +21,10 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
dbs.NotifyReady()
|
||||
|
||||
dao := NewIPListDAO()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = dao.IncreaseVersion(1)
|
||||
_, _ = dao.IncreaseVersion()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ type IPList struct {
|
||||
State uint8 `field:"state"` // 状态
|
||||
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
||||
Timeout string `field:"timeout"` // 默认超时时间
|
||||
Version uint64 `field:"version"` // 版本
|
||||
}
|
||||
|
||||
type IPListOperator struct {
|
||||
@@ -26,7 +25,6 @@ type IPListOperator struct {
|
||||
State interface{} // 状态
|
||||
CreatedAt interface{} // 创建时间
|
||||
Timeout interface{} // 默认超时时间
|
||||
Version interface{} // 版本
|
||||
}
|
||||
|
||||
func NewIPListOperator() *IPListOperator {
|
||||
|
||||
@@ -82,6 +82,15 @@ func (this *RegionCountryDAO) FindCountryIdWithDataId(dataId string) (int64, err
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 根据国家名查找国家ID
|
||||
func (this *RegionCountryDAO) FindCountryIdWithCountryName(countryName string) (int64, error) {
|
||||
return this.Query().
|
||||
Where("JSON_CONTAINS(codes, :countryName)").
|
||||
Param("countryName", "\""+countryName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号
|
||||
ResultPk().
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 根据数据ID创建国家
|
||||
func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, error) {
|
||||
op := NewRegionCountryOperator()
|
||||
|
||||
@@ -1 +1,18 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
)
|
||||
|
||||
func (this *RegionCountry) DecodeCodes() []string {
|
||||
if len(this.Codes) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
result := []string{}
|
||||
err := json.Unmarshal([]byte(this.Codes), &result)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -80,6 +80,15 @@ func (this *RegionProvinceDAO) FindProvinceIdWithDataId(dataId string) (int64, e
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 根据省份名查找省份ID
|
||||
func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(provinceName string) (int64, error) {
|
||||
return this.Query().
|
||||
Where("JSON_CONTAINS(codes, :provinceName)").
|
||||
Param("provinceName", "\""+provinceName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号
|
||||
ResultPk().
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 创建省份
|
||||
func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, dataId string) (int64, error) {
|
||||
op := NewRegionProvinceOperator()
|
||||
|
||||
@@ -1 +1,18 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
)
|
||||
|
||||
func (this *RegionProvince) DecodeCodes() []string {
|
||||
if len(this.Codes) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
result := []string{}
|
||||
err := json.Unmarshal([]byte(this.Codes), &result)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ const (
|
||||
SettingCodeServerGlobalConfig SettingCode = "serverGlobalConfig" // 服务相关全局设置
|
||||
SettingCodeNodeMonitor SettingCode = "nodeMonitor" // 监控节点状态
|
||||
SettingCodeClusterHealthCheck SettingCode = "clusterHealthCheck" // 集群健康检查
|
||||
SettingCodeIPListVersion SettingCode = "ipListVersion" // IP名单的版本号
|
||||
)
|
||||
|
||||
func NewSysSettingDAO() *SysSettingDAO {
|
||||
|
||||
@@ -122,3 +122,32 @@ func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEn
|
||||
Reason: item.Reason,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// 根据版本列出一组IP
|
||||
func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.ListIPItemsAfterVersionRequest) (*pb.ListIPItemsAfterVersionResponse, error) {
|
||||
// 校验请求
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := []*pb.IPItem{}
|
||||
items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(req.Version, req.Size)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, item := range items {
|
||||
result = append(result, &pb.IPItem{
|
||||
Id: int64(item.Id),
|
||||
IpFrom: item.IpFrom,
|
||||
IpTo: item.IpTo,
|
||||
Version: int64(item.Version),
|
||||
ExpiredAt: int64(item.ExpiredAt),
|
||||
Reason: "", // 这里我们不需要这个数据
|
||||
ListId: int64(item.ListId),
|
||||
IsDeleted: item.State == 0,
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.ListIPItemsAfterVersionResponse{IpItems: result}, nil
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ type RegionCountryService struct {
|
||||
// 查找所有的国家列表
|
||||
func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Context, req *pb.FindAllEnabledRegionCountriesRequest) (*pb.FindAllEnabledRegionCountriesResponse, error) {
|
||||
// 校验请求
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -39,6 +39,7 @@ func (this *RegionCountryService) FindAllEnabledRegionCountries(ctx context.Cont
|
||||
result = append(result, &pb.RegionCountry{
|
||||
Id: int64(country.Id),
|
||||
Name: country.Name,
|
||||
Codes: country.DecodeCodes(),
|
||||
Pinyin: pinyinStrings,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ type RegionProvinceService struct {
|
||||
// 查找所有省份
|
||||
func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ctx context.Context, req *pb.FindAllEnabledRegionProvincesWithCountryIdRequest) (*pb.FindAllEnabledRegionProvincesWithCountryIdResponse, error) {
|
||||
// 校验请求
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -26,8 +26,9 @@ func (this *RegionProvinceService) FindAllEnabledRegionProvincesWithCountryId(ct
|
||||
result := []*pb.RegionProvince{}
|
||||
for _, province := range provinces {
|
||||
result = append(result, &pb.RegionProvince{
|
||||
Id: int64(province.Id),
|
||||
Name: province.Name,
|
||||
Id: int64(province.Id),
|
||||
Name: province.Name,
|
||||
Codes: province.DecodeCodes(),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user