mirror of
				https://github.com/TeaOSLab/EdgeAdmin.git
				synced 2025-11-04 13:10:26 +08:00 
			
		
		
		
	
		
			
	
	
		
			242 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			242 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								package apinodeutils
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"bytes"
							 | 
						||
| 
								 | 
							
									"compress/gzip"
							 | 
						||
| 
								 | 
							
									"crypto/md5"
							 | 
						||
| 
								 | 
							
									"errors"
							 | 
						||
| 
								 | 
							
									"fmt"
							 | 
						||
| 
								 | 
							
									"github.com/TeaOSLab/EdgeAdmin/internal/configs"
							 | 
						||
| 
								 | 
							
									"github.com/TeaOSLab/EdgeAdmin/internal/rpc"
							 | 
						||
| 
								 | 
							
									"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
							 | 
						||
| 
								 | 
							
									"github.com/iwind/TeaGo/Tea"
							 | 
						||
| 
								 | 
							
									"github.com/iwind/TeaGo/types"
							 | 
						||
| 
								 | 
							
									stringutil "github.com/iwind/TeaGo/utils/string"
							 | 
						||
| 
								 | 
							
									"io"
							 | 
						||
| 
								 | 
							
									"os"
							 | 
						||
| 
								 | 
							
									"os/exec"
							 | 
						||
| 
								 | 
							
									"path/filepath"
							 | 
						||
| 
								 | 
							
									"regexp"
							 | 
						||
