mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	增加批量上传证书接口、使用域名查询证书接口
This commit is contained in:
		@@ -13,6 +13,8 @@ import (
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	timeutil "github.com/iwind/TeaGo/utils/time"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -283,8 +285,8 @@ func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64, ignoreData b
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CountCerts 计算符合条件的证书数量
 | 
			
		||||
func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64) (int64, error) {
 | 
			
		||||
	query := this.Query(tx).
 | 
			
		||||
func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string) (int64, error) {
 | 
			
		||||
	var query = this.Query(tx).
 | 
			
		||||
		State(SSLCertStateEnabled)
 | 
			
		||||
	if isCA {
 | 
			
		||||
		query.Attr("isCA", true)
 | 
			
		||||
@@ -309,12 +311,19 @@ func (this *SSLCertDAO) CountCerts(tx *dbs.Tx, isCA bool, isAvailable bool, isEx
 | 
			
		||||
		// 只查询管理员上传的
 | 
			
		||||
		query.Attr("userId", 0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 域名
 | 
			
		||||
	err := this.buildDomainSearchingQuery(query, domains)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return query.Count()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListCertIds 列出符合条件的证书
 | 
			
		||||
func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, offset int64, size int64) (certIds []int64, err error) {
 | 
			
		||||
	query := this.Query(tx).
 | 
			
		||||
func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, userId int64, domains []string, offset int64, size int64) (certIds []int64, err error) {
 | 
			
		||||
	var query = this.Query(tx).
 | 
			
		||||
		State(SSLCertStateEnabled)
 | 
			
		||||
	if isCA {
 | 
			
		||||
		query.Attr("isCA", true)
 | 
			
		||||
@@ -340,6 +349,12 @@ func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isE
 | 
			
		||||
		query.Attr("userId", 0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 域名
 | 
			
		||||
	err = this.buildDomainSearchingQuery(query, domains)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ones, err := query.
 | 
			
		||||
		ResultPk().
 | 
			
		||||
		DescPk().
 | 
			
		||||
@@ -350,7 +365,7 @@ func (this *SSLCertDAO) ListCertIds(tx *dbs.Tx, isCA bool, isAvailable bool, isE
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	result := []int64{}
 | 
			
		||||
	var result = []int64{}
 | 
			
		||||
	for _, one := range ones {
 | 
			
		||||
		result = append(result, int64(one.(*SSLCert).Id))
 | 
			
		||||
	}
 | 
			
		||||
@@ -649,3 +664,71 @@ func (this *SSLCertDAO) NotifyUpdate(tx *dbs.Tx, certId int64) error {
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 构造通过域名搜索证书的查询对象
 | 
			
		||||
func (this *SSLCertDAO) buildDomainSearchingQuery(query *dbs.Query, domains []string) error {
 | 
			
		||||
	if len(domains) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 不要查询太多
 | 
			
		||||
	const maxDomains = 10_000
 | 
			
		||||
	if len(domains) > maxDomains {
 | 
			
		||||
		domains = domains[:maxDomains]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 加入通配符
 | 
			
		||||
	var searchingDomains = []string{}
 | 
			
		||||
	var domainMap = map[string]bool{}
 | 
			
		||||
	for _, domain := range domains {
 | 
			
		||||
		domainMap[domain] = true
 | 
			
		||||
	}
 | 
			
		||||
	var reg = regexp.MustCompile(`^[\w.-]+$`) // 为了下面的SQL语句安全先不支持其他字符
 | 
			
		||||
	for domain := range domainMap {
 | 
			
		||||
		if !reg.MatchString(domain) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		searchingDomains = append(searchingDomains, domain)
 | 
			
		||||
 | 
			
		||||
		if strings.Count(domain, ".") >= 2 && !strings.HasPrefix(domain, "*.") {
 | 
			
		||||
			var wildcardDomain = "*" + domain[strings.Index(domain, "."):]
 | 
			
		||||
			if !domainMap[wildcardDomain] {
 | 
			
		||||
				domainMap[wildcardDomain] = true
 | 
			
		||||
				searchingDomains = append(searchingDomains, wildcardDomain)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检测 JSON_OVERLAPS() 函数是否可用
 | 
			
		||||
	var canJSONOverlaps = false
 | 
			
		||||
	_, funcErr := this.Instance.FindCol(0, "SELECT JSON_OVERLAPS('[1]', '[1]')")
 | 
			
		||||
	canJSONOverlaps = funcErr == nil
 | 
			
		||||
	if canJSONOverlaps {
 | 
			
		||||
		domainsJSON, err := json.Marshal(searchingDomains)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		query.
 | 
			
		||||
			Where("JSON_OVERLAPS(dnsNames, JSON_UNQUOTE(:domainsJSON))").
 | 
			
		||||
			Param("domainsJSON", string(domainsJSON))
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 不支持JSON_OVERLAPS()的情形
 | 
			
		||||
	query.Reuse(false)
 | 
			
		||||
 | 
			
		||||
	// TODO 需要判断是否超出max_allowed_packet
 | 
			
		||||
	var sqlPieces = []string{}
 | 
			
		||||
	for _, domain := range searchingDomains {
 | 
			
		||||
		domainJSON, err := json.Marshal(domain)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sqlPieces = append(sqlPieces, "JSON_CONTAINS(dnsNames, '"+string(domainJSON)+"')")
 | 
			
		||||
	}
 | 
			
		||||
	query.Where("(" + strings.Join(sqlPieces, " OR ") + ")")
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -238,3 +238,17 @@ func TestSSLCertDAO_Update_JSON(t *testing.T) {
 | 
			
		||||
		t.Log(cert)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSSLCertDAO_FindAllCertsMatchDomains(t *testing.T) {
 | 
			
		||||
	dbs.NotifyReady()
 | 
			
		||||
 | 
			
		||||
	var tx *dbs.Tx
 | 
			
		||||
	var dao = models.NewSSLCertDAO()
 | 
			
		||||
	certs, err := dao.FindAllCertsMatchDomains(tx, 0, []string{"goedge.cn", "teaos.cn", "www.goedge.cn", "'hello\"'", "中文.com", "xn---1.com", "global.dl.goedge.cn"})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	for _, cert := range certs {
 | 
			
		||||
		t.Log("id:", cert.Id, "userId:", cert.UserId, "name:", cert.Name, "dnsNames:", cert.DecodeDNSNames(), "end:", timeutil.FormatTime("Y-m-d H:i:s", int64(cert.TimeEndAt)))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,7 @@ package dbutils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
@@ -93,3 +94,36 @@ func IsLocalAddr(addr string) bool {
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MySQLVersion 读取当前MySQL版本
 | 
			
		||||
func MySQLVersion() (version string, err error) {
 | 
			
		||||
	db, err := dbs.Default()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	result, err := db.FindCol(0, "SELECT VERSION()")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	version = types.String(result)
 | 
			
		||||
	var suffixIndex = strings.Index(version, "-")
 | 
			
		||||
	if suffixIndex > 0 {
 | 
			
		||||
		version = version[:suffixIndex]
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MySQLVersionFrom8() (bool, error) {
 | 
			
		||||
	version, err := MySQLVersion()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	if len(version) == 0 {
 | 
			
		||||
		return false, nil
 | 
			
		||||
	}
 | 
			
		||||
	var dotIndex = strings.Index(version, ".")
 | 
			
		||||
	if dotIndex > 0 {
 | 
			
		||||
		return types.Int(version[:dotIndex]) >= 8, nil
 | 
			
		||||
	}
 | 
			
		||||
	return false, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -23,3 +23,15 @@ func TestIsLocalAddr(t *testing.T) {
 | 
			
		||||
	a.IsFalse(dbutils.IsLocalAddr("192.168.2.200"))
 | 
			
		||||
	a.IsFalse(dbutils.IsLocalAddr("192.168.2.200:3306"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMySQLVersion(t *testing.T) {
 | 
			
		||||
	version, err := dbutils.MySQLVersion()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("version:", version)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMySQLVersionFrom8(t *testing.T) {
 | 
			
		||||
	t.Log(dbutils.MySQLVersionFrom8())
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -16,7 +17,7 @@ type SSLCertService struct {
 | 
			
		||||
	BaseService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateSSLCert 创建Cert
 | 
			
		||||
// CreateSSLCert 创建证书
 | 
			
		||||
func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSLCertRequest) (*pb.CreateSSLCertResponse, error) {
 | 
			
		||||
	// 校验请求
 | 
			
		||||
	adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
@@ -41,6 +42,39 @@ func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSL
 | 
			
		||||
	return &pb.CreateSSLCertResponse{SslCertId: certId}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateSSLCerts 创建一组证书
 | 
			
		||||
func (this *SSLCertService) CreateSSLCerts(ctx context.Context, req *pb.CreateSSLCertsRequest) (*pb.CreateSSLCertsResponse, error) {
 | 
			
		||||
	// 校验请求
 | 
			
		||||
	adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if adminId > 0 {
 | 
			
		||||
		if req.UserId > 0 {
 | 
			
		||||
			userId = req.UserId
 | 
			
		||||
		} else {
 | 
			
		||||
			userId = 0
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var certIds = []int64{}
 | 
			
		||||
	err = this.RunTx(func(tx *dbs.Tx) error {
 | 
			
		||||
		for _, cert := range req.SSLCerts {
 | 
			
		||||
			certId, err := models.SharedSSLCertDAO.CreateCert(tx, adminId, userId, cert.IsOn, cert.Name, cert.Description, cert.ServerName, cert.IsCA, cert.CertData, cert.KeyData, cert.TimeBeginAt, cert.TimeEndAt, cert.DnsNames, cert.CommonNames)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			certIds = append(certIds, certId)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &pb.CreateSSLCertsResponse{SslCertIds: certIds}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateSSLCert 修改Cert
 | 
			
		||||
func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSLCertRequest) (*pb.RPCSuccess, error) {
 | 
			
		||||
	// 校验请求
 | 
			
		||||
@@ -139,7 +173,7 @@ func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSL
 | 
			
		||||
// CountSSLCerts 计算匹配的Cert数量
 | 
			
		||||
func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLCertRequest) (*pb.RPCCountResponse, error) {
 | 
			
		||||
	// 校验请求
 | 
			
		||||
	_, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -148,9 +182,11 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC
 | 
			
		||||
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		userId = req.UserId
 | 
			
		||||
	} else if adminId > 0 {
 | 
			
		||||
		userId = req.UserId
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	count, err := models.SharedSSLCertDAO.CountCerts(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId)
 | 
			
		||||
	count, err := models.SharedSSLCertDAO.CountCerts(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -161,23 +197,25 @@ func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLC
 | 
			
		||||
// ListSSLCerts 列出单页匹配的Cert
 | 
			
		||||
func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCertsRequest) (*pb.ListSSLCertsResponse, error) {
 | 
			
		||||
	// 校验请求
 | 
			
		||||
	_, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		userId = req.UserId
 | 
			
		||||
	} else if adminId > 0 {
 | 
			
		||||
		userId = req.UserId
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var tx = this.NullTx()
 | 
			
		||||
 | 
			
		||||
	certIds, err := models.SharedSSLCertDAO.ListCertIds(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Offset, req.Size)
 | 
			
		||||
	certIds, err := models.SharedSSLCertDAO.ListCertIds(tx, req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, userId, req.Domains, req.Offset, req.Size)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	certConfigs := []*sslconfigs.SSLCertConfig{}
 | 
			
		||||
	var certConfigs = []*sslconfigs.SSLCertConfig{}
 | 
			
		||||
	for _, certId := range certIds {
 | 
			
		||||
		certConfig, err := models.SharedSSLCertDAO.ComposeCertConfig(tx, certId, false, nil, nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user