diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index cc52734d..e847b489 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -13,6 +13,8 @@ import ( "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/types" timeutil "github.com/iwind/TeaGo/utils/time" + "regexp" + "strings" "time" ) @@ -283,8 +285,8 @@ func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64, ignoreData b } // CountCerts 计算符合条件的证书数量 -func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) { - query := this.Query(tx). +func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string) (int64, error) { + var query = this.Query(tx). State(SSLCertStateEnabled) if isCA { query.Attr("isCA", true) @@ -309,12 +311,19 @@ func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isEx // 只查询管理员上传的 query.Attr("userId", 0) } + + // 域名 + err := this.buildDomainSearchingQuery(query, domains) + if err != nil { + return 0, err + } + return query.Count() } // ListCertIds 列出符合条件的证书 -func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) { - query := this.Query(tx). +func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string, offset int64, size int64) (certIds []int64, err error) { + var query = this.Query(tx). State(SSLCertStateEnabled) if isCA { query.Attr("isCA", true) @@ -340,6 +349,12 @@ func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isE query.Attr("userId", 0) } + // 域名 + err = this.buildDomainSearchingQuery(query, domains) + if err != nil { + return nil, err + } + ones, err := query. ResultPk(). DescPk(). @@ -350,7 +365,7 @@ func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isE return nil, err } - result := []int64{} + var result = []int64{} for _, one := range ones { result = append(result, int64(one.(*SSLCert).Id)) } @@ -649,3 +664,71 @@ func (this *SSLCertDAO) NotifyUpdate(tx *dbs.Tx, certId int64) error { return nil } + +// 构造通过域名搜索证书的查询对象 +func (this *SSLCertDAO) buildDomainSearchingQuery(query *dbs.Query, domains []string) error { + if len(domains) == 0 { + return nil + } + + // 不要查询太多 + const maxDomains = 10_000 + if len(domains) > maxDomains { + domains = domains[:maxDomains] + } + + // 加入通配符 + var searchingDomains = []string{} + var domainMap = map[string]bool{} + for _, domain := range domains { + domainMap[domain] = true + } + var reg = regexp.MustCompile(`^[\w.-]+$`) // 为了下面的SQL语句安全先不支持其他字符 + for domain := range domainMap { + if !reg.MatchString(domain) { + continue + } + searchingDomains = append(searchingDomains, domain) + + if strings.Count(domain, ".") >= 2 && !strings.HasPrefix(domain, "*.") { + var wildcardDomain = "*" + domain[strings.Index(domain, "."):] + if !domainMap[wildcardDomain] { + domainMap[wildcardDomain] = true + searchingDomains = append(searchingDomains, wildcardDomain) + } + } + } + + // 检测 JSON_OVERLAPS() 函数是否可用 + var canJSONOverlaps = false + _, funcErr := this.Instance.FindCol(0, "SELECT JSON_OVERLAPS('[1]', '[1]')") + canJSONOverlaps = funcErr == nil + if canJSONOverlaps { + domainsJSON, err := json.Marshal(searchingDomains) + if err != nil { + return err + } + + query. + Where("JSON_OVERLAPS(dnsNames, JSON_UNQUOTE(:domainsJSON))"). + Param("domainsJSON", string(domainsJSON)) + return nil + } + + // 不支持JSON_OVERLAPS()的情形 + query.Reuse(false) + + // TODO 需要判断是否超出max_allowed_packet + var sqlPieces = []string{} + for _, domain := range searchingDomains { + domainJSON, err := json.Marshal(domain) + if err != nil { + return err + } + + sqlPieces = append(sqlPieces, "JSON_CONTAINS(dnsNames, '"+string(domainJSON)+"')") + } + query.Where("(" + strings.Join(sqlPieces, " OR ") + ")") + + return nil +} diff --git a/internal/db/models/ssl_cert_dao_test.go b/internal/db/models/ssl_cert_dao_test.go index 5e0f0756..f41504b9 100644 --- a/internal/db/models/ssl_cert_dao_test.go +++ b/internal/db/models/ssl_cert_dao_test.go @@ -238,3 +238,17 @@ func TestSSLCertDAO_Update_JSON(t *testing.T) { t.Log(cert) } } + +func TestSSLCertDAO_FindAllCertsMatchDomains(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + var dao = models.NewSSLCertDAO() + certs, err := dao.FindAllCertsMatchDomains(tx, 0, []string{"goedge.cn", "teaos.cn", "www.goedge.cn", "'hello\"'", "中文.com", "xn---1.com", "global.dl.goedge.cn"}) + if err != nil { + t.Fatal(err) + } + for _, cert := range certs { + t.Log("id:", cert.Id, "userId:", cert.UserId, "name:", cert.Name, "dnsNames:", cert.DecodeDNSNames(), "end:", timeutil.FormatTime("Y-m-d H:i:s", int64(cert.TimeEndAt))) + } +} diff --git a/internal/db/utils/utils.go b/internal/db/utils/utils.go index 686dc21a..aee0dc31 100644 --- a/internal/db/utils/utils.go +++ b/internal/db/utils/utils.go @@ -2,6 +2,7 @@ package dbutils import ( "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" "net" "strings" ) @@ -93,3 +94,36 @@ func IsLocalAddr(addr string) bool { } return false } + +// MySQLVersion 读取当前MySQL版本 +func MySQLVersion() (version string, err error) { + db, err := dbs.Default() + if err != nil { + return "", err + } + result, err := db.FindCol(0, "SELECT VERSION()") + if err != nil { + return "", err + } + version = types.String(result) + var suffixIndex = strings.Index(version, "-") + if suffixIndex > 0 { + version = version[:suffixIndex] + } + return +} + +func MySQLVersionFrom8() (bool, error) { + version, err := MySQLVersion() + if err != nil { + return false, err + } + if len(version) == 0 { + return false, nil + } + var dotIndex = strings.Index(version, ".") + if dotIndex > 0 { + return types.Int(version[:dotIndex]) >= 8, nil + } + return false, nil +} diff --git a/internal/db/utils/utils_test.go b/internal/db/utils/utils_test.go index db1d1b2b..35d62083 100644 --- a/internal/db/utils/utils_test.go +++ b/internal/db/utils/utils_test.go @@ -23,3 +23,15 @@ func TestIsLocalAddr(t *testing.T) { a.IsFalse(dbutils.IsLocalAddr("192.168.2.200")) a.IsFalse(dbutils.IsLocalAddr("192.168.2.200:3306")) } + +func TestMySQLVersion(t *testing.T) { + version, err := dbutils.MySQLVersion() + if err != nil { + t.Fatal(err) + } + t.Log("version:", version) +} + +func TestMySQLVersionFrom8(t *testing.T) { + t.Log(dbutils.MySQLVersionFrom8()) +} diff --git a/internal/rpc/services/service_ssl_cert.go b/internal/rpc/services/service_ssl_cert.go index 836e8f88..a8d6d7ab 100644 --- a/internal/rpc/services/service_ssl_cert.go +++ b/internal/rpc/services/service_ssl_cert.go @@ -8,6 +8,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" + "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/types" ) @@ -16,7 +17,7 @@ type SSLCertService struct { BaseService } -// CreateSSLCert 创建Cert +// CreateSSLCert 创建证书 func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSLCertRequest) (*pb.CreateSSLCertResponse, error) { // 校验请求 adminId, userId, err := this.ValidateAdminAndUser(ctx, true) @@ -41,6 +42,39 @@ func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSL return &pb.CreateSSLCertResponse{SslCertId: certId}, nil } +// CreateSSLCerts 创建一组证书 +func (this *SSLCertService) CreateSSLCerts(ctx context.Context, req *pb.CreateSSLCertsRequest) (*pb.CreateSSLCertsResponse, error) { + // 校验请求 + adminId, userId, err := this.ValidateAdminAndUser(ctx, true) + if err != nil { + return nil, err + } + + if adminId > 0 { + if req.UserId > 0 { + userId = req.UserId + } else { + userId = 0 + } + } + + var certIds = []int64{} + err = this.RunTx(func(tx *dbs.Tx) error { + for _, cert := range req.SSLCerts { + certId, err := models.SharedSSLCertDAO.CreateCert(tx, adminId, userId, cert.IsOn, cert.Name, cert.Description, cert.ServerName, cert.IsCA, cert.CertData, cert.KeyData, cert.TimeBeginAt, cert.TimeEndAt, cert.DnsNames, cert.CommonNames) + if err != nil { + return err + } + certIds = append(certIds, certId) + } + return nil + }) + if err != nil { + return nil, err + } + return &pb.CreateSSLCertsResponse{SslCertIds: certIds}, nil +} + // UpdateSSLCert 修改Cert func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSLCertRequest) (*pb.RPCSuccess, error) { // 校验请求 @@ -139,7 +173,7 @@ func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSL // CountSSLCerts 计算匹配的Cert数量 func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLCertRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, userId, err := this.ValidateAdminAndUser(ctx, true) + adminId, userId, err := this.ValidateAdminAndUser(ctx, true) if err != nil { return nil, err } @@ -148,9 +182,11 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC if userId > 0 { userId = req.UserId + } else if adminId > 0 { + userId = req.UserId } - count, err := models.SharedSSLCertDAO.CountCerts(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId) + count, err := models.SharedSSLCertDAO.CountCerts(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains) if err != nil { return nil, err } @@ -161,23 +197,25 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC // ListSSLCerts 列出单页匹配的Cert func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCertsRequest) (*pb.ListSSLCertsResponse, error) { // 校验请求 - _, userId, err := this.ValidateAdminAndUser(ctx, true) + adminId, userId, err := this.ValidateAdminAndUser(ctx, true) if err != nil { return nil, err } if userId > 0 { userId = req.UserId + } else if adminId > 0 { + userId = req.UserId } var tx = this.NullTx() - certIds, err := models.SharedSSLCertDAO.ListCertIds(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Offset, req.Size) + certIds, err := models.SharedSSLCertDAO.ListCertIds(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains, req.Offset, req.Size) if err != nil { return nil, err } - certConfigs := []*sslconfigs.SSLCertConfig{} + var certConfigs = []*sslconfigs.SSLCertConfig{} for _, certId := range certIds { certConfig, err := models.SharedSSLCertDAO.ComposeCertConfig(tx, certId, false, nil, nil) if err != nil {