diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 0ec2b48b..4d8807bb 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -253,7 +253,8 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, sourceServerId int64, sourceHTTPFirewallPolicyId int64, sourceHTTPFirewallRuleGroupId int64, - sourceHTTPFirewallRuleSetId int64) (int64, error) { + sourceHTTPFirewallRuleSetId int64, + shouldNotify bool) (int64, error) { version, err := SharedIPListDAO.IncreaseVersion(tx) if err != nil { return 0, err @@ -282,6 +283,15 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, op.SourceHTTPFirewallRuleGroupId = sourceHTTPFirewallRuleGroupId op.SourceHTTPFirewallRuleSetId = sourceHTTPFirewallRuleSetId + // 服务所属用户 + if sourceServerId > 0 { + userId, err := SharedServerDAO.FindServerUserId(tx, sourceServerId) + if err != nil { + return 0, err + } + op.SourceUserId = userId + } + var autoAdded = listId == firewallconfigs.GlobalListId || sourceNodeId > 0 || sourceServerId > 0 || sourceHTTPFirewallPolicyId > 0 if autoAdded { op.IsRead = 0 @@ -301,9 +311,11 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, return itemId, nil } - err = this.NotifyUpdate(tx, itemId) - if err != nil { - return 0, err + if shouldNotify { + err = this.NotifyUpdate(tx, itemId) + if err != nil { + return 0, err + } } return itemId, nil } diff --git a/internal/db/models/ip_item_dao_test.go b/internal/db/models/ip_item_dao_test.go index 1ce17501..bde5fbdc 100644 --- a/internal/db/models/ip_item_dao_test.go +++ b/internal/db/models/ip_item_dao_test.go @@ -51,7 +51,7 @@ func TestIPItemDAO_CreateManyIPs(t *testing.T) { var dao = models.NewIPItemDAO() var n = 10 for i := 0; i < n; i++ { - itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, "192."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)), "", time.Now().Unix()+86400, "test", models.IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0) + itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, "192."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)), "", time.Now().Unix()+86400, "test", models.IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0, false) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/ip_item_model.go b/internal/db/models/ip_item_model.go index a67046d6..b050e909 100644 --- a/internal/db/models/ip_item_model.go +++ b/internal/db/models/ip_item_model.go @@ -23,32 +23,34 @@ type IPItem struct { SourceHTTPFirewallPolicyId uint32 `field:"sourceHTTPFirewallPolicyId"` // 来源策略ID SourceHTTPFirewallRuleGroupId uint32 `field:"sourceHTTPFirewallRuleGroupId"` // 来源规则集分组ID SourceHTTPFirewallRuleSetId uint32 `field:"sourceHTTPFirewallRuleSetId"` // 来源规则集ID + SourceUserId uint64 `field:"sourceUserId"` // 用户ID IsRead bool `field:"isRead"` // 是否已读 } type IPItemOperator struct { - Id interface{} // ID - ListId interface{} // 所属名单ID - Type interface{} // 类型 - IpFrom interface{} // 开始IP - IpTo interface{} // 结束IP - IpFromLong interface{} // 开始IP整型 - IpToLong interface{} // 结束IP整型 - Version interface{} // 版本 - CreatedAt interface{} // 创建时间 - UpdatedAt interface{} // 修改时间 - Reason interface{} // 加入说明 - EventLevel interface{} // 事件级别 - State interface{} // 状态 - ExpiredAt interface{} // 过期时间 - ServerId interface{} // 有效范围服务ID - NodeId interface{} // 有效范围节点ID - SourceNodeId interface{} // 来源节点ID - SourceServerId interface{} // 来源服务ID - SourceHTTPFirewallPolicyId interface{} // 来源策略ID - SourceHTTPFirewallRuleGroupId interface{} // 来源规则集分组ID - SourceHTTPFirewallRuleSetId interface{} // 来源规则集ID - IsRead interface{} // 是否已读 + Id any // ID + ListId any // 所属名单ID + Type any // 类型 + IpFrom any // 开始IP + IpTo any // 结束IP + IpFromLong any // 开始IP整型 + IpToLong any // 结束IP整型 + Version any // 版本 + CreatedAt any // 创建时间 + UpdatedAt any // 修改时间 + Reason any // 加入说明 + EventLevel any // 事件级别 + State any // 状态 + ExpiredAt any // 过期时间 + ServerId any // 有效范围服务ID + NodeId any // 有效范围节点ID + SourceNodeId any // 来源节点ID + SourceServerId any // 来源服务ID + SourceHTTPFirewallPolicyId any // 来源策略ID + SourceHTTPFirewallRuleGroupId any // 来源规则集分组ID + SourceHTTPFirewallRuleSetId any // 来源规则集ID + SourceUserId any // 用户ID + IsRead any // 是否已读 } func NewIPItemOperator() *IPItemOperator { diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index f1c31ead..afa4ec52 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -29,7 +29,7 @@ func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPIte return nil, errors.New("'ipFrom' should not be empty") } - ipFrom := net.ParseIP(req.IpFrom) + var ipFrom = net.ParseIP(req.IpFrom) if ipFrom == nil { return nil, errors.New("invalid 'ipFrom'") } @@ -64,7 +64,7 @@ func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPIte return nil, err } - itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason, req.Type, req.EventLevel, req.NodeId, req.ServerId, req.SourceNodeId, req.SourceServerId, req.SourceHTTPFirewallPolicyId, req.SourceHTTPFirewallRuleGroupId, req.SourceHTTPFirewallRuleSetId) + itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, req.IpListId, req.IpFrom, req.IpTo, req.ExpiredAt, req.Reason, req.Type, req.EventLevel, req.NodeId, req.ServerId, req.SourceNodeId, req.SourceServerId, req.SourceHTTPFirewallPolicyId, req.SourceHTTPFirewallRuleGroupId, req.SourceHTTPFirewallRuleSetId, true) if err != nil { return nil, err } @@ -72,6 +72,77 @@ func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPIte return &pb.CreateIPItemResponse{IpItemId: itemId}, nil } +// CreateIPItems 创建一组IP +func (this *IPItemService) CreateIPItems(ctx context.Context, req *pb.CreateIPItemsRequest) (*pb.CreateIPItemsResponse, error) { + // 校验请求 + userType, _, userId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser, rpcutils.UserTypeNode, rpcutils.UserTypeDNS) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + + // 校验 + for _, item := range req.IpItems { + if len(item.IpFrom) == 0 { + return nil, errors.New("'ipFrom' should not be empty") + } + + var ipFrom = net.ParseIP(item.IpFrom) + if ipFrom == nil { + return nil, errors.New("invalid 'ipFrom'") + } + + if len(item.IpTo) > 0 { + ipTo := net.ParseIP(item.IpTo) + if ipTo == nil { + return nil, errors.New("invalid 'ipTo'") + } + } + + if userType == rpcutils.UserTypeUser { + if userId <= 0 { + return nil, errors.New("invalid userId") + } else { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, item.IpListId) + if err != nil { + return nil, err + } + } + } + + if len(item.Type) == 0 { + item.Type = models.IPItemTypeIPv4 + } + } + + // 创建 + // TODO 需要区分不同的用户 + var ipItemIds = []int64{} + for index, item := range req.IpItems { + var shouldNotify = index == len(req.IpItems)-1 + + // 删除以前的 + err = models.SharedIPItemDAO.DeleteOldItem(tx, item.IpListId, item.IpFrom, item.IpTo) + if err != nil { + return nil, err + } + + itemId, err := models.SharedIPItemDAO.CreateIPItem(tx, item.IpListId, item.IpFrom, item.IpTo, item.ExpiredAt, item.Reason, item.Type, item.EventLevel, item.NodeId, item.ServerId, item.SourceNodeId, item.SourceServerId, item.SourceHTTPFirewallPolicyId, item.SourceHTTPFirewallRuleGroupId, item.SourceHTTPFirewallRuleSetId, shouldNotify) + if err != nil { + return nil, err + } + if err != nil { + return nil, err + } + ipItemIds = append(ipItemIds, itemId) + } + + return &pb.CreateIPItemsResponse{ + IpItemIds: ipItemIds, + }, nil +} + // UpdateIPItem 修改IP func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPItemRequest) (*pb.RPCSuccess, error) { // 校验请求