增加根据IP名单代号查找IP名单ID的接口

This commit is contained in:
GoEdgeLab
2024-05-05 14:10:20 +08:00
parent d82af79e08
commit ec3d36de39
3 changed files with 94 additions and 2 deletions

View File

@@ -3,10 +3,12 @@ package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/rands"
)
// IPListService IP名单相关服务
@@ -24,9 +26,14 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis
var tx = this.NullTx()
// 修正默认的代号
if req.Code == "white" || req.Code == "black" {
req.Code = req.Code + "-" + rands.HexString(8)
}
// 检查用户相关信息
if userId > 0 {
// 检查服务ID
// 检查网站ID
if req.ServerId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
if err != nil {
@@ -35,6 +42,21 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis
}
}
// 检查代号
if len(req.Code) > 0 {
if !models.SharedIPListDAO.ValidateIPListCode(req.Code) {
return nil, errors.New("invalid 'code' format")
}
oldListId, findErr := models.SharedIPListDAO.FindIPListIdWithCode(tx, req.Code)
if findErr != nil {
return nil, findErr
}
if oldListId > 0 {
return nil, errors.New("the code '" + req.Code + "' has been used")
}
}
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.ServerId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
if err != nil {
return nil, err
@@ -52,6 +74,21 @@ func (this *IPListService) UpdateIPList(ctx context.Context, req *pb.UpdateIPLis
var tx = this.NullTx()
// 检查代号
if len(req.Code) > 0 {
if !models.SharedIPListDAO.ValidateIPListCode(req.Code) {
return nil, errors.New("invalid 'code' format")
}
oldListId, findErr := models.SharedIPListDAO.FindIPListIdWithCode(tx, req.Code)
if findErr != nil {
return nil, findErr
}
if oldListId > 0 && oldListId != req.IpListId {
return nil, errors.New("the code '" + req.Code + "' has been used")
}
}
err = models.SharedIPListDAO.UpdateIPList(tx, req.IpListId, req.Name, req.Code, req.TimeoutJSON, req.Description)
if err != nil {
return nil, err
@@ -250,3 +287,34 @@ func (this *IPListService) FindServerIdWithIPListId(ctx context.Context, req *pb
ServerId: serverId,
}, nil
}
// FindIPListIdWithCode 根据IP名单代号获取IP名单ID
func (this *IPListService) FindIPListIdWithCode(ctx context.Context, req *pb.FindIPListIdWithCodeRequest) (*pb.FindIPListIdWithCodeResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if len(req.Code) == 0 {
return nil, errors.New("require 'code'")
}
var tx = this.NullTx()
listId, err := models.SharedIPListDAO.FindIPListIdWithCode(tx, req.Code)
if err != nil {
return nil, err
}
if listId > 0 {
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
}
return &pb.FindIPListIdWithCodeResponse{
IpListId: listId,
}, nil
}