mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	支持在API节点启动的时候自动升级数据库
This commit is contained in:
		@@ -4,12 +4,15 @@ import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/configs"
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/db/models"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/setup"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeAPI/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	stringutil "github.com/iwind/TeaGo/utils/string"
 | 
			
		||||
	"google.golang.org/grpc"
 | 
			
		||||
	"google.golang.org/grpc/credentials"
 | 
			
		||||
	"net"
 | 
			
		||||
@@ -27,7 +30,14 @@ func NewAPINode() *APINode {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *APINode) Start() {
 | 
			
		||||
	logs.Println("[API]start api node, pid: " + strconv.Itoa(os.Getpid()))
 | 
			
		||||
	logs.Println("[API_NODE]start api node, pid: " + strconv.Itoa(os.Getpid()))
 | 
			
		||||
 | 
			
		||||
	// 自动升级
 | 
			
		||||
	err := this.autoUpgrade()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API_NODE]auto upgrade failed: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 数据库通知启动
 | 
			
		||||
	dbs.NotifyReady()
 | 
			
		||||
@@ -35,7 +45,7 @@ func (this *APINode) Start() {
 | 
			
		||||
	// 读取配置
 | 
			
		||||
	config, err := configs.SharedAPIConfig()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]start failed: " + err.Error())
 | 
			
		||||
		logs.Println("[API_NODE]start failed: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	sharedAPIConfig = config
 | 
			
		||||
@@ -43,11 +53,11 @@ func (this *APINode) Start() {
 | 
			
		||||
	// 校验
 | 
			
		||||
	apiNode, err := models.SharedAPINodeDAO.FindEnabledAPINodeWithUniqueIdAndSecret(config.NodeId, config.Secret)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]start failed: read api node from database failed: " + err.Error())
 | 
			
		||||
		logs.Println("[API_NODE]start failed: read api node from database failed: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if apiNode == nil {
 | 
			
		||||
		logs.Println("[API]can not start node, wrong 'nodeId' or 'secret'")
 | 
			
		||||
		logs.Println("[API_NODE]can not start node, wrong 'nodeId' or 'secret'")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	config.SetNumberId(int64(apiNode.Id))
 | 
			
		||||
@@ -56,12 +66,12 @@ func (this *APINode) Start() {
 | 
			
		||||
	_ = utils.SetRLimit(1024 * 1024)
 | 
			
		||||
 | 
			
		||||
	// 监听RPC服务
 | 
			
		||||
	logs.Println("[API]starting rpc ...")
 | 
			
		||||
	logs.Println("[API_NODE]starting rpc ...")
 | 
			
		||||
 | 
			
		||||
	// HTTP
 | 
			
		||||
	httpConfig, err := apiNode.DecodeHTTP()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]decode http config: " + err.Error())
 | 
			
		||||
		logs.Println("[API_NODE]decode http config: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	isListening := false
 | 
			
		||||
@@ -70,13 +80,13 @@ func (this *APINode) Start() {
 | 
			
		||||
			for _, addr := range listen.Addresses() {
 | 
			
		||||
				listener, err := net.Listen("tcp", addr)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
 | 
			
		||||
					logs.Println("[API_NODE]listening '" + addr + "' failed: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				go func() {
 | 
			
		||||
					err := this.listenRPC(listener, nil)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
 | 
			
		||||
						logs.Println("[API_NODE]listening '" + addr + "' rpc: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
@@ -88,7 +98,7 @@ func (this *APINode) Start() {
 | 
			
		||||
	// HTTPS
 | 
			
		||||
	httpsConfig, err := apiNode.DecodeHTTPS()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Println("[API]decode https config: " + err.Error())
 | 
			
		||||
		logs.Println("[API_NODE]decode https config: " + err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if httpsConfig != nil &&
 | 
			
		||||
@@ -106,7 +116,7 @@ func (this *APINode) Start() {
 | 
			
		||||
			for _, addr := range listen.Addresses() {
 | 
			
		||||
				listener, err := net.Listen("tcp", addr)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logs.Println("[API]listening '" + addr + "' failed: " + err.Error())
 | 
			
		||||
					logs.Println("[API_NODE]listening '" + addr + "' failed: " + err.Error())
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				go func() {
 | 
			
		||||
@@ -114,7 +124,7 @@ func (this *APINode) Start() {
 | 
			
		||||
						Certificates: certs,
 | 
			
		||||
					})
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						logs.Println("[API]listening '" + addr + "' rpc: " + err.Error())
 | 
			
		||||
						logs.Println("[API_NODE]listening '" + addr + "' rpc: " + err.Error())
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
@@ -124,7 +134,7 @@ func (this *APINode) Start() {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !isListening {
 | 
			
		||||
		logs.Println("[API]the api node does have a listening address")
 | 
			
		||||
		logs.Println("[API_NODE]the api node does have a listening address")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -136,10 +146,10 @@ func (this *APINode) Start() {
 | 
			
		||||
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
 | 
			
		||||
	var rpcServer *grpc.Server
 | 
			
		||||
	if tlsConfig == nil {
 | 
			
		||||
		logs.Println("[API]listening http://" + listener.Addr().String() + " ...")
 | 
			
		||||
		logs.Println("[API_NODE]listening http://" + listener.Addr().String() + " ...")
 | 
			
		||||
		rpcServer = grpc.NewServer()
 | 
			
		||||
	} else {
 | 
			
		||||
		logs.Println("[API]listening https://" + listener.Addr().String() + " ...")
 | 
			
		||||
		logs.Println("[API_NODE]listening https://" + listener.Addr().String() + " ...")
 | 
			
		||||
		rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
 | 
			
		||||
	}
 | 
			
		||||
	pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
 | 
			
		||||
@@ -186,8 +196,39 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err
 | 
			
		||||
	pb.RegisterDNSServiceServer(rpcServer, &services.DNSService{})
 | 
			
		||||
	err := rpcServer.Serve(listener)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("[API]start rpc failed: " + err.Error())
 | 
			
		||||
		return errors.New("[API_NODE]start rpc failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 自动升级
 | 
			
		||||
func (this *APINode) autoUpgrade() error {
 | 
			
		||||
	db, err := dbs.Default()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("load database config failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	one, err := db.FindOne("SELECT version FROM edgeVersions LIMIT 1")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("query version failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	if one != nil {
 | 
			
		||||
		// 如果是同样的版本,则直接认为是最新版本
 | 
			
		||||
		version := one.GetString("version")
 | 
			
		||||
		if stringutil.VersionCompare(version, teaconst.Version) >= 0 {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbConfig, err := db.Config()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("read db config failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	logs.Println("[API_NODE]upgrade database starting ...")
 | 
			
		||||
	err = setup.NewSQLExecutor(dbConfig).Run()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("execute sql failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	logs.Println("[API_NODE]upgrade database done")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user