// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . package apinodeutils import ( "bytes" "compress/gzip" "crypto/md5" "errors" "fmt" "github.com/TeaOSLab/EdgeAdmin/internal/configs" "github.com/TeaOSLab/EdgeAdmin/internal/rpc" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/types" stringutil "github.com/iwind/TeaGo/utils/string" "io" "os" "os/exec" "path/filepath" "regexp" "runtime" ) type Progress struct { Percent float64 } type Upgrader struct { progress *Progress apiExe string } func NewUpgrader() *Upgrader { return &Upgrader{ apiExe: Tea.Root + "/edge-api/bin/edge-api", progress: &Progress{Percent: 0}, } } func (this *Upgrader) CanUpgrade(apiVersion string) (canUpgrade bool, reason string) { stat, err := os.Stat(this.apiExe) if err != nil { return false, "stat error: " + err.Error() } if stat.IsDir() { return false, "is directory" } localVersion, err := this.localVersion() if err != nil { return false, "lookup version failed: " + err.Error() } if stringutil.VersionCompare(localVersion, apiVersion) <= 0 { return false, "need not upgrade, local '" + localVersion + "' vs remote '" + apiVersion + "'" } return true, "" } func (this *Upgrader) Upgrade(apiNodeId int64) error { sharedClient, err := rpc.SharedRPC() if err != nil { return err } apiNodeResp, err := sharedClient.APINodeRPC().FindEnabledAPINode(sharedClient.Context(0), &pb.FindEnabledAPINodeRequest{ApiNodeId: apiNodeId}) if err != nil { return err } var apiNode = apiNodeResp.ApiNode if apiNode == nil { return errors.New("could not find api node with id '" + types.String(apiNodeId) + "'") } apiConfig, err := configs.LoadAPIConfig() if err != nil { return err } var newAPIConfig = apiConfig.Clone() newAPIConfig.RPC.Endpoints = apiNode.AccessAddrs rpcClient, err := rpc.NewRPCClient(newAPIConfig, false) if err != nil { return err } versionResp, err := rpcClient.APINodeRPC().FindCurrentAPINodeVersion(sharedClient.Context(0), &pb.FindCurrentAPINodeVersionRequest{}) if err != nil { return err } if !Tea.IsTesting() /** 开发环境下允许突破此限制方便测试 **/ && (stringutil.VersionCompare(versionResp.Version, "0.6.4" /** 从0.6.4开始支持 **/) <= 0 || versionResp.Os != runtime.GOOS || versionResp.Arch != runtime.GOARCH) { return errors.New("could not upgrade api node v" + versionResp.Version + "/" + versionResp.Os + "/" + versionResp.Arch) } // 检查本地文件版本 canUpgrade, reason := this.CanUpgrade(versionResp.Version) if !canUpgrade { return errors.New(reason) } localVersion, err := this.localVersion() if err != nil { return errors.New("lookup version failed: " + err.Error()) } // 检查要升级的文件 var gzFile = this.apiExe + "." + localVersion + ".gz" gzReader, err := os.Open(gzFile) if err != nil { if !os.IsNotExist(err) { return err } err = func() error { // 压缩文件 exeReader, err := os.Open(this.apiExe) if err != nil { return err } defer func() { _ = exeReader.Close() }() var tmpGzFile = gzFile + ".tmp" gzFileWriter, err := os.OpenFile(tmpGzFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) if err != nil { return err } var gzWriter = gzip.NewWriter(gzFileWriter) defer func() { _ = gzWriter.Close() _ = gzFileWriter.Close() _ = os.Rename(tmpGzFile, gzFile) }() _, err = io.Copy(gzWriter, exeReader) if err != nil { return err } return nil }() if err != nil { return err } gzReader, err = os.Open(gzFile) if err != nil { return err } } defer func() { _ = gzReader.Close() }() // 开始上传 var hash = md5.New() var buf = make([]byte, 128*4096) var isFirst = true stat, err := gzReader.Stat() if err != nil { return err } var totalSize = stat.Size() if totalSize == 0 { _ = gzReader.Close() _ = os.Remove(gzFile) return errors.New("invalid gz file") } var uploadedSize int64 = 0 for { n, err := gzReader.Read(buf) if n > 0 { // 计算Hash hash.Write(buf[:n]) // 上传 _, uploadErr := rpcClient.APINodeRPC().UploadAPINodeFile(rpcClient.Context(0), &pb.UploadAPINodeFileRequest{ Filename: filepath.Base(this.apiExe), Sum: "", ChunkData: buf[:n], IsFirstChunk: isFirst, IsLastChunk: false, }) if uploadErr != nil { return uploadErr } // 进度 uploadedSize += int64(n) this.progress = &Progress{Percent: float64(uploadedSize*100) / float64(totalSize)} } if isFirst { isFirst = false } if err != nil { if err != io.EOF { return err } if err == io.EOF { _, uploadErr := rpcClient.APINodeRPC().UploadAPINodeFile(rpcClient.Context(0), &pb.UploadAPINodeFileRequest{ Filename: filepath.Base(this.apiExe), Sum: fmt.Sprintf("%x", hash.Sum(nil)), ChunkData: buf[:n], IsFirstChunk: isFirst, IsLastChunk: true, }) if uploadErr != nil { return uploadErr } break } } } return nil } func (this *Upgrader) Progress() *Progress { return this.progress } func (this *Upgrader) localVersion() (string, error) { var cmd = exec.Command(this.apiExe, "-V") var output = &bytes.Buffer{} cmd.Stdout = output err := cmd.Run() if err != nil { return "", err } var localVersion = output.String() // 检查版本号 var reg = regexp.MustCompile(`^[\d.]+$`) if !reg.MatchString(localVersion) { return "", errors.New("lookup version failed: " + localVersion) } return localVersion, nil }