| 
								 | 
							
									"runtime"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// Progress is a snapshot of how far an upload started by Upgrader.Upgrade has advanced.
								type Progress struct {
									// Percent is the uploaded share of the gz file, computed as
									// uploadedBytes*100/totalBytes; it starts at 0.
									Percent float64
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								type Upgrader struct {
							 | 
						||
| 
								 | 
							
									progress *Progress
							 | 
						||
| 
								 | 
							
									apiExe   string
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func NewUpgrader() *Upgrader {
							 | 
						||
| 
								 | 
							
									return &Upgrader{
							 | 
						||
| 
								 | 
							
										apiExe:   Tea.Root + "/edge-api/bin/edge-api",
							 | 
						||
| 
								 | 
							
										progress: &Progress{Percent: 0},
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (this *Upgrader) CanUpgrade(apiVersion string) (canUpgrade bool, reason string) {
							 | 
						||
| 
								 | 
							
									stat, err := os.Stat(this.apiExe)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return false, "stat error: " + err.Error()
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if stat.IsDir() {
							 | 
						||
| 
								 | 
							
										return false, "is directory"
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									localVersion, err := this.localVersion()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return false, "lookup version failed: " + err.Error()
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if stringutil.VersionCompare(localVersion, apiVersion) <= 0 {
							 | 
						||
| 
								 | 
							
										return false, "need not upgrade, local '" + localVersion + "' vs remote '" + apiVersion + "'"
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return true, ""
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (this *Upgrader) Upgrade(apiNodeId int64) error {
							 | 
						||
| 
								 | 
							
									sharedClient, err := rpc.SharedRPC()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									apiNodeResp, err := sharedClient.APINodeRPC().FindEnabledAPINode(sharedClient.Context(0), &pb.FindEnabledAPINodeRequest{ApiNodeId: apiNodeId})
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									var apiNode = apiNodeResp.ApiNode
							 | 
						||
| 
								 | 
							
									if apiNode == nil {
							 | 
						||
| 
								 | 
							
										return errors.New("could not find api node with id '" + types.String(apiNodeId) + "'")
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									apiConfig, err := configs.LoadAPIConfig()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									var newAPIConfig = apiConfig.Clone()
							 | 
						||
| 
								 | 
							
									newAPIConfig.RPC.Endpoints = apiNode.AccessAddrs
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									rpcClient, err := rpc.NewRPCClient(newAPIConfig, false)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									versionResp, err := rpcClient.APINodeRPC().FindCurrentAPINodeVersion(sharedClient.Context(0), &pb.FindCurrentAPINodeVersionRequest{})
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if !Tea.IsTesting() /** 开发环境下允许突破此限制方便测试 **/ &&
							 | 
						||
| 
								 | 
							
										(stringutil.VersionCompare(versionResp.Version, "0.6.4" /** 从0.6.4开始支持 **/) <= 0 || versionResp.Os != runtime.GOOS || versionResp.Arch != runtime.GOARCH) {
							 | 
						||
| 
								 | 
							
										return errors.New("could not upgrade api node v" + versionResp.Version + "/" + versionResp.Os + "/" + versionResp.Arch)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// 检查本地文件版本
							 | 
						||
| 
								 | 
							
									canUpgrade, reason := this.CanUpgrade(versionResp.Version)
							 | 
						||
| 
								 | 
							
									if !canUpgrade {
							 | 
						||
| 
								 | 
							
										return errors.New(reason)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									localVersion, err := this.localVersion()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return errors.New("lookup version failed: " + err.Error())
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// 检查要升级的文件
							 | 
						||
| 
								 | 
							
									var gzFile = this.apiExe + "." + localVersion + ".gz"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									gzReader, err := os.Open(gzFile)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										if !os.IsNotExist(err) {
							 | 
						||
| 
								 | 
							
											return err
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										err = func() error {
							 | 
						||
| 
								 | 
							
											// 压缩文件
							 | 
						||
| 
								 | 
							
											exeReader, err := os.Open(this.apiExe)
							 | 
						||
| 
								 | 
							
											if err != nil {
							 | 
						||
| 
								 | 
							
												return err
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											defer func() {
							 | 
						||
| 
								 | 
							
												_ = exeReader.Close()
							 | 
						||
| 
								 | 
							
											}()
							 | 
						||
| 
								 | 
							
											var tmpGzFile = gzFile + ".tmp"
							 | 
						||
| 
								 | 
							
											gzFileWriter, err := os.OpenFile(tmpGzFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666)
							 | 
						||
| 
								 | 
							
											if err != nil {
							 | 
						||
| 
								 | 
							
												return err
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											var gzWriter = gzip.NewWriter(gzFileWriter)
							 | 
						||
| 
								 | 
							
											defer func() {
							 | 
						||
| 
								 | 
							
												_ = gzWriter.Close()
							 | 
						||
| 
								 | 
							
												_ = gzFileWriter.Close()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
												_ = os.Rename(tmpGzFile, gzFile)
							 | 
						||
| 
								 | 
							
											}()
							 | 
						||
| 
								 | 
							
											_, err = io.Copy(gzWriter, exeReader)
							 | 
						||
| 
								 | 
							
											if err != nil {
							 | 
						||
| 
								 | 
							
												return err
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											return nil
							 | 
						||
| 
								 | 
							
										}()
							 | 
						||
| 
								 | 
							
										if err != nil {
							 | 
						||
| 
								 | 
							
											return err
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										gzReader, err = os.Open(gzFile)
							 | 
						||
| 
								 | 
							
										if err != nil {
							 | 
						||
| 
								 | 
							
											return err
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									defer func() {
							 | 
						||
| 
								 | 
							
										_ = gzReader.Close()
							 | 
						||
| 
								 | 
							
									}()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// 开始上传
							 | 
						||
| 
								 | 
							
									var hash = md5.New()
							 | 
						||
| 
								 | 
							
									var buf = make([]byte, 128*4096)
							 | 
						||
| 
								 | 
							
									var isFirst = true
							 | 
						||
| 
								 | 
							
									stat, err := gzReader.Stat()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									var totalSize = stat.Size()
							 | 
						||
| 
								 | 
							
									if totalSize == 0 {
							 | 
						||
| 
								 | 
							
										_ = gzReader.Close()
							 | 
						||
| 
								 | 
							
										_ = os.Remove(gzFile)
							 | 
						||
| 
								 | 
							
										return errors.New("invalid gz file")
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									var uploadedSize int64 = 0
							 | 
						||
| 
								 | 
							
									for {
							 | 
						||
| 
								 | 
							
										n, err := gzReader.Read(buf)
							 | 
						||
| 
								 | 
							
										if n > 0 {
							 | 
						||
| 
								 | 
							
											// 计算Hash
							 | 
						||
| 
								 | 
							
											hash.Write(buf[:n])
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											// 上传
							 | 
						||
| 
								 | 
							
											_, uploadErr := rpcClient.APINodeRPC().UploadAPINodeFile(rpcClient.Context(0), &pb.UploadAPINodeFileRequest{
							 | 
						||
| 
								 | 
							
												Filename:     filepath.Base(this.apiExe),
							 | 
						||
| 
								 | 
							
												Sum:          "",
							 | 
						||
| 
								 | 
							
												ChunkData:    buf[:n],
							 | 
						||
| 
								 | 
							
												IsFirstChunk: isFirst,
							 | 
						||
| 
								 | 
							
												IsLastChunk:  false,
							 | 
						||
| 
								 | 
							
											})
							 | 
						||
| 
								 | 
							
											if uploadErr != nil {
							 | 
						||
| 
								 | 
							
												return uploadErr
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
											// 进度
							 | 
						||
| 
								 | 
							
											uploadedSize += int64(n)
							 | 
						||
| 
								 | 
							
											this.progress = &Progress{Percent: float64(uploadedSize*100) / float64(totalSize)}
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										if isFirst {
							 | 
						||
| 
								 | 
							
											isFirst = false
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										if err != nil {
							 | 
						||
| 
								 | 
							
											if err != io.EOF {
							 | 
						||
| 
								 | 
							
												return err
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											if err == io.EOF {
							 | 
						||
| 
								 | 
							
												_, uploadErr := rpcClient.APINodeRPC().UploadAPINodeFile(rpcClient.Context(0), &pb.UploadAPINodeFileRequest{
							 | 
						||
| 
								 | 
							
													Filename:     filepath.Base(this.apiExe),
							 | 
						||
| 
								 | 
							
													Sum:          fmt.Sprintf("%x", hash.Sum(nil)),
							 | 
						||
| 
								 | 
							
													ChunkData:    buf[:n],
							 | 
						||
| 
								 | 
							
													IsFirstChunk: isFirst,
							 | 
						||
| 
								 | 
							
													IsLastChunk:  true,
							 | 
						||
| 
								 | 
							
												})
							 | 
						||
| 
								 | 
							
												if uploadErr != nil {
							 | 
						||
| 
								 | 
							
													return uploadErr
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
												break
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (this *Upgrader) Progress() *Progress {
							 | 
						||
| 
								 | 
							
									return this.progress
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (this *Upgrader) localVersion() (string, error) {
							 | 
						||
| 
								 | 
							
									var cmd = exec.Command(this.apiExe, "-V")
							 | 
						||
| 
								 | 
							
									var output = &bytes.Buffer{}
							 | 
						||
| 
								 | 
							
									cmd.Stdout = output
							 | 
						||
| 
								 | 
							
									err := cmd.Run()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return "", err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									var localVersion = output.String()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									// 检查版本号
							 | 
						||
| 
								 | 
							
									var reg = regexp.MustCompile(`^[\d.]+$`)
							 | 
						||
| 
								 | 
							
									if !reg.MatchString(localVersion) {
							 | 
						||
| 
								 | 
							
										return "", errors.New("lookup version failed: " + localVersion)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return localVersion, nil
							 | 
						||
| 
								 | 
							
								}
							 |