From 48a3a28f32fe3815d15c9dacf4a3f821936e9eda Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Sat, 7 Nov 2020 19:40:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0IP=E5=90=8D=E5=8D=95=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/ip_item_dao.go | 147 +++++++++++++++++++++++ internal/db/models/ip_item_dao_test.go | 5 + internal/db/models/ip_item_model.go | 32 +++++ internal/db/models/ip_item_model_ext.go | 1 + internal/db/models/ip_list_dao.go | 129 ++++++++++++++++++++ internal/db/models/ip_list_dao_test.go | 25 ++++ internal/db/models/ip_list_model.go | 34 ++++++ internal/db/models/ip_list_model_ext.go | 1 + internal/nodes/api_node.go | 2 + internal/rpc/services/service_ip_item.go | 124 +++++++++++++++++++ internal/rpc/services/service_ip_list.go | 67 +++++++++++ 11 files changed, 567 insertions(+) create mode 100644 internal/db/models/ip_item_dao.go create mode 100644 internal/db/models/ip_item_dao_test.go create mode 100644 internal/db/models/ip_item_model.go create mode 100644 internal/db/models/ip_item_model_ext.go create mode 100644 internal/db/models/ip_list_dao.go create mode 100644 internal/db/models/ip_list_dao_test.go create mode 100644 internal/db/models/ip_list_model.go create mode 100644 internal/db/models/ip_list_model_ext.go create mode 100644 internal/rpc/services/service_ip_item.go create mode 100644 internal/rpc/services/service_ip_list.go diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go new file mode 100644 index 00000000..02425485 --- /dev/null +++ b/internal/db/models/ip_item_dao.go @@ -0,0 +1,147 @@ +package models + +import ( + "github.com/TeaOSLab/EdgeAPI/internal/errors" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" +) + +const ( + IPItemStateEnabled = 1 // 已启用 + IPItemStateDisabled = 0 // 已禁用 +) + +type IPItemDAO dbs.DAO + +func NewIPItemDAO() *IPItemDAO { + return dbs.NewDAO(&IPItemDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeIPItems", + Model: new(IPItem), + PkName: "id", + }, + }).(*IPItemDAO) +} + +var SharedIPItemDAO *IPItemDAO + +func init() { + dbs.OnReady(func() { + SharedIPItemDAO = NewIPItemDAO() + }) +} + +// 启用条目 +func (this *IPItemDAO) EnableIPItem(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", IPItemStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *IPItemDAO) DisableIPItem(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", IPItemStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *IPItemDAO) FindEnabledIPItem(id int64) (*IPItem, error) { + result, err := this.Query(). + Pk(id). + Attr("state", IPItemStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*IPItem), err +} + +// 创建IP +func (this *IPItemDAO) CreateIPItem(listId int64, ipFrom string, ipTo string, expiredAt int64, reason string) (int64, error) { + version, err := SharedIPListDAO.IncreaseVersion(listId) + if err != nil { + return 0, err + } + + op := NewIPItemOperator() + op.ListId = listId + op.IpFrom = ipFrom + op.IpTo = ipTo + op.Reason = reason + op.Version = version + if expiredAt < 0 { + expiredAt = 0 + } + op.ExpiredAt = expiredAt + op.State = IPItemStateEnabled + _, err = this.Save(op) + if err != nil { + return 0, err + } + return types.Int64(op.Id), nil +} + +// 修改IP +func (this *IPItemDAO) UpdateIPItem(itemId int64, ipFrom string, ipTo string, expiredAt int64, reason string) error { + if itemId <= 0 { + return errors.New("invalid itemId") + } + + listId, err := this.Query(). + Pk(itemId). + Result("listId"). + FindInt64Col(0) + if err != nil { + return err + } + if listId == 0 { + return errors.New("not found") + } + + version, err := SharedIPListDAO.IncreaseVersion(listId) + if err != nil { + return err + } + + op := NewIPItemOperator() + op.Id = itemId + op.IpFrom = ipFrom + op.IpTo = ipTo + op.Reason = reason + if expiredAt < 0 { + expiredAt = 0 + } + op.ExpiredAt = expiredAt + op.Version = version + _, err = this.Save(op) + return err +} + +// 计算IP数量 +func (this *IPItemDAO) CountIPItemsWithListId(listId int64) (int64, error) { + return this.Query(). + State(IPItemStateEnabled). + Attr("listId", listId). + Count() +} + +// 查找IP列表 +func (this *IPItemDAO) ListIPItemsWithListId(listId int64, offset int64, size int64) (result []*IPItem, err error) { + _, err = this.Query(). + State(IPItemStateEnabled). + Attr("listId", listId). + DescPk(). + Slice(&result). + Offset(offset). + Limit(size). + FindAll() + return +} diff --git a/internal/db/models/ip_item_dao_test.go b/internal/db/models/ip_item_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/ip_item_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/ip_item_model.go b/internal/db/models/ip_item_model.go new file mode 100644 index 00000000..c8b44c4b --- /dev/null +++ b/internal/db/models/ip_item_model.go @@ -0,0 +1,32 @@ +package models + +// IP +type IPItem struct { + Id uint64 `field:"id"` // ID + ListId uint32 `field:"listId"` // 所属名单ID + IpFrom string `field:"ipFrom"` // 开始IP + IpTo string `field:"ipTo"` // 结束IP + Version uint64 `field:"version"` // 版本 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + UpdatedAt uint64 `field:"updatedAt"` // 修改时间 + Reason string `field:"reason"` // 加入说明 + State uint8 `field:"state"` // 状态 + ExpiredAt uint64 `field:"expiredAt"` // 过期时间 +} + +type IPItemOperator struct { + Id interface{} // ID + ListId interface{} // 所属名单ID + IpFrom interface{} // 开始IP + IpTo interface{} // 结束IP + Version interface{} // 版本 + CreatedAt interface{} // 创建时间 + UpdatedAt interface{} // 修改时间 + Reason interface{} // 加入说明 + State interface{} // 状态 + ExpiredAt interface{} // 过期时间 +} + +func NewIPItemOperator() *IPItemOperator { + return &IPItemOperator{} +} diff --git a/internal/db/models/ip_item_model_ext.go b/internal/db/models/ip_item_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/ip_item_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go new file mode 100644 index 00000000..b744733e --- /dev/null +++ b/internal/db/models/ip_list_dao.go @@ -0,0 +1,129 @@ +package models + +import ( + "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" +) + +const ( + IPListStateEnabled = 1 // 已启用 + IPListStateDisabled = 0 // 已禁用 +) + +type IPListDAO dbs.DAO + +func NewIPListDAO() *IPListDAO { + return dbs.NewDAO(&IPListDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeIPLists", + Model: new(IPList), + PkName: "id", + }, + }).(*IPListDAO) +} + +var SharedIPListDAO *IPListDAO + +func init() { + dbs.OnReady(func() { + SharedIPListDAO = NewIPListDAO() + }) +} + +// 启用条目 +func (this *IPListDAO) EnableIPList(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", IPListStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *IPListDAO) DisableIPList(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", IPListStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *IPListDAO) FindEnabledIPList(id int64) (*IPList, error) { + result, err := this.Query(). + Pk(id). + Attr("state", IPListStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*IPList), err +} + +// 根据主键查找名称 +func (this *IPListDAO) FindIPListName(id int64) (string, error) { + return this.Query(). + Pk(id). + Result("name"). + FindStringCol("") +} + +// 创建名单 +func (this *IPListDAO) CreateIPList(listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte) (int64, error) { + op := NewIPListOperator() + op.IsOn = true + op.State = IPListStateEnabled + op.Type = listType + op.Name = name + op.Code = code + if len(timeoutJSON) > 0 { + op.Timeout = timeoutJSON + } + _, err := this.Save(op) + if err != nil { + return 0, err + } + return types.Int64(op.Id), nil +} + +// 修改名单 +func (this *IPListDAO) UpdateIPList(listId int64, name string, code string, timeoutJSON []byte) error { + if listId <= 0 { + return errors.New("invalid listId") + } + op := NewIPListOperator() + op.Id = listId + op.Name = name + op.Code = code + if len(timeoutJSON) > 0 { + op.Timeout = timeoutJSON + } else { + op.Timeout = "null" + } + _, err := this.Save(op) + return err +} + +// 增加版本 +func (this *IPListDAO) IncreaseVersion(listId int64) (int64, error) { + if listId <= 0 { + return 0, errors.New("invalid listId") + } + op := NewIPListOperator() + op.Id = listId + op.Version = dbs.SQL("version+1") + _, err := this.Save(op) + if err != nil { + return 0, err + } + + return this.Query(). + Pk(listId). + Result("version"). + FindInt64Col(0) +} diff --git a/internal/db/models/ip_list_dao_test.go b/internal/db/models/ip_list_dao_test.go new file mode 100644 index 00000000..c29c4aab --- /dev/null +++ b/internal/db/models/ip_list_dao_test.go @@ -0,0 +1,25 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" + "runtime" + "testing" +) + +func TestIPListDAO_IncreaseVersion(t *testing.T) { + dao := NewIPListDAO() + version, err := dao.IncreaseVersion(1) + if err != nil { + t.Fatal(err) + } + t.Log("version:", version) +} + +func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { + runtime.GOMAXPROCS(1) + + dao := NewIPListDAO() + for i := 0; i < b.N; i++ { + _, _ = dao.IncreaseVersion(1) + } +} diff --git a/internal/db/models/ip_list_model.go b/internal/db/models/ip_list_model.go new file mode 100644 index 00000000..d045c650 --- /dev/null +++ b/internal/db/models/ip_list_model.go @@ -0,0 +1,34 @@ +package models + +// IP名单 +type IPList struct { + Id uint32 `field:"id"` // ID + IsOn uint8 `field:"isOn"` // 是否启用 + Type string `field:"type"` // 类型 + AdminId uint32 `field:"adminId"` // 用户ID + UserId uint32 `field:"userId"` // 用户ID + Name string `field:"name"` // 列表名 + Code string `field:"code"` // 代号 + State uint8 `field:"state"` // 状态 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + Timeout string `field:"timeout"` // 默认超时时间 + Version uint64 `field:"version"` // 版本 +} + +type IPListOperator struct { + Id interface{} // ID + IsOn interface{} // 是否启用 + Type interface{} // 类型 + AdminId interface{} // 用户ID + UserId interface{} // 用户ID + Name interface{} // 列表名 + Code interface{} // 代号 + State interface{} // 状态 + CreatedAt interface{} // 创建时间 + Timeout interface{} // 默认超时时间 + Version interface{} // 版本 +} + +func NewIPListOperator() *IPListOperator { + return &IPListOperator{} +} diff --git a/internal/db/models/ip_list_model_ext.go b/internal/db/models/ip_list_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/ip_list_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index 65c93fa4..e69a80ce 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -178,6 +178,8 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err pb.RegisterFileServiceServer(rpcServer, &services.FileService{}) pb.RegisterRegionCountryServiceServer(rpcServer, &services.RegionCountryService{}) pb.RegisterRegionProvinceServiceServer(rpcServer, &services.RegionProvinceService{}) + pb.RegisterIPListServiceServer(rpcServer, &services.IPListService{}) + pb.RegisterIPItemServiceServer(rpcServer, &services.IPItemService{}) err := rpcServer.Serve(listener) if err != nil { return errors.New("[API]start rpc failed: " + err.Error()) diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go new file mode 100644 index 00000000..3083707f --- /dev/null +++ b/internal/rpc/services/service_ip_item.go @@ -0,0 +1,124 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +// IP条目相关服务 +type IPItemService struct { +} + +// 创建IP +func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPItemRequest) (*pb.CreateIPItemResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + itemId, err := models.SharedIPItemDAO.CreateIPItem(req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) + if err != nil { + return nil, err + } + return &pb.CreateIPItemResponse{IpItemId: itemId}, nil +} + +// 修改IP +func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPItemRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedIPItemDAO.UpdateIPItem(req.IpItemId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason) + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +} + +// 删除IP +func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPItemRequest) (*pb.RPCDeleteSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedIPItemDAO.DisableIPItem(req.IpItemId) + if err != nil { + return nil, err + } + return rpcutils.RPCDeleteSuccess() +} + +// 计算IP数量 +func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.CountIPItemsWithListIdRequest) (*pb.CountIPItemsWithListIdResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + count, err := models.SharedIPItemDAO.CountIPItemsWithListId(req.IpListId) + if err != nil { + return nil, err + } + return &pb.CountIPItemsWithListIdResponse{Count: count}, nil +} + +// 列出单页的IP +func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.ListIPItemsWithListIdRequest) (*pb.ListIPItemsWithListIdResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + items, err := models.SharedIPItemDAO.ListIPItemsWithListId(req.IpListId, req.Offset, req.Size) + if err != nil { + return nil, err + } + result := []*pb.IPItem{} + for _, item := range items { + result = append(result, &pb.IPItem{ + Id: int64(item.Id), + IpFrom: item.IpFrom, + IpTo: item.IpTo, + Version: int64(item.Version), + ExpiredAt: int64(item.ExpiredAt), + Reason: item.Reason, + }) + } + + return &pb.ListIPItemsWithListIdResponse{IpItems: result}, nil +} + +// 查找单个IP +func (this *IPItemService) FindEnabledIPItem(ctx context.Context, req *pb.FindEnabledIPItemRequest) (*pb.FindEnabledIPItemResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + item, err := models.SharedIPItemDAO.FindEnabledIPItem(req.IpItemId) + if err != nil { + return nil, err + } + if item == nil { + return &pb.FindEnabledIPItemResponse{IpItem: nil}, nil + } + return &pb.FindEnabledIPItemResponse{IpItem: &pb.IPItem{ + Id: int64(item.Id), + IpFrom: item.IpFrom, + IpTo: item.IpTo, + Version: int64(item.Version), + ExpiredAt: int64(item.ExpiredAt), + Reason: item.Reason, + }}, nil +} diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go new file mode 100644 index 00000000..26f18d82 --- /dev/null +++ b/internal/rpc/services/service_ip_list.go @@ -0,0 +1,67 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +// IP名单相关服务 +type IPListService struct { +} + +// 创建IP列表 +func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPListRequest) (*pb.CreateIPListResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + listId, err := models.SharedIPListDAO.CreateIPList(req.Type, req.Name, req.Code, req.TimeoutJSON) + if err != nil { + return nil, err + } + return &pb.CreateIPListResponse{IpListId: listId}, nil +} + +// 修改IP列表 +func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPListRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedIPListDAO.UpdateIPList(req.IpListId, req.Name, req.Code, req.TimeoutJSON) + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +} + +// 查找IP列表 +func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEnabledIPListRequest) (*pb.FindEnabledIPListResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + list, err := models.SharedIPListDAO.FindEnabledIPList(req.IpListId) + if err != nil { + return nil, err + } + if list == nil { + return &pb.FindEnabledIPListResponse{IpList: nil}, nil + } + return &pb.FindEnabledIPListResponse{IpList: &pb.IPList{ + Id: int64(list.Id), + IsOn: list.IsOn == 1, + Type: list.Type, + Name: list.Name, + Code: list.Code, + TimeoutJSON: []byte(list.Timeout), + }}, nil +}