From a068ce9a4f4152fcaeb7fc2f19d1246a13dda62c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Thu, 2 Mar 2023 10:28:15 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AF=B9=E8=BE=B9=E7=BC=98=E8=8A=82=E7=82=B9?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E7=BC=93=E5=AD=98=E8=BF=9B=E8=A1=8C=E5=8A=A0?= =?UTF-8?q?=E5=AF=86=EF=BC=8C=E6=8F=90=E5=8D=87=E5=AE=89=E5=85=A8=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/nodeconfigs/node_config.go | 54 ++++++++++++++++++++++++++++++--- pkg/nodeutils/aes_utils.go | 47 +++++++++++++++++++++++++--- pkg/nodeutils/aes_utils_test.go | 22 +++++++++++--- 3 files changed, 110 insertions(+), 13 deletions(-) diff --git a/pkg/nodeconfigs/node_config.go b/pkg/nodeconfigs/node_config.go index 79181f4..abd6cfb 100644 --- a/pkg/nodeconfigs/node_config.go +++ b/pkg/nodeconfigs/node_config.go @@ -1,10 +1,13 @@ package nodeconfigs import ( + "bytes" "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" + "github.com/TeaOSLab/EdgeCommon/pkg/nodeutils" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" @@ -14,6 +17,7 @@ import ( "os" "reflect" "strconv" + "strings" ) var sharedNodeConfig *NodeConfig = nil @@ -122,17 +126,50 @@ func SharedNodeConfig() (*NodeConfig, error) { return sharedNodeConfig, nil } - data, err := os.ReadFile(Tea.ConfigFile("node.json")) + // 从本地缓存读取 + var configFile = Tea.ConfigFile("node.json") + var readCacheOk = false + defer func() { + if !readCacheOk { + _ = os.Remove(configFile) + } + }() + + data, err := os.ReadFile(configFile) + if err != nil { + return &NodeConfig{}, err + } + + encodedNodeInfo, encodedJSONData, found := bytes.Cut(data, []byte("\n")) + if !found { + // 删除缓存文件 + return &NodeConfig{}, errors.New("node.json: invalid data format") + } + + encodedNodeInfoData, err := base64.StdEncoding.DecodeString(string(encodedNodeInfo)) + if err != nil { + // 删除缓存文件 + return &NodeConfig{}, err + } + + nodeUniqueId, nodeSecret, found := strings.Cut(string(encodedNodeInfoData), "|") + if !found { + // 删除缓存文件 + return &NodeConfig{}, errors.New("node.json: node info: invalid data format") + } + + jsonData, err := nodeutils.DecryptData(nodeUniqueId, nodeSecret, string(encodedJSONData)) if err != nil { return &NodeConfig{}, err } var config = &NodeConfig{} - err = json.Unmarshal(data, &config) + err = json.Unmarshal(jsonData, &config) if err != nil { return &NodeConfig{}, err } + readCacheOk = true sharedNodeConfig = config return config, nil } @@ -397,7 +434,7 @@ func (this *NodeConfig) RemoveServer(serverId int64) { // AvailableGroups 根据网络地址和协议分组 func (this *NodeConfig) AvailableGroups() []*serverconfigs.ServerAddressGroup { - groupMapping := map[string]*serverconfigs.ServerAddressGroup{} // protocol://addr => Server Group + var groupMapping = map[string]*serverconfigs.ServerAddressGroup{} // protocol://addr => Server Group for _, server := range this.Servers { if !server.IsOk() || !server.IsOn { continue @@ -413,7 +450,7 @@ func (this *NodeConfig) AvailableGroups() []*serverconfigs.ServerAddressGroup { groupMapping[addr] = group } } - result := []*serverconfigs.ServerAddressGroup{} + var result = []*serverconfigs.ServerAddressGroup{} for _, group := range groupMapping { result = append(result, group) } @@ -435,7 +472,14 @@ func (this *NodeConfig) Save() error { return err } - return os.WriteFile(Tea.ConfigFile("node.json"), data, 0777) + var headerData = []byte(base64.StdEncoding.EncodeToString([]byte(this.NodeId+"|"+this.Secret)) + "\n") + + encodedData, err := nodeutils.EncryptData(this.NodeId, this.Secret, data) + if err != nil { + return err + } + + return os.WriteFile(Tea.ConfigFile("node.json"), append(headerData, encodedData...), 0777) } // PaddedId 获取填充后的ID diff --git a/pkg/nodeutils/aes_utils.go b/pkg/nodeutils/aes_utils.go index 48a2428..40baeaa 100644 --- a/pkg/nodeutils/aes_utils.go +++ b/pkg/nodeutils/aes_utils.go @@ -10,8 +10,8 @@ import ( "time" ) -// EncryptData 加密 -func EncryptData(nodeUniqueId string, nodeSecret string, data maps.Map, timeout int32) (string, error) { +// EncryptMap 加密 +func EncryptMap(nodeUniqueId string, nodeSecret string, data maps.Map, timeout int32) (string, error) { if data == nil { data = maps.Map{} } @@ -42,8 +42,8 @@ func EncryptData(nodeUniqueId string, nodeSecret string, data maps.Map, timeout return base64.StdEncoding.EncodeToString(result), nil } -// DecryptData 解密 -func DecryptData(nodeUniqueId string, nodeSecret string, encodedString string) (maps.Map, error) { +// DecryptMap 解密 +func DecryptMap(nodeUniqueId string, nodeSecret string, encodedString string) (maps.Map, error) { var method = &AES256CFBMethod{} err := method.Init([]byte(nodeUniqueId), []byte(nodeSecret)) if err != nil { @@ -73,3 +73,42 @@ func DecryptData(nodeUniqueId string, nodeSecret string, encodedString string) ( return result.GetMap("data"), nil } + +// EncryptData 加密 +func EncryptData(nodeUniqueId string, nodeSecret string, data []byte) (string, error) { + if len(data) == 0 { + return "", nil + } + + var method = &AES256CFBMethod{} + err := method.Init([]byte(nodeUniqueId), []byte(nodeSecret)) + if err != nil { + return "", err + } + result, err := method.Encrypt(data) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(result), nil +} + +// DecryptData 解密 +func DecryptData(nodeUniqueId string, nodeSecret string, encodedString string) ([]byte, error) { + if len(encodedString) == 0 { + return nil, nil + } + + var method = &AES256CFBMethod{} + err := method.Init([]byte(nodeUniqueId), []byte(nodeSecret)) + if err != nil { + return nil, err + } + + encodedData, err := base64.StdEncoding.DecodeString(encodedString) + if err != nil { + return nil, errors.New("base64 decode failed: " + err.Error()) + } + + return method.Decrypt(encodedData) +} diff --git a/pkg/nodeutils/aes_utils_test.go b/pkg/nodeutils/aes_utils_test.go index ce43471..f15645d 100644 --- a/pkg/nodeutils/aes_utils_test.go +++ b/pkg/nodeutils/aes_utils_test.go @@ -7,8 +7,8 @@ import ( "testing" ) -func TestEncryptData(t *testing.T) { - e, err := EncryptData("a", "b", maps.Map{ +func TestEncryptMap(t *testing.T) { + e, err := EncryptMap("a", "b", maps.Map{ "c": 1, }, 5) if err != nil { @@ -16,16 +16,30 @@ func TestEncryptData(t *testing.T) { } t.Log("e:", e) - s, err := DecryptData("a", "b", e) + s, err := DecryptMap("a", "b", e) if err != nil { t.Fatal(err) } t.Log("s:", s) } +func TestEncryptData(t *testing.T) { + encoded, err := EncryptData("a", "b", []byte("Hello, World")) + if err != nil { + t.Fatal(err) + } + t.Log("encoded:", encoded) + + source, err := DecryptData("a", "b", encoded) + if err != nil { + t.Fatal(err) + } + t.Log("source:", string(source)) +} + func BenchmarkEncryptData(b *testing.B) { for i := 0; i < b.N; i++ { - _, _ = EncryptData("a", "b", maps.Map{ + _, _ = EncryptMap("a", "b", maps.Map{ "c": 1, }, 5) }