优化代码

This commit is contained in:
GoEdgeLab
2022-09-23 09:28:19 +08:00
parent ebe5960731
commit 63a4c4c730
20 changed files with 90 additions and 58 deletions

View File

@@ -40,6 +40,7 @@ func (this *Request) Run() (certData []byte, keyData []byte, err error) {
} }
if this.task.Provider.RequireEAB && this.task.Account == nil { if this.task.Provider.RequireEAB && this.task.Account == nil {
err = errors.New("account should not be nil when provider require EAB") err = errors.New("account should not be nil when provider require EAB")
return
} }
switch this.task.AuthType { switch this.task.AuthType {

View File

@@ -9,7 +9,7 @@ import (
func TestDNSTaskDAO_CreateDNSTask(t *testing.T) { func TestDNSTaskDAO_CreateDNSTask(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
err := SharedDNSTaskDAO.CreateDNSTask(nil, 1, 2, 3, 0, "taskType") err := SharedDNSTaskDAO.CreateDNSTask(nil, 1, 2, 3, 0, "cdn", "taskType")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -21,7 +21,7 @@ func TestNodeClusterDAO_CheckClusterDNS(t *testing.T) {
t.Log("cluster not found, skip the test") t.Log("cluster not found, skip the test")
return return
} }
issues, err := CheckClusterDNS(tx, cluster) issues, err := CheckClusterDNS(tx, cluster, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -805,7 +805,7 @@ func (this *HTTPAccessLogDAO) SetupQueue() {
return return
} }
if bytes.Compare(accessLogConfigJSON, configJSON) == 0 { if bytes.Equal(accessLogConfigJSON, configJSON) {
return return
} }
accessLogConfigJSON = configJSON accessLogConfigJSON = configJSON

View File

@@ -422,7 +422,7 @@ func (this *HTTPAccessLogManager) checkTableFields(db *dbs.DB, tableName string)
} }
for _, field := range fields { for _, field := range fields {
var fieldName = field.GetString("Field") var fieldName = field.GetString("Field")
if strings.ToLower(fieldName) == strings.ToLower("remoteAddr") { if strings.EqualFold(fieldName, "remoteAddr") {
hasRemoteAddrField = true hasRemoteAddrField = true
} }
if strings.ToLower(fieldName) == "domain" { if strings.ToLower(fieldName) == "domain" {

View File

@@ -381,7 +381,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, cacheMap *util
// 认证 // 认证
if IsNotNull(web.Auth) { if IsNotNull(web.Auth) {
authConfig := &serverconfigs.HTTPAuthConfig{} var authConfig = &serverconfigs.HTTPAuthConfig{}
err = json.Unmarshal(web.Auth, authConfig) err = json.Unmarshal(web.Auth, authConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -395,6 +395,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(tx *dbs.Tx, webId int64, cacheMap *util
if policyConfig != nil { if policyConfig != nil {
ref.AuthPolicy = policyConfig ref.AuthPolicy = policyConfig
newRefs = append(newRefs, ref) newRefs = append(newRefs, ref)
authConfig.PolicyRefs = newRefs
} }
} }
config.Auth = authConfig config.Auth = authConfig

View File

@@ -141,7 +141,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) {
var tx *dbs.Tx var tx *dbs.Tx
{ {
exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "hello.teaos.cn", 0) exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "hello.teaos.cn", 0, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -149,7 +149,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) {
} }
{ {
exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 0) exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 0, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -157,7 +157,7 @@ func TestServerDAO_ExistServerNameInCluster(t *testing.T) {
} }
{ {
exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 23) exist, err := models.SharedServerDAO.ExistServerNameInCluster(tx, 18, "cdn.teaos.cn", 23, true)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -149,7 +149,7 @@ func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx,
return nil return nil
} }
var oldCert = oldOne.(*SSLCert) var oldCert = oldOne.(*SSLCert)
var dataIsChanged = bytes.Compare(certData, oldCert.CertData) != 0 || bytes.Compare(keyData, oldCert.KeyData) != 0 var dataIsChanged = !bytes.Equal(certData, oldCert.CertData) || !bytes.Equal(keyData, oldCert.KeyData)
var op = NewSSLCertOperator() var op = NewSSLCertOperator()
op.Id = certId op.Id = certId

View File

@@ -12,7 +12,6 @@ func TestUserDAO_UpdateUserFeatures(t *testing.T) {
var dao = NewUserDAO() var dao = NewUserDAO()
var tx *dbs.Tx var tx *dbs.Tx
err := dao.UpdateUsersFeatures(tx, []string{ err := dao.UpdateUsersFeatures(tx, []string{
userconfigs.UserFeatureCodeFinance,
userconfigs.UserFeatureCodeServerACME, userconfigs.UserFeatureCodeServerACME,
}, false) }, false)
if err != nil { if err != nil {

View File

@@ -433,17 +433,6 @@ func (this *DNSDomainService) convertDomainToPB(tx *dbs.Tx, domain *dns.DNSDomai
}, nil }, nil
} }
// 转换域名记录信息
func (this *DNSDomainService) convertRecordToPB(record *dnstypes.Record) *pb.DNSRecord {
return &pb.DNSRecord{
Id: record.Id,
Name: record.Name,
Value: record.Value,
Type: record.Type,
Route: record.Route,
}
}
// 检查集群节点变化 // 检查集群节点变化
func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnstypes.Record, domainName string, defaultRoute string) (result []maps.Map, doneNodeRecords []*dnstypes.Record, doneServerRecords []*dnstypes.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) { func (this *DNSDomainService) findClusterDNSChanges(cluster *models.NodeCluster, records []*dnstypes.Record, domainName string, defaultRoute string) (result []maps.Map, doneNodeRecords []*dnstypes.Record, doneServerRecords []*dnstypes.Record, countAllNodes int64, countAllServers int64, nodesChanged bool, serversChanged bool, err error) {
var clusterId = int64(cluster.Id) var clusterId = int64(cluster.Id)

View File

@@ -3,21 +3,21 @@
package services package services
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"testing" "testing"
) )
func TestHTTPCacheTaskService_CountHTTPCacheTasks(t *testing.T) { func TestHTTPCacheTaskService_ParseDomain(t *testing.T) {
var a = assert.NewAssertion(t) var a = assert.NewAssertion(t)
var service = &HTTPCacheTaskService{} a.IsTrue(utils.ParseDomainFromKey("aaa") == "aaa")
a.IsTrue(service.parseDomain("aaa") == "aaa") a.IsTrue(utils.ParseDomainFromKey("AAA") == "aaa")
a.IsTrue(service.parseDomain("AAA") == "aaa") a.IsTrue(utils.ParseDomainFromKey("a.b-c.com") == "a.b-c.com")
a.IsTrue(service.parseDomain("a.b-c.com") == "a.b-c.com") a.IsTrue(utils.ParseDomainFromKey("a.b-c.com/hello/world") == "a.b-c.com")
a.IsTrue(service.parseDomain("a.b-c.com/hello/world") == "a.b-c.com") a.IsTrue(utils.ParseDomainFromKey("https://a.b-c.com") == "a.b-c.com")
a.IsTrue(service.parseDomain("https://a.b-c.com") == "a.b-c.com") a.IsTrue(utils.ParseDomainFromKey("http://a.b-c.com/hello/world") == "a.b-c.com")
a.IsTrue(service.parseDomain("http://a.b-c.com/hello/world") == "a.b-c.com") a.IsTrue(utils.ParseDomainFromKey("http://a.B-c.com/hello/world") == "a.b-c.com")
a.IsTrue(service.parseDomain("http://a.B-c.com/hello/world") == "a.b-c.com") a.IsTrue(utils.ParseDomainFromKey("http:/aaaa.com") == "http")
a.IsTrue(service.parseDomain("http:/aaaa.com") == "http") a.IsTrue(utils.ParseDomainFromKey("北京") == "")
a.IsTrue(service.parseDomain("北京") == "")
} }

View File

@@ -25,6 +25,8 @@ func init() {
goman.New(func() { goman.New(func() {
// 将队列导入数据库 // 将队列导入数据库
var countKeys = 0 var countKeys = 0
var useTx = true
for key := range metricStatKeysQueue { for key := range metricStatKeysQueue {
err := func(key string) error { err := func(key string) error {
metricStatsLocker.Lock() metricStatsLocker.Lock()
@@ -43,18 +45,31 @@ func init() {
var itemId = types.Int64(pieces[3]) var itemId = types.Int64(pieces[3])
// 删除旧的数据 // 删除旧的数据
tx, err := models.SharedMetricStatDAO.Instance.Begin() var tx *dbs.Tx
var err error
if useTx {
var before = time.Now()
tx, err = models.SharedMetricStatDAO.Instance.Begin()
if err != nil { if err != nil {
return err return err
} }
defer func() { defer func() {
// 失败时不需要rollback // 失败时不需要rollback
if tx != nil {
commitErr := tx.Commit() commitErr := tx.Commit()
if commitErr != nil { if commitErr != nil {
remotelogs.Error("METRIC_STAT", "commit metric stats failed: "+commitErr.Error()) remotelogs.Error("METRIC_STAT", "commit metric stats failed: "+commitErr.Error())
} }
}
// 如果运行时间过长,则不使用事务
if time.Since(before) > 1*time.Second {
useTx = false
}
}() }()
}
err = models.SharedMetricStatDAO.DeleteNodeItemStats(tx, nodeId, serverId, itemId, req.Time) err = models.SharedMetricStatDAO.DeleteNodeItemStats(tx, nodeId, serverId, itemId, req.Time)
if err != nil { if err != nil {

View File

@@ -213,6 +213,7 @@ func (this *ServerService) UpdateServerGroupIds(ctx context.Context, req *pb.Upd
} }
// 检查分组IDs // 检查分组IDs
var serverGroupIds = []int64{}
for _, groupId := range req.ServerGroupIds { for _, groupId := range req.ServerGroupIds {
if userId > 0 { if userId > 0 {
err = models.SharedServerGroupDAO.CheckUserGroup(tx, userId, groupId) err = models.SharedServerGroupDAO.CheckUserGroup(tx, userId, groupId)
@@ -228,18 +229,19 @@ func (this *ServerService) UpdateServerGroupIds(ctx context.Context, req *pb.Upd
continue continue
} }
} }
serverGroupIds = append(serverGroupIds, groupId)
} }
// 增加默认分组 // 增加默认分组
if userId > 0 { if userId > 0 {
config, err := models.SharedSysSettingDAO.ReadUserServerConfig(tx) config, err := models.SharedSysSettingDAO.ReadUserServerConfig(tx)
if err == nil && config.GroupId > 0 && !lists.ContainsInt64(req.ServerGroupIds, config.GroupId) { if err == nil && config.GroupId > 0 && !lists.ContainsInt64(serverGroupIds, config.GroupId) {
req.ServerGroupIds = append(req.ServerGroupIds, config.GroupId) serverGroupIds = append(serverGroupIds, config.GroupId)
} }
} }
// 修改 // 修改
err = models.SharedServerDAO.UpdateServerGroupIds(tx, req.ServerId, req.ServerGroupIds) err = models.SharedServerDAO.UpdateServerGroupIds(tx, req.ServerId, serverGroupIds)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -20,6 +20,7 @@ var serverBandwidthStatsLocker = &sync.Mutex{}
func init() { func init() {
var ticker = time.NewTicker(1 * time.Minute) var ticker = time.NewTicker(1 * time.Minute)
var useTx = true
dbs.OnReadyDone(func() { dbs.OnReadyDone(func() {
goman.New(func() { goman.New(func() {
@@ -30,15 +31,32 @@ func init() {
serverBandwidthStatsMap = map[string]*pb.ServerBandwidthStat{} serverBandwidthStatsMap = map[string]*pb.ServerBandwidthStat{}
serverBandwidthStatsLocker.Unlock() serverBandwidthStatsLocker.Unlock()
tx, err := models.SharedServerBandwidthStatDAO.Instance.Begin() var tx *dbs.Tx
var err error
if useTx {
var before = time.Now()
tx, err = models.SharedServerBandwidthStatDAO.Instance.Begin()
if err != nil { if err != nil {
remotelogs.Error("ServerBandwidthStatService", "begin transaction failed: "+err.Error()) remotelogs.Error("ServerBandwidthStatService", "begin transaction failed: "+err.Error())
return return
} }
defer func() { defer func() {
_ = tx.Commit() if tx != nil {
commitErr := tx.Commit()
if commitErr != nil {
remotelogs.Error("METRIC_STAT", "commit bandwidth stats failed: "+commitErr.Error())
}
}
// 如果运行时间过长,则不使用事务
if time.Since(before) > 1*time.Second {
useTx = false
}
}() }()
}
for _, stat := range m { for _, stat := range m {
// 更新服务的带宽峰值 // 更新服务的带宽峰值

View File

@@ -70,7 +70,7 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) {
Name: table.Name, Name: table.Name,
Engine: table.Engine, Engine: table.Engine,
Charset: table.Collation, Charset: table.Collation,
Definition: regexp.MustCompile(" AUTO_INCREMENT=\\d+").ReplaceAllString(table.Code, ""), Definition: regexp.MustCompile(` AUTO_INCREMENT=\d+`).ReplaceAllString(table.Code, ""),
} }
// 字段 // 字段

View File

@@ -8,7 +8,7 @@ type SQLDumpResult struct {
func (this *SQLDumpResult) FindTable(tableName string) *SQLTable { func (this *SQLDumpResult) FindTable(tableName string) *SQLTable {
for _, table := range this.Tables { for _, table := range this.Tables {
if strings.ToLower(table.Name) == strings.ToLower(tableName) { if strings.EqualFold(table.Name, tableName) {
return table return table
} }
} }

View File

@@ -45,7 +45,7 @@ func (this *HealthCheckClusterTask) Reset(config *serverconfigs.HealthCheckConfi
this.logErr("HealthCheckClusterTask", err.Error()) this.logErr("HealthCheckClusterTask", err.Error())
return return
} }
if bytes.Compare(oldJSON, newJSON) != 0 { if !bytes.Equal(oldJSON, newJSON) {
this.config = config this.config = config
this.Run() this.Run()
} }

View File

@@ -86,7 +86,7 @@ func (this *HealthCheckTask) Loop() error {
// 检查是否有变化 // 检查是否有变化
newJSON, _ := json.Marshal(config) newJSON, _ := json.Marshal(config)
oldJSON, _ := json.Marshal(task.Config()) oldJSON, _ := json.Marshal(task.Config())
if bytes.Compare(oldJSON, newJSON) != 0 { if !bytes.Equal(oldJSON, newJSON) {
remotelogs.Println("TASK", "[HealthCheckTask]update cluster '"+numberutils.FormatInt64(clusterId)+"'") remotelogs.Println("TASK", "[HealthCheckTask]update cluster '"+numberutils.FormatInt64(clusterId)+"'")
goman.New(func() { goman.New(func() {
task.Reset(config) task.Reset(config)

View File

@@ -122,6 +122,9 @@ func (this *SSLCertExpireCheckExecutor) Loop() error {
"certId": cert.Id, "certId": cert.Id,
"acmeTaskId": cert.AcmeTaskId, "acmeTaskId": cert.AcmeTaskId,
}.AsJSON()) }.AsJSON())
if err != nil {
return err
}
// 更新通知时间 // 更新通知时间
err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id))
@@ -136,6 +139,9 @@ func (this *SSLCertExpireCheckExecutor) Loop() error {
"certId": cert.Id, "certId": cert.Id,
"acmeTaskId": cert.AcmeTaskId, "acmeTaskId": cert.AcmeTaskId,
}.AsJSON()) }.AsJSON())
if err != nil {
return err
}
// 更新通知时间 // 更新通知时间
err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id)) err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(nil, int64(cert.Id))

View File

@@ -8,6 +8,7 @@ import (
pb2 "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" pb2 "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"log" "log"
"net" "net"
@@ -51,7 +52,7 @@ func TestTCPServer(t *testing.T) {
} }
func TestTCPClient(t *testing.T) { func TestTCPClient(t *testing.T) {
conn, err := grpc.Dial("127.0.0.1:8001", grpc.WithInsecure()) conn, err := grpc.Dial("127.0.0.1:8001", grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }