// File: EdgeAPI/internal/setup/sql_executor.go
// Commit: GoEdgeLab 8962b6a0bb "更新SQL" (update SQL), 2021-07-13 15:10:19 +08:00
// Original listing: 292 lines, 6.2 KiB, Go

package setup
import (
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
_ "github.com/go-sql-driver/mysql"
"github.com/go-yaml/yaml"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"io/ioutil"
"strings"
"time"
)
// LatestSQLResult is the SQL dump result that Run applies to the target
// database through SQLDump.Apply.
// NOTE(review): presumably populated elsewhere in this package with the
// schema/data bundled into this build — confirm.
var LatestSQLResult = &SQLDumpResult{}

// SQLExecutor installs or upgrades the database schema and seeds the
// default data for a single database configuration.
type SQLExecutor struct {
	dbConfig *dbs.DBConfig // connection settings of the target database
}

// NewSQLExecutor returns an executor bound to dbConfig. The config is stored
// as-is and is not validated here.
func NewSQLExecutor(dbConfig *dbs.DBConfig) *SQLExecutor {
	executor := new(SQLExecutor)
	executor.dbConfig = dbConfig
	return executor
}
// NewSQLExecutorFromCmd builds an executor from the db.yaml config file,
// selecting the database entry whose key equals the current Tea.Env.
//
// It returns an error if db.yaml cannot be read or parsed, or if it holds no
// database entry for Tea.Env.
func NewSQLExecutorFromCmd() (*SQLExecutor, error) {
	configFile := Tea.ConfigFile("db.yaml")
	configData, err := ioutil.ReadFile(configFile)
	if err != nil {
		return nil, err
	}

	config := &dbs.Config{}
	err = yaml.Unmarshal(configData, config)
	if err != nil {
		return nil, err
	}

	// Without this check a missing entry yields a nil *dbs.DBConfig. That nil
	// value would only surface later, when Run passes it to
	// dbs.NewInstanceFromConfig.
	dbConfig, ok := config.DBs[Tea.Env]
	if !ok || dbConfig == nil {
		return nil, errors.New("can not find database config for env '" + Tea.Env + "' in '" + configFile + "'")
	}
	return NewSQLExecutor(dbConfig), nil
}
// Run opens a database instance from the executor's config and applies
// LatestSQLResult to it through SQLDump. It then seeds default data and
// records the current version (see checkData). The first error encountered
// is returned.
func (this *SQLExecutor) Run() error {
	db, err := dbs.NewInstanceFromConfig(this.dbConfig)
	if err != nil {
		return err
	}
	// The instance is created only for this run, so release its connections
	// when done instead of leaking them.
	defer func() {
		_ = db.Close()
	}()

	sqlDump := NewSQLDump()
	_, err = sqlDump.Apply(db, LatestSQLResult)
	if err != nil {
		return err
	}

	// Check and seed default data.
	return this.checkData(db)
}
// checkData seeds the default records a fresh installation needs, then
// records teaconst.Version in edgeVersions. The checks run in a fixed order,
// and the first error aborts the remaining steps.
func (this *SQLExecutor) checkData(db *dbs.DB) error {
	checks := []func(db *dbs.DB) error{
		this.checkAdmin,     // default administrator
		this.checkAdminNode, // admin platform API token
		this.checkUserNode,  // user platform API token
		this.checkCluster,   // default node cluster
		this.checkIPList,    // public IP lists
	}
	for _, check := range checks {
		if err := check(db); err != nil {
			return err
		}
	}

	// Record the version of this build.
	return this.updateVersion(db, teaconst.Version)
}
// checkAdmin makes sure edgeAdmins is not empty. On an empty table it inserts
// a default super administrator with username "admin" and password "123456",
// stored through stringutil.Md5.
// NOTE(review): this is a well-known default credential; confirm the
// installer forces it to be changed.
func (this *SQLExecutor) checkAdmin(db *dbs.DB) error {
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAdmins")
	if err != nil {
		return errors.New("check admin failed: " + err.Error())
	}
	defer func() {
		_ = stmt.Close()
	}()

	col, err := stmt.FindCol(0)
	if err != nil {
		return errors.New("check admin failed: " + err.Error())
	}
	if types.Int(col) != 0 {
		// Administrators already exist; nothing to seed.
		return nil
	}

	const insertAdmin = "INSERT INTO edgeAdmins (username, password, fullname, isSuper, createdAt, state) VALUES (?, ?, ?, ?, ?, ?)"
	_, err = db.Exec(insertAdmin, "admin", stringutil.Md5("123456"), "管理员", 1, time.Now().Unix(), 1)
	if err != nil {
		return errors.New("create admin failed: " + err.Error())
	}
	return nil
}
// checkAdminNode makes sure an API token with role "admin" exists in
// edgeAPITokens, creating one with a random nodeId/secret if none exists.
func (this *SQLExecutor) checkAdminNode(db *dbs.DB) error {
	return this.ensureAPITokenForRole(db, "admin")
}

