优化操作IP条目时检查用户ID的相关代码

This commit is contained in:
GoEdgeLab
2023-04-03 10:02:17 +08:00
parent 61cb729305
commit 3a3854d443
3 changed files with 81 additions and 61 deletions

View File

@@ -75,13 +75,18 @@ func (this *IPItemDAO) EnableIPItem(tx *dbs.Tx, id int64) error {
} }
// DisableIPItem 禁用条目 // DisableIPItem 禁用条目
func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error { func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64, sourceUserId int64) error {
version, err := SharedIPListDAO.IncreaseVersion(tx) version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil { if err != nil {
return err return err
} }
_, err = this.Query(tx). var query = this.Query(tx)
if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
}
_, err = query.
Pk(id). Pk(id).
Set("state", IPItemStateDisabled). Set("state", IPItemStateDisabled).
Set("version", version). Set("version", version).
@@ -94,7 +99,7 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
} }
// DisableIPItemsWithIP 禁用某个IP相关条目 // DisableIPItemsWithIP 禁用某个IP相关条目
func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, userId int64, listId int64) error { func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, sourceUserId int64, listId int64) error {
if len(ipFrom) == 0 { if len(ipFrom) == 0 {
return errors.New("invalid 'ipFrom'") return errors.New("invalid 'ipFrom'")
} }
@@ -106,14 +111,11 @@ func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo stri
State(IPItemStateEnabled) State(IPItemStateEnabled)
if listId > 0 { if listId > 0 {
if userId > 0 { query.Attr("listId", listId)
err := SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return err
}
} }
query.Attr("listId", listId) if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
} }
ones, err := query.FindAll() ones, err := query.FindAll()
@@ -125,14 +127,6 @@ func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo stri
for _, one := range ones { for _, one := range ones {
var item = one.(*IPItem) var item = one.(*IPItem)
var itemId = int64(item.Id) var itemId = int64(item.Id)
var itemListId = int64(item.ListId)
if itemListId != listId && userId > 0 {
err = SharedIPListDAO.CheckUserIPList(tx, userId, itemListId)
if err != nil {
// ignore error
continue
}
}
itemIds = append(itemIds, itemId) itemIds = append(itemIds, itemId)
} }
@@ -366,10 +360,13 @@ func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipT
} }
// CountIPItemsWithListId 计算IP数量 // CountIPItemsWithListId 计算IP数量
func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, keyword string, ipFrom string, ipTo string, eventLevel string) (int64, error) { func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUserId int64, keyword string, ipFrom string, ipTo string, eventLevel string) (int64, error) {
var query = this.Query(tx). var query = this.Query(tx).
State(IPItemStateEnabled). State(IPItemStateEnabled).
Attr("listId", listId) Attr("listId", listId)
if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
}
if len(keyword) > 0 { if len(keyword) > 0 {
query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)"). query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)").
Param("keyword", dbutils.QuoteLike(keyword)) Param("keyword", dbutils.QuoteLike(keyword))
@@ -387,10 +384,13 @@ func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, keyword
} }
// ListIPItemsWithListId 查找IP列表 // ListIPItemsWithListId 查找IP列表
func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, keyword string, ipFrom string, ipTo string, eventLevel string, offset int64, size int64) (result []*IPItem, err error) { func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUserId int64, keyword string, ipFrom string, ipTo string, eventLevel string, offset int64, size int64) (result []*IPItem, err error) {
var query = this.Query(tx). var query = this.Query(tx).
State(IPItemStateEnabled). State(IPItemStateEnabled).
Attr("listId", listId) Attr("listId", listId)
if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
}
if len(keyword) > 0 { if len(keyword) > 0 {
query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)"). query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)").
Param("keyword", dbutils.QuoteLike(keyword)) Param("keyword", dbutils.QuoteLike(keyword))
@@ -479,8 +479,12 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error)
} }
// CountAllEnabledIPItems 计算数量 // CountAllEnabledIPItems 计算数量
func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) { func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) {
var query = this.Query(tx) var query = this.Query(tx)
if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
query.UseIndex("sourceUserId")
}
if len(keyword) > 0 { if len(keyword) > 0 {
query.Like("ipFrom", dbutils.QuoteLike(keyword)) query.Like("ipFrom", dbutils.QuoteLike(keyword))
} }
@@ -512,8 +516,12 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, keyword string, ip str
} }
// ListAllEnabledIPItems 搜索所有IP // ListAllEnabledIPItems 搜索所有IP
func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) { func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) {
var query = this.Query(tx) var query = this.Query(tx)
if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
query.UseIndex("sourceUserId")
}
if len(keyword) > 0 { if len(keyword) > 0 {
query.Like("ipFrom", dbutils.QuoteLike(keyword)) query.Like("ipFrom", dbutils.QuoteLike(keyword))
} }
@@ -549,11 +557,17 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, keyword string, ip stri
} }
// UpdateItemsRead 设置所有未已读 // UpdateItemsRead 设置所有未已读
func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx) error { func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx, sourceUserId int64) error {
return this.Query(tx). var query = this.Query(tx).
Attr("isRead", 0). Attr("isRead", 0).
Set("isRead", 1). Set("isRead", 1)
UpdateQuickly()
if sourceUserId > 0 {
query.Attr("sourceUserId", sourceUserId)
query.UseIndex("sourceUserId")
}
return query.UpdateQuickly()
} }
// CleanExpiredIPItems 清除过期数据 // CleanExpiredIPItems 清除过期数据

