diff --git a/internal/db/models/http_access_log_dao.go b/internal/db/models/http_access_log_dao.go index 867b96ac..d050d34f 100644 --- a/internal/db/models/http_access_log_dao.go +++ b/internal/db/models/http_access_log_dao.go @@ -116,7 +116,7 @@ func (this *HTTPAccessLogDAO) CreateHTTPAccessLogsWithDAO(tx *dbs.Tx, daoWrapper } // 读取往前的 单页访问日志 -func (this *HTTPAccessLogDAO) ListAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64) (result []*HTTPAccessLog, nextLastRequestId string, hasMore bool, err error) { +func (this *HTTPAccessLogDAO) ListAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64, hasFirewallPolicy bool, userId int64) (result []*HTTPAccessLog, nextLastRequestId string, hasMore bool, err error) { if len(day) != 8 { return } @@ -126,22 +126,33 @@ func (this *HTTPAccessLogDAO) ListAccessLogs(tx *dbs.Tx, lastRequestId string, s size = 1000 } - result, nextLastRequestId, err = this.listAccessLogs(tx, lastRequestId, size, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId) + result, nextLastRequestId, err = this.listAccessLogs(tx, lastRequestId, size, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId, hasFirewallPolicy, userId) if err != nil || int64(len(result)) < size { return } - moreResult, _, _ := this.listAccessLogs(tx, nextLastRequestId, 1, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId) + moreResult, _, _ := this.listAccessLogs(tx, nextLastRequestId, 1, day, serverId, reverse, hasError, firewallPolicyId, firewallRuleGroupId, firewallRuleSetId, hasFirewallPolicy, userId) hasMore = len(moreResult) > 0 return } // 读取往前的单页访问日志 -func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64) (result []*HTTPAccessLog, nextLastRequestId string, err error) { +func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, serverId int64, reverse bool, hasError bool, firewallPolicyId int64, firewallRuleGroupId int64, firewallRuleSetId int64, hasFirewallPolicy bool, userId int64) (result []*HTTPAccessLog, nextLastRequestId string, err error) { if size <= 0 { return nil, lastRequestId, nil } + serverIds := []int64{} + if userId > 0 { + serverIds, err = SharedServerDAO.FindAllEnabledServerIdsWithUserId(tx, userId) + if err != nil { + return + } + if len(serverIds) == 0 { + return + } + } + accessLogLocker.RLock() daoList := []*HTTPAccessLogDAOWrapper{} for _, daoWrapper := range accessLogDAOMapping { @@ -182,6 +193,9 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, s // 条件 if serverId > 0 { query.Attr("serverId", serverId) + } else if userId > 0 && len(serverIds) > 0 { + query.Attr("serverId", serverIds). + Reuse(false) } if hasError { query.Where("status>=400") @@ -195,6 +209,9 @@ func (this *HTTPAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, s if firewallRuleSetId > 0 { query.Attr("firewallRuleSetId", firewallRuleSetId) } + if hasFirewallPolicy { + query.Where("firewallPolicyId>0") + } // offset if len(lastRequestId) > 0 { diff --git a/internal/rpc/services/service_http_access_log.go b/internal/rpc/services/service_http_access_log.go index b10710ae..d73aa846 100644 --- a/internal/rpc/services/service_http_access_log.go +++ b/internal/rpc/services/service_http_access_log.go @@ -3,7 +3,6 @@ package services import ( "context" "github.com/TeaOSLab/EdgeAPI/internal/db/models" - "github.com/TeaOSLab/EdgeAPI/internal/errors" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" ) @@ -47,17 +46,20 @@ func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *p // 检查服务ID if userId > 0 { - if req.ServerId <= 0 { - return nil, errors.New("invalid serverId") + if req.UserId > 0 && userId != req.UserId { + return nil, this.PermissionError() } - err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) - if err != nil { - return nil, err + // 这里不用担心serverId <= 0 的情况,因为如果userId>0,则只会查询当前用户下的服务,不会产生安全问题 + if req.ServerId > 0 { + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) + if err != nil { + return nil, err + } } } - accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(tx, req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId) + accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(tx, req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId, req.HasFirewallPolicy, req.UserId) if err != nil { return nil, err }