mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-05 01:20:25 +08:00
支持在API节点启动的时候自动升级数据库
This commit is contained in:
@@ -4,12 +4,15 @@ import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/setup"
|
||||
"github.com/TeaOSLab/EdgeAPI/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"net"
|
||||
@@ -27,7 +30,14 @@ func NewAPINode() *APINode {
|
||||
}
|
||||
|
||||
func (this *APINode) Start() {
|
||||
logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
|
||||
logs.Println("[API_NODE]start api node, pid: " + strconv.Itoa(os.Getpid()))
|
||||
|
||||
// 自动升级
|
||||
err := this.autoUpgrade()
|
||||
if err != nil {
|
||||
logs.Println("[API_NODE]auto upgrade failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 数据库通知启动
|
||||
dbs.NotifyReady()
|
||||
@@ -35,7 +45,7 @@ func (this *APINode) Start() {
|
||||
// 读取配置
|
||||
config, err := configs.SharedAPIConfig()
|
||||
if err != nil {
|
||||
logs.Println("[API]start failed: " + err.Error())
|
||||
logs.Println("[API_NODE]start failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
sharedAPIConfig = config
|
||||
@@ -43,11 +53,11 @@ func (this *APINode) Start() {
|
||||
// 校验
|
||||
apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret)
|
||||
if err != nil {
|
||||
logs.Println("[API]start failed: read api node from database failed: " + err.Error())
|
||||
logs.Println("[API_NODE]start failed: read api node from database failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
if apiNode == nil {
|
||||
logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'")
|
||||
logs.Println("[API_NODE]can not start node, wrong 'nodeId' or 'secret'")
|
||||
return
|
||||
}
|
||||
config.SetNumberId(int64(apiNode.Id))
|
||||
@@ -56,12 +66,12 @@ func (this *APINode) Start() {
|
||||
_ = utils.SetRLimit(1024 * 1024)
|
||||
|
||||
// 监听RPC服务
|
||||
logs.Println("[API]starting rpc ...")
|
||||
logs.Println("[API_NODE]starting rpc ...")
|
||||
|
||||
// HTTP
|
||||
httpConfig, err := apiNode.DecodeHTTP()
|
||||
if err != nil {
|
||||
logs.Println("[API]decode http config: " + err.Error())
|
||||
logs.Println("[API_NODE]decode http config: " + err.Error())
|
||||
return
|
||||
}
|
||||
isListening := false
|
||||
@@ -70,13 +80,13 @@ func (this *APINode) Start() {
|
||||
for _, addr := range listen.Addresses() {
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
|
||||
logs.Println("[API_NODE]listening '" + addr + "' failed: " + err.Error())
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
err := this.listenRPC(listener, nil)
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
|
||||
logs.Println("[API_NODE]listening '" + addr + "' rpc: " + err.Error())
|
||||
return
|
||||
}
|
||||
}()
|
||||
@@ -88,7 +98,7 @@ func (this *APINode) Start() {
|
||||
// HTTPS
|
||||
httpsConfig, err := apiNode.DecodeHTTPS()
|
||||
if err != nil {
|
||||
logs.Println("[API]decode https config: " + err.Error())
|
||||
logs.Println("[API_NODE]decode https config: " + err.Error())
|
||||
return
|
||||
}
|
||||
if httpsConfig != nil &&
|
||||
@@ -106,7 +116,7 @@ func (this *APINode) Start() {
|
||||
for _, addr := range listen.Addresses() {
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
|
||||
logs.Println("[API_NODE]listening '" + addr + "' failed: " + err.Error())
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
@@ -114,7 +124,7 @@ func (this *APINode) Start() {
|
||||
Certificates: certs,
|
||||
})
|
||||
if err != nil {
|
||||
logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
|
||||
logs.Println("[API_NODE]listening '" + addr + "' rpc: " + err.Error())
|
||||
return
|
||||
}
|
||||
}()
|
||||
@@ -124,7 +134,7 @@ func (this *APINode) Start() {
|
||||
}
|
||||
|
||||
if !isListening {
|
||||
logs.Println("[API]the api node does have a listening address")
|
||||
logs.Println("[API_NODE]the api node does have a listening address")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -136,10 +146,10 @@ func (this *APINode) Start() {
|
||||
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
|
||||
var rpcServer *grpc.Server
|
||||
if tlsConfig == nil {
|
||||
logs.Println("[API]listening http://" + listener.Addr().String() + " ...")
|
||||
logs.Println("[API_NODE]listening http://" + listener.Addr().String() + " ...")
|
||||
rpcServer = grpc.NewServer()
|
||||
} else {
|
||||
logs.Println("[API]listening https://" + listener.Addr().String() + " ...")
|
||||
logs.Println("[API_NODE]listening https://" + listener.Addr().String() + " ...")
|
||||
rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
}
|
||||
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
|
||||
@@ -186,8 +196,39 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err
|
||||
pb.RegisterDNSServiceServer(rpcServer, &services.DNSService{})
|
||||
err := rpcServer.Serve(listener)
|
||||
if err != nil {
|
||||
return errors.New("[API]start rpc failed: " + err.Error())
|
||||
return errors.New("[API_NODE]start rpc failed: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 自动升级
|
||||
func (this *APINode) autoUpgrade() error {
|
||||
db, err := dbs.Default()
|
||||
if err != nil {
|
||||
return errors.New("load database config failed: " + err.Error())
|
||||
}
|
||||
one, err := db.FindOne("SELECT version FROM edgeVersions LIMIT 1")
|
||||
if err != nil {
|
||||
return errors.New("query version failed: " + err.Error())
|
||||
}
|
||||
if one != nil {
|
||||
// 如果是同样的版本,则直接认为是最新版本
|
||||
version := one.GetString("version")
|
||||
if stringutil.VersionCompare(version, teaconst.Version) >= 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
dbConfig, err := db.Config()
|
||||
if err != nil {
|
||||
return errors.New("read db config failed: " + err.Error())
|
||||
}
|
||||
logs.Println("[API_NODE]upgrade database starting ...")
|
||||
err = setup.NewSQLExecutor(dbConfig).Run()
|
||||
if err != nil {
|
||||
return errors.New("execute sql failed: " + err.Error())
|
||||
}
|
||||
logs.Println("[API_NODE]upgrade database done")
|
||||
return nil
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -208,9 +208,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult) (ops []string,
|
||||
args := []string{}
|
||||
values := []interface{}{}
|
||||
for k, v := range record.Values {
|
||||
if k == "id" {
|
||||
continue
|
||||
}
|
||||
// ID需要保留,因为各个表格之间需要有对应关系
|
||||
params = append(params, "`"+k+"`")
|
||||
args = append(args, "?")
|
||||
values = append(values, v)
|
||||
@@ -231,7 +229,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult) (ops []string,
|
||||
values = append(values, v)
|
||||
}
|
||||
values = append(values, one.GetInt("id"))
|
||||
_, err = db.Exec("UPDATE " + newTable.Name + " SET " + strings.Join(args, ", ") + " WHERE id=?", values...)
|
||||
_, err = db.Exec("UPDATE "+newTable.Name+" SET "+strings.Join(args, ", ")+" WHERE id=?", values...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user