diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index 24cdf545..2f75f92d 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -315,10 +315,14 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon } // 创建Web配置 -func (this *HTTPWebDAO) CreateWeb(rootJSON []byte) (int64, error) { +func (this *HTTPWebDAO) CreateWeb(adminId int64, userId int64, rootJSON []byte) (int64, error) { op := NewHTTPWebOperator() op.State = HTTPWebStateEnabled - op.Root = JSONBytes(rootJSON) + op.AdminId = adminId + op.UserId = userId + if len(rootJSON) > 0 { + op.Root = JSONBytes(rootJSON) + } err := this.Save(op) if err != nil { return 0, err diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index f47e728b..639c0373 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -91,8 +91,10 @@ func (this *OriginDAO) FindOriginName(id int64) (string, error) { } // 创建源站 -func (this *OriginDAO) CreateOrigin(name string, addrJSON string, description string, weight int32, isOn bool) (originId int64, err error) { +func (this *OriginDAO) CreateOrigin(adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool) (originId int64, err error) { op := NewOriginOperator() + op.AdminId = adminId + op.UserId = userId op.IsOn = isOn op.Name = name op.Addr = addrJSON diff --git a/internal/db/models/reverse_proxy_dao.go b/internal/db/models/reverse_proxy_dao.go index f30a8467..6f9832a8 100644 --- a/internal/db/models/reverse_proxy_dao.go +++ b/internal/db/models/reverse_proxy_dao.go @@ -151,10 +151,13 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*s } // 创建反向代理 -func (this *ReverseProxyDAO) CreateReverseProxy(schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) { +func (this *ReverseProxyDAO) CreateReverseProxy(adminId int64, userId int64, schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) { op := NewReverseProxyOperator() op.IsOn = true op.State = ReverseProxyStateEnabled + op.AdminId = adminId + op.UserId = userId + if len(schedulingJSON) > 0 { op.Scheduling = string(schedulingJSON) } diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 08dfd4cf..89b49498 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -392,7 +392,12 @@ func (this *ServerDAO) InitServerWeb(serverId int64) (int64, error) { return 0, errors.New("serverId should not be smaller than 0") } - webId, err := SharedHTTPWebDAO.CreateWeb(nil) + adminId, userId, err := this.FindServerAdminIdAndUserId(serverId) + if err != nil { + return 0, err + } + + webId, err := SharedHTTPWebDAO.CreateWeb(adminId, userId, nil) if err != nil { return 0, err } @@ -475,14 +480,14 @@ func (this *ServerDAO) CountAllEnabledServersMatch(groupId int64, keyword string query.Where("(name LIKE :keyword OR serverNames LIKE :keyword)"). Param("keyword", "%"+keyword+"%") } - if userId > 0{ + if userId > 0 { query.Attr("userId", userId) } return query.Count() } // 列出单页的服务 -func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId int64, keyword string) (result []*Server, err error) { +func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId int64, keyword string, userId int64) (result []*Server, err error) { query := this.Query(). State(ServerStateEnabled). Offset(offset). @@ -498,6 +503,9 @@ func (this *ServerDAO) ListEnabledServersMatch(offset int64, size int64, groupId query.Where("(name LIKE :keyword OR serverNames LIKE :keyword)"). Param("keyword", "%"+keyword+"%") } + if userId > 0 { + query.Attr("userId", userId) + } _, err = query.FindAll() return @@ -914,6 +922,21 @@ func (this *ServerDAO) FindServerDNSName(serverId int64) (string, error) { FindStringCol("") } +// 获取当前服务的管理员ID和用户ID +func (this *ServerDAO) FindServerAdminIdAndUserId(serverId int64) (adminId int64, userId int64, err error) { + one, err := this.Query(). + Pk(serverId). + Result("adminId", "userId"). + Find() + if err != nil { + return 0, 0, err + } + if one == nil { + return 0, 0, nil + } + return int64(one.(*Server).AdminId), int64(one.(*Server).UserId), nil +} + // 生成DNS Name func (this *ServerDAO) genDNSName() (string, error) { for { diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index c6e178a3..ed50a173 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -210,7 +210,7 @@ func (this *SSLCertDAO) ComposeCertConfig(certId int64) (*sslconfigs.SSLCertConf } // 计算符合条件的证书数量 -func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string) (int64, error) { +func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) { query := this.Query(). State(SSLCertStateEnabled) if isCA { @@ -230,11 +230,17 @@ func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, query.Where("(name LIKE :keyword OR description LIKE :keyword OR dnsNames LIKE :keyword OR commonNames LIKE :keyword)"). Param("keyword", "%"+keyword+"%") } + if userId > 0 { + query.Attr("userId", userId) + } else { + // 只查询管理员上传的 + query.Attr("userId", 0) + } return query.Count() } // 列出符合条件的证书 -func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, offset int64, size int64) (certIds []int64, err error) { +func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) { query := this.Query(). State(SSLCertStateEnabled) if isCA { @@ -254,6 +260,12 @@ func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, query.Where("(name LIKE :keyword OR description LIKE :keyword OR dnsNames LIKE :keyword OR commonNames LIKE :keyword)"). Param("keyword", "%"+keyword+"%") } + if userId > 0 { + query.Attr("userId", userId) + } else { + // 只查询管理员上传的 + query.Attr("userId", 0) + } ones, err := query. ResultPk(). @@ -313,3 +325,22 @@ func (this *SSLCertDAO) UpdateCertNotifiedAt(certId int64) error { Update() return err } + +// 检查用户权限 +func (this *SSLCertDAO) CheckUserCert(certId int64, userId int64) error { + if certId <= 0 || userId <= 0 { + return errors.New("not found") + } + ok, err := this.Query(). + Pk(certId). + Attr("userId", userId). + State(SSLCertStateEnabled). + Exist() + if err != nil { + return err + } + if !ok { + return errors.New("not found") + } + return nil +} diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index 60eedd76..e653d853 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -187,10 +187,13 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(certId int64) (polic } // 创建Policy -func (this *SSLPolicyDAO) CreatePolicy(http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) { +func (this *SSLPolicyDAO) CreatePolicy(adminId int64, userId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) { op := NewSSLPolicyOperator() op.State = SSLPolicyStateEnabled op.IsOn = true + op.AdminId = adminId + op.UserId = userId + op.Http2Enabled = http2Enabled op.MinVersion = minVersion @@ -258,3 +261,22 @@ func (this *SSLPolicyDAO) UpdatePolicy(policyId int64, http2Enabled bool, minVer err := this.Save(op) return err } + +// 检查是否为用户所属策略 +func (this *SSLPolicyDAO) CheckUserPolicy(policyId int64, userId int64) error { + if policyId <= 0 || userId <= 0 { + return errors.New("not found") + } + ok, err := this.Query(). + State(SSLPolicyStateEnabled). + Pk(policyId). + Attr("userId", userId). + Exist() + if err != nil { + return err + } + if !ok { + return errors.New("not found") + } + return nil +} diff --git a/internal/db/models/user_dao.go b/internal/db/models/user_dao.go index 12a35bde..aac41442 100644 --- a/internal/db/models/user_dao.go +++ b/internal/db/models/user_dao.go @@ -221,3 +221,11 @@ func (this *UserDAO) CheckUserPassword(username string, encryptedPassword string ResultPk(). FindInt64Col(0) } + +// 查找用户所在集群 +func (this *UserDAO) FindUserClusterId(userId int64) (int64, error) { + return this.Query(). + Pk(userId). + Result("clusterId"). + FindInt64Col(0) +} diff --git a/internal/rpc/services/service_acme_task.go b/internal/rpc/services/service_acme_task.go index 060aa575..7bc48a9b 100644 --- a/internal/rpc/services/service_acme_task.go +++ b/internal/rpc/services/service_acme_task.go @@ -15,7 +15,7 @@ type ACMETaskService struct { // 计算某个ACME用户相关的任务数量 func (this *ACMETaskService) CountAllEnabledACMETasksWithACMEUserId(ctx context.Context, req *pb.CountAllEnabledACMETasksWithACMEUserIdRequest) (*pb.RPCCountResponse, error) { - _, _, err := this.ValidateAdminAndUser(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -31,7 +31,7 @@ func (this *ACMETaskService) CountAllEnabledACMETasksWithACMEUserId(ctx context. // 计算跟某个DNS服务商相关的任务数量 func (this *ACMETaskService) CountEnabledACMETasksWithDNSProviderId(ctx context.Context, req *pb.CountEnabledACMETasksWithDNSProviderIdRequest) (*pb.RPCCountResponse, error) { - _, _, err := this.ValidateAdminAndUser(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -47,7 +47,7 @@ func (this *ACMETaskService) CountEnabledACMETasksWithDNSProviderId(ctx context. // 计算所有任务数量 func (this *ACMETaskService) CountAllEnabledACMETasks(ctx context.Context, req *pb.CountAllEnabledACMETasksRequest) (*pb.RPCCountResponse, error) { - _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -61,7 +61,7 @@ func (this *ACMETaskService) CountAllEnabledACMETasks(ctx context.Context, req * // 列出单页任务 func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.ListEnabledACMETasksRequest) (*pb.ListEnabledACMETasksResponse, error) { - _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -160,7 +160,7 @@ func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.L // 创建任务 func (this *ACMETaskService) CreateACMETask(ctx context.Context, req *pb.CreateACMETaskRequest) (*pb.CreateACMETaskResponse, error) { - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -178,7 +178,7 @@ func (this *ACMETaskService) CreateACMETask(ctx context.Context, req *pb.CreateA // 修改任务 func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateACMETaskRequest) (*pb.RPCSuccess, error) { - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -200,7 +200,7 @@ func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateA // 删除任务 func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteACMETaskRequest) (*pb.RPCSuccess, error) { - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -222,7 +222,7 @@ func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteA // 运行某个任务 func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETaskRequest) (*pb.RunACMETaskResponse, error) { - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -246,7 +246,7 @@ func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETas // 查找单个任务信息 func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.FindEnabledACMETaskRequest) (*pb.FindEnabledACMETaskResponse, error) { - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_acme_user.go b/internal/rpc/services/service_acme_user.go index f49d9ab2..b21cd66f 100644 --- a/internal/rpc/services/service_acme_user.go +++ b/internal/rpc/services/service_acme_user.go @@ -14,7 +14,7 @@ type ACMEUserService struct { // 创建用户 func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateACMEUserRequest) (*pb.CreateACMEUserResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -29,7 +29,7 @@ func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateA // 修改用户 func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateACMEUserRequest) (*pb.RPCSuccess, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -53,7 +53,7 @@ func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateA // 删除用户 func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteACMEUserRequest) (*pb.RPCSuccess, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -77,7 +77,7 @@ func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteA // 计算用户数量 func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAcmeUsersRequest) (*pb.RPCCountResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAc // 列出单页用户 func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACMEUsersRequest) (*pb.ListACMEUsersResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACME // 查找单个用户 func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.FindEnabledACMEUserRequest) (*pb.FindEnabledACMEUserResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -148,7 +148,7 @@ func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.Fi // 查找所有用户 func (this *ACMEUserService) FindAllACMEUsers(ctx context.Context, req *pb.FindAllACMEUsersRequest) (*pb.FindAllACMEUsersResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_base.go b/internal/rpc/services/service_base.go index 5265b94f..05443e86 100644 --- a/internal/rpc/services/service_base.go +++ b/internal/rpc/services/service_base.go @@ -23,7 +23,7 @@ func (this *BaseService) ValidateAdmin(ctx context.Context, reqAdminId int64) (a } // 校验管理员和用户 -func (this *BaseService) ValidateAdminAndUser(ctx context.Context, reqUserId int64) (adminId int64, userId int64, err error) { +func (this *BaseService) ValidateAdminAndUser(ctx context.Context, requireAdminId int64, requireUserId int64) (adminId int64, userId int64, err error) { reqUserType, reqUserId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser) if err != nil { return @@ -38,15 +38,17 @@ func (this *BaseService) ValidateAdminAndUser(ctx context.Context, reqUserId int err = errors.New("invalid 'adminId'") return } + if requireAdminId > 0 && adminId != requireAdminId { + err = this.PermissionError() + return + } case rpcutils.UserTypeUser: userId = reqUserId if userId <= 0 { err = errors.New("invalid 'userId'") return } - - // 校验权限 - if reqUserId > 0 && reqUserId != userId { + if requireUserId > 0 && userId != requireUserId { err = this.PermissionError() return } diff --git a/internal/rpc/services/service_dns_domain.go b/internal/rpc/services/service_dns_domain.go index dc53415f..3422222d 100644 --- a/internal/rpc/services/service_dns_domain.go +++ b/internal/rpc/services/service_dns_domain.go @@ -21,7 +21,7 @@ type DNSDomainService struct { // 创建域名 func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.CreateDNSDomainRequest) (*pb.CreateDNSDomainResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_dns_provider.go b/internal/rpc/services/service_dns_provider.go index 62bcd04b..20af9ef6 100644 --- a/internal/rpc/services/service_dns_provider.go +++ b/internal/rpc/services/service_dns_provider.go @@ -16,7 +16,7 @@ type DNSProviderService struct { // 创建服务商 func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.CreateDNSProviderRequest) (*pb.CreateDNSProviderResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -32,7 +32,7 @@ func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.C // 修改服务商 func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.UpdateDNSProviderRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -49,7 +49,7 @@ func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.U // 计算服务商数量 func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context, req *pb.CountAllEnabledDNSProvidersRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -64,7 +64,7 @@ func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context, // 列出单页服务商信息 func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req *pb.ListEnabledDNSProvidersRequest) (*pb.ListEnabledDNSProvidersResponse, error) { // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req // 查找所有的DNS服务商 func (this *DNSProviderService) FindAllEnabledDNSProviders(ctx context.Context, req *pb.FindAllEnabledDNSProvidersRequest) (*pb.FindAllEnabledDNSProvidersResponse, error) { // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (this *DNSProviderService) FindAllEnabledDNSProviders(ctx context.Context, // 删除服务商 func (this *DNSProviderService) DeleteDNSProvider(ctx context.Context, req *pb.DeleteDNSProviderRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_location.go b/internal/rpc/services/service_http_location.go index c99f343c..b81fd635 100644 --- a/internal/rpc/services/service_http_location.go +++ b/internal/rpc/services/service_http_location.go @@ -82,7 +82,7 @@ func (this *HTTPLocationService) DeleteHTTPLocation(ctx context.Context, req *pb // 查找反向代理设置 func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationReverseProxyConfigRequest) (*pb.FindAndInitHTTPLocationReverseProxyConfigResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx c return nil, err } if reverseProxyRef == nil || reverseProxyRef.ReverseProxyId <= 0 { - reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(nil, nil, nil) + reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, nil, nil, nil) if err != nil { return nil, err } @@ -133,7 +133,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx c // 初始化Web设置 func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationWebConfigRequest) (*pb.FindAndInitHTTPLocationWebConfigResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, rpcutils.Wrap("ValidateRequest()", err) } @@ -144,7 +144,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Co } if webId <= 0 { - webId, err = models.SharedHTTPWebDAO.CreateWeb(nil) + webId, err = models.SharedHTTPWebDAO.CreateWeb(adminId, userId, nil) if err != nil { return nil, rpcutils.Wrap("CreateWeb()", err) } diff --git a/internal/rpc/services/service_http_web.go b/internal/rpc/services/service_http_web.go index ad0cf7b9..f617a4f8 100644 --- a/internal/rpc/services/service_http_web.go +++ b/internal/rpc/services/service_http_web.go @@ -15,12 +15,12 @@ type HTTPWebService struct { // 创建Web配置 func (this *HTTPWebService) CreateHTTPWeb(ctx context.Context, req *pb.CreateHTTPWebRequest) (*pb.CreateHTTPWebResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - webId, err := models.SharedHTTPWebDAO.CreateWeb(req.RootJSON) + webId, err := models.SharedHTTPWebDAO.CreateWeb(adminId, userId, req.RootJSON) if err != nil { return nil, err } @@ -213,11 +213,15 @@ func (this *HTTPWebService) UpdateHTTPWebStat(ctx context.Context, req *pb.Updat // 更改缓存配置 func (this *HTTPWebService) UpdateHTTPWebCache(ctx context.Context, req *pb.UpdateHTTPWebCacheRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } + if userId > 0 { + // TODO 检查权限 + } + err = models.SharedHTTPWebDAO.UpdateWebCache(req.WebId, req.CacheJSON) if err != nil { return nil, err diff --git a/internal/rpc/services/service_message.go b/internal/rpc/services/service_message.go index 120b81e8..aa81cf73 100644 --- a/internal/rpc/services/service_message.go +++ b/internal/rpc/services/service_message.go @@ -14,7 +14,7 @@ type MessageService struct { // 计算未读消息数 func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.CountUnreadMessagesRequest) (*pb.RPCCountResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -29,7 +29,7 @@ func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.Cou // 列出单页未读消息 func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.ListUnreadMessagesRequest) (*pb.ListUnreadMessagesResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -88,7 +88,7 @@ func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.List // 设置消息已读状态 func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.UpdateMessageReadRequest) (*pb.RPCSuccess, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -112,7 +112,7 @@ func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.Updat // 设置一组消息已读状态 func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.UpdateMessagesReadRequest) (*pb.RPCSuccess, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -139,7 +139,7 @@ func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.Upda func (this *MessageService) UpdateAllMessagesRead(ctx context.Context, req *pb.UpdateAllMessagesReadRequest) (*pb.RPCSuccess, error) { // 校验请求 // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_origin.go b/internal/rpc/services/service_origin.go index 3e6020bf..76f19ef7 100644 --- a/internal/rpc/services/service_origin.go +++ b/internal/rpc/services/service_origin.go @@ -17,7 +17,7 @@ type OriginService struct { // 创建源站 func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOriginRequest) (*pb.CreateOriginResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi "portRange": req.Addr.PortRange, "host": req.Addr.Host, } - originId, err := models.SharedOriginDAO.CreateOrigin(req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) + originId, err := models.SharedOriginDAO.CreateOrigin(adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_reverse_proxy.go b/internal/rpc/services/service_reverse_proxy.go index 7a166bcf..00f88e0e 100644 --- a/internal/rpc/services/service_reverse_proxy.go +++ b/internal/rpc/services/service_reverse_proxy.go @@ -16,12 +16,16 @@ type ReverseProxyService struct { // 创建反向代理 func (this *ReverseProxyService) CreateReverseProxy(ctx context.Context, req *pb.CreateReverseProxyRequest) (*pb.CreateReverseProxyResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(req.SchedulingJSON, req.PrimaryOriginsJSON, req.BackupOriginsJSON) + if userId > 0 { + // TODO 校验源站 + } + + reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, req.SchedulingJSON, req.PrimaryOriginsJSON, req.BackupOriginsJSON) if err != nil { return nil, err } @@ -126,11 +130,15 @@ func (this *ReverseProxyService) UpdateReverseProxyBackupOrigins(ctx context.Con // 修改是否启用 func (this *ReverseProxyService) UpdateReverseProxy(ctx context.Context, req *pb.UpdateReverseProxyRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } + if userId > 0 { + // TODO 检查权限 + } + err = models.SharedReverseProxyDAO.UpdateReverseProxy(req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush) if err != nil { return nil, err diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index e98f2ca5..65fe8e40 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -18,10 +18,29 @@ type ServerService struct { // 创建服务 func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServerRequest) (*pb.CreateServerResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } + + // 校验用户相关数据 + if userId > 0 { + // HTTPS + if len(req.HttpsJSON) > 0 { + httpsConfig := &serverconfigs.HTTPSProtocolConfig{} + err = json.Unmarshal(req.HttpsJSON, httpsConfig) + if err != nil { + return nil, err + } + if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { + err := models.SharedSSLPolicyDAO.CheckUserPolicy(httpsConfig.SSLPolicyRef.SSLPolicyId, userId) + if err != nil { + return nil, err + } + } + } + } + serverId, err := models.SharedServerDAO.CreateServer(req.AdminId, req.UserId, req.Type, req.Name, req.Description, string(req.ServerNamesJON), string(req.HttpJSON), string(req.HttpsJSON), string(req.TcpJSON), string(req.TlsJSON), string(req.UnixJSON), string(req.UdpJSON), req.WebId, req.ReverseProxyJSON, req.NodeClusterId, string(req.IncludeNodesJSON), string(req.ExcludeNodesJSON), req.GroupIds) if err != nil { return nil, err @@ -273,11 +292,15 @@ func (this *ServerService) UpdateServerUDP(ctx context.Context, req *pb.UpdateSe // 修改Web服务 func (this *ServerService) UpdateServerWeb(ctx context.Context, req *pb.UpdateServerWebRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } + if userId > 0 { + // TODO 检查权限 + } + if req.ServerId <= 0 { return nil, errors.New("invalid serverId") } @@ -377,11 +400,11 @@ func (this *ServerService) UpdateServerNames(ctx context.Context, req *pb.Update // 计算服务数量 func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req *pb.CountAllEnabledServersMatchRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } - count, err := models.SharedServerDAO.CountAllEnabledServersMatch(req.GroupId, req.Keyword, 0) + count, err := models.SharedServerDAO.CountAllEnabledServersMatch(req.GroupId, req.Keyword, req.UserId) if err != nil { return nil, err } @@ -392,11 +415,11 @@ func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req // 列出单页服务 func (this *ServerService) ListEnabledServersMatch(ctx context.Context, req *pb.ListEnabledServersMatchRequest) (*pb.ListEnabledServersMatchResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } - servers, err := models.SharedServerDAO.ListEnabledServersMatch(req.Offset, req.Size, req.GroupId, req.Keyword) + servers, err := models.SharedServerDAO.ListEnabledServersMatch(req.Offset, req.Size, req.GroupId, req.Keyword, req.UserId) if err != nil { return nil, err } @@ -599,7 +622,7 @@ func (this *ServerService) FindEnabledServerType(ctx context.Context, req *pb.Fi // 查找反向代理设置 func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Context, req *pb.FindAndInitServerReverseProxyConfigRequest) (*pb.FindAndInitServerReverseProxyConfigResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -610,7 +633,7 @@ func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Conte } if reverseProxyRef == nil { - reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(nil, nil, nil) + reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(adminId, userId, nil, nil, nil) if err != nil { return nil, err } @@ -682,12 +705,18 @@ func (this *ServerService) FindAndInitServerWebConfig(ctx context.Context, req * // 计算使用某个SSL证书的服务数量 func (this *ServerService) CountAllEnabledServersWithSSLCertId(ctx context.Context, req *pb.CountAllEnabledServersWithSSLCertIdRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } + if userId > 0 { + err = models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId) + if err != nil { + return nil, err + } + } - policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.CertId) + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.SslCertId) if err != nil { return nil, err } @@ -712,7 +741,7 @@ func (this *ServerService) FindAllEnabledServersWithSSLCertId(ctx context.Contex return nil, err } - policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.CertId) + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.SslCertId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ssl_cert.go b/internal/rpc/services/service_ssl_cert.go index e77ab001..9347279d 100644 --- a/internal/rpc/services/service_ssl_cert.go +++ b/internal/rpc/services/service_ssl_cert.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/db/models" - rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" ) @@ -17,30 +16,36 @@ type SSLCertService struct { // 创建Cert func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSLCertRequest) (*pb.CreateSSLCertResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - // TODO 校验权限 - certId, err := models.SharedSSLCertDAO.CreateCert(adminId, userId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) if err != nil { return nil, err } - return &pb.CreateSSLCertResponse{CertId: certId}, nil + return &pb.CreateSSLCertResponse{SslCertId: certId}, nil } // 修改Cert func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSLCertRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - err = models.SharedSSLCertDAO.UpdateCert(req.CertId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) + // 检查权限 + if userId > 0 { + err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId) + if err != nil { + return nil, err + } + } + + err = models.SharedSSLCertDAO.UpdateCert(req.SslCertId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) if err != nil { return nil, err } @@ -51,12 +56,20 @@ func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSL // 查找证书配置 func (this *SSLCertService) FindEnabledSSLCertConfig(ctx context.Context, req *pb.FindEnabledSSLCertConfigRequest) (*pb.FindEnabledSSLCertConfigResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - config, err := models.SharedSSLCertDAO.ComposeCertConfig(req.CertId) + // 检查权限 + if userId > 0 { + err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId) + if err != nil { + return nil, err + } + } + + config, err := models.SharedSSLCertDAO.ComposeCertConfig(req.SslCertId) if err != nil { return nil, err } @@ -65,24 +78,32 @@ func (this *SSLCertService) FindEnabledSSLCertConfig(ctx context.Context, req *p if err != nil { return nil, err } - return &pb.FindEnabledSSLCertConfigResponse{CertJSON: configJSON}, nil + return &pb.FindEnabledSSLCertConfigResponse{SslCertJSON: configJSON}, nil } // 删除证书 func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSLCertRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - err = models.SharedSSLCertDAO.DisableSSLCert(req.CertId) + // 检查权限 + if userId > 0 { + err := models.SharedSSLCertDAO.CheckUserCert(req.SslCertId, userId) + if err != nil { + return nil, err + } + } + + err = models.SharedSSLCertDAO.DisableSSLCert(req.SslCertId) if err != nil { return nil, err } // 停止相关ACME任务 - err = models.SharedACMETaskDAO.DisableAllTasksWithCertId(req.CertId) + err = models.SharedACMETaskDAO.DisableAllTasksWithCertId(req.SslCertId) if err != nil { return nil, err } @@ -93,12 +114,12 @@ func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSL // 计算匹配的Cert数量 func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLCertRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } - count, err := models.SharedSSLCertDAO.CountCerts(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword) + count, err := models.SharedSSLCertDAO.CountCerts(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId) if err != nil { return nil, err } @@ -109,12 +130,12 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC // 列出单页匹配的Cert func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCertsRequest) (*pb.ListSSLCertsResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } - certIds, err := models.SharedSSLCertDAO.ListCertIds(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.Offset, req.Size) + certIds, err := models.SharedSSLCertDAO.ListCertIds(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.UserId, req.Offset, req.Size) if err != nil { return nil, err } @@ -136,5 +157,5 @@ func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCer if err != nil { return nil, err } - return &pb.ListSSLCertsResponse{CertsJSON: certConfigsJSON}, nil + return &pb.ListSSLCertsResponse{SslCertsJSON: certConfigsJSON}, nil } diff --git a/internal/rpc/services/service_ssl_policy.go b/internal/rpc/services/service_ssl_policy.go index d55b7197..b1f09aaa 100644 --- a/internal/rpc/services/service_ssl_policy.go +++ b/internal/rpc/services/service_ssl_policy.go @@ -6,6 +6,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" ) type SSLPolicyService struct { @@ -15,12 +16,32 @@ type SSLPolicyService struct { // 创建Policy func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.CreateSSLPolicyRequest) (*pb.CreateSSLPolicyResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } - policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(req.Http2Enabled, req.MinVersion, req.CertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) + if userId > 0 { + // 检查证书 + if len(req.SslCertsJSON) > 0 { + certRefs := []*sslconfigs.SSLCertRef{} + err = json.Unmarshal(req.SslCertsJSON, &certRefs) + if err != nil { + return nil, err + } + for _, certRef := range certRefs { + err = models.SharedSSLCertDAO.CheckUserCert(certRef.CertId, userId) + if err != nil { + return nil, err + } + } + } + + // 检查CA证书 + // TODO + } + + policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(adminId, userId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) if err != nil { return nil, err } @@ -31,12 +52,18 @@ func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.Creat // 修改Policy func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.UpdateSSLPolicyRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } + if userId > 0 { + err := models.SharedSSLPolicyDAO.CheckUserPolicy(req.SslPolicyId, userId) + if err != nil { + return nil, err + } + } - err = models.SharedSSLPolicyDAO.UpdatePolicy(req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.CertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) + err = models.SharedSSLPolicyDAO.UpdatePolicy(req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.SslCertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index 98ee72b3..88f6f751 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -117,7 +117,7 @@ func (this *UserService) ListEnabledUsers(ctx context.Context, req *pb.ListEnabl // 查询单个用户信息 func (this *UserService) FindEnabledUser(ctx context.Context, req *pb.FindEnabledUserRequest) (*pb.FindEnabledUserResponse, error) { - _, _, err := this.ValidateAdminAndUser(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } @@ -321,3 +321,17 @@ func (this *UserService) ComposeUserDashboard(ctx context.Context, req *pb.Compo DailyPeekTrafficStats: dailyPeekTrafficStats, }, nil } + +// 获取用户所在的集群ID +func (this *UserService) FindUserNodeClusterId(ctx context.Context, req *pb.FindUserNodeClusterIdRequest) (*pb.FindUserNodeClusterIdResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) + if err != nil { + return nil, err + } + + clusterId, err := models.SharedUserDAO.FindUserClusterId(req.UserId) + if err != nil { + return nil, err + } + return &pb.FindUserNodeClusterIdResponse{NodeClusterId: clusterId}, nil +} diff --git a/internal/rpc/services/service_user_bill.go b/internal/rpc/services/service_user_bill.go index fa7d445d..f963df6c 100644 --- a/internal/rpc/services/service_user_bill.go +++ b/internal/rpc/services/service_user_bill.go @@ -39,7 +39,7 @@ func (this *UserBillService) GenerateAllUserBills(ctx context.Context, req *pb.G // 计算所有账单数量 func (this *UserBillService) CountAllUserBills(ctx context.Context, req *pb.CountAllUserBillsRequest) (*pb.RPCCountResponse, error) { - _, _, err := this.ValidateAdminAndUser(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err } @@ -53,7 +53,7 @@ func (this *UserBillService) CountAllUserBills(ctx context.Context, req *pb.Coun // 列出单页账单 func (this *UserBillService) ListUserBills(ctx context.Context, req *pb.ListUserBillsRequest) (*pb.ListUserBillsResponse, error) { - _, _, err := this.ValidateAdminAndUser(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0, req.UserId) if err != nil { return nil, err }