mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	修改SQL对比算法
This commit is contained in:
		@@ -3,21 +3,20 @@ package setup
 | 
			
		||||
import (
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/setup/sqls"
 | 
			
		||||
	_ "github.com/go-sql-driver/mysql"
 | 
			
		||||
	"github.com/go-yaml/yaml"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"github.com/iwind/TeaGo/rands"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	stringutil "github.com/iwind/TeaGo/utils/string"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var LatestSQLResult = &SQLDumpResult{}
 | 
			
		||||
 | 
			
		||||
// 安装或升级SQL执行器
 | 
			
		||||
type SQLExecutor struct {
 | 
			
		||||
	dbConfig *dbs.DBConfig
 | 
			
		||||
@@ -49,113 +48,12 @@ func (this *SQLExecutor) Run() error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tableNames, err := db.TableNames()
 | 
			
		||||
	sqlDump := NewSQLDump()
 | 
			
		||||
	_, err = sqlDump.Apply(db, LatestSQLResult)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查已有数据
 | 
			
		||||
	if lists.ContainsString(tableNames, "edgeVersions") {
 | 
			
		||||
		stmt, err := db.Prepare("SELECT version FROM " + db.TablePrefix() + "Versions")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer func() {
 | 
			
		||||
			_ = stmt.Close()
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		rows, err := stmt.Query()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer func() {
 | 
			
		||||
			_ = rows.Close()
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		// 对比版本
 | 
			
		||||
		oldVersion := ""
 | 
			
		||||
		if rows.Next() {
 | 
			
		||||
			err = rows.Scan(&oldVersion)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.New("query version failed: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if oldVersion == teaconst.Version {
 | 
			
		||||
			err = this.checkData(db)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		oldVersion = strings.Replace(oldVersion, "_", ".", -1)
 | 
			
		||||
 | 
			
		||||
		upgradeVersions := []string{}
 | 
			
		||||
		sqlMap := map[string]string{} // version => sql
 | 
			
		||||
 | 
			
		||||
		for _, m := range sqls.SQLVersions {
 | 
			
		||||
			version, _ := m["version"]
 | 
			
		||||
			if version == "full" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			version = strings.Replace(version, "_", ".", -1)
 | 
			
		||||
			if len(oldVersion) == 0 || stringutil.VersionCompare(version, oldVersion) > 0 {
 | 
			
		||||
				upgradeVersions = append(upgradeVersions, version)
 | 
			
		||||
				sql, _ := m["sql"]
 | 
			
		||||
				sqlMap[version] = sql
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 如果没有可以升级的版本,直接返回
 | 
			
		||||
		if len(upgradeVersions) == 0 {
 | 
			
		||||
			err = this.checkData(db)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sort.Slice(upgradeVersions, func(i, j int) bool {
 | 
			
		||||
			return stringutil.VersionCompare(upgradeVersions[i], upgradeVersions[j]) < 0
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// 执行升级的SQL
 | 
			
		||||
		for _, version := range upgradeVersions {
 | 
			
		||||
			sql := sqlMap[version]
 | 
			
		||||
			_, err = db.Exec(sql)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if !this.canIgnoreError(err) {
 | 
			
		||||
					return errors.New("exec upgrade sql for version '" + version + "' failed: " + err.Error())
 | 
			
		||||
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			err = this.updateVersion(db, version)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 检查数据
 | 
			
		||||
		err = this.checkData(db)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 全新安装
 | 
			
		||||
	fullSQL, found := this.findFullSQL()
 | 
			
		||||
	if !found {
 | 
			
		||||
		return errors.New("not found full setup sql")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 执行SQL
 | 
			
		||||
	_, err = db.Exec(fullSQL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("create tables failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查数据
 | 
			
		||||
	err = this.checkData(db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -165,18 +63,6 @@ func (this *SQLExecutor) Run() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找完整的SQL
 | 
			
		||||
func (this *SQLExecutor) findFullSQL() (sql string, found bool) {
 | 
			
		||||
	for _, m := range sqls.SQLVersions {
 | 
			
		||||
		code, _ := m["version"]
 | 
			
		||||
		if code == "full" {
 | 
			
		||||
			sql, found = m["sql"]
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查数据
 | 
			
		||||
func (this *SQLExecutor) checkData(db *dbs.DB) error {
 | 
			
		||||
	// 检查管理员
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user