mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	优化代码
This commit is contained in:
		@@ -40,6 +40,7 @@ func (this *Request) Run() (certData []byte, keyData []byte, err error) {
 | 
			
		||||
	}
 | 
			
		||||
	if this.task.Provider.RequireEAB && this.task.Account == nil {
 | 
			
		||||
		err = errors.New("account should not be nil when provider require EAB")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch this.task.AuthType {
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func TestDNSTaskDAO_CreateDNSTask(t *testing.T) {
 | 
			
		||||
	dbs.NotifyReady()
 | 
			
		||||
	err := SharedDNSTaskDAO.CreateDNSTask(nil, 1, 2, 3, 0, "taskType")
 | 
			
		||||
	err := SharedDNSTaskDAO.CreateDNSTask(nil, 1, 2, 3, 0, "cdn", "taskType")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,7 +21,7 @@ func TestNodeClusterDAO_CheckClusterDNS(t *testing.T) {
 | 
			
		||||
		t.Log("cluster not found, skip the test")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	issues, err := CheckClusterDNS(tx, cluster)
 | 
			
		||||
	issues, err := CheckClusterDNS(tx, cluster, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -805,7 +805,7 @@ func (this *HTTPAccessLogDAO) SetupQueue() {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if bytes.Compare(accessLogConfigJSON, configJSON) == 0 {
 | 
			
		||||
	if bytes.Equal(accessLogConfigJSON, configJSON) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	accessLogConfigJSON = configJSON
 | 
			
		||||
 
 | 
			
		||||
@@ -422,7 +422,7 @@ func (this *HTTPAccessLogManager) checkTableFields(db *dbs.DB, tableName string)
 | 
			
		||||
	}
 | 
			
		||||
	for _, field := range fields {
 | 
			
		||||
		var fieldName = field.GetString("Field")
 | 
			
		||||
		if strings.ToLower(fieldName) == strings.ToLower("remoteAddr") {
 | 
			
		||||
		if strings.EqualFold(fieldName, "remoteAddr") {
 | 
			
		||||
			hasRemoteAddrField = true
 | 
			
		||||
		}
 | 
			
		||||
		if strings.ToLower(fieldName) == "domain" {
 | 
			
		||||
 
 | 
			
		||||
@@ -381,7 +381,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, cacheMap *util
 | 
			
		||||
 | 
			
		||||
	// 认证
 | 
			
		||||
	if IsNotNull(web.Auth) {
 | 
			
		||||
		authConfig := &serverconfigs.HTTPAuthConfig{}
 | 
			
		||||
		var authConfig = &serverconfigs.HTTPAuthConfig{}
 | 
			
		||||
		err = json.Unmarshal(web.Auth, authConfig)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
@@ -395,6 +395,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, cacheMap *util
 | 
			
		||||
			if policyConfig != nil {
 | 
			
		||||
				ref.AuthPolicy = policyConfig
 | 
			
		||||
				newRefs = append(newRefs, ref)
 | 
			
		||||
				authConfig.PolicyRefs = newRefs
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		config.Auth = authConfig
 | 
			
		||||
 
 | 
			
		||||
@@ -141,7 +141,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	var tx *dbs.Tx
 | 
			
		||||
	{
 | 
			
		||||
		exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "hello.teaos.cn", 0)
 | 
			
		||||
		exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "hello.teaos.cn", 0, true)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
@@ -149,7 +149,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 0)
 | 
			
		||||
		exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 0, true)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
@@ -157,7 +157,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 23)
 | 
			
		||||
		exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 23, true)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -149,7 +149,7 @@ func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx,
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	var oldCert = oldOne.(*SSLCert)
 | 
			
		||||
	var dataIsChanged = bytes.Compare(certData, oldCert.CertData) != 0 || bytes.Compare(keyData, oldCert.KeyData) != 0
 | 
			
		||||
	var dataIsChanged = !bytes.Equal(certData, oldCert.CertData) || !bytes.Equal(keyData, oldCert.KeyData)
 | 
			
		||||
 | 
			
		||||
	var op = NewSSLCertOperator()
 | 
			
		||||
	op.Id = certId
 | 
			
		||||
 
 | 
			
		||||
@@ -12,7 +12,6 @@ func TestUserDAO_UpdateUserFeatures(t *testing.T) {
 | 
			
		||||
	var dao = NewUserDAO()
 | 
			
		||||
	var tx *dbs.Tx
 | 
			
		||||
	err := dao.UpdateUsersFeatures(tx, []string{
 | 
			
		||||
		userconfigs.UserFeatureCodeFinance,
 | 
			
		||||
		userconfigs.UserFeatureCodeServerACME,
 | 
			
		||||
	}, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -433,17 +433,6 @@ func (this *DNSDomainService) convertDomainToPB(tx *dbs.Tx, domain *dns.DNSDomai
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 转换域名记录信息
 | 
			
		||||
func (this *DNSDomainService) convertRecordToPB(record *dnstypes.Record) *pb.DNSRecord {
 | 
			
		||||
	return &pb.DNSRecord{
 | 
			
		||||
		Id:    record.Id,
 | 
			
		||||
		Name:  record.Name,
 | 
			
		||||
		Value: record.Value,
 | 
			
		||||
		Type:  record.Type,
 | 
			
		||||
		Route: record.Route,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查集群节点变化
 | 
			
		||||
func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnstypes.Record, domainName string, defaultRoute string) (result []maps.Map, doneNodeRecords []*dnstypes.Record, doneServerRecords []*dnstypes.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) {
 | 
			
		||||
	var clusterId = int64(cluster.Id)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,21 +3,21 @@
 | 
			
		||||
package services
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/utils"
 | 
			
		||||
	"github.com/iwind/TeaGo/assert"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHTTPCacheTaskService_CountHTTPCacheTasks(t *testing.T) {
 | 
			
		||||
func TestHTTPCacheTaskService_ParseDomain(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
 | 
			
		||||
	var service = &HTTPCacheTaskService{}
 | 
			
		||||
	a.IsTrue(service.parseDomain("aaa") == "aaa")
 | 
			
		||||
	a.IsTrue(service.parseDomain("AAA") == "aaa")
 | 
			
		||||
	a.IsTrue(service.parseDomain("a.b-c.com") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(service.parseDomain("a.b-c.com/hello/world") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(service.parseDomain("https://a.b-c.com") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(service.parseDomain("http://a.b-c.com/hello/world") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(service.parseDomain("http://a.B-c.com/hello/world") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(service.parseDomain("http:/aaaa.com") == "http")
 | 
			
		||||
	a.IsTrue(service.parseDomain("北京") == "")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("aaa") == "aaa")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("AAA") == "aaa")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("a.b-c.com") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("a.b-c.com/hello/world") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("https://a.b-c.com") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("http://a.b-c.com/hello/world") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("http://a.B-c.com/hello/world") == "a.b-c.com")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("http:/aaaa.com") == "http")
 | 
			
		||||
	a.IsTrue(utils.ParseDomainFromKey("北京") == "")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -25,6 +25,8 @@ func init() {
 | 
			
		||||
		goman.New(func() {
 | 
			
		||||
			// 将队列导入数据库
 | 
			
		||||
			var countKeys = 0
 | 
			
		||||
			var useTx = true
 | 
			
		||||
 | 
			
		||||
			for key := range metricStatKeysQueue {
 | 
			
		||||
				err := func(key string) error {
 | 
			
		||||
					metricStatsLocker.Lock()
 | 
			
		||||
@@ -43,18 +45,31 @@ func init() {
 | 
			
		||||
					var itemId = types.Int64(pieces[3])
 | 
			
		||||
 | 
			
		||||
					// 删除旧的数据
 | 
			
		||||
					tx, err := models.SharedMetricStatDAO.Instance.Begin()
 | 
			
		||||
					var tx *dbs.Tx
 | 
			
		||||
					var err error
 | 
			
		||||
					if useTx {
 | 
			
		||||
						var before = time.Now()
 | 
			
		||||
 | 
			
		||||
						tx, err = models.SharedMetricStatDAO.Instance.Begin()
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							return err
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						defer func() {
 | 
			
		||||
							// 失败时不需要rollback
 | 
			
		||||
							if tx != nil {
 | 
			
		||||
								commitErr := tx.Commit()
 | 
			
		||||
								if commitErr != nil {
 | 
			
		||||
									remotelogs.Error("METRIC_STAT", "commit metric stats failed: "+commitErr.Error())
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							// 如果运行时间过长,则不使用事务
 | 
			
		||||
							if time.Since(before) > 1*time.Second {
 | 
			
		||||
								useTx = false
 | 
			
		||||
							}
 | 
			
		||||
						}()
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					err = models.SharedMetricStatDAO.DeleteNodeItemStats(tx, nodeId, serverId, itemId, req.Time)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -213,6 +213,7 @@ func (this *ServerService) UpdateServerGroupIds(ctx context.Context, req *pb.Upd
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查分组IDs
 | 
			
		||||
	var serverGroupIds = []int64{}
 | 
			
		||||
	for _, groupId := range req.ServerGroupIds {
 | 
			
		||||
		if userId > 0 {
 | 
			
		||||
			err = models.SharedServerGroupDAO.CheckUserGroup(tx, userId, groupId)
 | 
			
		||||
@@ -228,18 +229,19 @@ func (this *ServerService) UpdateServerGroupIds(ctx context.Context, req *pb.Upd
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		serverGroupIds = append(serverGroupIds, groupId)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 增加默认分组
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		config, err := models.SharedSysSettingDAO.ReadUserServerConfig(tx)
 | 
			
		||||
		if err == nil && config.GroupId > 0 && !lists.ContainsInt64(req.ServerGroupIds, config.GroupId) {
 | 
			
		||||
			req.ServerGroupIds = append(req.ServerGroupIds, config.GroupId)
 | 
			
		||||
		if err == nil && config.GroupId > 0 && !lists.ContainsInt64(serverGroupIds, config.GroupId) {
 | 
			
		||||
			serverGroupIds = append(serverGroupIds, config.GroupId)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 修改
 | 
			
		||||
	err = models.SharedServerDAO.UpdateServerGroupIds(tx, req.ServerId, req.ServerGroupIds)
 | 
			
		||||
	err = models.SharedServerDAO.UpdateServerGroupIds(tx, req.ServerId, serverGroupIds)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,7 @@ var serverBandwidthStatsLocker = &sync.Mutex{}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	var ticker = time.NewTicker(1 * time.Minute)
 | 
			
		||||
	var useTx = true
 | 
			
		||||
 | 
			
		||||
	dbs.OnReadyDone(func() {
 | 
			
		||||
		goman.New(func() {
 | 
			
		||||
@@ -30,15 +31,32 @@ func init() {
 | 
			
		||||
					serverBandwidthStatsMap = map[string]*pb.ServerBandwidthStat{}
 | 
			
		||||
					serverBandwidthStatsLocker.Unlock()
 | 
			
		||||
 | 
			
		||||
					tx, err := models.SharedServerBandwidthStatDAO.Instance.Begin()
 | 
			
		||||
					var tx *dbs.Tx
 | 
			
		||||
					var err error
 | 
			
		||||
 | 
			
		||||
					if useTx {
 | 
			
		||||
						var before = time.Now()
 | 
			
		||||
 | 
			
		||||
						tx, err = models.SharedServerBandwidthStatDAO.Instance.Begin()
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							remotelogs.Error("ServerBandwidthStatService", "begin transaction failed: "+err.Error())
 | 
			
		||||
							return
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						defer func() {
 | 
			
		||||
						_ = tx.Commit()
 | 
			
		||||
							if tx != nil {
 | 
			
		||||
								commitErr := tx.Commit()
 | 
			
		||||
								if commitErr != nil {
 | 
			
		||||
									remotelogs.Error("METRIC_STAT", "commit bandwidth stats failed: "+commitErr.Error())
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							// 如果运行时间过长,则不使用事务
 | 
			
		||||
							if time.Since(before) > 1*time.Second {
 | 
			
		||||
								useTx = false
 | 
			
		||||
							}
 | 
			
		||||
						}()
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					for _, stat := range m {
 | 
			
		||||
						// 更新服务的带宽峰值
 | 
			
		||||
 
 | 
			
		||||
@@ -70,7 +70,7 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) {
 | 
			
		||||
			Name:       table.Name,
 | 
			
		||||
			Engine:     table.Engine,
 | 
			
		||||
			Charset:    table.Collation,
 | 
			
		||||
			Definition: regexp.MustCompile(" AUTO_INCREMENT=\\d+").ReplaceAllString(table.Code, ""),
 | 
			
		||||
			Definition: regexp.MustCompile(` AUTO_INCREMENT=\d+`).ReplaceAllString(table.Code, ""),
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 字段
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,7 @@ type SQLDumpResult struct {
 | 
			
		||||
 | 
			
		||||
func (this *SQLDumpResult) FindTable(tableName string) *SQLTable {
 | 
			
		||||
	for _, table := range this.Tables {
 | 
			
		||||
		if strings.ToLower(table.Name) == strings.ToLower(tableName) {
 | 
			
		||||
		if strings.EqualFold(table.Name, tableName) {
 | 
			
		||||
			return table
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -45,7 +45,7 @@ func (this *HealthCheckClusterTask) Reset(config *serverconfigs.HealthCheckConfi
 | 
			
		||||
		this.logErr("HealthCheckClusterTask", err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if bytes.Compare(oldJSON, newJSON) != 0 {
 | 
			
		||||
	if !bytes.Equal(oldJSON, newJSON) {
 | 
			
		||||
		this.config = config
 | 
			
		||||
		this.Run()
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -86,7 +86,7 @@ func (this *HealthCheckTask) Loop() error {
 | 
			
		||||
			// 检查是否有变化
 | 
			
		||||
			newJSON, _ := json.Marshal(config)
 | 
			
		||||
			oldJSON, _ := json.Marshal(task.Config())
 | 
			
		||||
			if bytes.Compare(oldJSON, newJSON) != 0 {
 | 
			
		||||
			if !bytes.Equal(oldJSON, newJSON) {
 | 
			
		||||
				remotelogs.Println("TASK", "[HealthCheckTask]update cluster '"+numberutils.FormatInt64(clusterId)+"'")
 | 
			
		||||
				goman.New(func() {
 | 
			
		||||
					task.Reset(config)
 | 
			
		||||
 
 | 
			
		||||
@@ -122,6 +122,9 @@ func (this *SSLCertExpireCheckExecutor) Loop() error {
 | 
			
		||||
								"certId":     cert.Id,
 | 
			
		||||
								"acmeTaskId": cert.AcmeTaskId,
 | 
			
		||||
							}.AsJSON())
 | 
			
		||||
							if err != nil {
 | 
			
		||||
								return err
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							// 更新通知时间
 | 
			
		||||
							err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id))
 | 
			
		||||
@@ -136,6 +139,9 @@ func (this *SSLCertExpireCheckExecutor) Loop() error {
 | 
			
		||||
								"certId":     cert.Id,
 | 
			
		||||
								"acmeTaskId": cert.AcmeTaskId,
 | 
			
		||||
							}.AsJSON())
 | 
			
		||||
							if err != nil {
 | 
			
		||||
								return err
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							// 更新通知时间
 | 
			
		||||
							err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id))
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	pb2 "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"google.golang.org/grpc"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
	"google.golang.org/grpc/credentials/insecure"
 | 
			
		||||
	"google.golang.org/grpc/metadata"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
@@ -51,7 +52,7 @@ func TestTCPServer(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTCPClient(t *testing.T) {
 | 
			
		||||
	conn, err := grpc.Dial("127.0.0.1:8001", grpc.WithInsecure())
 | 
			
		||||
	conn, err := grpc.Dial("127.0.0.1:8001", grpc.WithTransportCredentials(insecure.NewCredentials()))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user