支持在API节点启动的时候自动升级数据库

This commit is contained in:
GoEdgeLab
2020-11-17 10:26:31 +08:00
parent 59cc185a3f
commit a3449f028e
3 changed files with 59 additions and 20 deletions

View File

@@ -4,12 +4,15 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"github.com/TeaOSLab/EdgeAPI/internal/configs" "github.com/TeaOSLab/EdgeAPI/internal/configs"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services" "github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
"github.com/TeaOSLab/EdgeAPI/internal/setup"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
stringutil "github.com/iwind/TeaGo/utils/string"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"net" "net"
@@ -27,7 +30,14 @@ func NewAPINode() *APINode {
} }
func (this *APINode) Start() { func (this *APINode) Start() {
logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid())) logs.Println("[API_NODE]start api node, pid: " + strconv.Itoa(os.Getpid()))
// 自动升级
err := this.autoUpgrade()
if err != nil {
logs.Println("[API_NODE]auto upgrade failed: " + err.Error())
return
}
// 数据库通知启动 // 数据库通知启动
dbs.NotifyReady() dbs.NotifyReady()
@@ -35,7 +45,7 @@ func (this *APINode) Start() {
// 读取配置 // 读取配置
config, err := configs.SharedAPIConfig() config, err := configs.SharedAPIConfig()
if err != nil { if err != nil {
logs.Println("[API]start failed: " + err.Error()) logs.Println("[API_NODE]start failed: " + err.Error())
return return
} }
sharedAPIConfig = config sharedAPIConfig = config
@@ -43,11 +53,11 @@ func (this *APINode) Start() {
// 校验 // 校验
apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret) apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret)
if err != nil { if err != nil {
logs.Println("[API]start failed: read api node from database failed: " + err.Error()) logs.Println("[API_NODE]start failed: read api node from database failed: " + err.Error())
return return
} }
if apiNode == nil { if apiNode == nil {
logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'") logs.Println("[API_NODE]can not start node, wrong 'nodeId' or 'secret'")
return return
} }
config.SetNumberId(int64(apiNode.Id)) config.SetNumberId(int64(apiNode.Id))
@@ -56,12 +66,12 @@ func (this *APINode) Start() {
_ = utils.SetRLimit(1024 * 1024) _ = utils.SetRLimit(1024 * 1024)
// 监听RPC服务 // 监听RPC服务
logs.Println("[API]starting rpc ...") logs.Println("[API_NODE]starting rpc ...")
// HTTP // HTTP
httpConfig, err := apiNode.DecodeHTTP() httpConfig, err := apiNode.DecodeHTTP()
if err != nil { if err != nil {
logs.Println("[API]decode http config: " + err.Error()) logs.Println("[API_NODE]decode http config: " + err.Error())
return return
} }
isListening := false isListening := false
@@ -70,13 +80,13 @@ func (this *APINode) Start() {
for _, addr := range listen.Addresses() { for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
logs.Println("[API]listening '" + addr + "' failed: " + err.Error()) logs.Println("[API_NODE]listening '" + addr + "' failed: " + err.Error())
continue continue
} }
go func() { go func() {
err := this.listenRPC(listener, nil) err := this.listenRPC(listener, nil)
if err != nil { if err != nil {
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error()) logs.Println("[API_NODE]listening '" + addr + "' rpc: " + err.Error())
return return
} }
}() }()
@@ -88,7 +98,7 @@ func (this *APINode) Start() {
// HTTPS // HTTPS
httpsConfig, err := apiNode.DecodeHTTPS() httpsConfig, err := apiNode.DecodeHTTPS()
if err != nil { if err != nil {
logs.Println("[API]decode https config: " + err.Error()) logs.Println("[API_NODE]decode https config: " + err.Error())
return return
} }
if httpsConfig != nil && if httpsConfig != nil &&
@@ -106,7 +116,7 @@ func (this *APINode) Start() {
for _, addr := range listen.Addresses() { for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
logs.Println("[API]listening '" + addr + "' failed: " + err.Error()) logs.Println("[API_NODE]listening '" + addr + "' failed: " + err.Error())
continue continue
} }
go func() { go func() {
@@ -114,7 +124,7 @@ func (this *APINode) Start() {
Certificates: certs, Certificates: certs,
}) })
if err != nil { if err != nil {
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error()) logs.Println("[API_NODE]listening '" + addr + "' rpc: " + err.Error())
return return
} }
}() }()
@@ -124,7 +134,7 @@ func (this *APINode) Start() {
} }
if !isListening { if !isListening {
logs.Println("[API]the api node does have a listening address") logs.Println("[API_NODE]the api node does have a listening address")
return return
} }
@@ -136,10 +146,10 @@ func (this *APINode) Start() {
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error { func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
var rpcServer *grpc.Server var rpcServer *grpc.Server
if tlsConfig == nil { if tlsConfig == nil {
logs.Println("[API]listening http://" + listener.Addr().String() + " ...") logs.Println("[API_NODE]listening http://" + listener.Addr().String() + " ...")
rpcServer = grpc.NewServer() rpcServer = grpc.NewServer()
} else { } else {
logs.Println("[API]listening https://" + listener.Addr().String() + " ...") logs.Println("[API_NODE]listening https://" + listener.Addr().String() + " ...")
rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
} }
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{}) pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
@@ -186,8 +196,39 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err
pb.RegisterDNSServiceServer(rpcServer, &services.DNSService{}) pb.RegisterDNSServiceServer(rpcServer, &services.DNSService{})
err := rpcServer.Serve(listener) err := rpcServer.Serve(listener)
if err != nil { if err != nil {
return errors.New("[API]start rpc failed: " + err.Error()) return errors.New("[API_NODE]start rpc failed: " + err.Error())
} }
return nil return nil
} }
// 自动升级
func (this *APINode) autoUpgrade() error {
db, err := dbs.Default()
if err != nil {
return errors.New("load database config failed: " + err.Error())
}
one, err := db.FindOne("SELECT version FROM edgeVersions LIMIT 1")
if err != nil {
return errors.New("query version failed: " + err.Error())
}
if one != nil {
// 如果是同样的版本,则直接认为是最新版本
version := one.GetString("version")
if stringutil.VersionCompare(version, teaconst.Version) >= 0 {
return nil
}
}
dbConfig, err := db.Config()
if err != nil {
return errors.New("read db config failed: " + err.Error())
}
logs.Println("[API_NODE]upgrade database starting ...")
err = setup.NewSQLExecutor(dbConfig).Run()
if err != nil {
return errors.New("execute sql failed: " + err.Error())
}
logs.Println("[API_NODE]upgrade database done")
return nil
}

File diff suppressed because one or more lines are too long

View File

@@ -208,9 +208,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult) (ops []string,
args := []string{} args := []string{}
values := []interface{}{} values := []interface{}{}
for k, v := range record.Values { for k, v := range record.Values {
if k == "id" { // ID需要保留因为各个表格之间需要有对应关系
continue
}
params = append(params, "`"+k+"`") params = append(params, "`"+k+"`")
args = append(args, "?") args = append(args, "?")
values = append(values, v) values = append(values, v)
@@ -231,7 +229,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult) (ops []string,
values = append(values, v) values = append(values, v)
} }
values = append(values, one.GetInt("id")) values = append(values, one.GetInt("id"))
_, err = db.Exec("UPDATE " + newTable.Name + " SET " + strings.Join(args, ", ") + " WHERE id=?", values...) _, err = db.Exec("UPDATE "+newTable.Name+" SET "+strings.Join(args, ", ")+" WHERE id=?", values...)
if err != nil { if err != nil {
return nil, err return nil, err
} }