优化SQL升级代码

This commit is contained in:
GoEdgeLab
2024-01-29 10:22:27 +08:00
parent cc2184796b
commit 12fd8b80b0

View File

@@ -56,6 +56,12 @@ func (this *SQLExecutor) Run(showLog bool) error {
return err return err
} }
// prevent default configure loading
var globalConfig = dbs.GlobalConfig()
if globalConfig != nil && len(globalConfig.DBs) == 0 {
globalConfig.DBs = map[string]*dbs.DBConfig{"prod": this.dbConfig}
}
defer func() { defer func() {
_ = db.Close() _ = db.Close()
}() }()
@@ -91,56 +97,56 @@ func (this *SQLExecutor) checkData(db *dbs.DB) error {
// 检查管理员平台节点 // 检查管理员平台节点
err := this.checkAdminNode(db) err := this.checkAdminNode(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check admin node failed: %w", err)
} }
// 检查用户平台节点 // 检查用户平台节点
err = this.checkUserNode(db) err = this.checkUserNode(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check user node failed: %w", err)
} }
// 检查集群配置 // 检查集群配置
err = this.checkCluster(db) err = this.checkCluster(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check cluster failed: %w", err)
} }
// 检查初始化用户 // 检查初始化用户
// 需要放在检查集群后面 // 需要放在检查集群后面
err = this.checkUser(db) err = this.checkUser(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check user failed: %w", err)
} }
// 检查IP名单 // 检查IP名单
err = this.checkIPList(db) err = this.checkIPList(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check ip list failed: %w", err)
} }
// 检查指标设置 // 检查指标设置
err = this.checkMetricItems(db) err = this.checkMetricItems(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check metric items failed: %w", err)
} }
// 检查自建DNS全局设置 // 检查自建DNS全局设置
err = this.checkNS(db) err = this.checkNS(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check ns failed: %w", err)
} }
// 更新Agents // 更新Agents
err = this.checkClientAgents(db) err = this.checkClientAgents(db)
if err != nil { if err != nil {
return err return fmt.Errorf("check client agents failed: %w", err)
} }
// 更新版本号 // 更新版本号
err = this.updateVersion(db, ComposeSQLVersion()) err = this.updateVersion(db, ComposeSQLVersion())
if err != nil { if err != nil {
return err return fmt.Errorf("update version failed: %w", err)
} }
return nil return nil
@@ -180,13 +186,13 @@ func (this *SQLExecutor) checkAdminNode(db *dbs.DB) error {
if err != nil { if err != nil {
return err return err
} }
count := types.Int(col) var count = types.Int(col)
if count > 0 { if count > 0 {
return nil return nil
} }
nodeId := rands.HexString(32) var nodeId = rands.HexString(32)
secret := rands.String(32) var secret = rands.String(32)
_, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "admin") _, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "admin")
if err != nil { if err != nil {
return err return err
@@ -208,13 +214,13 @@ func (this *SQLExecutor) checkUserNode(db *dbs.DB) error {
if err != nil { if err != nil {
return err return err
} }
count := types.Int(col) var count = types.Int(col)
if count > 0 { if count > 0 {
return nil return nil
} }
nodeId := rands.HexString(32) var nodeId = rands.HexString(32)
secret := rands.String(32) var secret = rands.String(32)
_, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "user") _, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "user")
if err != nil { if err != nil {
return err return err
@@ -225,7 +231,7 @@ func (this *SQLExecutor) checkUserNode(db *dbs.DB) error {
// 检查集群配置 // 检查集群配置
func (this *SQLExecutor) checkCluster(db *dbs.DB) error { func (this *SQLExecutor) checkCluster(db *dbs.DB) error {
/// 检查是否有集群数 /// 检查是否有集群数
stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeNodeClusters") stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeNodeClusters")
if err != nil { if err != nil {
return fmt.Errorf("query clusters failed: %w", err) return fmt.Errorf("query clusters failed: %w", err)
@@ -238,7 +244,7 @@ func (this *SQLExecutor) checkCluster(db *dbs.DB) error {
if err != nil { if err != nil {
return fmt.Errorf("query clusters failed: %w", err) return fmt.Errorf("query clusters failed: %w", err)
} }
count := types.Int(col) var count = types.Int(col)
if count > 0 { if count > 0 {
return nil return nil
} }
@@ -281,6 +287,7 @@ func (this *SQLExecutor) checkCluster(db *dbs.DB) error {
} }
// 默认缓存策略 // 默认缓存策略
models.SharedHTTPCachePolicyDAO = models.NewHTTPCachePolicyDAO() models.SharedHTTPCachePolicyDAO = models.NewHTTPCachePolicyDAO()
models.SharedHTTPCachePolicyDAO.Instance = db models.SharedHTTPCachePolicyDAO.Instance = db
policyId, err := models.SharedHTTPCachePolicyDAO.CreateDefaultCachePolicy(nil, "默认集群") policyId, err := models.SharedHTTPCachePolicyDAO.CreateDefaultCachePolicy(nil, "默认集群")
@@ -305,6 +312,15 @@ func (this *SQLExecutor) checkCluster(db *dbs.DB) error {
models.SharedHTTPFirewallRuleDAO = models.NewHTTPFirewallRuleDAO() models.SharedHTTPFirewallRuleDAO = models.NewHTTPFirewallRuleDAO()
models.SharedHTTPFirewallRuleDAO.Instance = db models.SharedHTTPFirewallRuleDAO.Instance = db
models.SharedHTTPWebDAO = models.NewHTTPWebDAO()
models.SharedHTTPWebDAO.Instance = db
models.SharedServerDAO = models.NewServerDAO()
models.SharedServerDAO.Instance = db
models.SharedNodeClusterDAO = models.NewNodeClusterDAO()
models.SharedNodeClusterDAO.Instance = db
policyId, err = models.SharedHTTPFirewallPolicyDAO.CreateDefaultFirewallPolicy(nil, "默认集群") policyId, err = models.SharedHTTPFirewallPolicyDAO.CreateDefaultFirewallPolicy(nil, "默认集群")
if err != nil { if err != nil {
return err return err
@@ -331,7 +347,7 @@ func (this *SQLExecutor) checkIPList(db *dbs.DB) error {
if err != nil { if err != nil {
return fmt.Errorf("query ip lists failed: %w", err) return fmt.Errorf("query ip lists failed: %w", err)
} }
count := types.Int(col) var count = types.Int(col)
if count > 0 { if count > 0 {
return nil return nil
} }
@@ -388,7 +404,7 @@ func (this *SQLExecutor) checkMetricItems(db *dbs.DB) error {
// chart // chart
for _, chartMap := range chartMaps { for _, chartMap := range chartMaps {
chartCode := chartMap.GetString("code") var chartCode = chartMap.GetString("code")
one, err := db.FindOne("SELECT id FROM edgeMetricCharts WHERE itemId=? AND code=? LIMIT 1", itemId, chartCode) one, err := db.FindOne("SELECT id FROM edgeMetricCharts WHERE itemId=? AND code=? LIMIT 1", itemId, chartCode)
if err != nil { if err != nil {
return err return err
@@ -528,7 +544,7 @@ func (this *SQLExecutor) updateVersion(db *dbs.DB, version string) error {
if err != nil { if err != nil {
return fmt.Errorf("query version failed: %w", err) return fmt.Errorf("query version failed: %w", err)
} }
count := types.Int(col) var count = types.Int(col)
if count > 0 { if count > 0 {
_, err = db.Exec("UPDATE edgeVersions SET version=?", version) _, err = db.Exec("UPDATE edgeVersions SET version=?", version)
if err != nil { if err != nil {