mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2026-01-01 19:36:36 +08:00
实现IP黑白名单、国家|地区封禁、省份封禁
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -66,7 +67,7 @@ func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) {
|
||||
|
||||
// 创建IP
|
||||
func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) {
|
||||
version, err := SharedIPListDAO.IncreaseVersion(listId)
|
||||
version, err := SharedIPListDAO.IncreaseVersion()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -106,7 +107,7 @@ func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, ex
|
||||
return errors.New("not found")
|
||||
}
|
||||
|
||||
version, err := SharedIPListDAO.IncreaseVersion(listId)
|
||||
version, err := SharedIPListDAO.IncreaseVersion()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -145,3 +146,17 @@ func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size in
|
||||
FindAll()
|
||||
return
|
||||
}
|
||||
|
||||
// 根据版本号查找IP列表
|
||||
func (this *IPItemDAO) ListIPItemsAfterVersion(version int64, size int64) (result []*IPItem, err error) {
|
||||
_, err = this.Query().
|
||||
// 这里不要设置状态参数,因为我们要知道哪些是删除的
|
||||
Gt("version", version).
|
||||
Where("(expiredAt=0 OR expiredAt>:expiredAt)").
|
||||
Param("expiredAt", time.Now().Unix()).
|
||||
Asc("version").
|
||||
Limit(size).
|
||||
Slice(&result).
|
||||
FindAll()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package models
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
@@ -110,20 +111,20 @@ func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, time
|
||||
}
|
||||
|
||||
// 增加版本
|
||||
func (this *IPListDAO) IncreaseVersion(listId int64) (int64, error) {
|
||||
if listId <= 0 {
|
||||
return 0, errors.New("invalid listId")
|
||||
}
|
||||
op := NewIPListOperator()
|
||||
op.Id = listId
|
||||
op.Version = dbs.SQL("version+1")
|
||||
_, err := this.Save(op)
|
||||
func (this *IPListDAO) IncreaseVersion() (int64, error) {
|
||||
valueJSON, err := SharedSysSettingDAO.ReadSetting(SettingCodeIPListVersion)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(valueJSON) == 0 {
|
||||
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte("1"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
return this.Query().
|
||||
Pk(listId).
|
||||
Result("version").
|
||||
FindInt64Col(0)
|
||||
value := types.Int64(string(valueJSON)) + 1
|
||||
err = SharedSysSettingDAO.UpdateSetting(SettingCodeIPListVersion, []byte(numberutils.FormatInt64(value)))
|
||||
return value, nil
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package models
|
||||
|
||||
import (
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
||||
dbs.NotifyReady()
|
||||
|
||||
dao := NewIPListDAO()
|
||||
version, err := dao.IncreaseVersion(1)
|
||||
version, err := dao.IncreaseVersion()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -18,8 +21,10 @@ func TestIPListDAO_IncreaseVersion(t *testing.T) {
|
||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
dbs.NotifyReady()
|
||||
|
||||
dao := NewIPListDAO()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = dao.IncreaseVersion(1)
|
||||
_, _ = dao.IncreaseVersion()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ type IPList struct {
|
||||
State uint8 `field:"state"` // 状态
|
||||
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
||||
Timeout string `field:"timeout"` // 默认超时时间
|
||||
Version uint64 `field:"version"` // 版本
|
||||
}
|
||||
|
||||
type IPListOperator struct {
|
||||
@@ -26,7 +25,6 @@ type IPListOperator struct {
|
||||
State interface{} // 状态
|
||||
CreatedAt interface{} // 创建时间
|
||||
Timeout interface{} // 默认超时时间
|
||||
Version interface{} // 版本
|
||||
}
|
||||
|
||||
func NewIPListOperator() *IPListOperator {
|
||||
|
||||
@@ -82,6 +82,15 @@ func (this *RegionCountryDAO) FindCountryIdWithDataId(dataId string) (int64, err
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 根据国家名查找国家ID
|
||||
func (this *RegionCountryDAO) FindCountryIdWithCountryName(countryName string) (int64, error) {
|
||||
return this.Query().
|
||||
Where("JSON_CONTAINS(codes, :countryName)").
|
||||
Param("countryName", "\""+countryName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号
|
||||
ResultPk().
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 根据数据ID创建国家
|
||||
func (this *RegionCountryDAO) CreateCountry(name string, dataId string) (int64, error) {
|
||||
op := NewRegionCountryOperator()
|
||||
|
||||
@@ -1 +1,18 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
)
|
||||
|
||||
func (this *RegionCountry) DecodeCodes() []string {
|
||||
if len(this.Codes) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
result := []string{}
|
||||
err := json.Unmarshal([]byte(this.Codes), &result)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -80,6 +80,15 @@ func (this *RegionProvinceDAO) FindProvinceIdWithDataId(dataId string) (int64, e
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 根据省份名查找省份ID
|
||||
func (this *RegionProvinceDAO) FindProvinceIdWithProvinceName(provinceName string) (int64, error) {
|
||||
return this.Query().
|
||||
Where("JSON_CONTAINS(codes, :provinceName)").
|
||||
Param("provinceName", "\""+provinceName+"\""). // 查询的需要是个JSON字符串,所以这里加双引号
|
||||
ResultPk().
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// 创建省份
|
||||
func (this *RegionProvinceDAO) CreateProvince(countryId int64, name string, dataId string) (int64, error) {
|
||||
op := NewRegionProvinceOperator()
|
||||
|
||||
@@ -1 +1,18 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
)
|
||||
|
||||
func (this *RegionProvince) DecodeCodes() []string {
|
||||
if len(this.Codes) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
result := []string{}
|
||||
err := json.Unmarshal([]byte(this.Codes), &result)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ const (
|
||||
SettingCodeServerGlobalConfig SettingCode = "serverGlobalConfig" // 服务相关全局设置
|
||||
SettingCodeNodeMonitor SettingCode = "nodeMonitor" // 监控节点状态
|
||||
SettingCodeClusterHealthCheck SettingCode = "clusterHealthCheck" // 集群健康检查
|
||||
SettingCodeIPListVersion SettingCode = "ipListVersion" // IP名单的版本号
|
||||
)
|
||||
|
||||
func NewSysSettingDAO() *SysSettingDAO {
|
||||
|
||||
Reference in New Issue
Block a user