实现IP名单管理

This commit is contained in:
GoEdgeLab
2020-11-07 19:40:24 +08:00
parent b4896e03a2
commit 48a3a28f32
11 changed files with 567 additions and 0 deletions

View File

@@ -0,0 +1,147 @@
package models
import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
)
const (
IPItemStateEnabled = 1 // 已启用
IPItemStateDisabled = 0 // 已禁用
)
type IPItemDAO dbs.DAO
func NewIPItemDAO() *IPItemDAO {
return dbs.NewDAO(&IPItemDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeIPItems",
Model: new(IPItem),
PkName: "id",
},
}).(*IPItemDAO)
}
var SharedIPItemDAO *IPItemDAO
func init() {
dbs.OnReady(func() {
SharedIPItemDAO = NewIPItemDAO()
})
}
// 启用条目
func (this *IPItemDAO) EnableIPItem(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", IPItemStateEnabled).
Update()
return err
}
// 禁用条目
func (this *IPItemDAO) DisableIPItem(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", IPItemStateDisabled).
Update()
return err
}
// 查找启用中的条目
func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) {
result, err := this.Query().
Pk(id).
Attr("state", IPItemStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*IPItem), err
}
// 创建IP
func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) {
version, err := SharedIPListDAO.IncreaseVersion(listId)
if err != nil {
return 0, err
}
op := NewIPItemOperator()
op.ListId = listId
op.IpFrom = ipFrom
op.IpTo = ipTo
op.Reason = reason
op.Version = version
if expiredAt < 0 {
expiredAt = 0
}
op.ExpiredAt = expiredAt
op.State = IPItemStateEnabled
_, err = this.Save(op)
if err != nil {
return 0, err
}
return types.Int64(op.Id), nil
}
// 修改IP
func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, expiredAt int64, reason string) error {
if itemId <= 0 {
return errors.New("invalid itemId")
}
listId, err := this.Query().
Pk(itemId).
Result("listId").
FindInt64Col(0)
if err != nil {
return err
}
if listId == 0 {
return errors.New("not found")
}
version, err := SharedIPListDAO.IncreaseVersion(listId)
if err != nil {
return err
}
op := NewIPItemOperator()
op.Id = itemId
op.IpFrom = ipFrom
op.IpTo = ipTo
op.Reason = reason
if expiredAt < 0 {
expiredAt = 0
}
op.ExpiredAt = expiredAt
op.Version = version
_, err = this.Save(op)
return err
}
// 计算IP数量
func (this *IPItemDAO) CountIPItemsWithListId(listId int64) (int64, error) {
return this.Query().
State(IPItemStateEnabled).
Attr("listId", listId).
Count()
}
// 查找IP列表
func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size int64) (result []*IPItem, err error) {
_, err = this.Query().
State(IPItemStateEnabled).
Attr("listId", listId).
DescPk().
Slice(&result).
Offset(offset).
Limit(size).
FindAll()
return
}

View File

@@ -0,0 +1,5 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
)

View File

