兼容用户节点

This commit is contained in:
GoEdgeLab
2020-12-18 21:18:53 +08:00
parent 83d7528b88
commit 8dcaab40af
22 changed files with 288 additions and 90 deletions

View File

@@ -315,10 +315,14 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon
}
// 创建Web配置
func (this *HTTPWebDAO) CreateWeb(rootJSON []byte) (int64, error) {
func (this *HTTPWebDAO) CreateWeb(adminId int64, userId int64, rootJSON []byte) (int64, error) {
op := NewHTTPWebOperator()
op.State = HTTPWebStateEnabled
op.Root = JSONBytes(rootJSON)
op.AdminId = adminId
op.UserId = userId
if len(rootJSON) > 0 {
op.Root = JSONBytes(rootJSON)
}
err := this.Save(op)
if err != nil {
return 0, err

View File

@@ -91,8 +91,10 @@ func (this *OriginDAO) FindOriginName(id int64) (string, error) {
}
// 创建源站
func (this *OriginDAO) CreateOrigin(name string, addrJSON string, description string, weight int32, isOn bool) (originId int64, err error) {
func (this *OriginDAO) CreateOrigin(adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool) (originId int64, err error) {
op := NewOriginOperator()
op.AdminId = adminId
op.UserId = userId
op.IsOn = isOn
op.Name = name
op.Addr = addrJSON

View File

@@ -151,10 +151,13 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*s
}
// 创建反向代理
func (this *ReverseProxyDAO) CreateReverseProxy(schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) {
func (this *ReverseProxyDAO) CreateReverseProxy(adminId int64, userId int64, schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) {
op := NewReverseProxyOperator()
op.IsOn = true
op.State = ReverseProxyStateEnabled
op.AdminId = adminId
op.UserId = userId
if len(schedulingJSON) > 0 {
op.Scheduling = string(schedulingJSON)
}

View File

@@ -392,7 +392,12 @@ func (this *ServerDAO) InitServerWeb(serverId int64) (int64, error) {
return 0, errors.New("serverId should not be smaller than 0")
}
webId, err := SharedHTTPWebDAO.CreateWeb(nil)
adminId, userId, err := this.FindServerAdminIdAndUserId(serverId)
if err != nil {
return 0, err
}
webId, err := SharedHTTPWebDAO.CreateWeb(adminId, userId, nil)
if err != nil {
return 0, err
}
@@ -475,14 +480,14 @@ func (this *ServerDAO) CountAllEnabledServersMatch(groupId int64, keyword string
query.Where("(name LIKE :keyword OR serverNames LIKE :keyword)").
Param("keyword", "%"+keyword+"%")
}
if userId > 0{
if userId > 0 {
query.Attr("userId", userId)
}
return query.Count()
}
// 列出单页的服务
func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId int64, keyword string) (result []*Server, err error) {
func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId int64, keyword string, userId int64) (result []*Server, err error) {
query := this.Query().
State(ServerStateEnabled).
Offset(offset).
@@ -498,6 +503,9 @@ func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId
query.Where("(name LIKE :keyword OR serverNames LIKE :keyword)").
Param("keyword", "%"+keyword+"%")
}
if userId > 0 {
query.Attr("userId", userId)
}
_, err = query.FindAll()
return
@@ -914,6 +922,21 @@ func (this *ServerDAO) FindServerDNSName(serverId int64) (string, error) {
FindStringCol("")
}
// 获取当前服务的管理员ID和用户ID
func (this *ServerDAO) FindServerAdminIdAndUserId(serverId int64) (adminId int64, userId int64, err error) {
one, err := this.Query().
Pk(serverId).
Result("adminId", "userId").
Find()
if err != nil {
return 0, 0, err
}
if one == nil {
return 0, 0, nil
}
return int64(one.(*Server).AdminId), int64(one.(*Server).UserId), nil
}
// 生成DNS Name
func (this *ServerDAO) genDNSName() (string, error) {
for {

View File

@@ -210,7 +210,7 @@ func (this *SSLCertDAO) ComposeCertConfig(certId int64) (*sslconfigs.SSLCertConf
}
// 计算符合条件的证书数量
func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string) (int64, error) {
func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) {
query := this.Query().
State(SSLCertStateEnabled)
if isCA {
@@ -230,11 +230,17 @@ func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool,
query.Where("(name LIKE :keyword OR description LIKE :keyword OR dnsNames LIKE :keyword OR commonNames LIKE :keyword)").
Param("keyword", "%"+keyword+"%")
}
if userId > 0 {
query.Attr("userId", userId)
} else {
// 只查询管理员上传的
query.Attr("userId", 0)
}
return query.Count()
}
// 列出符合条件的证书
func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, offset int64, size int64) (certIds []int64, err error) {
func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) {
query := this.Query().
State(SSLCertStateEnabled)
if isCA {
@@ -254,6 +260,12 @@ func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool,
query.Where("(name LIKE :keyword OR description LIKE :keyword OR dnsNames LIKE :keyword OR commonNames LIKE :keyword)").
Param("keyword", "%"+keyword+"%")
}
if userId > 0 {
query.Attr("userId", userId)
} else {
// 只查询管理员上传的
query.Attr("userId", 0)
}
ones, err := query.
ResultPk().
@@ -313,3 +325,22 @@ func (this *SSLCertDAO) UpdateCertNotifiedAt(certId int64) error {
Update()
return err
}
// 检查用户权限
func (this *SSLCertDAO) CheckUserCert(certId int64, userId int64) error {
if certId <= 0 || userId <= 0 {
return errors.New("not found")
}
ok, err := this.Query().
Pk(certId).
Attr("userId", userId).
State(SSLCertStateEnabled).
Exist()
if err != nil {
return err
}
if !ok {
return errors.New("not found")
}
return nil
}

View File

@@ -187,10 +187,13 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(certId int64) (polic
}
// 创建Policy
func (this *SSLPolicyDAO) CreatePolicy(http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) {
func (this *SSLPolicyDAO) CreatePolicy(adminId int64, userId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) {
op := NewSSLPolicyOperator()
op.State = SSLPolicyStateEnabled
op.IsOn = true
op.AdminId = adminId
op.UserId = userId
op.Http2Enabled = http2Enabled
op.MinVersion = minVersion
@@ -258,3 +261,22 @@ func (this *SSLPolicyDAO) UpdatePolicy(policyId int64, http2Enabled bool, minVer
err := this.Save(op)
return err
}
// 检查是否为用户所属策略
func (this *SSLPolicyDAO) CheckUserPolicy(policyId int64, userId int64) error {
if policyId <= 0 || userId <= 0 {
return errors.New("not found")
}
ok, err := this.Query().
State(SSLPolicyStateEnabled).
Pk(policyId).
Attr("userId", userId).
Exist()
if err != nil {
return err
}
if !ok {
return errors.New("not found")
}
return nil
}

View File

@@ -221,3 +221,11 @@ func (this *UserDAO) CheckUserPassword(username string, encryptedPassword string
ResultPk().
FindInt64Col(0)
}
// 查找用户所在集群
func (this *UserDAO) FindUserClusterId(userId int64) (int64, error) {
return this.Query().
Pk(userId).
Result("clusterId").
FindInt64Col(0)
}

View File

@@ -15,7 +15,7 @@ type ACMETaskService struct {
// 计算某个ACME用户相关的任务数量
func (this *ACMETaskService) CountAllEnabledACMETasksWithACMEUserId(ctx context.Context, req *pb.CountAllEnabledACMETasksWithACMEUserIdRequest) (*pb.RPCCountResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0)
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -31,7 +31,7 @@ func (this *ACMETaskService) CountAllEnabledACMETasksWithACMEUserId(ctx context.
// 计算跟某个DNS服务商相关的任务数量
func (this *ACMETaskService) CountEnabledACMETasksWithDNSProviderId(ctx context.Context, req *pb.CountEnabledACMETasksWithDNSProviderIdRequest) (*pb.RPCCountResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0)
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -47,7 +47,7 @@ func (this *ACMETaskService) CountEnabledACMETasksWithDNSProviderId(ctx context.
// 计算所有任务数量
func (this *ACMETaskService) CountAllEnabledACMETasks(ctx context.Context, req *pb.CountAllEnabledACMETasksRequest) (*pb.RPCCountResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, req.UserId)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -61,7 +61,7 @@ func (this *ACMETaskService) CountAllEnabledACMETasks(ctx context.Context, req *
// 列出单页任务
func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.ListEnabledACMETasksRequest) (*pb.ListEnabledACMETasksResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, req.UserId)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -160,7 +160,7 @@ func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.L
// 创建任务
func (this *ACMETaskService) CreateACMETask(ctx context.Context, req *pb.CreateACMETaskRequest) (*pb.CreateACMETaskResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -178,7 +178,7 @@ func (this *ACMETaskService) CreateACMETask(ctx context.Context, req *pb.CreateA
// 修改任务
func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateACMETaskRequest) (*pb.RPCSuccess, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -200,7 +200,7 @@ func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateA
// 删除任务
func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteACMETaskRequest) (*pb.RPCSuccess, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -222,7 +222,7 @@ func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteA
// 运行某个任务
func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETaskRequest) (*pb.RunACMETaskResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -246,7 +246,7 @@ func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETas
// 查找单个任务信息
func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.FindEnabledACMETaskRequest) (*pb.FindEnabledACMETaskResponse, error) {
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}

View File

@@ -14,7 +14,7 @@ type ACMEUserService struct {
// 创建用户
func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateACMEUserRequest) (*pb.CreateACMEUserResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -29,7 +29,7 @@ func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateA
// 修改用户
func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateACMEUserRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -53,7 +53,7 @@ func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateA
// 删除用户
func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteACMEUserRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -77,7 +77,7 @@ func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteA
// 计算用户数量
func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAcmeUsersRequest) (*pb.RPCCountResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -92,7 +92,7 @@ func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAc
// 列出单页用户
func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACMEUsersRequest) (*pb.ListACMEUsersResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -116,7 +116,7 @@ func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACME
// 查找单个用户
func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.FindEnabledACMEUserRequest) (*pb.FindEnabledACMEUserResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -148,7 +148,7 @@ func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.Fi
// 查找所有用户
func (this *ACMEUserService) FindAllACMEUsers(ctx context.Context, req *pb.FindAllACMEUsersRequest) (*pb.FindAllACMEUsersResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}

View File

@@ -23,7 +23,7 @@ func (this *BaseService) ValidateAdmin(ctx context.Context, reqAdminId int64) (a
}
// 校验管理员和用户
func (this *BaseService) ValidateAdminAndUser(ctx context.Context, reqUserId int64) (adminId int64, userId int64, err error) {
func (this *BaseService) ValidateAdminAndUser(ctx context.Context, requireAdminId int64, requireUserId int64) (adminId int64, userId int64, err error) {
reqUserType, reqUserId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser)
if err != nil {
return
@@ -38,15 +38,17 @@ func (this *BaseService) ValidateAdminAndUser(ctx context.Context, reqUserId int
err = errors.New("invalid 'adminId'")
return
}
if requireAdminId > 0 && adminId != requireAdminId {
err = this.PermissionError()
return
}
case rpcutils.UserTypeUser:
userId = reqUserId
if userId <= 0 {
err = errors.New("invalid 'userId'")
return
}
// 校验权限
if reqUserId > 0 && reqUserId != userId {
if requireUserId > 0 && userId != requireUserId {
err = this.PermissionError()
return
}

View File

@@ -21,7 +21,7 @@ type DNSDomainService struct {
// 创建域名
func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.CreateDNSDomainRequest) (*pb.CreateDNSDomainResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}

View File

@@ -16,7 +16,7 @@ type DNSProviderService struct {
// 创建服务商
func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.CreateDNSProviderRequest) (*pb.CreateDNSProviderResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -32,7 +32,7 @@ func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.C
// 修改服务商
func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.UpdateDNSProviderRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, 0)
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -49,7 +49,7 @@ func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.U
// 计算服务商数量
func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context, req *pb.CountAllEnabledDNSProvidersRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, req.UserId)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -64,7 +64,7 @@ func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context,
// 列出单页服务商信息
func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req *pb.ListEnabledDNSProvidersRequest) (*pb.ListEnabledDNSProvidersResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, req.UserId)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -92,7 +92,7 @@ func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req
// 查找所有的DNS服务商
func (this *DNSProviderService) FindAllEnabledDNSProviders(ctx context.Context, req *pb.FindAllEnabledDNSProvidersRequest) (*pb.FindAllEnabledDNSProvidersResponse, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, req.UserId)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -120,7 +120,7 @@ func (this *DNSProviderService) FindAllEnabledDNSProviders(ctx context.Context,
// 删除服务商
func (this *DNSProviderService) DeleteDNSProvider(ctx context.Context, req *pb.DeleteDNSProviderRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := this.ValidateAdminAndUser(ctx, 0)
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}

View File

@@ -82,7 +82,7 @@ func (this *HTTPLocationService) DeleteHTTPLocation(ctx context.Context, req *pb
// 查找反向代理设置
func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationReverseProxyConfigRequest) (*pb.FindAndInitHTTPLocationReverseProxyConfigResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -92,7 +92,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx c
return nil, err
}
if reverseProxyRef == nil || reverseProxyRef.ReverseProxyId <= 0 {
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(nil, nil, nil)
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, nil, nil, nil)
if err != nil {
return nil, err
}
@@ -133,7 +133,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx c
// 初始化Web设置
func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationWebConfigRequest) (*pb.FindAndInitHTTPLocationWebConfigResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, rpcutils.Wrap("ValidateRequest()", err)
}
@@ -144,7 +144,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Co
}
if webId <= 0 {
webId, err = models.SharedHTTPWebDAO.CreateWeb(nil)
webId, err = models.SharedHTTPWebDAO.CreateWeb(adminId, userId, nil)
if err != nil {
return nil, rpcutils.Wrap("CreateWeb()", err)
}

View File

@@ -15,12 +15,12 @@ type HTTPWebService struct {
// 创建Web配置
func (this *HTTPWebService) CreateHTTPWeb(ctx context.Context, req *pb.CreateHTTPWebRequest) (*pb.CreateHTTPWebResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
webId, err := models.SharedHTTPWebDAO.CreateWeb(req.RootJSON)
webId, err := models.SharedHTTPWebDAO.CreateWeb(adminId, userId, req.RootJSON)
if err != nil {
return nil, err
}
@@ -213,11 +213,15 @@ func (this *HTTPWebService) UpdateHTTPWebStat(ctx context.Context, req *pb.Updat
// 更改缓存配置
func (this *HTTPWebService) UpdateHTTPWebCache(ctx context.Context, req *pb.UpdateHTTPWebCacheRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
// TODO 检查权限
}
err = models.SharedHTTPWebDAO.UpdateWebCache(req.WebId, req.CacheJSON)
if err != nil {
return nil, err

View File

@@ -14,7 +14,7 @@ type MessageService struct {
// 计算未读消息数
func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.CountUnreadMessagesRequest) (*pb.RPCCountResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -29,7 +29,7 @@ func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.Cou
// 列出单页未读消息
func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.ListUnreadMessagesRequest) (*pb.ListUnreadMessagesResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -88,7 +88,7 @@ func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.List
// 设置消息已读状态
func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.UpdateMessageReadRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -112,7 +112,7 @@ func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.Updat
// 设置一组消息已读状态
func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.UpdateMessagesReadRequest) (*pb.RPCSuccess, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -139,7 +139,7 @@ func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.Upda
func (this *MessageService) UpdateAllMessagesRead(ctx context.Context, req *pb.UpdateAllMessagesReadRequest) (*pb.RPCSuccess, error) {
// 校验请求
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}

View File

@@ -17,7 +17,7 @@ type OriginService struct {
// 创建源站
func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOriginRequest) (*pb.CreateOriginResponse, error) {
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -30,7 +30,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi
"portRange": req.Addr.PortRange,
"host": req.Addr.Host,
}
originId, err := models.SharedOriginDAO.CreateOrigin(req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn)
originId, err := models.SharedOriginDAO.CreateOrigin(adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn)
if err != nil {
return nil, err
}

View File

@@ -16,12 +16,16 @@ type ReverseProxyService struct {
// 创建反向代理
func (this *ReverseProxyService) CreateReverseProxy(ctx context.Context, req *pb.CreateReverseProxyRequest) (*pb.CreateReverseProxyResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(req.SchedulingJSON, req.PrimaryOriginsJSON, req.BackupOriginsJSON)
if userId > 0 {
// TODO 校验源站
}
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, req.SchedulingJSON, req.PrimaryOriginsJSON, req.BackupOriginsJSON)
if err != nil {
return nil, err
}
@@ -126,11 +130,15 @@ func (this *ReverseProxyService) UpdateReverseProxyBackupOrigins(ctx context.Con
// 修改是否启用
func (this *ReverseProxyService) UpdateReverseProxy(ctx context.Context, req *pb.UpdateReverseProxyRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
// TODO 检查权限
}
err = models.SharedReverseProxyDAO.UpdateReverseProxy(req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush)
if err != nil {
return nil, err

View File

@@ -18,10 +18,29 @@ type ServerService struct {
// 创建服务
func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServerRequest) (*pb.CreateServerResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
// 校验用户相关数据
if userId > 0 {
// HTTPS
if len(req.HttpsJSON) > 0 {
httpsConfig := &serverconfigs.HTTPSProtocolConfig{}
err = json.Unmarshal(req.HttpsJSON, httpsConfig)
if err != nil {
return nil, err
}
if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 {
err := models.SharedSSLPolicyDAO.CheckUserPolicy(httpsConfig.SSLPolicyRef.SSLPolicyId, userId)
if err != nil {
return nil, err
}
}
}
}
serverId, err := models.SharedServerDAO.CreateServer(req.AdminId, req.UserId, req.Type, req.Name, req.Description, string(req.ServerNamesJON), string(req.HttpJSON), string(req.HttpsJSON), string(req.TcpJSON), string(req.TlsJSON), string(req.UnixJSON), string(req.UdpJSON), req.WebId, req.ReverseProxyJSON, req.NodeClusterId, string(req.IncludeNodesJSON), string(req.ExcludeNodesJSON), req.GroupIds)
if err != nil {
return nil, err
@@ -273,11 +292,15 @@ func (this *ServerService) UpdateServerUDP(ctx context.Context, req *pb.UpdateSe
// 修改Web服务
func (this *ServerService) UpdateServerWeb(ctx context.Context, req *pb.UpdateServerWebRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
// TODO 检查权限
}
if req.ServerId <= 0 {
return nil, errors.New("invalid serverId")
}
@@ -377,11 +400,11 @@ func (this *ServerService) UpdateServerNames(ctx context.Context, req *pb.Update
// 计算服务数量
func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req *pb.CountAllEnabledServersMatchRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
count, err := models.SharedServerDAO.CountAllEnabledServersMatch(req.GroupId, req.Keyword, 0)
count, err := models.SharedServerDAO.CountAllEnabledServersMatch(req.GroupId, req.Keyword, req.UserId)
if err != nil {
return nil, err
}
@@ -392,11 +415,11 @@ func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req
// 列出单页服务
func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb.ListEnabledServersMatchRequest) (*pb.ListEnabledServersMatchResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
servers, err := models.SharedServerDAO.ListEnabledServersMatch(req.Offset, req.Size, req.GroupId, req.Keyword)
servers, err := models.SharedServerDAO.ListEnabledServersMatch(req.Offset, req.Size, req.GroupId, req.Keyword, req.UserId)
if err != nil {
return nil, err
}
@@ -599,7 +622,7 @@ func (this *ServerService) FindEnabledServerType(ctx context.Context, req *pb.Fi
// 查找反向代理设置
func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Context, req *pb.FindAndInitServerReverseProxyConfigRequest) (*pb.FindAndInitServerReverseProxyConfigResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -610,7 +633,7 @@ func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Conte
}
if reverseProxyRef == nil {
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(nil, nil, nil)
reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, nil, nil, nil)
if err != nil {
return nil, err
}
@@ -682,12 +705,18 @@ func (this *ServerService) FindAndInitServerWebConfig(ctx context.Context, req *
// 计算使用某个SSL证书的服务数量
func (this *ServerService) CountAllEnabledServersWithSSLCertId(ctx context.Context, req *pb.CountAllEnabledServersWithSSLCertIdRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
err = models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId)
if err != nil {
return nil, err
}
}
policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.CertId)
policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.SslCertId)
if err != nil {
return nil, err
}
@@ -712,7 +741,7 @@ func (this *ServerService) FindAllEnabledServersWithSSLCertId(ctx context.Contex
return nil, err
}
policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.CertId)
policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.SslCertId)
if err != nil {
return nil, err
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
)
@@ -17,30 +16,36 @@ type SSLCertService struct {
// 创建Cert
func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSLCertRequest) (*pb.CreateSSLCertResponse, error) {
// 校验请求
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
// TODO 校验权限
certId, err := models.SharedSSLCertDAO.CreateCert(adminId, userId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames)
if err != nil {
return nil, err
}
return &pb.CreateSSLCertResponse{CertId: certId}, nil
return &pb.CreateSSLCertResponse{SslCertId: certId}, nil
}
// 修改Cert
func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSLCertRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
err = models.SharedSSLCertDAO.UpdateCert(req.CertId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames)
// 检查权限
if userId > 0 {
err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId)
if err != nil {
return nil, err
}
}
err = models.SharedSSLCertDAO.UpdateCert(req.SslCertId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames)
if err != nil {
return nil, err
}
@@ -51,12 +56,20 @@ func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSL
// 查找证书配置
func (this *SSLCertService) FindEnabledSSLCertConfig(ctx context.Context, req *pb.FindEnabledSSLCertConfigRequest) (*pb.FindEnabledSSLCertConfigResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
config, err := models.SharedSSLCertDAO.ComposeCertConfig(req.CertId)
// 检查权限
if userId > 0 {
err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId)
if err != nil {
return nil, err
}
}
config, err := models.SharedSSLCertDAO.ComposeCertConfig(req.SslCertId)
if err != nil {
return nil, err
}
@@ -65,24 +78,32 @@ func (this *SSLCertService) FindEnabledSSLCertConfig(ctx context.Context, req *p
if err != nil {
return nil, err
}
return &pb.FindEnabledSSLCertConfigResponse{CertJSON: configJSON}, nil
return &pb.FindEnabledSSLCertConfigResponse{SslCertJSON: configJSON}, nil
}
// 删除证书
func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSLCertRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
err = models.SharedSSLCertDAO.DisableSSLCert(req.CertId)
// 检查权限
if userId > 0 {
err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId)
if err != nil {
return nil, err
}
}
err = models.SharedSSLCertDAO.DisableSSLCert(req.SslCertId)
if err != nil {
return nil, err
}
// 停止相关ACME任务
err = models.SharedACMETaskDAO.DisableAllTasksWithCertId(req.CertId)
err = models.SharedACMETaskDAO.DisableAllTasksWithCertId(req.SslCertId)
if err != nil {
return nil, err
}
@@ -93,12 +114,12 @@ func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSL
// 计算匹配的Cert数量
func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLCertRequest) (*pb.RPCCountResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
count, err := models.SharedSSLCertDAO.CountCerts(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword)
count, err := models.SharedSSLCertDAO.CountCerts(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId)
if err != nil {
return nil, err
}
@@ -109,12 +130,12 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC
// 列出单页匹配的Cert
func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCertsRequest) (*pb.ListSSLCertsResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
certIds, err := models.SharedSSLCertDAO.ListCertIds(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.Offset, req.Size)
certIds, err := models.SharedSSLCertDAO.ListCertIds(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId, req.Offset, req.Size)
if err != nil {
return nil, err
}
@@ -136,5 +157,5 @@ func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCer
if err != nil {
return nil, err
}
return &pb.ListSSLCertsResponse{CertsJSON: certConfigsJSON}, nil
return &pb.ListSSLCertsResponse{SslCertsJSON: certConfigsJSON}, nil
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
)
type SSLPolicyService struct {
@@ -15,12 +16,32 @@ type SSLPolicyService struct {
// 创建Policy
func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.CreateSSLPolicyRequest) (*pb.CreateSSLPolicyResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(req.Http2Enabled, req.MinVersion, req.CertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
if userId > 0 {
// 检查证书
if len(req.SslCertsJSON) > 0 {
certRefs := []*sslconfigs.SSLCertRef{}
err = json.Unmarshal(req.SslCertsJSON, &certRefs)
if err != nil {
return nil, err
}
for _, certRef := range certRefs {
err = models.SharedSSLCertDAO.CheckUserCert(certRef.CertId, userId)
if err != nil {
return nil, err
}
}
}
// 检查CA证书
// TODO
}
policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(adminId, userId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
if err != nil {
return nil, err
}
@@ -31,12 +52,18 @@ func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.Creat
// 修改Policy
func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.UpdateSSLPolicyRequest) (*pb.RPCSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
if userId > 0 {
err := models.SharedSSLPolicyDAO.CheckUserPolicy(req.SslPolicyId, userId)
if err != nil {
return nil, err
}
}
err = models.SharedSSLPolicyDAO.UpdatePolicy(req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.CertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
err = models.SharedSSLPolicyDAO.UpdatePolicy(req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites)
if err != nil {
return nil, err
}

View File

@@ -117,7 +117,7 @@ func (this *UserService) ListEnabledUsers(ctx context.Context, req *pb.ListEnabl
// 查询单个用户信息
func (this *UserService) FindEnabledUser(ctx context.Context, req *pb.FindEnabledUserRequest) (*pb.FindEnabledUserResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0)
_, _, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil {
return nil, err
}
@@ -321,3 +321,17 @@ func (this *UserService) ComposeUserDashboard(ctx context.Context, req *pb.Compo
DailyPeekTrafficStats: dailyPeekTrafficStats,
}, nil
}
// 获取用户所在的集群ID
func (this *UserService) FindUserNodeClusterId(ctx context.Context, req *pb.FindUserNodeClusterIdRequest) (*pb.FindUserNodeClusterIdResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
clusterId, err := models.SharedUserDAO.FindUserClusterId(req.UserId)
if err != nil {
return nil, err
}
return &pb.FindUserNodeClusterIdResponse{NodeClusterId: clusterId}, nil
}

View File

@@ -39,7 +39,7 @@ func (this *UserBillService) GenerateAllUserBills(ctx context.Context, req *pb.G
// 计算所有账单数量
func (this *UserBillService) CountAllUserBills(ctx context.Context, req *pb.CountAllUserBillsRequest) (*pb.RPCCountResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}
@@ -53,7 +53,7 @@ func (this *UserBillService) CountAllUserBills(ctx context.Context, req *pb.Coun
// 列出单页账单
func (this *UserBillService) ListUserBills(ctx context.Context, req *pb.ListUserBillsRequest) (*pb.ListUserBillsResponse, error) {
_, _, err := this.ValidateAdminAndUser(ctx, 0)
_, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId)
if err != nil {
return nil, err
}