From 12fd8b80b01620c487df8c6dd858ff0207ce6aa5 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Mon, 29 Jan 2024 10:22:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96SQL=E5=8D=87=E7=BA=A7?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/setup/sql_executor.go | 56 ++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/internal/setup/sql_executor.go b/internal/setup/sql_executor.go index ef0f2678..1bbc804a 100644 --- a/internal/setup/sql_executor.go +++ b/internal/setup/sql_executor.go @@ -56,6 +56,12 @@ func (this *SQLExecutor) Run(showLog bool) error { return err } + // prevent default configure loading + var globalConfig = dbs.GlobalConfig() + if globalConfig != nil && len(globalConfig.DBs) == 0 { + globalConfig.DBs = map[string]*dbs.DBConfig{"prod": this.dbConfig} + } + defer func() { _ = db.Close() }() @@ -91,56 +97,56 @@ func (this *SQLExecutor) checkData(db *dbs.DB) error { // 检查管理员平台节点 err := this.checkAdminNode(db) if err != nil { - return err + return fmt.Errorf("check admin node failed: %w", err) } // 检查用户平台节点 err = this.checkUserNode(db) if err != nil { - return err + return fmt.Errorf("check user node failed: %w", err) } // 检查集群配置 err = this.checkCluster(db) if err != nil { - return err + return fmt.Errorf("check cluster failed: %w", err) } // 检查初始化用户 // 需要放在检查集群后面 err = this.checkUser(db) if err != nil { - return err + return fmt.Errorf("check user failed: %w", err) } // 检查IP名单 err = this.checkIPList(db) if err != nil { - return err + return fmt.Errorf("check ip list failed: %w", err) } // 检查指标设置 err = this.checkMetricItems(db) if err != nil { - return err + return fmt.Errorf("check metric items failed: %w", err) } // 检查自建DNS全局设置 err = this.checkNS(db) if err != nil { - return err + return fmt.Errorf("check ns failed: %w", err) } // 更新Agents err = this.checkClientAgents(db) if err != nil { - return err + return fmt.Errorf("check client agents failed: %w", err) } // 更新版本号 err = this.updateVersion(db, ComposeSQLVersion()) if err != nil { - return err + return fmt.Errorf("update version failed: %w", err) } return nil @@ -180,13 +186,13 @@ func (this *SQLExecutor) checkAdminNode(db *dbs.DB) error { if err != nil { return err } - count := types.Int(col) + var count = types.Int(col) if count > 0 { return nil } - nodeId := rands.HexString(32) - secret := rands.String(32) + var nodeId = rands.HexString(32) + var secret = rands.String(32) _, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "admin") if err != nil { return err @@ -208,13 +214,13 @@ func (this *SQLExecutor) checkUserNode(db *dbs.DB) error { if err != nil { return err } - count := types.Int(col) + var count = types.Int(col) if count > 0 { return nil } - nodeId := rands.HexString(32) - secret := rands.String(32) + var nodeId = rands.HexString(32) + var secret = rands.String(32) _, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "user") if err != nil { return err @@ -225,7 +231,7 @@ func (this *SQLExecutor) checkUserNode(db *dbs.DB) error { // 检查集群配置 func (this *SQLExecutor) checkCluster(db *dbs.DB) error { - /// 检查是否有集群数字 + /// 检查是否有集群数据 stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeNodeClusters") if err != nil { return fmt.Errorf("query clusters failed: %w", err) @@ -238,7 +244,7 @@ func (this *SQLExecutor) checkCluster(db *dbs.DB) error { if err != nil { return fmt.Errorf("query clusters failed: %w", err) } - count := types.Int(col) + var count = types.Int(col) if count > 0 { return nil } @@ -281,6 +287,7 @@ func (this *SQLExecutor) checkCluster(db *dbs.DB) error { } // 默认缓存策略 + models.SharedHTTPCachePolicyDAO = models.NewHTTPCachePolicyDAO() models.SharedHTTPCachePolicyDAO.Instance = db policyId, err := models.SharedHTTPCachePolicyDAO.CreateDefaultCachePolicy(nil, "默认集群") @@ -305,6 +312,15 @@ func (this *SQLExecutor) checkCluster(db *dbs.DB) error { models.SharedHTTPFirewallRuleDAO = models.NewHTTPFirewallRuleDAO() models.SharedHTTPFirewallRuleDAO.Instance = db + models.SharedHTTPWebDAO = models.NewHTTPWebDAO() + models.SharedHTTPWebDAO.Instance = db + + models.SharedServerDAO = models.NewServerDAO() + models.SharedServerDAO.Instance = db + + models.SharedNodeClusterDAO = models.NewNodeClusterDAO() + models.SharedNodeClusterDAO.Instance = db + policyId, err = models.SharedHTTPFirewallPolicyDAO.CreateDefaultFirewallPolicy(nil, "默认集群") if err != nil { return err @@ -331,7 +347,7 @@ func (this *SQLExecutor) checkIPList(db *dbs.DB) error { if err != nil { return fmt.Errorf("query ip lists failed: %w", err) } - count := types.Int(col) + var count = types.Int(col) if count > 0 { return nil } @@ -388,7 +404,7 @@ func (this *SQLExecutor) checkMetricItems(db *dbs.DB) error { // chart for _, chartMap := range chartMaps { - chartCode := chartMap.GetString("code") + var chartCode = chartMap.GetString("code") one, err := db.FindOne("SELECT id FROM edgeMetricCharts WHERE itemId=? AND code=? LIMIT 1", itemId, chartCode) if err != nil { return err @@ -528,7 +544,7 @@ func (this *SQLExecutor) updateVersion(db *dbs.DB, version string) error { if err != nil { return fmt.Errorf("query version failed: %w", err) } - count := types.Int(col) + var count = types.Int(col) if count > 0 { _, err = db.Exec("UPDATE edgeVersions SET version=?", version) if err != nil {