@@ -0,0 +1,32 @@
package models
// IP
type IPItem struct {
Id uint64 `field:"id"` // ID
ListId uint32 `field:"listId"` // 所属名单ID
IpFrom string `field:"ipFrom"` // 开始IP
IpTo string `field:"ipTo"` // 结束IP
Version uint64 `field:"version"` // 版本
CreatedAt uint64 `field:"createdAt"` // 创建时间
UpdatedAt uint64 `field:"updatedAt"` // 修改时间
Reason string `field:"reason"` // 加入说明
State uint8 `field:"state"` // 状态
ExpiredAt uint64 `field:"expiredAt"` // 过期时间
}
type IPItemOperator struct {
Id interface{} // ID
ListId interface{} // 所属名单ID
IpFrom interface{} // 开始IP
IpTo interface{} // 结束IP
Version interface{} // 版本
CreatedAt interface{} // 创建时间
UpdatedAt interface{} // 修改时间
Reason interface{} // 加入说明
State interface{} // 状态
ExpiredAt interface{} // 过期时间
}
func NewIPItemOperator() *IPItemOperator {
return &IPItemOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -0,0 +1,129 @@
package models
import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
)
const (
IPListStateEnabled = 1 // 已启用
IPListStateDisabled = 0 // 已禁用
)
type IPListDAO dbs.DAO
func NewIPListDAO() *IPListDAO {
return dbs.NewDAO(&IPListDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeIPLists",
Model: new(IPList),
PkName: "id",
},
}).(*IPListDAO)
}
var SharedIPListDAO *IPListDAO
func init() {
dbs.OnReady(func() {
SharedIPListDAO = NewIPListDAO()
})
}
// 启用条目
func (this *IPListDAO) EnableIPList(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", IPListStateEnabled).
Update()
return err
}
// 禁用条目
func (this *IPListDAO) DisableIPList(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", IPListStateDisabled).
Update()
return err
}
// 查找启用中的条目
func (this *IPListDAO) FindEnabledIPList(id int64) (*IPList, error) {
result, err := this.Query().
Pk(id).
Attr("state", IPListStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*IPList), err
}
// 根据主键查找名称
func (this *IPListDAO) FindIPListName(id int64) (string, error) {
return this.Query().
Pk(id).
Result("name").
FindStringCol("")
}
// 创建名单
func (this *IPListDAO) CreateIPList(listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) {
op := NewIPListOperator()
op.IsOn = true
op.State = IPListStateEnabled
op.Type = listType
op.Name = name
op.Code = code
if len(timeoutJSON) > 0 {
op.Timeout = timeoutJSON
}
_, err := this.Save(op)
if err != nil {
return 0, err
}
return types.Int64(op.Id), nil
}
// 修改名单
func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, timeoutJSON []byte) error {
if listId <= 0 {
return errors.New("invalid listId")
}
op := NewIPListOperator()
op.Id = listId
op.Name = name
op.Code = code
if len(timeoutJSON) > 0 {
op.Timeout = timeoutJSON
} else {
op.Timeout = "null"
}
_, err := this.Save(op)
return err
}
// 增加版本
func (this *IPListDAO) IncreaseVersion(listId int64) (int64, error) {
if listId <= 0 {
return 0, errors.New("invalid listId")
}
op := NewIPListOperator()
op.Id = listId
op.Version = dbs.SQL("version+1")
_, err := this.Save(op)
if err != nil {
return 0, err
}
return this.Query().
Pk(listId).
Result("version").
FindInt64Col(0)
}

View File

@@ -0,0 +1,25 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
"runtime"
"testing"
)
func TestIPListDAO_IncreaseVersion(t *testing.T) {
dao := NewIPListDAO()
version, err := dao.IncreaseVersion(1)
if err != nil {
t.Fatal(err)
}
t.Log("version:", version)
}
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
runtime.GOMAXPROCS(1)
dao := NewIPListDAO()
for i := 0; i < b.N; i++ {
_, _ = dao.IncreaseVersion(1)
}
}

View File

@@ -0,0 +1,34 @@
package models
// IP名单
type IPList struct {
Id uint32 `field:"id"` // ID
IsOn uint8 `field:"isOn"` // 是否启用
Type string `field:"type"` // 类型
AdminId uint32 `field:"adminId"` // 用户ID
UserId uint32 `field:"userId"` // 用户ID
Name string `field:"name"` // 列表名
Code string `field:"code"` // 代号
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
Timeout string `field:"timeout"` // 默认超时时间
Version uint64 `field:"version"` // 版本
}
type IPListOperator struct {
Id interface{} // ID
IsOn interface{} // 是否启用
Type interface{} // 类型
AdminId interface{} // 用户ID
UserId interface{} // 用户ID
Name interface{} // 列表名
Code interface{} // 代号
State interface{} // 状态
CreatedAt interface{} // 创建时间
Timeout interface{} // 默认超时时间
Version interface{} // 版本
}
func NewIPListOperator() *IPListOperator {
return &IPListOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -178,6 +178,8 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err
pb.RegisterFileServiceServer(rpcServer, &services.FileService{})
pb.RegisterRegionCountryServiceServer(rpcServer, &services.RegionCountryService{})
pb.RegisterRegionProvinceServiceServer(rpcServer, &services.RegionProvinceService{})
pb.RegisterIPListServiceServer(rpcServer, &services.IPListService{})
pb.RegisterIPItemServiceServer(rpcServer, &services.IPItemService{})
err := rpcServer.Serve(listener)
if err != nil {
return errors.New("[API]start rpc failed: " + err.Error())

View File

@@ -0,0 +1,124 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// IP条目相关服务
type IPItemService struct {
}
// 创建IP
func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPItemRequest) (*pb.CreateIPItemResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
itemId, err := models.SharedIPItemDAO.CreateIPItem(req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason)
if err != nil {
return nil, err
}
return &pb.CreateIPItemResponse{IpItemId: itemId}, nil
}
// 修改IP
func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPItemRequest) (*pb.RPCUpdateSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
err = models.SharedIPItemDAO.UpdateIPItem(req.IpItemId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason)
if err != nil {
return nil, err
}
return rpcutils.RPCUpdateSuccess()
}
// 删除IP
func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPItemRequest) (*pb.RPCDeleteSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
err = models.SharedIPItemDAO.DisableIPItem(req.IpItemId)
if err != nil {
return nil, err
}
return rpcutils.RPCDeleteSuccess()
}
// 计算IP数量
func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.CountIPItemsWithListIdRequest) (*pb.CountIPItemsWithListIdResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
count, err := models.SharedIPItemDAO.CountIPItemsWithListId(req.IpListId)
if err != nil {
return nil, err
}
return &pb.CountIPItemsWithListIdResponse{Count: count}, nil
}
// 列出单页的IP
func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.ListIPItemsWithListIdRequest) (*pb.ListIPItemsWithListIdResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
items, err := models.SharedIPItemDAO.ListIPItemsWithListId(req.IpListId, req.Offset, req.Size)
if err != nil {
return nil, err
}
result := []*pb.IPItem{}
for _, item := range items {
result = append(result, &pb.IPItem{
Id: int64(item.Id),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
})
}
return &pb.ListIPItemsWithListIdResponse{IpItems: result}, nil
}
// 查找单个IP
func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEnabledIPItemRequest) (*pb.FindEnabledIPItemResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
item, err := models.SharedIPItemDAO.FindEnabledIPItem(req.IpItemId)
if err != nil {
return nil, err
}
if item == nil {
return &pb.FindEnabledIPItemResponse{IpItem: nil}, nil
}
return &pb.FindEnabledIPItemResponse{IpItem: &pb.IPItem{
Id: int64(item.Id),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
}}, nil
}

View File

@@ -0,0 +1,67 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// IP名单相关服务
type IPListService struct {
}
// 创建IP列表
func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPListRequest) (*pb.CreateIPListResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
listId, err := models.SharedIPListDAO.CreateIPList(req.Type, req.Name, req.Code, req.TimeoutJSON)
if err != nil {
return nil, err
}
return &pb.CreateIPListResponse{IpListId: listId}, nil
}
// 修改IP列表
func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPListRequest) (*pb.RPCUpdateSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
err = models.SharedIPListDAO.UpdateIPList(req.IpListId, req.Name, req.Code, req.TimeoutJSON)
if err != nil {
return nil, err
}
return rpcutils.RPCUpdateSuccess()
}
// 查找IP列表
func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
list, err := models.SharedIPListDAO.FindEnabledIPList(req.IpListId)
if err != nil {
return nil, err
}
if list == nil {
return &pb.FindEnabledIPListResponse{IpList: nil}, nil
}
return &pb.FindEnabledIPListResponse{IpList: &pb.IPList{
Id: int64(list.Id),
IsOn: list.IsOn == 1,
Type: list.Type,
Name: list.Name,
Code: list.Code,
TimeoutJSON: []byte(list.Timeout),
}}, nil
}