mirror of
				https://github.com/TeaOSLab/EdgeAdmin.git
				synced 2025-11-04 05:00:25 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			228 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			228 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
						||
 | 
						||
package certs
 | 
						||
 | 
						||
import (
 | 
						||
	"bytes"
 | 
						||
	"context"
 | 
						||
	"crypto/tls"
 | 
						||
	"github.com/TeaOSLab/EdgeAdmin/internal/web/actions/actionutils"
 | 
						||
	"github.com/TeaOSLab/EdgeCommon/pkg/langs/codes"
 | 
						||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
						||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
 | 
						||
	"github.com/iwind/TeaGo/actions"
 | 
						||
	"github.com/iwind/TeaGo/types"
 | 
						||
	"io"
 | 
						||
	"mime/multipart"
 | 
						||
	"strings"
 | 
						||
)
 | 
						||
 | 
						||
// UploadBatchPopupAction 批量上传证书
 | 
						||
type UploadBatchPopupAction struct {
 | 
						||
	actionutils.ParentAction
 | 
						||
}
 | 
						||
 | 
						||
func (this *UploadBatchPopupAction) Init() {
 | 
						||
	this.Nav("", "", "")
 | 
						||
}
 | 
						||
 | 
						||
func (this *UploadBatchPopupAction) RunGet(params struct {
 | 
						||
	ServerId int64
 | 
						||
	UserId   int64
 | 
						||
}) {
 | 
						||
	// 读取服务用户
 | 
						||
	if params.ServerId > 0 {
 | 
						||
		serverResp, err := this.RPC().ServerRPC().FindEnabledUserServerBasic(this.AdminContext(), &pb.FindEnabledUserServerBasicRequest{ServerId: params.ServerId})
 | 
						||
		if err != nil {
 | 
						||
			this.ErrorPage(err)
 | 
						||
			return
 | 
						||
		}
 | 
						||
		var server = serverResp.Server
 | 
						||
		if server != nil {
 | 
						||
			params.UserId = server.UserId
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	this.Data["userId"] = params.UserId
 | 
						||
	this.Data["maxFiles"] = this.maxFiles()
 | 
						||
 | 
						||
	this.Show()
 | 
						||
}
 | 
						||
 | 
						||
func (this *UploadBatchPopupAction) RunPost(params struct {
 | 
						||
	UserId int64
 | 
						||
 | 
						||
	Must *actions.Must
 | 
						||
	CSRF *actionutils.CSRF
 | 
						||
}) {
 | 
						||
	defer this.CreateLogInfo(codes.SSLCert_LogUploadSSLCertBatch)
 | 
						||
 | 
						||
	var files = this.Request.MultipartForm.File["certFiles"]
 | 
						||
	if len(files) == 0 {
 | 
						||
		this.Fail("请选择要上传的证书和私钥文件")
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	// 限制每次上传的文件数量
 | 
						||
	var maxFiles = this.maxFiles()
 | 
						||
	if len(files) > maxFiles {
 | 
						||
		this.Fail("每次上传最多不能超过" + types.String(maxFiles) + "个文件")
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	type certInfo struct {
 | 
						||
		filename string
 | 
						||
		data     []byte
 | 
						||
	}
 | 
						||
 | 
						||
	var certDataList = []*certInfo{}
 | 
						||
	var keyDataList = [][]byte{}
 | 
						||
 | 
						||
	var failMessages = []string{}
 | 
						||
	for _, file := range files {
 | 
						||
		func(file *multipart.FileHeader) {
 | 
						||
			fp, err := file.Open()
 | 
						||
			if err != nil {
 | 
						||
				failMessages = append(failMessages, "文件"+file.Filename+"读取失败:"+err.Error())
 | 
						||
				return
 | 
						||
			}
 | 
						||
 | 
						||
			defer func() {
 | 
						||
				_ = fp.Close()
 | 
						||
			}()
 | 
						||
 | 
						||
			data, err := io.ReadAll(fp)
 | 
						||
			if err != nil {
 | 
						||
				failMessages = append(failMessages, "文件"+file.Filename+"读取失败:"+err.Error())
 | 
						||
				return
 | 
						||
			}
 | 
						||
 | 
						||
			if bytes.Contains(data, []byte("CERTIFICATE-")) {
 | 
						||
				certDataList = append(certDataList, &certInfo{
 | 
						||
					filename: file.Filename,
 | 
						||
					data:     data,
 | 
						||
				})
 | 
						||
			} else if bytes.Contains(data, []byte("PRIVATE KEY-")) {
 | 
						||
				keyDataList = append(keyDataList, data)
 | 
						||
			} else {
 | 
						||
				failMessages = append(failMessages, "文件"+file.Filename+"读取失败:文件格式错误,无法识别是证书还是私钥")
 | 
						||
				return
 | 
						||
			}
 | 
						||
		}(file)
 | 
						||
	}
 | 
						||
 | 
						||
	if len(failMessages) > 0 {
 | 
						||
		this.Fail("发生了错误:" + strings.Join(failMessages, ";"))
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	// 对比证书和私钥数量是否一致
 | 
						||
	if len(certDataList) != len(keyDataList) {
 | 
						||
		this.Fail("证书文件数量(" + types.String(len(certDataList)) + ")和私钥文件数量(" + types.String(len(keyDataList)) + ")不一致")
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	// 自动匹配
 | 
						||
	var pairs = [][2][]byte{}        // [] { cert, key }
 | 
						||
	var keyIndexMap = map[int]bool{} // 方便下面跳过已匹配的Key
 | 
						||
	for _, cert := range certDataList {
 | 
						||
		var found = false
 | 
						||
		for keyIndex, keyData := range keyDataList {
 | 
						||
			if keyIndexMap[keyIndex] {
 | 
						||
				continue
 | 
						||
			}
 | 
						||
 | 
						||
			_, err := tls.X509KeyPair(cert.data, keyData)
 | 
						||
			if err == nil {
 | 
						||
				found = true
 | 
						||
				pairs = append(pairs, [2][]byte{cert.data, keyData})
 | 
						||
				keyIndexMap[keyIndex] = true
 | 
						||
				break
 | 
						||
			}
 | 
						||
		}
 | 
						||
		if !found {
 | 
						||
			this.Fail("找不到" + cert.filename + "对应的私钥")
 | 
						||
			return
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	// 组织 CertConfig
 | 
						||
	var pbCerts = []*pb.CreateSSLCertsRequestCert{}
 | 
						||
	var certConfigs = []*sslconfigs.SSLCertConfig{}
 | 
						||
	for _, pair := range pairs {
 | 
						||
		certData, keyData := pair[0], pair[1]
 | 
						||
 | 
						||
		var certConfig = &sslconfigs.SSLCertConfig{
 | 
						||
			IsCA:     false,
 | 
						||
			CertData: certData,
 | 
						||
			KeyData:  keyData,
 | 
						||
		}
 | 
						||
		err := certConfig.Init(context.TODO())
 | 
						||
		if err != nil {
 | 
						||
			this.Fail("证书验证失败:" + err.Error())
 | 
						||
			return
 | 
						||
		}
 | 
						||
 | 
						||
		certConfigs = append(certConfigs, certConfig)
 | 
						||
 | 
						||
		var certName = ""
 | 
						||
		if len(certConfig.DNSNames) > 0 {
 | 
						||
			certName = certConfig.DNSNames[0]
 | 
						||
			if len(certConfig.DNSNames) > 1 {
 | 
						||
				certName += "等" + types.String(len(certConfig.DNSNames)) + "个域名"
 | 
						||
			}
 | 
						||
		}
 | 
						||
		certConfig.Name = certName
 | 
						||
 | 
						||
		pbCerts = append(pbCerts, &pb.CreateSSLCertsRequestCert{
 | 
						||
			IsOn:        true,
 | 
						||
			Name:        certName,
 | 
						||
			Description: "",
 | 
						||
			ServerName:  "",
 | 
						||
			IsCA:        false,
 | 
						||
			CertData:    certData,
 | 
						||
			KeyData:     keyData,
 | 
						||
			TimeBeginAt: certConfig.TimeBeginAt,
 | 
						||
			TimeEndAt:   certConfig.TimeEndAt,
 | 
						||
			DnsNames:    certConfig.DNSNames,
 | 
						||
			CommonNames: certConfig.CommonNames,
 | 
						||
		})
 | 
						||
	}
 | 
						||
 | 
						||
	createResp, err := this.RPC().SSLCertRPC().CreateSSLCerts(this.AdminContext(), &pb.CreateSSLCertsRequest{
 | 
						||
		UserId:   params.UserId,
 | 
						||
		SSLCerts: pbCerts,
 | 
						||
	})
 | 
						||
	if err != nil {
 | 
						||
		this.ErrorPage(err)
 | 
						||
		return
 | 
						||
	}
 | 
						||
	var certIds = createResp.SslCertIds
 | 
						||
	if len(certIds) != len(certConfigs) {
 | 
						||
		this.Fail("上传成功但API返回的证书ID数量错误,请反馈给开发者")
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	// 返回数据
 | 
						||
	this.Data["count"] = len(pbCerts)
 | 
						||
 | 
						||
	var certRefs = []*sslconfigs.SSLCertRef{}
 | 
						||
	for index, cert := range certConfigs {
 | 
						||
		// ID
 | 
						||
		cert.Id = certIds[index]
 | 
						||
 | 
						||
		// 减少不必要的数据
 | 
						||
		cert.CertData = nil
 | 
						||
		cert.KeyData = nil
 | 
						||
 | 
						||
		certRefs = append(certRefs, &sslconfigs.SSLCertRef{
 | 
						||
			IsOn:   true,
 | 
						||
			CertId: cert.Id,
 | 
						||
		})
 | 
						||
	}
 | 
						||
	this.Data["certs"] = certConfigs
 | 
						||
	this.Data["certRefs"] = certRefs
 | 
						||
 | 
						||
	this.Success()
 | 
						||
}
 |