diff --git a/internal/dnsclients/provider_alidns.go b/internal/dnsclients/provider_alidns.go index d13d1836..ebc341cf 100644 --- a/internal/dnsclients/provider_alidns.go +++ b/internal/dnsclients/provider_alidns.go @@ -42,6 +42,14 @@ func (this *AliDNSProvider) Auth(params maps.Map) error { return nil } +// MaskParams 对参数进行掩码 +func (this *AliDNSProvider) MaskParams(params maps.Map) { + if params == nil { + return + } + params["accessKeySecret"] = MaskString(params.GetString("accessKeySecret")) +} + // GetDomains 获取所有域名列表 func (this *AliDNSProvider) GetDomains() (domains []string, err error) { var pageNumber = 1 diff --git a/internal/dnsclients/provider_cloud_flare.go b/internal/dnsclients/provider_cloud_flare.go index 20db0cdd..3f26273f 100644 --- a/internal/dnsclients/provider_cloud_flare.go +++ b/internal/dnsclients/provider_cloud_flare.go @@ -66,6 +66,14 @@ func (this *CloudFlareProvider) Auth(params maps.Map) error { return nil } +// MaskParams 对参数进行掩码 +func (this *CloudFlareProvider) MaskParams(params maps.Map) { + if params == nil { + return + } + params["apiKey"] = MaskString(params.GetString("apiKey")) +} + // GetDomains 获取所有域名列表 func (this *CloudFlareProvider) GetDomains() (domains []string, err error) { for page := 1; page <= 500; page++ { diff --git a/internal/dnsclients/provider_custom_http.go b/internal/dnsclients/provider_custom_http.go index 854cc427..292ef56e 100644 --- a/internal/dnsclients/provider_custom_http.go +++ b/internal/dnsclients/provider_custom_http.go @@ -53,6 +53,11 @@ func (this *CustomHTTPProvider) Auth(params maps.Map) error { return nil } +// MaskParams 对参数进行掩码 +func (this *CustomHTTPProvider) MaskParams(params maps.Map) { + // 这里暂时不要掩码,避免用户忘记 +} + // GetDomains 获取所有域名列表 func (this *CustomHTTPProvider) GetDomains() (domains []string, err error) { resp, err := this.post(maps.Map{ diff --git a/internal/dnsclients/provider_dnspod.go b/internal/dnsclients/provider_dnspod.go index b184151b..5fd4c04b 100644 --- a/internal/dnsclients/provider_dnspod.go +++ b/internal/dnsclients/provider_dnspod.go @@ -69,6 +69,19 @@ func (this *DNSPodProvider) Auth(params maps.Map) error { return nil } +// MaskParams 对参数进行掩码 +func (this *DNSPodProvider) MaskParams(params maps.Map) { + if params == nil { + return + } + + if params.GetString("apiType") == "tencentDNS" { + params["accessKeySecret"] = MaskString(params.GetString("accessKeySecret")) + } else { + params["token"] = MaskString(params.GetString("token")) + } +} + // GetDomains 获取所有域名列表 func (this *DNSPodProvider) GetDomains() (domains []string, err error) { if this.tencentDNSProvider != nil { diff --git a/internal/dnsclients/provider_edge_dns_api.go b/internal/dnsclients/provider_edge_dns_api.go index b8a8f61c..2b7b0ea1 100644 --- a/internal/dnsclients/provider_edge_dns_api.go +++ b/internal/dnsclients/provider_edge_dns_api.go @@ -71,6 +71,14 @@ func (this *EdgeDNSAPIProvider) Auth(params maps.Map) error { return nil } +// MaskParams 对参数进行掩码 +func (this *EdgeDNSAPIProvider) MaskParams(params maps.Map) { + if params == nil { + return + } + params["accessKeySecret"] = MaskString(params.GetString("accessKeySecret")) +} + // GetDomains 获取所有域名列表 func (this *EdgeDNSAPIProvider) GetDomains() (domains []string, err error) { var offset = 0 diff --git a/internal/dnsclients/provider_huawei_dns.go b/internal/dnsclients/provider_huawei_dns.go index 1eb23623..21fb2c0a 100644 --- a/internal/dnsclients/provider_huawei_dns.go +++ b/internal/dnsclients/provider_huawei_dns.go @@ -71,6 +71,14 @@ func (this *HuaweiDNSProvider) Auth(params maps.Map) error { return nil } +// MaskParams 对参数进行掩码 +func (this *HuaweiDNSProvider) MaskParams(params maps.Map) { + if params == nil { + return + } + params["accessKeySecret"] = MaskString(params.GetString("accessKeySecret")) +} + // GetDomains 获取所有域名列表 func (this *HuaweiDNSProvider) GetDomains() (domains []string, err error) { var resp = new(huaweidns.ZonesResponse) diff --git a/internal/dnsclients/provider_interface.go b/internal/dnsclients/provider_interface.go index 2a6d8aa2..d60e6561 100644 --- a/internal/dnsclients/provider_interface.go +++ b/internal/dnsclients/provider_interface.go @@ -10,6 +10,9 @@ type ProviderInterface interface { // Auth 认证 Auth(params maps.Map) error + // MaskParams 对参数进行掩码 + MaskParams(params maps.Map) + // GetDomains 获取所有域名列表 GetDomains() (domains []string, err error) diff --git a/internal/dnsclients/provider_tencent_dns.go b/internal/dnsclients/provider_tencent_dns.go index caaaf2e6..8df478f8 100644 --- a/internal/dnsclients/provider_tencent_dns.go +++ b/internal/dnsclients/provider_tencent_dns.go @@ -47,6 +47,14 @@ func (this *TencentDNSProvider) Auth(params maps.Map) error { return nil } +// MaskParams 对参数进行掩码 +func (this *TencentDNSProvider) MaskParams(params maps.Map) { + if params == nil { + return + } + params["accessKeySecret"] = MaskString(params.GetString("accessKeySecret")) +} + // GetDomains 获取所有域名列表 func (this *TencentDNSProvider) GetDomains() (domains []string, err error) { var offset int64 = 0 diff --git a/internal/dnsclients/utils.go b/internal/dnsclients/utils.go new file mode 100644 index 00000000..69fa4711 --- /dev/null +++ b/internal/dnsclients/utils.go @@ -0,0 +1,67 @@ +// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package dnsclients + +import ( + "encoding/json" + "github.com/iwind/TeaGo/maps" + "strings" +) + +// MaskString 对字符串进行掩码 +func MaskString(s string) string { + var l = len(s) + if l == 0 { + return "" + } + if l < 8 { + return strings.Repeat("*", l) + } + return s[:4] + strings.Repeat("*", l-4) +} + +// IsMasked 判断字符串是否被掩码 +func IsMasked(s string) bool { + if len(s) == 0 { + return false + } + return s == strings.Repeat("*", len(s)) || strings.HasSuffix(s, "**") +} + +// UnmaskAPIParams 恢复API参数 +func UnmaskAPIParams(oldParamsJSON []byte, newParamsJSON []byte) (resultJSON []byte, err error) { + var oldParams maps.Map + var newParams maps.Map + + if len(oldParamsJSON) == 0 || len(newParamsJSON) == 0 { + return newParamsJSON, nil + } + err = json.Unmarshal(oldParamsJSON, &oldParams) + if err != nil { + return nil, err + } + + err = json.Unmarshal(newParamsJSON, &newParams) + if err != nil { + return nil, err + } + + if oldParams == nil || newParams == nil { + return newParamsJSON, nil + } + + for k, v := range newParams { + if v != nil { + s, ok := v.(string) + if ok && IsMasked(s) { + var oldV = oldParams.GetString(k) + if len(oldV) > 0 { + newParams[k] = oldV + } + } + } + } + + resultJSON, err = json.Marshal(newParams) + return +} diff --git a/internal/dnsclients/utils_test.go b/internal/dnsclients/utils_test.go new file mode 100644 index 00000000..da254026 --- /dev/null +++ b/internal/dnsclients/utils_test.go @@ -0,0 +1,35 @@ +// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package dnsclients_test + +import ( + "github.com/TeaOSLab/EdgeAPI/internal/dnsclients" + "github.com/iwind/TeaGo/assert" + "testing" +) + +func TestIsMasked(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsFalse(dnsclients.IsMasked("")) + a.IsFalse(dnsclients.IsMasked("abc")) + a.IsFalse(dnsclients.IsMasked("abc*")) + a.IsTrue(dnsclients.IsMasked("*")) + a.IsTrue(dnsclients.IsMasked("**")) + a.IsTrue(dnsclients.IsMasked("***")) + a.IsTrue(dnsclients.IsMasked("*******")) + a.IsTrue(dnsclients.IsMasked("abc**")) + a.IsTrue(dnsclients.IsMasked("abcd*********")) +} + +func TestUnmaskAPIParams(t *testing.T) { + data, err := dnsclients.UnmaskAPIParams([]byte(`{ + "key": "a", + "secret": "abc12" +}`), []byte(`{ + "secret": "abc**" +}`)) + if err != nil { + t.Fatal(err) + } + t.Log(string(data)) +} diff --git a/internal/rpc/services/service_dns_provider.go b/internal/rpc/services/service_dns_provider.go index 3518a8c9..15a6cef2 100644 --- a/internal/rpc/services/service_dns_provider.go +++ b/internal/rpc/services/service_dns_provider.go @@ -2,9 +2,11 @@ package services import ( "context" + "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/iwind/TeaGo/maps" ) // DNSProviderService DNS服务商相关服务 @@ -42,6 +44,21 @@ func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.U var tx = this.NullTx() + provider, err := dns.SharedDNSProviderDAO.FindEnabledDNSProvider(tx, req.DnsProviderId) + if err != nil { + return nil, err + } + if provider == nil { + // do nothing here + return this.Success() + } + + // 恢复被掩码的数据 + req.ApiParamsJSON, err = dnsclients.UnmaskAPIParams(provider.ApiParams, req.ApiParamsJSON) + if err != nil { + return nil, err + } + err = dns.SharedDNSProviderDAO.UpdateDNSProvider(tx, req.DnsProviderId, req.Name, req.ApiParamsJSON) if err != nil { return nil, err @@ -175,14 +192,34 @@ func (this *DNSProviderService) FindEnabledDNSProvider(ctx context.Context, req return &pb.FindEnabledDNSProviderResponse{DnsProvider: nil}, nil } - return &pb.FindEnabledDNSProviderResponse{DnsProvider: &pb.DNSProvider{ - Id: int64(provider.Id), - Name: provider.Name, - Type: provider.Type, - TypeName: dnsclients.FindProviderTypeName(provider.Type), - ApiParamsJSON: provider.ApiParams, - DataUpdatedAt: int64(provider.DataUpdatedAt), - }}, nil + if req.MaskParams { + var providerObj = dnsclients.FindProvider(provider.Type, int64(provider.Id)) + if providerObj != nil { + var paramsMap = maps.Map{} + if len(provider.ApiParams) > 0 { + err = json.Unmarshal(provider.ApiParams, ¶msMap) + if err != nil { + return nil, err + } + providerObj.MaskParams(paramsMap) + provider.ApiParams, err = json.Marshal(paramsMap) + if err != nil { + return nil, err + } + } + } + } + + return &pb.FindEnabledDNSProviderResponse{ + DnsProvider: &pb.DNSProvider{ + Id: int64(provider.Id), + Name: provider.Name, + Type: provider.Type, + TypeName: dnsclients.FindProviderTypeName(provider.Type), + ApiParamsJSON: provider.ApiParams, + DataUpdatedAt: int64(provider.DataUpdatedAt), + }, + }, nil } // FindAllDNSProviderTypes 取得所有服务商类型