View File

@@ -189,19 +189,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
// 如果是使用IPItemId删除 // 如果是使用IPItemId删除
if req.IpItemId > 0 { if req.IpItemId > 0 {
if userId > 0 { err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId, userId)
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
}
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -210,7 +198,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
// 如果是使用ipFrom+ipTo删除 // 如果是使用ipFrom+ipTo删除
if len(req.IpFrom) > 0 { if len(req.IpFrom) > 0 {
// 检查IP列表 // 检查IP列表
if req.IpListId > 0 && userId > 0 { if req.IpListId > 0 && userId > 0 && req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -228,14 +216,14 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
// DeleteIPItems 批量删除IP // DeleteIPItems 批量删除IP
func (this *IPItemService) DeleteIPItems(ctx context.Context, req *pb.DeleteIPItemsRequest) (*pb.RPCSuccess, error) { func (this *IPItemService) DeleteIPItems(ctx context.Context, req *pb.DeleteIPItemsRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx) _, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var tx = this.NullTx() var tx = this.NullTx()
for _, itemId := range req.IpItemIds { for _, itemId := range req.IpItemIds {
err = models.SharedIPItemDAO.DisableIPItem(tx, itemId) err = models.SharedIPItemDAO.DisableIPItem(tx, itemId, userId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -254,13 +242,16 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C
var tx = this.NullTx() var tx = this.NullTx()
if userId > 0 { if userId > 0 {
// 检查用户所属名单
if req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
}
count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel) count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -278,13 +269,16 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li
var tx = this.NullTx() var tx = this.NullTx()
if userId > 0 { if userId > 0 {
// 检查用户所属名单
if req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
}
items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size) items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -564,17 +558,21 @@ func (this *IPItemService) ExistsEnabledIPItem(ctx context.Context, req *pb.Exis
// CountAllEnabledIPItems 计算所有IP数量 // CountAllEnabledIPItems 计算所有IP数量
func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) { func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx) adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if adminId > 0 {
userId = req.UserId
}
var tx = this.NullTx() var tx = this.NullTx()
var listId int64 = 0 var listId int64 = 0
if req.GlobalOnly { if req.GlobalOnly {
listId = firewallconfigs.GlobalListId listId = firewallconfigs.GlobalListId
} }
count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType) count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -583,18 +581,22 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C
// ListAllEnabledIPItems 搜索IP // ListAllEnabledIPItems 搜索IP
func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) { func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) {
_, err := this.ValidateAdmin(ctx) adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if adminId > 0 {
userId = req.UserId
}
var results = []*pb.ListAllEnabledIPItemsResponse_Result{} var results = []*pb.ListAllEnabledIPItemsResponse_Result{}
var tx = this.NullTx() var tx = this.NullTx()
var listId int64 = 0 var listId int64 = 0
if req.GlobalOnly { if req.GlobalOnly {
listId = firewallconfigs.GlobalListId listId = firewallconfigs.GlobalListId
} }
items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size) items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -703,7 +705,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
return nil, err return nil, err
} }
if list == nil { if list == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id)) err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -728,7 +730,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
return nil, err return nil, err
} }
if policy == nil { if policy == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id)) err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -768,13 +770,13 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
// UpdateIPItemsRead 设置所有为已读 // UpdateIPItemsRead 设置所有为已读
func (this *IPItemService) UpdateIPItemsRead(ctx context.Context, req *pb.UpdateIPItemsReadRequest) (*pb.RPCSuccess, error) { func (this *IPItemService) UpdateIPItemsRead(ctx context.Context, req *pb.UpdateIPItemsReadRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx) _, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var tx = this.NullTx() var tx = this.NullTx()
err = models.SharedIPItemDAO.UpdateItemsRead(tx) err = models.SharedIPItemDAO.UpdateItemsRead(tx, userId)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -5,6 +5,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
) )
@@ -68,11 +69,14 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn
var tx = this.NullTx() var tx = this.NullTx()
if userId > 0 { if userId > 0 {
// 检查用户所属名单
if req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
}
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil) list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
if err != nil { if err != nil {