From 5655f89ba63df39cdeae58497993911a9603f9d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Fri, 23 Sep 2022 09:28:19 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/acme/request.go | 1 + internal/db/models/dns/dns_task_dao_test.go | 2 +- .../db/models/dns/dnsutils/dns_utils_test.go | 2 +- internal/db/models/http_access_log_dao.go | 2 +- internal/db/models/http_access_log_manager.go | 2 +- internal/db/models/http_web_dao.go | 3 +- internal/db/models/server_dao_test.go | 6 ++-- internal/db/models/ssl_cert_dao.go | 2 +- internal/db/models/user_dao_test.go | 1 - internal/rpc/services/service_dns_domain.go | 11 ------ .../services/service_http_cache_task_test.go | 22 ++++++------ internal/rpc/services/service_metric_stat.go | 35 +++++++++++++------ internal/rpc/services/service_server.go | 8 +++-- .../services/service_server_bandwidth_stat.go | 34 +++++++++++++----- internal/setup/sql_dump.go | 2 +- internal/setup/sql_dump_result.go | 2 +- internal/tasks/health_check_cluster_task.go | 2 +- internal/tasks/health_check_task.go | 2 +- .../tasks/ssl_cert_expire_check_executor.go | 6 ++++ internal/tests/grpc_test.go | 3 +- 20 files changed, 90 insertions(+), 58 deletions(-) diff --git a/internal/acme/request.go b/internal/acme/request.go index 5b419c52..58569b99 100644 --- a/internal/acme/request.go +++ b/internal/acme/request.go @@ -40,6 +40,7 @@ func (this *Request) Run() (certData []byte, keyData []byte, err error) { } if this.task.Provider.RequireEAB && this.task.Account == nil { err = errors.New("account should not be nil when provider require EAB") + return } switch this.task.AuthType { diff --git a/internal/db/models/dns/dns_task_dao_test.go b/internal/db/models/dns/dns_task_dao_test.go index 9c71349a..17449080 100644 --- a/internal/db/models/dns/dns_task_dao_test.go +++ b/internal/db/models/dns/dns_task_dao_test.go @@ -9,7 +9,7 @@ import ( func TestDNSTaskDAO_CreateDNSTask(t *testing.T) { dbs.NotifyReady() - err := SharedDNSTaskDAO.CreateDNSTask(nil, 1, 2, 3, 0, "taskType") + err := SharedDNSTaskDAO.CreateDNSTask(nil, 1, 2, 3, 0, "cdn", "taskType") if err != nil { t.Fatal(err) } diff --git a/internal/db/models/dns/dnsutils/dns_utils_test.go b/internal/db/models/dns/dnsutils/dns_utils_test.go index a923d4a1..7da64254 100644 --- a/internal/db/models/dns/dnsutils/dns_utils_test.go +++ b/internal/db/models/dns/dnsutils/dns_utils_test.go @@ -21,7 +21,7 @@ func TestNodeClusterDAO_CheckClusterDNS(t *testing.T) { t.Log("cluster not found, skip the test") return } - issues, err := CheckClusterDNS(tx, cluster) + issues, err := CheckClusterDNS(tx, cluster, true) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/http_access_log_dao.go b/internal/db/models/http_access_log_dao.go index 4fc1109c..520a6416 100644 --- a/internal/db/models/http_access_log_dao.go +++ b/internal/db/models/http_access_log_dao.go @@ -805,7 +805,7 @@ func (this *HTTPAccessLogDAO) SetupQueue() { return } - if bytes.Compare(accessLogConfigJSON, configJSON) == 0 { + if bytes.Equal(accessLogConfigJSON, configJSON) { return } accessLogConfigJSON = configJSON diff --git a/internal/db/models/http_access_log_manager.go b/internal/db/models/http_access_log_manager.go index 5fafedaa..72c4798a 100644 --- a/internal/db/models/http_access_log_manager.go +++ b/internal/db/models/http_access_log_manager.go @@ -422,7 +422,7 @@ func (this *HTTPAccessLogManager) checkTableFields(db *dbs.DB, tableName string) } for _, field := range fields { var fieldName = field.GetString("Field") - if strings.ToLower(fieldName) == strings.ToLower("remoteAddr") { + if strings.EqualFold(fieldName, "remoteAddr") { hasRemoteAddrField = true } if strings.ToLower(fieldName) == "domain" { diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index 2342a166..d1431b3d 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -381,7 +381,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, cacheMap *util // 认证 if IsNotNull(web.Auth) { - authConfig := &serverconfigs.HTTPAuthConfig{} + var authConfig = &serverconfigs.HTTPAuthConfig{} err = json.Unmarshal(web.Auth, authConfig) if err != nil { return nil, err @@ -395,6 +395,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, cacheMap *util if policyConfig != nil { ref.AuthPolicy = policyConfig newRefs = append(newRefs, ref) + authConfig.PolicyRefs = newRefs } } config.Auth = authConfig diff --git a/internal/db/models/server_dao_test.go b/internal/db/models/server_dao_test.go index 4f16d9ee..d9366ec2 100644 --- a/internal/db/models/server_dao_test.go +++ b/internal/db/models/server_dao_test.go @@ -141,7 +141,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) { var tx *dbs.Tx { - exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "hello.teaos.cn", 0) + exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "hello.teaos.cn", 0, true) if err != nil { t.Fatal(err) } @@ -149,7 +149,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) { } { - exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 0) + exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 0, true) if err != nil { t.Fatal(err) } @@ -157,7 +157,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) { } { - exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 23) + exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 23, true) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index 2a035f24..d8fdea7f 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -149,7 +149,7 @@ func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx, return nil } var oldCert = oldOne.(*SSLCert) - var dataIsChanged = bytes.Compare(certData, oldCert.CertData) != 0 || bytes.Compare(keyData, oldCert.KeyData) != 0 + var dataIsChanged = !bytes.Equal(certData, oldCert.CertData) || !bytes.Equal(keyData, oldCert.KeyData) var op = NewSSLCertOperator() op.Id = certId diff --git a/internal/db/models/user_dao_test.go b/internal/db/models/user_dao_test.go index 95838d5a..9bde4fae 100644 --- a/internal/db/models/user_dao_test.go +++ b/internal/db/models/user_dao_test.go @@ -12,7 +12,6 @@ func TestUserDAO_UpdateUserFeatures(t *testing.T) { var dao = NewUserDAO() var tx *dbs.Tx err := dao.UpdateUsersFeatures(tx, []string{ - userconfigs.UserFeatureCodeFinance, userconfigs.UserFeatureCodeServerACME, }, false) if err != nil { diff --git a/internal/rpc/services/service_dns_domain.go b/internal/rpc/services/service_dns_domain.go index 0db6c9fc..cda7b7d3 100644 --- a/internal/rpc/services/service_dns_domain.go +++ b/internal/rpc/services/service_dns_domain.go @@ -433,17 +433,6 @@ func (this *DNSDomainService) convertDomainToPB(tx *dbs.Tx, domain *dns.DNSDomai }, nil } -// 转换域名记录信息 -func (this *DNSDomainService) convertRecordToPB(record *dnstypes.Record) *pb.DNSRecord { - return &pb.DNSRecord{ - Id: record.Id, - Name: record.Name, - Value: record.Value, - Type: record.Type, - Route: record.Route, - } -} - // 检查集群节点变化 func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnstypes.Record, domainName string, defaultRoute string) (result []maps.Map, doneNodeRecords []*dnstypes.Record, doneServerRecords []*dnstypes.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) { var clusterId = int64(cluster.Id) diff --git a/internal/rpc/services/service_http_cache_task_test.go b/internal/rpc/services/service_http_cache_task_test.go index 89be398b..73ad29f1 100644 --- a/internal/rpc/services/service_http_cache_task_test.go +++ b/internal/rpc/services/service_http_cache_task_test.go @@ -3,21 +3,21 @@ package services import ( + "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/iwind/TeaGo/assert" "testing" ) -func TestHTTPCacheTaskService_CountHTTPCacheTasks(t *testing.T) { +func TestHTTPCacheTaskService_ParseDomain(t *testing.T) { var a = assert.NewAssertion(t) - var service = &HTTPCacheTaskService{} - a.IsTrue(service.parseDomain("aaa") == "aaa") - a.IsTrue(service.parseDomain("AAA") == "aaa") - a.IsTrue(service.parseDomain("a.b-c.com") == "a.b-c.com") - a.IsTrue(service.parseDomain("a.b-c.com/hello/world") == "a.b-c.com") - a.IsTrue(service.parseDomain("https://a.b-c.com") == "a.b-c.com") - a.IsTrue(service.parseDomain("http://a.b-c.com/hello/world") == "a.b-c.com") - a.IsTrue(service.parseDomain("http://a.B-c.com/hello/world") == "a.b-c.com") - a.IsTrue(service.parseDomain("http:/aaaa.com") == "http") - a.IsTrue(service.parseDomain("北京") == "") + a.IsTrue(utils.ParseDomainFromKey("aaa") == "aaa") + a.IsTrue(utils.ParseDomainFromKey("AAA") == "aaa") + a.IsTrue(utils.ParseDomainFromKey("a.b-c.com") == "a.b-c.com") + a.IsTrue(utils.ParseDomainFromKey("a.b-c.com/hello/world") == "a.b-c.com") + a.IsTrue(utils.ParseDomainFromKey("https://a.b-c.com") == "a.b-c.com") + a.IsTrue(utils.ParseDomainFromKey("http://a.b-c.com/hello/world") == "a.b-c.com") + a.IsTrue(utils.ParseDomainFromKey("http://a.B-c.com/hello/world") == "a.b-c.com") + a.IsTrue(utils.ParseDomainFromKey("http:/aaaa.com") == "http") + a.IsTrue(utils.ParseDomainFromKey("北京") == "") } diff --git a/internal/rpc/services/service_metric_stat.go b/internal/rpc/services/service_metric_stat.go index 597dfa6d..5fad9c46 100644 --- a/internal/rpc/services/service_metric_stat.go +++ b/internal/rpc/services/service_metric_stat.go @@ -25,6 +25,8 @@ func init() { goman.New(func() { // 将队列导入数据库 var countKeys = 0 + var useTx = true + for key := range metricStatKeysQueue { err := func(key string) error { metricStatsLocker.Lock() @@ -43,18 +45,31 @@ func init() { var itemId = types.Int64(pieces[3]) // 删除旧的数据 - tx, err := models.SharedMetricStatDAO.Instance.Begin() - if err != nil { - return err - } + var tx *dbs.Tx + var err error + if useTx { + var before = time.Now() - defer func() { - // 失败时不需要rollback - commitErr := tx.Commit() - if commitErr != nil { - remotelogs.Error("METRIC_STAT", "commit metric stats failed: "+commitErr.Error()) + tx, err = models.SharedMetricStatDAO.Instance.Begin() + if err != nil { + return err } - }() + + defer func() { + // 失败时不需要rollback + if tx != nil { + commitErr := tx.Commit() + if commitErr != nil { + remotelogs.Error("METRIC_STAT", "commit metric stats failed: "+commitErr.Error()) + } + } + + // 如果运行时间过长,则不使用事务 + if time.Since(before) > 1*time.Second { + useTx = false + } + }() + } err = models.SharedMetricStatDAO.DeleteNodeItemStats(tx, nodeId, serverId, itemId, req.Time) if err != nil { diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 85dbe826..f72d18cc 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -213,6 +213,7 @@ func (this *ServerService) UpdateServerGroupIds(ctx context.Context, req *pb.Upd } // 检查分组IDs + var serverGroupIds = []int64{} for _, groupId := range req.ServerGroupIds { if userId > 0 { err = models.SharedServerGroupDAO.CheckUserGroup(tx, userId, groupId) @@ -228,18 +229,19 @@ func (this *ServerService) UpdateServerGroupIds(ctx context.Context, req *pb.Upd continue } } + serverGroupIds = append(serverGroupIds, groupId) } // 增加默认分组 if userId > 0 { config, err := models.SharedSysSettingDAO.ReadUserServerConfig(tx) - if err == nil && config.GroupId > 0 && !lists.ContainsInt64(req.ServerGroupIds, config.GroupId) { - req.ServerGroupIds = append(req.ServerGroupIds, config.GroupId) + if err == nil && config.GroupId > 0 && !lists.ContainsInt64(serverGroupIds, config.GroupId) { + serverGroupIds = append(serverGroupIds, config.GroupId) } } // 修改 - err = models.SharedServerDAO.UpdateServerGroupIds(tx, req.ServerId, req.ServerGroupIds) + err = models.SharedServerDAO.UpdateServerGroupIds(tx, req.ServerId, serverGroupIds) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_server_bandwidth_stat.go b/internal/rpc/services/service_server_bandwidth_stat.go index 47ce9d44..3df1d538 100644 --- a/internal/rpc/services/service_server_bandwidth_stat.go +++ b/internal/rpc/services/service_server_bandwidth_stat.go @@ -20,6 +20,7 @@ var serverBandwidthStatsLocker = &sync.Mutex{} func init() { var ticker = time.NewTicker(1 * time.Minute) + var useTx = true dbs.OnReadyDone(func() { goman.New(func() { @@ -30,15 +31,32 @@ func init() { serverBandwidthStatsMap = map[string]*pb.ServerBandwidthStat{} serverBandwidthStatsLocker.Unlock() - tx, err := models.SharedServerBandwidthStatDAO.Instance.Begin() - if err != nil { - remotelogs.Error("ServerBandwidthStatService", "begin transaction failed: "+err.Error()) - return - } + var tx *dbs.Tx + var err error - defer func() { - _ = tx.Commit() - }() + if useTx { + var before = time.Now() + + tx, err = models.SharedServerBandwidthStatDAO.Instance.Begin() + if err != nil { + remotelogs.Error("ServerBandwidthStatService", "begin transaction failed: "+err.Error()) + return + } + + defer func() { + if tx != nil { + commitErr := tx.Commit() + if commitErr != nil { + remotelogs.Error("METRIC_STAT", "commit bandwidth stats failed: "+commitErr.Error()) + } + } + + // 如果运行时间过长,则不使用事务 + if time.Since(before) > 1*time.Second { + useTx = false + } + }() + } for _, stat := range m { // 更新服务的带宽峰值 diff --git a/internal/setup/sql_dump.go b/internal/setup/sql_dump.go index e6a1ccec..4c0af150 100644 --- a/internal/setup/sql_dump.go +++ b/internal/setup/sql_dump.go @@ -70,7 +70,7 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) { Name: table.Name, Engine: table.Engine, Charset: table.Collation, - Definition: regexp.MustCompile(" AUTO_INCREMENT=\\d+").ReplaceAllString(table.Code, ""), + Definition: regexp.MustCompile(` AUTO_INCREMENT=\d+`).ReplaceAllString(table.Code, ""), } // 字段 diff --git a/internal/setup/sql_dump_result.go b/internal/setup/sql_dump_result.go index fad42c79..2a7b82b0 100644 --- a/internal/setup/sql_dump_result.go +++ b/internal/setup/sql_dump_result.go @@ -8,7 +8,7 @@ type SQLDumpResult struct { func (this *SQLDumpResult) FindTable(tableName string) *SQLTable { for _, table := range this.Tables { - if strings.ToLower(table.Name) == strings.ToLower(tableName) { + if strings.EqualFold(table.Name, tableName) { return table } } diff --git a/internal/tasks/health_check_cluster_task.go b/internal/tasks/health_check_cluster_task.go index 4ed775ca..f1128b22 100644 --- a/internal/tasks/health_check_cluster_task.go +++ b/internal/tasks/health_check_cluster_task.go @@ -45,7 +45,7 @@ func (this *HealthCheckClusterTask) Reset(config *serverconfigs.HealthCheckConfi this.logErr("HealthCheckClusterTask", err.Error()) return } - if bytes.Compare(oldJSON, newJSON) != 0 { + if !bytes.Equal(oldJSON, newJSON) { this.config = config this.Run() } diff --git a/internal/tasks/health_check_task.go b/internal/tasks/health_check_task.go index 699603c7..58a2ff21 100644 --- a/internal/tasks/health_check_task.go +++ b/internal/tasks/health_check_task.go @@ -86,7 +86,7 @@ func (this *HealthCheckTask) Loop() error { // 检查是否有变化 newJSON, _ := json.Marshal(config) oldJSON, _ := json.Marshal(task.Config()) - if bytes.Compare(oldJSON, newJSON) != 0 { + if !bytes.Equal(oldJSON, newJSON) { remotelogs.Println("TASK", "[HealthCheckTask]update cluster '"+numberutils.FormatInt64(clusterId)+"'") goman.New(func() { task.Reset(config) diff --git a/internal/tasks/ssl_cert_expire_check_executor.go b/internal/tasks/ssl_cert_expire_check_executor.go index 533ee1e9..ed82e954 100644 --- a/internal/tasks/ssl_cert_expire_check_executor.go +++ b/internal/tasks/ssl_cert_expire_check_executor.go @@ -122,6 +122,9 @@ func (this *SSLCertExpireCheckExecutor) Loop() error { "certId": cert.Id, "acmeTaskId": cert.AcmeTaskId, }.AsJSON()) + if err != nil { + return err + } // 更新通知时间 err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) @@ -136,6 +139,9 @@ func (this *SSLCertExpireCheckExecutor) Loop() error { "certId": cert.Id, "acmeTaskId": cert.AcmeTaskId, }.AsJSON()) + if err != nil { + return err + } // 更新通知时间 err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) diff --git a/internal/tests/grpc_test.go b/internal/tests/grpc_test.go index e0eb68c0..1a80e5c2 100644 --- a/internal/tests/grpc_test.go +++ b/internal/tests/grpc_test.go @@ -8,6 +8,7 @@ import ( pb2 "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "log" "net" @@ -51,7 +52,7 @@ func TestTCPServer(t *testing.T) { } func TestTCPClient(t *testing.T) { - conn, err := grpc.Dial("127.0.0.1:8001", grpc.WithInsecure()) + conn, err := grpc.Dial("127.0.0.1:8001", grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { t.Fatal(err) }