diff --git a/internal/db/models/nameservers/ns_record_dao.go b/internal/db/models/nameservers/ns_record_dao.go index 6cfdc655..f992bcf5 100644 --- a/internal/db/models/nameservers/ns_record_dao.go +++ b/internal/db/models/nameservers/ns_record_dao.go @@ -9,7 +9,6 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" - "strconv" ) const ( @@ -89,7 +88,7 @@ func (this *NSRecordDAO) FindNSRecordName(tx *dbs.Tx, id int64) (string, error) } // CreateRecord 创建记录 -func (this *NSRecordDAO) CreateRecord(tx *dbs.Tx, domainId int64, description string, name string, dnsType dnsconfigs.RecordType, value string, ttl int32, routeIds []int64) (int64, error) { +func (this *NSRecordDAO) CreateRecord(tx *dbs.Tx, domainId int64, description string, name string, dnsType dnsconfigs.RecordType, value string, ttl int32, routeIds []string) (int64, error) { version, err := this.IncreaseVersion(tx) if err != nil { return 0, err @@ -104,7 +103,7 @@ func (this *NSRecordDAO) CreateRecord(tx *dbs.Tx, domainId int64, description st op.Ttl = ttl if len(routeIds) == 0 { - op.RouteIds = "[]" + op.RouteIds = `["default"]` } else { routeIds, err := json.Marshal(routeIds) if err != nil { @@ -129,7 +128,7 @@ func (this *NSRecordDAO) CreateRecord(tx *dbs.Tx, domainId int64, description st } // UpdateRecord 修改记录 -func (this *NSRecordDAO) UpdateRecord(tx *dbs.Tx, recordId int64, description string, name string, dnsType dnsconfigs.RecordType, value string, ttl int32, routeIds []int64, isOn bool) error { +func (this *NSRecordDAO) UpdateRecord(tx *dbs.Tx, recordId int64, description string, name string, dnsType dnsconfigs.RecordType, value string, ttl int32, routeIds []string, isOn bool) error { if recordId <= 0 { return errors.New("invalid recordId") } @@ -149,7 +148,7 @@ func (this *NSRecordDAO) UpdateRecord(tx *dbs.Tx, recordId int64, description st op.IsOn = isOn if len(routeIds) == 0 { - op.RouteIds = "[]" + op.RouteIds = `["default"]` } else { routeIds, err := json.Marshal(routeIds) if err != nil { @@ -169,7 +168,7 @@ func (this *NSRecordDAO) UpdateRecord(tx *dbs.Tx, recordId int64, description st } // CountAllEnabledDomainRecords 计算域名中记录数量 -func (this *NSRecordDAO) CountAllEnabledDomainRecords(tx *dbs.Tx, domainId int64, dnsType dnsconfigs.RecordType, keyword string, routeId int64) (int64, error) { +func (this *NSRecordDAO) CountAllEnabledDomainRecords(tx *dbs.Tx, domainId int64, dnsType dnsconfigs.RecordType, keyword string, routeCode string) (int64, error) { query := this.Query(tx). Attr("domainId", domainId). State(NSRecordStateEnabled) @@ -180,8 +179,12 @@ func (this *NSRecordDAO) CountAllEnabledDomainRecords(tx *dbs.Tx, domainId int64 query.Where("(name LIKE :keyword OR value LIKE :keyword OR description LIKE :keyword)"). Param("keyword", "%"+keyword+"%") } - if routeId > 0 { - query.JSONContains("routeIds", strconv.FormatInt(routeId, 10)) + if len(routeCode) > 0 { + routeCodeJSON, err := json.Marshal(routeCode) + if err != nil { + return 0, err + } + query.JSONContains("routeIds", string(routeCodeJSON)) } return query.Count() } @@ -195,7 +198,7 @@ func (this *NSRecordDAO) CountAllEnabledRecords(tx *dbs.Tx) (int64, error) { } // ListEnabledRecords 列出单页记录 -func (this *NSRecordDAO) ListEnabledRecords(tx *dbs.Tx, domainId int64, dnsType dnsconfigs.RecordType, keyword string, routeId int64, offset int64, size int64) (result []*NSRecord, err error) { +func (this *NSRecordDAO) ListEnabledRecords(tx *dbs.Tx, domainId int64, dnsType dnsconfigs.RecordType, keyword string, routeCode string, offset int64, size int64) (result []*NSRecord, err error) { query := this.Query(tx). Attr("domainId", domainId). State(NSRecordStateEnabled) @@ -206,8 +209,12 @@ func (this *NSRecordDAO) ListEnabledRecords(tx *dbs.Tx, domainId int64, dnsType query.Where("(name LIKE :keyword OR value LIKE :keyword OR description LIKE :keyword)"). Param("keyword", "%"+keyword+"%") } - if routeId > 0 { - query.JSONContains("routeIds", strconv.FormatInt(routeId, 10)) + if len(routeCode) > 0 { + routeCodeJSON, err := json.Marshal(routeCode) + if err != nil { + return nil, err + } + query.JSONContains("routeIds", string(routeCodeJSON)) } _, err = query. DescPk(). diff --git a/internal/db/models/nameservers/ns_record_dao_test.go b/internal/db/models/nameservers/ns_record_dao_test.go index e29fe962..d4b0a51a 100644 --- a/internal/db/models/nameservers/ns_record_dao_test.go +++ b/internal/db/models/nameservers/ns_record_dao_test.go @@ -3,4 +3,27 @@ package nameservers import ( _ "github.com/go-sql-driver/mysql" _ "github.com/iwind/TeaGo/bootstrap" + "testing" ) + +func TestNSRecord_DecodeRouteIds(t *testing.T) { + { + record := &NSRecord{} + t.Log(record.DecodeRouteIds()) + } + + { + record := &NSRecord{RouteIds: "[]"} + t.Log(record.DecodeRouteIds()) + } + + { + record := &NSRecord{RouteIds: "[1, 2, 3]"} + t.Log(record.DecodeRouteIds()) + } + + { + record := &NSRecord{RouteIds: `["id:1", "id:2", "isp:liantong"]`} + t.Log(record.DecodeRouteIds()) + } +} diff --git a/internal/db/models/nameservers/ns_record_model_ext.go b/internal/db/models/nameservers/ns_record_model_ext.go index ccb05483..b84c59d5 100644 --- a/internal/db/models/nameservers/ns_record_model_ext.go +++ b/internal/db/models/nameservers/ns_record_model_ext.go @@ -1,11 +1,26 @@ package nameservers -import "encoding/json" +import ( + "encoding/json" + "github.com/iwind/TeaGo/types" +) -func (this *NSRecord) DecodeRouteIds() []int64 { - routeIds := []int64{} +func (this *NSRecord) DecodeRouteIds() []string { + var routeIds = []string{} if len(this.RouteIds) > 0 { - _ = json.Unmarshal([]byte(this.RouteIds), &routeIds) + err := json.Unmarshal([]byte(this.RouteIds), &routeIds) + if err != nil { + // 检查是否有旧的数据 + var oldRouteIds = []int64{} + err = json.Unmarshal([]byte(this.RouteIds), &oldRouteIds) + if err != nil { + return []string{} + } + routeIds = []string{} + for _, routeId := range oldRouteIds { + routeIds = append(routeIds, "id:"+types.String(routeId)) + } + } } return routeIds } diff --git a/internal/db/models/nameservers/ns_route_dao.go b/internal/db/models/nameservers/ns_route_dao.go index 7dbf9440..d206102c 100644 --- a/internal/db/models/nameservers/ns_route_dao.go +++ b/internal/db/models/nameservers/ns_route_dao.go @@ -3,10 +3,14 @@ package nameservers import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" + "regexp" + "strings" ) const ( @@ -83,6 +87,33 @@ func (this *NSRouteDAO) FindEnabledNSRoute(tx *dbs.Tx, id int64) (*NSRoute, erro return result.(*NSRoute), err } +// FindEnabledRouteWithCode 根据代号获取线路信息 +func (this *NSRouteDAO) FindEnabledRouteWithCode(tx *dbs.Tx, code string) (*NSRoute, error) { + if regexp.MustCompile(`^id:\d+$`).MatchString(code) { + var routeId = types.Int64(code[strings.Index(code, ":")+1:]) + route, err := this.FindEnabledNSRoute(tx, routeId) + if route == nil || err != nil { + return nil, err + } + + route.Code = "id:" + types.String(routeId) + return route, nil + } + + route := dnsconfigs.FindDefaultRoute(code) + if route == nil { + return nil, nil + } + + return &NSRoute{ + Id: 0, + IsOn: 1, + Name: route.Name, + Code: route.Code, + State: NSRouteStateEnabled, + }, nil +} + // FindNSRouteName 根据主键查找名称 func (this *NSRouteDAO) FindNSRouteName(tx *dbs.Tx, id int64) (string, error) { return this.Query(tx). diff --git a/internal/db/models/nameservers/ns_route_model.go b/internal/db/models/nameservers/ns_route_model.go index 27a17e5b..39e7d7aa 100644 --- a/internal/db/models/nameservers/ns_route_model.go +++ b/internal/db/models/nameservers/ns_route_model.go @@ -11,6 +11,7 @@ type NSRoute struct { Ranges string `field:"ranges"` // 范围 Order uint32 `field:"order"` // 排序 Version uint64 `field:"version"` // 版本号 + Code string `field:"code"` // 代号 State uint8 `field:"state"` // 状态 } @@ -24,6 +25,7 @@ type NSRouteOperator struct { Ranges interface{} // 范围 Order interface{} // 排序 Version interface{} // 版本号 + Code interface{} // 代号 State interface{} // 状态 } diff --git a/internal/dnsclients/provider_local_edge_dns.go b/internal/dnsclients/provider_local_edge_dns.go index ee27d72b..b3829463 100644 --- a/internal/dnsclients/provider_local_edge_dns.go +++ b/internal/dnsclients/provider_local_edge_dns.go @@ -7,10 +7,10 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models/nameservers" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes" "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" - "regexp" "strings" ) @@ -48,7 +48,7 @@ func (this *LocalEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes offset := int64(0) size := int64(1000) for { - result, err := nameservers.SharedNSRecordDAO.ListEnabledRecords(tx, domainId, "", "", 0, offset, size) + result, err := nameservers.SharedNSRecordDAO.ListEnabledRecords(tx, domainId, "", "", "", offset, size) if err != nil { return nil, err } @@ -61,17 +61,15 @@ func (this *LocalEdgeDNSProvider) GetRecords(domain string) (records []*dnstypes } routeIds := record.DecodeRouteIds() - var routeIdString = "" - if len(routeIds) > 0 { - routeIdString = fmt.Sprintf("%d", routeIds[0]) + if len(routeIds) == 0 { + routeIds = []string{dnsconfigs.DefaultRouteCode} } - records = append(records, &dnstypes.Record{ Id: fmt.Sprintf("%d", record.Id), Name: record.Name, Type: record.Type, Value: record.Value, - Route: routeIdString, + Route: routeIds[0], }) } @@ -92,7 +90,15 @@ func (this *LocalEdgeDNSProvider) GetRoutes(domain string) (routes []*dnstypes.R return nil, errors.New("can not find domain '" + domain + "'") } - // TODO 将来支持集群、域名、用户自定义线路 + // 默认线路 + for _, route := range dnsconfigs.AllDefaultRoutes { + routes = append(routes, &dnstypes.Route{ + Name: route.Name, + Code: route.Code, + }) + } + + // 自定义线路 result, err := nameservers.SharedNSRouteDAO.FindAllEnabledRoutes(tx, 0, 0, 0) if err != nil { return nil, err @@ -100,7 +106,31 @@ func (this *LocalEdgeDNSProvider) GetRoutes(domain string) (routes []*dnstypes.R for _, route := range result { routes = append(routes, &dnstypes.Route{ Name: route.Name, - Code: fmt.Sprintf("%d", route.Id), + Code: "id:" + types.String(route.Id), + }) + } + + // 默认ISP + for _, route := range dnsconfigs.AllDefaultISPRoutes { + routes = append(routes, &dnstypes.Route{ + Name: route.Name, + Code: route.Code, + }) + } + + // 默认中国省份 + for _, route := range dnsconfigs.AllDefaultChinaProvinceRoutes { + routes = append(routes, &dnstypes.Route{ + Name: route.Name, + Code: route.Code, + }) + } + + // 默认全球国家/地区 + for _, route := range dnsconfigs.AllDefaultWorldRegionRoutes { + routes = append(routes, &dnstypes.Route{ + Name: route.Name, + Code: route.Code, }) } @@ -129,7 +159,9 @@ func (this *LocalEdgeDNSProvider) QueryRecord(domain string, name string, record routeIds := record.DecodeRouteIds() var routeIdString = "" if len(routeIds) > 0 { - routeIdString = fmt.Sprintf("%d", routeIds[0]) + routeIdString = routeIds[0] + } else { + routeIdString = dnsconfigs.DefaultRouteCode } return &dnstypes.Record{ @@ -152,12 +184,9 @@ func (this *LocalEdgeDNSProvider) AddRecord(domain string, newRecord *dnstypes.R return errors.New("can not find domain '" + domain + "'") } - var routeIds []int64 - if len(newRecord.Route) > 0 && regexp.MustCompile(`^\d+$`).MatchString(newRecord.Route) { - routeId := types.Int64(newRecord.Route) - if routeId > 0 { - routeIds = append(routeIds, routeId) - } + var routeIds = []string{} + if len(newRecord.Route) > 0 { + routeIds = append(routeIds, newRecord.Route) } _, err = nameservers.SharedNSRecordDAO.CreateRecord(tx, domainId, "", newRecord.Name, newRecord.Type, newRecord.Value, this.ttl, routeIds) @@ -179,12 +208,9 @@ func (this *LocalEdgeDNSProvider) UpdateRecord(domain string, record *dnstypes.R return errors.New("can not find domain '" + domain + "'") } - var routeIds []int64 - if len(newRecord.Route) > 0 && regexp.MustCompile(`^\d+$`).MatchString(newRecord.Route) { - routeId := types.Int64(newRecord.Route) - if routeId > 0 { - routeIds = append(routeIds, routeId) - } + var routeIds []string + if len(newRecord.Route) > 0 { + routeIds = append(routeIds, newRecord.Route) } if len(record.Id) > 0 { @@ -242,5 +268,5 @@ func (this *LocalEdgeDNSProvider) DeleteRecord(domain string, record *dnstypes.R // DefaultRoute 默认线路 func (this *LocalEdgeDNSProvider) DefaultRoute() string { - return "" + return "default" } diff --git a/internal/dnsclients/provider_local_edge_dns_test.go b/internal/dnsclients/provider_local_edge_dns_test.go index ea5bda0c..b8c1a292 100644 --- a/internal/dnsclients/provider_local_edge_dns_test.go +++ b/internal/dnsclients/provider_local_edge_dns_test.go @@ -1,8 +1,9 @@ // Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. -package dnsclients +package dnsclients_test import ( + "github.com/TeaOSLab/EdgeAPI/internal/dnsclients" "github.com/TeaOSLab/EdgeAPI/internal/dnsclients/dnstypes" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/logs" @@ -10,12 +11,14 @@ import ( "testing" ) +const testClusterId = 7 + func TestLocalEdgeDNSProvider_GetRecords(t *testing.T) { dbs.NotifyReady() - provider := &LocalEdgeDNSProvider{} + provider := &dnsclients.LocalEdgeDNSProvider{} err := provider.Auth(maps.Map{ - "clusterId": 1, + "clusterId": testClusterId, }) if err != nil { t.Fatal(err) @@ -31,9 +34,9 @@ func TestLocalEdgeDNSProvider_GetRecords(t *testing.T) { func TestLocalEdgeDNSProvider_GetRoutes(t *testing.T) { dbs.NotifyReady() - provider := &LocalEdgeDNSProvider{} + provider := &dnsclients.LocalEdgeDNSProvider{} err := provider.Auth(maps.Map{ - "clusterId": 1, + "clusterId": testClusterId, }) if err != nil { t.Fatal(err) @@ -49,9 +52,9 @@ func TestLocalEdgeDNSProvider_GetRoutes(t *testing.T) { func TestLocalEdgeDNSProvider_QueryRecord(t *testing.T) { dbs.NotifyReady() - provider := &LocalEdgeDNSProvider{} + provider := &dnsclients.LocalEdgeDNSProvider{} err := provider.Auth(maps.Map{ - "clusterId": 1, + "clusterId": testClusterId, }) if err != nil { t.Fatal(err) @@ -66,9 +69,9 @@ func TestLocalEdgeDNSProvider_QueryRecord(t *testing.T) { func TestLocalEdgeDNSProvider_AddRecord(t *testing.T) { dbs.NotifyReady() - provider := &LocalEdgeDNSProvider{} + provider := &dnsclients.LocalEdgeDNSProvider{} err := provider.Auth(maps.Map{ - "clusterId": 1, + "clusterId": testClusterId, }) if err != nil { t.Fatal(err) @@ -79,7 +82,7 @@ func TestLocalEdgeDNSProvider_AddRecord(t *testing.T) { Name: "example", Type: dnstypes.RecordTypeA, Value: "10.0.0.1", - Route: "7", + Route: "id:7", }) if err != nil { t.Fatal(err) @@ -90,9 +93,9 @@ func TestLocalEdgeDNSProvider_AddRecord(t *testing.T) { func TestLocalEdgeDNSProvider_UpdateRecord(t *testing.T) { dbs.NotifyReady() - provider := &LocalEdgeDNSProvider{} + provider := &dnsclients.LocalEdgeDNSProvider{} err := provider.Auth(maps.Map{ - "clusterId": 1, + "clusterId": testClusterId, }) if err != nil { t.Fatal(err) @@ -124,9 +127,9 @@ func TestLocalEdgeDNSProvider_UpdateRecord(t *testing.T) { func TestLocalEdgeDNSProvider_DeleteRecord(t *testing.T) { dbs.NotifyReady() - provider := &LocalEdgeDNSProvider{} + provider := &dnsclients.LocalEdgeDNSProvider{} err := provider.Auth(maps.Map{ - "clusterId": 1, + "clusterId": testClusterId, }) if err != nil { t.Fatal(err) @@ -148,9 +151,9 @@ func TestLocalEdgeDNSProvider_DeleteRecord(t *testing.T) { func TestLocalEdgeDNSProvider_DefaultRoute(t *testing.T) { dbs.NotifyReady() - provider := &LocalEdgeDNSProvider{} + provider := &dnsclients.LocalEdgeDNSProvider{} err := provider.Auth(maps.Map{ - "clusterId": 1, + "clusterId": testClusterId, }) if err != nil { t.Fatal(err) diff --git a/internal/rpc/services/nameservers/service_ns_record.go b/internal/rpc/services/nameservers/service_ns_record.go index 570e5dcb..e247ba40 100644 --- a/internal/rpc/services/nameservers/service_ns_record.go +++ b/internal/rpc/services/nameservers/service_ns_record.go @@ -9,6 +9,8 @@ import ( rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/iwind/TeaGo/types" + "regexp" + "strings" ) // NSRecordService 域名记录相关服务 @@ -24,7 +26,7 @@ func (this *NSRecordService) CreateNSRecord(ctx context.Context, req *pb.CreateN } var tx = this.NullTx() - recordId, err := nameservers.SharedNSRecordDAO.CreateRecord(tx, req.NsDomainId, req.Description, req.Name, req.Type, req.Value, req.Ttl, req.NsRouteIds) + recordId, err := nameservers.SharedNSRecordDAO.CreateRecord(tx, req.NsDomainId, req.Description, req.Name, req.Type, req.Value, req.Ttl, req.NsRouteCodes) if err != nil { return nil, err } @@ -39,7 +41,7 @@ func (this *NSRecordService) UpdateNSRecord(ctx context.Context, req *pb.UpdateN } var tx = this.NullTx() - err = nameservers.SharedNSRecordDAO.UpdateRecord(tx, req.NsRecordId, req.Description, req.Name, req.Type, req.Value, req.Ttl, req.NsRouteIds, req.IsOn) + err = nameservers.SharedNSRecordDAO.UpdateRecord(tx, req.NsRecordId, req.Description, req.Name, req.Type, req.Value, req.Ttl, req.NsRouteCodes, req.IsOn) if err != nil { return nil, err } @@ -69,7 +71,7 @@ func (this *NSRecordService) CountAllEnabledNSRecords(ctx context.Context, req * } var tx = this.NullTx() - count, err := nameservers.SharedNSRecordDAO.CountAllEnabledDomainRecords(tx, req.NsDomainId, req.Type, req.Keyword, req.NsRouteId) + count, err := nameservers.SharedNSRecordDAO.CountAllEnabledDomainRecords(tx, req.NsDomainId, req.Type, req.Keyword, req.NsRouteCode) if err != nil { return nil, err } @@ -84,7 +86,7 @@ func (this *NSRecordService) ListEnabledNSRecords(ctx context.Context, req *pb.L } var tx = this.NullTx() - records, err := nameservers.SharedNSRecordDAO.ListEnabledRecords(tx, req.NsDomainId, req.Type, req.Keyword, req.NsRouteId, req.Offset, req.Size) + records, err := nameservers.SharedNSRecordDAO.ListEnabledRecords(tx, req.NsDomainId, req.Type, req.Keyword, req.NsRouteCode, req.Offset, req.Size) if err != nil { return nil, err } @@ -92,8 +94,8 @@ func (this *NSRecordService) ListEnabledNSRecords(ctx context.Context, req *pb.L for _, record := range records { // 线路 var pbRoutes = []*pb.NSRoute{} - for _, recordId := range record.DecodeRouteIds() { - route, err := nameservers.SharedNSRouteDAO.FindEnabledNSRoute(tx, recordId) + for _, routeCode := range record.DecodeRouteIds() { + route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(tx, routeCode) if err != nil { return nil, err } @@ -103,7 +105,10 @@ func (this *NSRecordService) ListEnabledNSRecords(ctx context.Context, req *pb.L pbRoutes = append(pbRoutes, &pb.NSRoute{ Id: int64(route.Id), Name: route.Name, + Code: route.Code, }) + + // TODO 读取其他线路 } pbRecords = append(pbRecords, &pb.NSRecord{ @@ -155,8 +160,8 @@ func (this *NSRecordService) FindEnabledNSRecord(ctx context.Context, req *pb.Fi // 线路 var pbRoutes = []*pb.NSRoute{} - for _, recordId := range record.DecodeRouteIds() { - route, err := nameservers.SharedNSRouteDAO.FindEnabledNSRoute(tx, recordId) + for _, routeCode := range record.DecodeRouteIds() { + route, err := nameservers.SharedNSRouteDAO.FindEnabledRouteWithCode(tx, routeCode) if err != nil { return nil, err } @@ -166,9 +171,12 @@ func (this *NSRecordService) FindEnabledNSRecord(ctx context.Context, req *pb.Fi pbRoutes = append(pbRoutes, &pb.NSRoute{ Id: int64(route.Id), Name: route.Name, + Code: route.Code, }) } + // TODO 读取其他线路 + return &pb.FindEnabledNSRecordResponse{NsRecord: &pb.NSRecord{ Id: int64(record.Id), Description: record.Description, @@ -207,9 +215,19 @@ func (this *NSRecordService) ListNSRecordsAfterVersion(ctx context.Context, req pbRoutes := []*pb.NSRoute{} routeIds := record.DecodeRouteIds() for _, routeId := range routeIds { - pbRoutes = append(pbRoutes, &pb.NSRoute{Id: routeId}) + var routeIdInt int64 = 0 + if regexp.MustCompile(`^id:\d+$`).MatchString(routeId) { + routeIdInt = types.Int64(routeId[strings.Index(routeId, ":")+1:]) + } + + pbRoutes = append(pbRoutes, &pb.NSRoute{ + Id: routeIdInt, + Code: routeId, + }) } + // TODO 读取其他线路 + pbRecords = append(pbRecords, &pb.NSRecord{ Id: int64(record.Id), Description: "", diff --git a/internal/setup/sql_upgrade.go b/internal/setup/sql_upgrade.go index c9efc3a3..16abacd0 100644 --- a/internal/setup/sql_upgrade.go +++ b/internal/setup/sql_upgrade.go @@ -15,6 +15,7 @@ import ( "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" stringutil "github.com/iwind/TeaGo/utils/string" + "regexp" ) type upgradeVersion struct { @@ -42,7 +43,7 @@ var upgradeFuncs = []*upgradeVersion{ "0.2.5", upgradeV0_2_5, }, { - "0.2.9", upgradeV0_2_9, + "0.2.8.1", upgradeV0_2_8_1, }, } @@ -255,9 +256,9 @@ func upgradeV0_2_5(db *dbs.DB) error { return nil } -// v0.2.9 -func upgradeV0_2_9(db *dbs.DB) error { - // 访问日志 +// v0.2.8.1 +func upgradeV0_2_8_1(db *dbs.DB) error { + // 访问日志设置 { one, err := db.FindOne("SELECT id FROM edgeSysSettings WHERE code=? LIMIT 1", systemconfigs.SettingCodeNSAccessLogSetting) if err != nil { @@ -279,5 +280,47 @@ func upgradeV0_2_9(db *dbs.DB) error { } } } + + // 升级EdgeDNS线路 + ones, _, err := db.FindOnes("SELECT id, dnsRoutes FROM edgeNodes WHERE dnsRoutes IS NOT NULL") + if err != nil { + return err + } + for _, one := range ones { + var nodeId = one.GetInt64("id") + var dnsRoutes = one.GetString("dnsRoutes") + if len(dnsRoutes) == 0 { + continue + } + var m = map[string][]string{} + err = json.Unmarshal([]byte(dnsRoutes), &m) + if err != nil { + continue + } + var isChanged = false + var reg = regexp.MustCompile(`^\d+$`) + for k, routes := range m { + for index, route := range routes { + if reg.MatchString(route) { + route = "id:" + route + isChanged = true + } + routes[index] = route + } + m[k] = routes + } + + if isChanged { + mJSON, err := json.Marshal(m) + if err != nil { + return err + } + _, err = db.Exec("UPDATE edgeNodes SET dnsRoutes=? WHERE id=? LIMIT 1", string(mJSON), nodeId) + if err != nil { + return err + } + } + } + return nil }