mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 07:50:25 +08:00 
			
		
		
		
	优化操作IP条目时检查用户ID的相关代码
This commit is contained in:
		@@ -75,13 +75,18 @@ func (this *IPItemDAO) EnableIPItem(tx *dbs.Tx, id int64) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DisableIPItem 禁用条目
 | 
			
		||||
func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
 | 
			
		||||
func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64, sourceUserId int64) error {
 | 
			
		||||
	version, err := SharedIPListDAO.IncreaseVersion(tx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = this.Query(tx).
 | 
			
		||||
	var query = this.Query(tx)
 | 
			
		||||
	if sourceUserId > 0 {
 | 
			
		||||
		query.Attr("sourceUserId", sourceUserId)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = query.
 | 
			
		||||
		Pk(id).
 | 
			
		||||
		Set("state", IPItemStateDisabled).
 | 
			
		||||
		Set("version", version).
 | 
			
		||||
@@ -94,7 +99,7 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DisableIPItemsWithIP 禁用某个IP相关条目
 | 
			
		||||
func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, userId int64, listId int64) error {
 | 
			
		||||
func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, sourceUserId int64, listId int64) error {
 | 
			
		||||
	if len(ipFrom) == 0 {
 | 
			
		||||
		return errors.New("invalid 'ipFrom'")
 | 
			
		||||
	}
 | 
			
		||||
@@ -106,16 +111,13 @@ func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo stri
 | 
			
		||||
		State(IPItemStateEnabled)
 | 
			
		||||
 | 
			
		||||
	if listId > 0 {
 | 
			
		||||
		if userId > 0 {
 | 
			
		||||
			err := SharedIPListDAO.CheckUserIPList(tx, userId, listId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		query.Attr("listId", listId)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if sourceUserId > 0 {
 | 
			
		||||
		query.Attr("sourceUserId", sourceUserId)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ones, err := query.FindAll()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -125,14 +127,6 @@ func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo stri
 | 
			
		||||
	for _, one := range ones {
 | 
			
		||||
		var item = one.(*IPItem)
 | 
			
		||||
		var itemId = int64(item.Id)
 | 
			
		||||
		var itemListId = int64(item.ListId)
 | 
			
		||||
		if itemListId != listId && userId > 0 {
 | 
			
		||||
			err = SharedIPListDAO.CheckUserIPList(tx, userId, itemListId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				// ignore error
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		itemIds = append(itemIds, itemId)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -366,10 +360,13 @@ func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipT
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CountIPItemsWithListId 计算IP数量
 | 
			
		||||
func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, keyword string, ipFrom string, ipTo string, eventLevel string) (int64, error) {
 | 
			
		||||
func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUserId int64, keyword string, ipFrom string, ipTo string, eventLevel string) (int64, error) {
 | 
			
		||||
	var query = this.Query(tx).
 | 
			
		||||
		State(IPItemStateEnabled).
 | 
			
		||||
		Attr("listId", listId)
 | 
			
		||||
	if sourceUserId > 0 {
 | 
			
		||||
		query.Attr("sourceUserId", sourceUserId)
 | 
			
		||||
	}
 | 
			
		||||
	if len(keyword) > 0 {
 | 
			
		||||
		query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)").
 | 
			
		||||
			Param("keyword", dbutils.QuoteLike(keyword))
 | 
			
		||||
@@ -387,10 +384,13 @@ func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, keyword
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListIPItemsWithListId 查找IP列表
 | 
			
		||||
func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, keyword string, ipFrom string, ipTo string, eventLevel string, offset int64, size int64) (result []*IPItem, err error) {
 | 
			
		||||
func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUserId int64, keyword string, ipFrom string, ipTo string, eventLevel string, offset int64, size int64) (result []*IPItem, err error) {
 | 
			
		||||
	var query = this.Query(tx).
 | 
			
		||||
		State(IPItemStateEnabled).
 | 
			
		||||
		Attr("listId", listId)
 | 
			
		||||
	if sourceUserId > 0 {
 | 
			
		||||
		query.Attr("sourceUserId", sourceUserId)
 | 
			
		||||
	}
 | 
			
		||||
	if len(keyword) > 0 {
 | 
			
		||||
		query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)").
 | 
			
		||||
			Param("keyword", dbutils.QuoteLike(keyword))
 | 
			
		||||
@@ -479,8 +479,12 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CountAllEnabledIPItems 计算数量
 | 
			
		||||
func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) {
 | 
			
		||||
func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) {
 | 
			
		||||
	var query = this.Query(tx)
 | 
			
		||||
	if sourceUserId > 0 {
 | 
			
		||||
		query.Attr("sourceUserId", sourceUserId)
 | 
			
		||||
		query.UseIndex("sourceUserId")
 | 
			
		||||
	}
 | 
			
		||||
	if len(keyword) > 0 {
 | 
			
		||||
		query.Like("ipFrom", dbutils.QuoteLike(keyword))
 | 
			
		||||
	}
 | 
			
		||||
@@ -512,8 +516,12 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, keyword string, ip str
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListAllEnabledIPItems 搜索所有IP
 | 
			
		||||
func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) {
 | 
			
		||||
func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) {
 | 
			
		||||
	var query = this.Query(tx)
 | 
			
		||||
	if sourceUserId > 0 {
 | 
			
		||||
		query.Attr("sourceUserId", sourceUserId)
 | 
			
		||||
		query.UseIndex("sourceUserId")
 | 
			
		||||
	}
 | 
			
		||||
	if len(keyword) > 0 {
 | 
			
		||||
		query.Like("ipFrom", dbutils.QuoteLike(keyword))
 | 
			
		||||
	}
 | 
			
		||||
@@ -549,11 +557,17 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, keyword string, ip stri
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateItemsRead 设置所有未已读
 | 
			
		||||
func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx) error {
 | 
			
		||||
	return this.Query(tx).
 | 
			
		||||
func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx, sourceUserId int64) error {
 | 
			
		||||
	var query = this.Query(tx).
 | 
			
		||||
		Attr("isRead", 0).
 | 
			
		||||
		Set("isRead", 1).
 | 
			
		||||
		UpdateQuickly()
 | 
			
		||||
		Set("isRead", 1)
 | 
			
		||||
 | 
			
		||||
	if sourceUserId > 0 {
 | 
			
		||||
		query.Attr("sourceUserId", sourceUserId)
 | 
			
		||||
		query.UseIndex("sourceUserId")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return query.UpdateQuickly()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CleanExpiredIPItems 清除过期数据
 | 
			
		||||
 
 | 
			
		||||
@@ -189,19 +189,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
 | 
			
		||||
 | 
			
		||||
	// 如果是使用IPItemId删除
 | 
			
		||||
	if req.IpItemId > 0 {
 | 
			
		||||
		if userId > 0 {
 | 
			
		||||
			listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId)
 | 
			
		||||
		err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId, userId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
@@ -210,7 +198,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
 | 
			
		||||
	// 如果是使用ipFrom+ipTo删除
 | 
			
		||||
	if len(req.IpFrom) > 0 {
 | 
			
		||||
		// 检查IP列表
 | 
			
		||||
		if req.IpListId > 0 && userId > 0 {
 | 
			
		||||
		if req.IpListId > 0 && userId > 0 && req.IpListId != firewallconfigs.GlobalListId {
 | 
			
		||||
			err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
@@ -228,14 +216,14 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
 | 
			
		||||
 | 
			
		||||
// DeleteIPItems 批量删除IP
 | 
			
		||||
func (this *IPItemService) DeleteIPItems(ctx context.Context, req *pb.DeleteIPItemsRequest) (*pb.RPCSuccess, error) {
 | 
			
		||||
	_, err := this.ValidateAdmin(ctx)
 | 
			
		||||
	_, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
	for _, itemId := range req.IpItemIds {
 | 
			
		||||
		err = models.SharedIPItemDAO.DisableIPItem(tx, itemId)
 | 
			
		||||
		err = models.SharedIPItemDAO.DisableIPItem(tx, itemId, userId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
@@ -254,13 +242,16 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		// 检查用户所属名单
 | 
			
		||||
		if req.IpListId != firewallconfigs.GlobalListId {
 | 
			
		||||
			err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel)
 | 
			
		||||
	count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -278,13 +269,16 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		// 检查用户所属名单
 | 
			
		||||
		if req.IpListId != firewallconfigs.GlobalListId {
 | 
			
		||||
			err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size)
 | 
			
		||||
	items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -564,17 +558,21 @@ func (this *IPItemService) ExistsEnabledIPItem(ctx context.Context, req *pb.Exis
 | 
			
		||||
 | 
			
		||||
// CountAllEnabledIPItems 计算所有IP数量
 | 
			
		||||
func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) {
 | 
			
		||||
	_, err := this.ValidateAdmin(ctx)
 | 
			
		||||
	adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if adminId > 0 {
 | 
			
		||||
		userId = req.UserId
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
	var listId int64 = 0
 | 
			
		||||
	if req.GlobalOnly {
 | 
			
		||||
		listId = firewallconfigs.GlobalListId
 | 
			
		||||
	}
 | 
			
		||||
	count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType)
 | 
			
		||||
	count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -583,18 +581,22 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C
 | 
			
		||||
 | 
			
		||||
// ListAllEnabledIPItems 搜索IP
 | 
			
		||||
func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) {
 | 
			
		||||
	_, err := this.ValidateAdmin(ctx)
 | 
			
		||||
	adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if adminId > 0 {
 | 
			
		||||
		userId = req.UserId
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var results = []*pb.ListAllEnabledIPItemsResponse_Result{}
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
	var listId int64 = 0
 | 
			
		||||
	if req.GlobalOnly {
 | 
			
		||||
		listId = firewallconfigs.GlobalListId
 | 
			
		||||
	}
 | 
			
		||||
	items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size)
 | 
			
		||||
	items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -703,7 +705,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if list == nil {
 | 
			
		||||
			err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id))
 | 
			
		||||
			err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
@@ -728,7 +730,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			if policy == nil {
 | 
			
		||||
				err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id))
 | 
			
		||||
				err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return nil, err
 | 
			
		||||
				}
 | 
			
		||||
@@ -768,13 +770,13 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
 | 
			
		||||
 | 
			
		||||
// UpdateIPItemsRead 设置所有为已读
 | 
			
		||||
func (this *IPItemService) UpdateIPItemsRead(ctx context.Context, req *pb.UpdateIPItemsReadRequest) (*pb.RPCSuccess, error) {
 | 
			
		||||
	_, err := this.ValidateAdmin(ctx)
 | 
			
		||||
	_, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
	err = models.SharedIPItemDAO.UpdateItemsRead(tx)
 | 
			
		||||
	err = models.SharedIPItemDAO.UpdateItemsRead(tx, userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/db/models"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -68,9 +69,12 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn
 | 
			
		||||
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		// 检查用户所属名单
 | 
			
		||||
		if req.IpListId != firewallconfigs.GlobalListId {
 | 
			
		||||
			err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user