mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	初步实现安装界面
This commit is contained in:
		
							
								
								
									
										279
									
								
								internal/setup/sql_executor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										279
									
								
								internal/setup/sql_executor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,279 @@
 | 
			
		||||
package setup
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/setup/sqls"
 | 
			
		||||
	_ "github.com/go-sql-driver/mysql"
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"github.com/iwind/TeaGo/rands"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	stringutil "github.com/iwind/TeaGo/utils/string"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SQLExecutor struct {
 | 
			
		||||
	dbConfig *dbs.DBConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSQLExecutor(dbConfig *dbs.DBConfig) *SQLExecutor {
 | 
			
		||||
	return &SQLExecutor{
 | 
			
		||||
		dbConfig: dbConfig,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *SQLExecutor) Run() error {
 | 
			
		||||
	db, err := dbs.NewInstanceFromConfig(this.dbConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tableNames, err := db.TableNames()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查已有数据
 | 
			
		||||
	if lists.ContainsString(tableNames, "edgeVersions") {
 | 
			
		||||
		stmt, err := db.Prepare("SELECT version FROM " + db.TablePrefix() + "Versions")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer func() {
 | 
			
		||||
			_ = stmt.Close()
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		rows, err := stmt.Query()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		defer func() {
 | 
			
		||||
			_ = rows.Close()
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		// 对比版本
 | 
			
		||||
		oldVersion := ""
 | 
			
		||||
		if rows.Next() {
 | 
			
		||||
			err = rows.Scan(&oldVersion)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.New("query version failed: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if oldVersion == teaconst.Version {
 | 
			
		||||
			err = this.checkData(db)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		oldVersion = strings.Replace(oldVersion, "_", ".", -1)
 | 
			
		||||
 | 
			
		||||
		upgradeVersions := []string{}
 | 
			
		||||
		sqlMap := map[string]string{} // version => sql
 | 
			
		||||
 | 
			
		||||
		for _, m := range sqls.SQLVersions {
 | 
			
		||||
			version, _ := m["version"]
 | 
			
		||||
			if version == "full" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			version = strings.Replace(version, "_", ".", -1)
 | 
			
		||||
			if len(oldVersion) == 0 || stringutil.VersionCompare(version, oldVersion) > 0 {
 | 
			
		||||
				upgradeVersions = append(upgradeVersions, version)
 | 
			
		||||
				sql, _ := m["sql"]
 | 
			
		||||
				sqlMap[version] = sql
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 如果没有可以升级的版本,直接返回
 | 
			
		||||
		if len(upgradeVersions) == 0 {
 | 
			
		||||
			err = this.checkData(db)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sort.Slice(upgradeVersions, func(i, j int) bool {
 | 
			
		||||
			return stringutil.VersionCompare(upgradeVersions[i], upgradeVersions[j]) < 0
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		// 执行升级的SQL
 | 
			
		||||
		for _, version := range upgradeVersions {
 | 
			
		||||
			sql := sqlMap[version]
 | 
			
		||||
			_, err = db.Exec(sql)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if !this.canIgnoreError(err) {
 | 
			
		||||
					return errors.New("exec upgrade sql for version '" + version + "' failed: " + err.Error())
 | 
			
		||||
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			err = this.updateVersion(db, version)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 检查数据
 | 
			
		||||
		err = this.checkData(db)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 全新安装
 | 
			
		||||
	fullSQL, found := this.findFullSQL()
 | 
			
		||||
	if !found {
 | 
			
		||||
		return errors.New("not found full setup sql")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 执行SQL
 | 
			
		||||
	_, err = db.Exec(fullSQL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("create tables failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查数据
 | 
			
		||||
	err = this.checkData(db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找完整的SQL
 | 
			
		||||
func (this *SQLExecutor) findFullSQL() (sql string, found bool) {
 | 
			
		||||
	for _, m := range sqls.SQLVersions {
 | 
			
		||||
		code, _ := m["version"]
 | 
			
		||||
		if code == "full" {
 | 
			
		||||
			sql, found = m["sql"]
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查数据
 | 
			
		||||
func (this *SQLExecutor) checkData(db *dbs.DB) error {
 | 
			
		||||
	// 检查管理员
 | 
			
		||||
	err := this.checkAdmin(db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查管理员平台节点
 | 
			
		||||
	err = this.checkAdminNode(db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新版本号
 | 
			
		||||
	err = this.updateVersion(db, teaconst.Version)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查管理员
 | 
			
		||||
func (this *SQLExecutor) checkAdmin(db *dbs.DB) error {
 | 
			
		||||
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAdmins")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("check admin failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = stmt.Close()
 | 
			
		||||
	}()
 | 
			
		||||
	col, err := stmt.FindCol(0)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("check admin failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	count := types.Int(col)
 | 
			
		||||
	if count == 0 {
 | 
			
		||||
		_, err = db.Exec("INSERT INTO edgeAdmins (username, password, fullname, isSuper, createdAt, state) VALUES (?, ?, ?, ?, ?, ?)", "admin", stringutil.Md5("123456"), "管理员", 1, time.Now().Unix(), 1)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("create admin failed: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查管理员平台节点
 | 
			
		||||
func (this *SQLExecutor) checkAdminNode(db *dbs.DB) error {
 | 
			
		||||
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAPITokens WHERE role='admin'")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = stmt.Close()
 | 
			
		||||
	}()
 | 
			
		||||
	col, err := stmt.FindCol(0)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	count := types.Int(col)
 | 
			
		||||
	if count > 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nodeId := rands.HexString(32)
 | 
			
		||||
	secret := rands.String(32)
 | 
			
		||||
	_, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, "admin")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 更新版本号
 | 
			
		||||
func (this *SQLExecutor) updateVersion(db *dbs.DB, version string) error {
 | 
			
		||||
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeVersions")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("query version failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = stmt.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	col, err := stmt.FindCol(0)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("query version failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	count := types.Int(col)
 | 
			
		||||
	if count > 0 {
 | 
			
		||||
		_, err = db.Exec("UPDATE edgeVersions SET version=?", version)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("update version failed: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = db.Exec("INSERT edgeVersions (version) VALUES (?)", version)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("create version failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 判断某个错误是否可以忽略
 | 
			
		||||
func (this *SQLExecutor) canIgnoreError(err error) bool {
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Error 1050: Table 'xxx' already exists
 | 
			
		||||
	if strings.Contains(err.Error(), "Error 1050") {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user