// checkUserNode makes sure an API token with role "user" exists in
// edgeAPITokens, creating one with a random nodeId/secret if none exists.
func (this *SQLExecutor) checkUserNode(db *dbs.DB) error {
	return this.ensureAPITokenForRole(db, "user")
}

// ensureAPITokenForRole counts the edgeAPITokens rows with the given role. If
// there are none, it inserts one with a 32-char hex nodeId (rands.HexString)
// and a 32-char secret (rands.String). Errors carry context like the other
// checks in this file.
func (this *SQLExecutor) ensureAPITokenForRole(db *dbs.DB, role string) error {
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeAPITokens WHERE role=?")
	if err != nil {
		return errors.New("query api tokens failed: " + err.Error())
	}
	defer func() {
		_ = stmt.Close()
	}()

	col, err := stmt.FindCol(0, role)
	if err != nil {
		return errors.New("query api tokens failed: " + err.Error())
	}
	if types.Int(col) > 0 {
		return nil
	}

	nodeId := rands.HexString(32)
	secret := rands.String(32)
	_, err = db.Exec("INSERT INTO edgeAPITokens (nodeId, secret, role) VALUES (?, ?, ?)", nodeId, secret, role)
	if err != nil {
		return errors.New("create api token for role '" + role + "' failed: " + err.Error())
	}
	return nil
}
// checkCluster makes sure at least one row exists in edgeNodeClusters. If the
// table is empty, it creates a default cluster with useAllAPINodes=1, state=1,
// a random 32-char hex uniqueId and a random 32-char secret.
func (this *SQLExecutor) checkCluster(db *dbs.DB) error {
	// Count existing clusters.
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeNodeClusters")
	if err != nil {
		return errors.New("query clusters failed: " + err.Error())
	}
	defer func() {
		_ = stmt.Close()
	}()

	col, err := stmt.FindCol(0)
	if err != nil {
		return errors.New("query clusters failed: " + err.Error())
	}
	if types.Int(col) > 0 {
		return nil
	}

	// Create the default cluster.
	_, err = db.Exec("INSERT INTO edgeNodeClusters (name, useAllAPINodes, state, uniqueId, secret) VALUES (?, ?, ?, ?, ?)", "默认集群", 1, 1, rands.HexString(32), rands.String(32))
	if err != nil {
		// Add context, matching how the query errors above are reported.
		return errors.New("create default cluster failed: " + err.Error())
	}
	return nil
}
// checkIPList makes sure edgeIPLists is not empty. When the table is empty,
// it creates two public lists: a public blacklist (type and code "black") and
// a public whitelist (type and code "white").
// NOTE(review): the two inserts are not in a transaction. If the second one
// fails, the next run sees a non-empty table and will not retry it.
func (this *SQLExecutor) checkIPList(db *dbs.DB) error {
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeIPLists")
	if err != nil {
		return errors.New("query ip lists failed: " + err.Error())
	}
	defer func() {
		_ = stmt.Close()
	}()

	col, err := stmt.FindCol(0)
	if err != nil {
		return errors.New("query ip lists failed: " + err.Error())
	}
	if types.Int(col) > 0 {
		return nil
	}

	publicLists := []struct {
		name string
		kind string // written to both the type and code columns
	}{
		{name: "公共黑名单", kind: "black"},
		{name: "公共白名单", kind: "white"},
	}
	for _, list := range publicLists {
		_, err = db.Exec("INSERT INTO edgeIPLists(name, type, code, isPublic, createdAt) VALUES (?, ?, ?, ?, ?)", list.name, list.kind, list.kind, 1, time.Now().Unix())
		if err != nil {
			return err
		}
	}
	return nil
}
// updateVersion stores version in edgeVersions. It inserts a row when the
// table is empty; otherwise it overwrites the version column of every
// existing row (the UPDATE has no WHERE clause).
func (this *SQLExecutor) updateVersion(db *dbs.DB, version string) error {
	stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeVersions")
	if err != nil {
		return errors.New("query version failed: " + err.Error())
	}
	defer func() {
		_ = stmt.Close()
	}()

	col, err := stmt.FindCol(0)
	if err != nil {
		return errors.New("query version failed: " + err.Error())
	}

	if types.Int(col) <= 0 {
		// No version row yet: create one.
		if _, insertErr := db.Exec("INSERT edgeVersions (version) VALUES (?)", version); insertErr != nil {
			return errors.New("create version failed: " + insertErr.Error())
		}
		return nil
	}

	// A version row exists: overwrite it.
	if _, updateErr := db.Exec("UPDATE edgeVersions SET version=?", version); updateErr != nil {
		return errors.New("update version failed: " + updateErr.Error())
	}
	return nil
}
// canIgnoreError reports whether err can be safely ignored while executing
// SQL. That is the case for a nil error, or for an error whose text contains
// "Error 1050" (MySQL: Table 'xxx' already exists), matched as a substring.
func (this *SQLExecutor) canIgnoreError(err error) bool {
	return err == nil || strings.Contains(err.Error(), "Error 1050")
}