diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index f716b2ba..02587c9a 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -285,7 +285,7 @@ func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64, ignoreData b } // CountCerts 计算符合条件的证书数量 -func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string) (int64, error) { +func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string, userOnly bool) (int64, error) { var query = this.Query(tx). State(SSLCertStateEnabled) if isCA { @@ -308,8 +308,12 @@ func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isEx if userId > 0 { query.Attr("userId", userId) } else { - // 只查询管理员上传的 - query.Attr("userId", 0) + if userOnly { + query.Gt("userId", 0) + } else { + // 只查询管理员上传的 + query.Attr("userId", 0) + } } // 域名 @@ -322,7 +326,7 @@ func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isEx } // ListCertIds 列出符合条件的证书 -func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string, offset int64, size int64) (certIds []int64, err error) { +func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string, userOnly bool, offset int64, size int64) (certIds []int64, err error) { var query = this.Query(tx). State(SSLCertStateEnabled) if isCA { @@ -345,8 +349,12 @@ func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isE if userId > 0 { query.Attr("userId", userId) } else { - // 只查询管理员上传的 - query.Attr("userId", 0) + if userOnly { + query.Gt("userId", 0) + } else { + // 只查询管理员上传的 + query.Attr("userId", 0) + } } // 域名 @@ -434,6 +442,14 @@ func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) er return nil } +// FindCertUserId 查找证书所属用户ID +func (this *SSLCertDAO) FindCertUserId(tx *dbs.Tx, certId int64) (userId int64, err error) { + return this.Query(tx). + Pk(certId). + Result("userId"). + FindInt64Col(0) +} + // UpdateCertUser 修改证书所属用户 func (this *SSLCertDAO) UpdateCertUser(tx *dbs.Tx, certId int64, userId int64) error { if certId <= 0 || userId <= 0 { diff --git a/internal/rpc/services/service_ssl_cert.go b/internal/rpc/services/service_ssl_cert.go index d474e7d0..7e58c5bb 100644 --- a/internal/rpc/services/service_ssl_cert.go +++ b/internal/rpc/services/service_ssl_cert.go @@ -191,7 +191,7 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC return nil, errors.New("invalid user") } - count, err := models.SharedSSLCertDAO.CountCerts(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains) + count, err := models.SharedSSLCertDAO.CountCerts(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains, req.UserOnly) if err != nil { return nil, err } @@ -215,7 +215,7 @@ func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCer var tx = this.NullTx() - certIds, err := models.SharedSSLCertDAO.ListCertIds(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains, req.Offset, req.Size) + certIds, err := models.SharedSSLCertDAO.ListCertIds(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains, req.UserOnly, req.Offset, req.Size) if err != nil { return nil, err } @@ -368,3 +368,40 @@ func (this *SSLCertService) ListUpdatedSSLCertOCSP(ctx context.Context, req *pb. SslCertOCSP: result, }, nil } + +// FindSSLCertUser 查找证书所属用户 +func (this *SSLCertService) FindSSLCertUser(ctx context.Context, req *pb.FindSSLCertUserRequest) (*pb.FindSSLCertUserResponse, error) { + _, err := this.ValidateAdmin(ctx) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + userId, err := models.SharedSSLCertDAO.FindCertUserId(tx, req.SslCertId) + if err != nil { + return nil, err + } + if userId <= 0 { + return &pb.FindSSLCertUserResponse{User: nil}, nil + } + + user, err := models.SharedUserDAO.FindEnabledBasicUser(tx, userId) + if err != nil { + return nil, err + } + if user == nil { + return &pb.FindSSLCertUserResponse{ + User: &pb.User{ + Id: userId, + }, + }, nil + } + + return &pb.FindSSLCertUserResponse{ + User: &pb.User{ + Id: userId, + Username: user.Username, + Fullname: user.Fullname, + }, + }, nil +}