优化代码

This commit is contained in:
刘祥超
2022-09-23 15:25:51 +08:00
parent 5655f89ba6
commit c94b3c26c1
18 changed files with 358 additions and 1457 deletions

View File

@@ -2,19 +2,14 @@ package models
import (
"fmt"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/rands"
timeutil "github.com/iwind/TeaGo/utils/time"
"hash/crc32"
"regexp"
"strconv"
"strings"
"sync"
"time"
)
@@ -32,22 +27,12 @@ type httpAccessLogDefinition struct {
// HTTP服务访问
var httpAccessLogDAOMapping = map[int64]*HTTPAccessLogDAOWrapper{} // dbNodeId => DAO
// DNS服务访问
var nsAccessLogDAOMapping = map[int64]*NSAccessLogDAOWrapper{} // dbNodeId => DAO
var nsAccessLogTableMapping = map[string]bool{} // tableName_crc(dsn) => true
// HTTPAccessLogDAOWrapper HTTP访问日志DAO
type HTTPAccessLogDAOWrapper struct {
DAO *HTTPAccessLogDAO
NodeId int64
}
// NSAccessLogDAOWrapper NS访问日志DAO
type NSAccessLogDAOWrapper struct {
DAO *NSAccessLogDAO
NodeId int64
}
func init() {
initializer := NewDBNodeInitializer()
dbs.OnReadyDone(func() {
@@ -103,112 +88,6 @@ func randomHTTPAccessLogDAO() (dao *HTTPAccessLogDAOWrapper) {
return daoList[rands.Int(0, l-1)]
}
func randomNSAccessLogDAO() (dao *NSAccessLogDAOWrapper) {
accessLogLocker.RLock()
defer accessLogLocker.RUnlock()
if len(nsAccessLogDAOMapping) == 0 {
dao = nil
return
}
var daoList = []*NSAccessLogDAOWrapper{}
for _, d := range nsAccessLogDAOMapping {
daoList = append(daoList, d)
}
var l = len(daoList)
if l == 0 {
return
}
if l == 1 {
return daoList[0]
}
return daoList[rands.Int(0, l-1)]
}
func findNSAccessLogTableName(db *dbs.DB, day string) (tableName string, ok bool, err error) {
if !regexp.MustCompile(`^\d{8}$`).MatchString(day) {
err = errors.New("invalid day '" + day + "', should be YYYYMMDD")
return
}
config, err := db.Config()
if err != nil {
return "", false, err
}
tableName = "edgeNSAccessLogs_" + day
cacheKey := tableName + "_" + fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(config.Dsn)))
accessLogLocker.RLock()
_, ok = nsAccessLogTableMapping[cacheKey]
accessLogLocker.RUnlock()
if ok {
return tableName, true, nil
}
tableNames, err := db.TableNames()
if err != nil {
return tableName, false, err
}
return tableName, utils.ContainsStringInsensitive(tableNames, tableName), nil
}
func findNSAccessLogTable(db *dbs.DB, day string, force bool) (string, error) {
config, err := db.Config()
if err != nil {
return "", err
}
var tableName = "edgeNSAccessLogs_" + day
var cacheKey = tableName + "_" + fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(config.Dsn)))
if !force {
accessLogLocker.RLock()
_, ok := nsAccessLogTableMapping[cacheKey]
accessLogLocker.RUnlock()
if ok {
return tableName, nil
}
}
tableNames, err := db.TableNames()
if err != nil {
return tableName, err
}
if utils.ContainsStringInsensitive(tableNames, tableName) {
accessLogLocker.Lock()
nsAccessLogTableMapping[cacheKey] = true
accessLogLocker.Unlock()
return tableName, nil
}
// 创建表格
_, err = db.Exec("CREATE TABLE `" + tableName + "` (\n `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 'ID',\n `nodeId` int(11) unsigned DEFAULT '0' COMMENT '节点ID',\n `domainId` int(11) unsigned DEFAULT '0' COMMENT '域名ID',\n `recordId` int(11) unsigned DEFAULT '0' COMMENT '记录ID',\n `content` json DEFAULT NULL COMMENT '访问数据',\n `requestId` varchar(128) DEFAULT NULL COMMENT '请求ID',\n `createdAt` bigint(11) unsigned DEFAULT '0' COMMENT '创建时间',\n `remoteAddr` varchar(128) DEFAULT NULL COMMENT 'IP',\n PRIMARY KEY (`id`),\n KEY `nodeId` (`nodeId`),\n KEY `domainId` (`domainId`),\n KEY `recordId` (`recordId`),\n KEY `requestId` (`requestId`),\n KEY `remoteAddr` (`remoteAddr`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='域名服务访问日志';")
if err != nil {
if CheckSQLErrCode(err, 1050) { // Error 1050: Table 'xxx' already exists
accessLogLocker.Lock()
nsAccessLogTableMapping[cacheKey] = true
accessLogLocker.Unlock()
return tableName, nil
}
return tableName, err
}
accessLogLocker.Lock()
nsAccessLogTableMapping[cacheKey] = true
accessLogLocker.Unlock()
return tableName, nil
}
// DBNodeInitializer 初始化数据库连接
type DBNodeInitializer struct {
}
@@ -242,14 +121,14 @@ func (this *DBNodeInitializer) loop() error {
return err
}
nodeIds := []int64{}
var nodeIds = []int64{}
for _, node := range dbNodes {
nodeIds = append(nodeIds, int64(node.Id))
}
// 关掉老的
accessLogLocker.Lock()
closingDbs := []*dbs.DB{}
var closingDbs = []*dbs.DB{}
for nodeId, db := range accessLogDBMapping {
if !lists.ContainsInt64(nodeIds, nodeId) {
closingDbs = append(closingDbs, db)
@@ -266,12 +145,12 @@ func (this *DBNodeInitializer) loop() error {
// 启动新的
for _, node := range dbNodes {
nodeId := int64(node.Id)
var nodeId = int64(node.Id)
accessLogLocker.Lock()
db, ok := accessLogDBMapping[nodeId]
accessLogLocker.Unlock()
dsn := node.Username + ":" + node.Password + "@tcp(" + node.Host + ":" + fmt.Sprintf("%d", node.Port) + ")/" + node.Database + "?charset=utf8mb4&timeout=10s"
var dsn = node.Username + ":" + node.Password + "@tcp(" + node.Host + ":" + fmt.Sprintf("%d", node.Port) + ")/" + node.Database + "?charset=utf8mb4&timeout=10s"
if ok {
// 检查配置是否有变化
@@ -341,49 +220,8 @@ func (this *DBNodeInitializer) loop() error {
accessLogLocker.Unlock()
}
// nsAccessLog
{
tableName, err := findNSAccessLogTable(db, timeutil.Format("Ymd"), false)
if err != nil {
if !strings.Contains(err.Error(), "1050") { // 非表格已存在错误
remotelogs.Error("DB_NODE", "create first table in database node failed: "+err.Error())
// 创建节点日志
createLogErr := SharedNodeLogDAO.CreateLog(nil, nodeconfigs.NodeRoleDatabase, nodeId, 0, 0, "error", "ACCESS_LOG", "can not create access log table: "+err.Error(), time.Now().Unix(), "", nil)
if createLogErr != nil {
remotelogs.Error("NODE_LOG", createLogErr.Error())
}
continue
} else {
err = nil
}
}
daoObject := dbs.DAOObject{
Instance: db,
DB: node.Name + "(id:" + strconv.Itoa(int(node.Id)) + ")",
Table: tableName,
PkName: "id",
Model: new(NSAccessLog),
}
err = daoObject.Init()
if err != nil {
remotelogs.Error("DB_NODE", "initialize dao failed: "+err.Error())
continue
}
accessLogLocker.Lock()
accessLogDBMapping[nodeId] = db
dao := &NSAccessLogDAO{
DAOObject: daoObject,
}
nsAccessLogDAOMapping[nodeId] = &NSAccessLogDAOWrapper{
DAO: dao,
NodeId: nodeId,
}
accessLogLocker.Unlock()
}
// 扩展
initAccessLogDAO(db, node)
}
}

View File

@@ -0,0 +1,11 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import "github.com/iwind/TeaGo/dbs"
var nsAccessLogDAOMapping = map[int64]any{} // dbNodeId => DAO
func initAccessLogDAO(db *dbs.DB, node *DBNode) {
}

View File

@@ -391,6 +391,13 @@ func (this *NodeLogDAO) DeleteNodeLogsWithCluster(tx *dbs.Tx, role nodeconfigs.N
if clusterId <= 0 {
return nil
}
// 执行钩子
err := this.deleteNodeLogsWithCluster(tx, role, clusterId)
if err != nil {
return err
}
var query = this.Query(tx).
Attr("role", role)
@@ -398,13 +405,10 @@ func (this *NodeLogDAO) DeleteNodeLogsWithCluster(tx *dbs.Tx, role nodeconfigs.N
case nodeconfigs.NodeRoleNode:
query.Where("nodeId IN (SELECT id FROM " + SharedNodeDAO.Table + " WHERE clusterId=:clusterId)")
query.Param("clusterId", clusterId)
case nodeconfigs.NodeRoleDNS:
query.Where("nodeId IN (SELECT id FROM " + SharedNSNodeDAO.Table + " WHERE clusterId=:clusterId)")
query.Param("clusterId", clusterId)
default:
return nil
}
_, err := query.Delete()
_, err = query.Delete()
return err
}

View File

@@ -0,0 +1,13 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/dbs"
)
func (this *NodeLogDAO) deleteNodeLogsWithCluster(tx *dbs.Tx, role nodeconfigs.NodeRole, clusterId int64) error {
return nil
}

View File

@@ -155,32 +155,3 @@ func (this *NodeLoginDAO) FindFrequentPorts(tx *dbs.Tx) ([]int32, error) {
}
return ports, nil
}
func (this *NodeLoginDAO) FindFrequentGrantIds(tx *dbs.Tx, nodeClusterId int64, nsClusterId int64) ([]int64, error) {
var query = this.Query(tx).
Attr("state", NodeLoginStateEnabled).
Result("JSON_EXTRACT(params, '$.grantId') as `grantId`", "COUNT(*) AS c").
Having("grantId>0").
Desc("c").
Limit(3).
Group("grantId")
if nodeClusterId > 0 {
query.Attr("role", nodeconfigs.NodeRoleNode)
query.Where("(nodeId IN (SELECT id FROM "+SharedNodeDAO.Table+" WHERE state=1 AND clusterId=:clusterId))").
Param("clusterId", nodeClusterId)
} else if nsClusterId > 0 {
query.Attr("role", nodeconfigs.NodeRoleDNS)
query.Where("(nodeId IN (SELECT id FROM "+SharedNSNodeDAO.Table+" WHERE state=1 AND clusterId=:clusterId))").
Param("clusterId", nsClusterId)
}
ones, _, err := query.
FindOnes()
if err != nil {
return nil, err
}
var grantIds = []int64{}
for _, one := range ones {
grantIds = append(grantIds, one.GetInt64("grantId"))
}
return grantIds, nil
}

View File

@@ -0,0 +1,36 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/iwind/TeaGo/dbs"
)
func (this *NodeLoginDAO) FindFrequentGrantIds(tx *dbs.Tx, nodeClusterId int64, nsClusterId int64) ([]int64, error) {
var query = this.Query(tx).
Attr("state", NodeLoginStateEnabled).
Result("JSON_EXTRACT(params, '$.grantId') as `grantId`", "COUNT(*) AS c").
Having("grantId>0").
Desc("c").
Limit(3).
Group("grantId")
if nodeClusterId > 0 {
query.Attr("role", nodeconfigs.NodeRoleNode)
query.Where("(nodeId IN (SELECT id FROM "+SharedNodeDAO.Table+" WHERE state=1 AND clusterId=:clusterId))").
Param("clusterId", nodeClusterId)
} else if nsClusterId > 0 {
return nil, nil
}
ones, _, err := query.
FindOnes()
if err != nil {
return nil, err
}
var grantIds = []int64{}
for _, one := range ones {
grantIds = append(grantIds, one.GetInt64("grantId"))
}
return grantIds, nil
}

View File

@@ -159,45 +159,6 @@ func (this *NodeTaskDAO) ExtractNodeClusterTask(tx *dbs.Tx, clusterId int64, ser
return nil
}
// ExtractNSClusterTask 分解NS节点集群任务
func (this *NodeTaskDAO) ExtractNSClusterTask(tx *dbs.Tx, clusterId int64, taskType NodeTaskType) error {
nodeIds, err := SharedNSNodeDAO.FindAllNodeIdsMatch(tx, clusterId, true, configutils.BoolStateYes)
if err != nil {
return err
}
_, err = this.Query(tx).
Attr("role", nodeconfigs.NodeRoleDNS).
Attr("clusterId", clusterId).
Param("clusterIdString", types.String(clusterId)).
Where("nodeId > 0").
Attr("type", taskType).
Delete()
if err != nil {
return err
}
var version = time.Now().UnixNano()
for _, nodeId := range nodeIds {
err = this.CreateNodeTask(tx, nodeconfigs.NodeRoleDNS, clusterId, nodeId, 0, taskType, version)
if err != nil {
return err
}
}
_, err = this.Query(tx).
Attr("role", nodeconfigs.NodeRoleDNS).
Attr("clusterId", clusterId).
Attr("nodeId", 0).
Attr("type", taskType).
Delete()
if err != nil {
return err
}
return nil
}
// ExtractAllClusterTasks 分解所有集群任务
func (this *NodeTaskDAO) ExtractAllClusterTasks(tx *dbs.Tx, role string) error {
ones, err := this.Query(tx).

View File

@@ -0,0 +1,11 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import "github.com/iwind/TeaGo/dbs"
// ExtractNSClusterTask 分解NS节点集群任务
func (this *NodeTaskDAO) ExtractNSClusterTask(tx *dbs.Tx, clusterId int64, taskType NodeTaskType) error {
return nil
}

View File

@@ -1,336 +0,0 @@
package models
import (
"encoding/json"
dbutils "github.com/TeaOSLab/EdgeAPI/internal/db/utils"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"regexp"
"sort"
"strings"
"sync"
"time"
)
type NSAccessLogDAO dbs.DAO
func NewNSAccessLogDAO() *NSAccessLogDAO {
return dbs.NewDAO(&NSAccessLogDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeNSAccessLogs",
Model: new(NSAccessLog),
PkName: "id",
},
}).(*NSAccessLogDAO)
}
var SharedNSAccessLogDAO *NSAccessLogDAO
func init() {
dbs.OnReady(func() {
SharedNSAccessLogDAO = NewNSAccessLogDAO()
})
}
// CreateNSAccessLogs 创建访问日志
func (this *NSAccessLogDAO) CreateNSAccessLogs(tx *dbs.Tx, accessLogs []*pb.NSAccessLog) error {
dao := randomNSAccessLogDAO()
if dao == nil {
dao = &NSAccessLogDAOWrapper{
DAO: SharedNSAccessLogDAO,
NodeId: 0,
}
}
return this.CreateNSAccessLogsWithDAO(tx, dao, accessLogs)
}
// CreateNSAccessLogsWithDAO 使用特定的DAO创建访问日志
func (this *NSAccessLogDAO) CreateNSAccessLogsWithDAO(tx *dbs.Tx, daoWrapper *NSAccessLogDAOWrapper, accessLogs []*pb.NSAccessLog) error {
if daoWrapper == nil {
return errors.New("dao should not be nil")
}
if len(accessLogs) == 0 {
return nil
}
dao := daoWrapper.DAO
// TODO 改成事务批量提交,以加快速度
for _, accessLog := range accessLogs {
day := timeutil.Format("Ymd", time.Unix(accessLog.Timestamp, 0))
table, err := findNSAccessLogTable(dao.Instance, day, false)
if err != nil {
return err
}
fields := map[string]interface{}{}
fields["nodeId"] = accessLog.NsNodeId
fields["domainId"] = accessLog.NsDomainId
fields["recordId"] = accessLog.NsRecordId
fields["createdAt"] = accessLog.Timestamp
fields["requestId"] = accessLog.RequestId
content, err := json.Marshal(accessLog)
if err != nil {
return err
}
fields["content"] = content
_, err = dao.Query(tx).
Table(table).
Sets(fields).
Insert()
if err != nil {
// 是否为 Error 1146: Table 'xxx.xxx' doesn't exist 如果是,则创建表之后重试
if strings.Contains(err.Error(), "1146") {
table, err = findNSAccessLogTable(dao.Instance, day, true)
if err != nil {
return err
}
_, err = dao.Query(tx).
Table(table).
Sets(fields).
Insert()
if err != nil {
return err
}
}
}
}
return nil
}
// ListAccessLogs 读取往前的 单页访问日志
func (this *NSAccessLogDAO) ListAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, clusterId int64, nodeId int64, domainId int64, recordId int64, recordType string, keyword string, reverse bool) (result []*NSAccessLog, nextLastRequestId string, hasMore bool, err error) {
if len(day) != 8 {
return
}
// 限制能查询的最大条数,防止占用内存过多
if size > 1000 {
size = 1000
}
result, nextLastRequestId, err = this.listAccessLogs(tx, lastRequestId, size, day, clusterId, nodeId, domainId, recordId, recordType, keyword, reverse)
if err != nil || int64(len(result)) < size {
return
}
moreResult, _, _ := this.listAccessLogs(tx, nextLastRequestId, 1, day, clusterId, nodeId, domainId, recordId, recordType, keyword, reverse)
hasMore = len(moreResult) > 0
return
}
// 读取往前的单页访问日志
func (this *NSAccessLogDAO) listAccessLogs(tx *dbs.Tx, lastRequestId string, size int64, day string, clusterId int64, nodeId int64, domainId int64, recordId int64, recordType string, keyword string, reverse bool) (result []*NSAccessLog, nextLastRequestId string, err error) {
if size <= 0 {
return nil, lastRequestId, nil
}
accessLogLocker.RLock()
var daoList = []*NSAccessLogDAOWrapper{}
for _, daoWrapper := range nsAccessLogDAOMapping {
daoList = append(daoList, daoWrapper)
}
accessLogLocker.RUnlock()
if len(daoList) == 0 {
daoList = []*NSAccessLogDAOWrapper{{
DAO: SharedNSAccessLogDAO,
NodeId: 0,
}}
}
// 检查是否有集群筛选条件
var nodeIds []int64
if clusterId > 0 && nodeId <= 0 {
nodeIds, err = SharedNSNodeDAO.FindEnabledNodeIdsWithClusterId(tx, clusterId)
if err != nil {
return
}
if len(nodeIds) == 0 {
// 没有任何节点则直接返回空
return nil, "", nil
}
}
var locker = sync.Mutex{}
var count = len(daoList)
var wg = &sync.WaitGroup{}
wg.Add(count)
for _, daoWrapper := range daoList {
go func(daoWrapper *NSAccessLogDAOWrapper) {
defer wg.Done()
dao := daoWrapper.DAO
tableName, exists, err := findNSAccessLogTableName(dao.Instance, day)
if !exists {
// 表格不存在则跳过
return
}
if err != nil {
logs.Println("[DB_NODE]" + err.Error())
return
}
var query = dao.Query(tx)
// 条件
if nodeId > 0 {
query.Attr("nodeId", nodeId)
} else if clusterId > 0 {
query.Attr("nodeId", nodeIds)
query.Reuse(false)
}
if domainId > 0 {
query.Attr("domainId", domainId)
}
if recordId > 0 {
query.Attr("recordId", recordId)
}
// offset
if len(lastRequestId) > 0 {
if !reverse {
query.Where("requestId<:requestId").
Param("requestId", lastRequestId)
} else {
query.Where("requestId>:requestId").
Param("requestId", lastRequestId)
}
}
// keyword
if len(keyword) > 0 {
query.Where("(JSON_EXTRACT(content, '$.remoteAddr') LIKE :keyword OR JSON_EXTRACT(content, '$.questionName') LIKE :keyword OR JSON_EXTRACT(content, '$.recordValue') LIKE :keyword)").
Param("keyword", dbutils.QuoteLike(keyword))
}
// record type
if len(recordType) > 0 {
query.Where("JSON_EXTRACT(content, '$.questionType')=:recordType")
query.Param("recordType", recordType)
}
if !reverse {
query.Desc("requestId")
} else {
query.Asc("requestId")
}
// 开始查询
ones, err := query.
Table(tableName).
Limit(size).
FindAll()
if err != nil {
logs.Println("[DB_NODE]" + err.Error())
return
}
locker.Lock()
for _, one := range ones {
accessLog := one.(*NSAccessLog)
result = append(result, accessLog)
}
locker.Unlock()
}(daoWrapper)
}
wg.Wait()
if len(result) == 0 {
return nil, lastRequestId, nil
}
// 按照requestId排序
sort.Slice(result, func(i, j int) bool {
if !reverse {
return result[i].RequestId > result[j].RequestId
} else {
return result[i].RequestId < result[j].RequestId
}
})
if int64(len(result)) > size {
result = result[:size]
}
var requestId = result[len(result)-1].RequestId
if reverse {
lists.Reverse(result)
}
if !reverse {
return result, requestId, nil
} else {
return result, requestId, nil
}
}
// FindAccessLogWithRequestId 根据请求ID获取访问日志
func (this *NSAccessLogDAO) FindAccessLogWithRequestId(tx *dbs.Tx, requestId string) (*NSAccessLog, error) {
if !regexp.MustCompile(`^\d{11,}`).MatchString(requestId) {
return nil, errors.New("invalid requestId")
}
accessLogLocker.RLock()
daoList := []*NSAccessLogDAOWrapper{}
for _, daoWrapper := range nsAccessLogDAOMapping {
daoList = append(daoList, daoWrapper)
}
accessLogLocker.RUnlock()
if len(daoList) == 0 {
daoList = []*NSAccessLogDAOWrapper{{
DAO: SharedNSAccessLogDAO,
NodeId: 0,
}}
}
count := len(daoList)
wg := &sync.WaitGroup{}
wg.Add(count)
var result *NSAccessLog = nil
day := timeutil.FormatTime("Ymd", types.Int64(requestId[:10]))
for _, daoWrapper := range daoList {
go func(daoWrapper *NSAccessLogDAOWrapper) {
defer wg.Done()
dao := daoWrapper.DAO
tableName, exists, err := findNSAccessLogTableName(dao.Instance, day)
if err != nil {
logs.Println("[DB_NODE]" + err.Error())
return
}
if !exists {
return
}
one, err := dao.Query(tx).
Table(tableName).
Attr("requestId", requestId).
Find()
if err != nil {
logs.Println("[DB_NODE]" + err.Error())
return
}
if one != nil {
result = one.(*NSAccessLog)
}
}(daoWrapper)
}
wg.Wait()
return result, nil
}

View File

@@ -1,6 +0,0 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -1,22 +1,15 @@
//go:build !plus
package models
import (
"encoding/json"
dbutils "github.com/TeaOSLab/EdgeAPI/internal/db/utils"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"time"
)
const (
@@ -95,110 +88,6 @@ func (this *NSNodeDAO) FindEnabledNSNodeName(tx *dbs.Tx, nodeId int64) (string,
FindStringCol("")
}
// FindAllEnabledNodesWithClusterId 查找一个集群下的所有节点
func (this *NSNodeDAO) FindAllEnabledNodesWithClusterId(tx *dbs.Tx, clusterId int64) (result []*NSNode, err error) {
_, err = this.Query(tx).
Attr("clusterId", clusterId).
State(NSNodeStateEnabled).
DescPk().
Slice(&result).
FindAll()
return
}
// CountAllEnabledNodes 所有集群的可用的节点数量
func (this *NSNodeDAO) CountAllEnabledNodes(tx *dbs.Tx) (int64, error) {
return this.Query(tx).
State(NSNodeStateEnabled).
Where("clusterId IN (SELECT id FROM " + SharedNSClusterDAO.Table + " WHERE state=1)").
Count()
}
// CountAllOfflineNodes 计算离线节点数量
func (this *NSNodeDAO) CountAllOfflineNodes(tx *dbs.Tx) (int64, error) {
return this.Query(tx).
State(NSNodeStateEnabled).
Where("(status IS NULL OR JSON_EXTRACT(status, '$.updatedAt')<UNIX_TIMESTAMP()-120)").
Where("clusterId IN (SELECT id FROM " + SharedNSClusterDAO.Table + " WHERE state=1)").
Count()
}
// CountAllEnabledNodesMatch 计算满足条件的节点数量
func (this *NSNodeDAO) CountAllEnabledNodesMatch(tx *dbs.Tx, clusterId int64, installState configutils.BoolState, activeState configutils.BoolState, keyword string) (int64, error) {
query := this.Query(tx)
if clusterId > 0 {
query.Attr("clusterId", clusterId)
}
// 安装状态
switch installState {
case configutils.BoolStateAll:
// 所有
case configutils.BoolStateYes:
query.Attr("isInstalled", 1)
case configutils.BoolStateNo:
query.Attr("isInstalled", 0)
}
// 在线状态
switch activeState {
case configutils.BoolStateAll:
// 所有
case configutils.BoolStateYes:
query.Where("(isActive=1 AND JSON_EXTRACT(status, '$.isActive') AND UNIX_TIMESTAMP()-JSON_EXTRACT(status, '$.updatedAt')<=60)")
case configutils.BoolStateNo:
query.Where("(isActive=0 OR status IS NULL OR NOT JSON_EXTRACT(status, '$.isActive') OR UNIX_TIMESTAMP()-JSON_EXTRACT(status, '$.updatedAt')>60)")
}
if len(keyword) > 0 {
query.Where("(name LIKE :keyword)").
Param("keyword", dbutils.QuoteLike(keyword))
}
return query.
State(NSNodeStateEnabled).
Count()
}
// ListAllEnabledNodesMatch 列出单页匹配的节点
func (this *NSNodeDAO) ListAllEnabledNodesMatch(tx *dbs.Tx, clusterId int64, installState configutils.BoolState, activeState configutils.BoolState, keyword string, offset int64, size int64) (result []*NSNode, err error) {
query := this.Query(tx)
// 安装状态
switch installState {
case configutils.BoolStateAll:
// 所有
case configutils.BoolStateYes:
query.Attr("isInstalled", 1)
case configutils.BoolStateNo:
query.Attr("isInstalled", 0)
}
// 在线状态
switch activeState {
case configutils.BoolStateAll:
// 所有
case configutils.BoolStateYes:
query.Where("(isActive=1 AND JSON_EXTRACT(status, '$.isActive') AND UNIX_TIMESTAMP()-JSON_EXTRACT(status, '$.updatedAt')<=60)")
case configutils.BoolStateNo:
query.Where("(isActive=0 OR status IS NULL OR NOT JSON_EXTRACT(status, '$.isActive') OR UNIX_TIMESTAMP()-JSON_EXTRACT(status, '$.updatedAt')>60)")
}
if clusterId > 0 {
query.Attr("clusterId", clusterId)
}
if len(keyword) > 0 {
query.Where("(name LIKE :keyword)").
Param("keyword", dbutils.QuoteLike(keyword))
}
_, err = query.
State(NSNodeStateEnabled).
Offset(offset).
Limit(size).
Slice(&result).
DescPk().
FindAll()
return
}
// CountAllLowerVersionNodesWithClusterId 计算单个集群中所有低于某个版本的节点数量
func (this *NSNodeDAO) CountAllLowerVersionNodesWithClusterId(tx *dbs.Tx, clusterId int64, os string, arch string, version string) (int64, error) {
return this.Query(tx).
@@ -214,73 +103,6 @@ func (this *NSNodeDAO) CountAllLowerVersionNodesWithClusterId(tx *dbs.Tx, cluste
Count()
}
// CreateNode 创建节点
func (this *NSNodeDAO) CreateNode(tx *dbs.Tx, adminId int64, name string, clusterId int64) (nodeId int64, err error) {
uniqueId, err := this.GenUniqueId(tx)
if err != nil {
return 0, err
}
secret := rands.String(32)
// 保存API Token
err = SharedApiTokenDAO.CreateAPIToken(tx, uniqueId, secret, nodeconfigs.NodeRoleDNS)
if err != nil {
return
}
var op = NewNSNodeOperator()
op.AdminId = adminId
op.Name = name
op.UniqueId = uniqueId
op.Secret = secret
op.ClusterId = clusterId
op.IsOn = 1
op.State = NSNodeStateEnabled
err = this.Save(tx, op)
if err != nil {
return 0, err
}
// 通知节点更新
nodeId = types.Int64(op.Id)
err = this.NotifyUpdate(tx, nodeId)
if err != nil {
return 0, err
}
// 通知DNS更新
err = this.NotifyDNSUpdate(tx, nodeId)
if err != nil {
return 0, err
}
return nodeId, nil
}
// UpdateNode 修改节点
func (this *NSNodeDAO) UpdateNode(tx *dbs.Tx, nodeId int64, name string, clusterId int64, isOn bool) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
var op = NewNSNodeOperator()
op.Id = nodeId
op.Name = name
op.ClusterId = clusterId
op.IsOn = isOn
err := this.Save(tx, op)
if err != nil {
return err
}
err = this.NotifyUpdate(tx, nodeId)
if err != nil {
return err
}
return this.NotifyDNSUpdate(tx, nodeId)
}
// FindEnabledNodeIdWithUniqueId 根据唯一ID获取节点ID
func (this *NSNodeDAO) FindEnabledNodeIdWithUniqueId(tx *dbs.Tx, uniqueId string) (int64, error) {
return this.Query(tx).
@@ -290,37 +112,6 @@ func (this *NSNodeDAO) FindEnabledNodeIdWithUniqueId(tx *dbs.Tx, uniqueId string
FindInt64Col(0)
}
// FindNodeInstallStatus 查询节点的安装状态
func (this *NSNodeDAO) FindNodeInstallStatus(tx *dbs.Tx, nodeId int64) (*NodeInstallStatus, error) {
node, err := this.Query(tx).
Pk(nodeId).
Result("installStatus", "isInstalled").
Find()
if err != nil {
return nil, err
}
if node == nil {
return nil, errors.New("not found")
}
installStatus := node.(*NSNode).InstallStatus
isInstalled := node.(*NSNode).IsInstalled
if len(installStatus) == 0 {
return NewNodeInstallStatus(), nil
}
status := &NodeInstallStatus{}
err = json.Unmarshal(installStatus, status)
if err != nil {
return nil, err
}
if isInstalled {
status.IsFinished = true
status.IsOk = true
}
return status, nil
}
// GenUniqueId 生成唯一ID
func (this *NSNodeDAO) GenUniqueId(tx *dbs.Tx) (string, error) {
for {
@@ -377,132 +168,6 @@ func (this *NSNodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (in
Count()
}
// ComposeNodeConfig 组合节点配置
func (this *NSNodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64) (*dnsconfigs.NSNodeConfig, error) {
if nodeId <= 0 {
return nil, nil
}
node, err := this.FindEnabledNSNode(tx, nodeId)
if err != nil {
return nil, err
}
if node == nil {
return nil, nil
}
cluster, err := SharedNSClusterDAO.FindEnabledNSCluster(tx, int64(node.ClusterId))
if err != nil {
return nil, err
}
if cluster == nil {
return nil, nil
}
var config = &dnsconfigs.NSNodeConfig{
Id: int64(node.Id),
NodeId: node.UniqueId,
Secret: node.Secret,
ClusterId: int64(node.ClusterId),
TimeZone: cluster.TimeZone,
}
// 访问日志
// 全局配置
{
globalValue, err := SharedSysSettingDAO.ReadSetting(tx, systemconfigs.SettingCodeNSAccessLogSetting)
if err != nil {
return nil, err
}
if len(globalValue) > 0 {
var ref = &dnsconfigs.NSAccessLogRef{}
err = json.Unmarshal(globalValue, ref)
if err != nil {
return nil, err
}
config.AccessLogRef = ref
}
// 集群配置
if len(cluster.AccessLog) > 0 {
ref := &dnsconfigs.NSAccessLogRef{}
err = json.Unmarshal(cluster.AccessLog, ref)
if err != nil {
return nil, err
}
if ref.IsPrior {
config.AccessLogRef = ref
}
}
}
// 递归DNS配置
if IsNotNull(cluster.Recursion) {
var recursionConfig = &dnsconfigs.NSRecursionConfig{}
err = json.Unmarshal(cluster.Recursion, recursionConfig)
if err != nil {
return nil, err
}
config.RecursionConfig = recursionConfig
}
// TCP
if IsNotNull(cluster.Tcp) {
var tcpConfig = &serverconfigs.TCPProtocolConfig{}
err = json.Unmarshal(cluster.Tcp, tcpConfig)
if err != nil {
return nil, err
}
config.TCP = tcpConfig
}
// TLS
if IsNotNull(cluster.Tls) {
var tlsConfig = &serverconfigs.TLSProtocolConfig{}
err = json.Unmarshal(cluster.Tls, tlsConfig)
if err != nil {
return nil, err
}
// SSL
if tlsConfig.SSLPolicyRef != nil {
sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tx, tlsConfig.SSLPolicyRef.SSLPolicyId, nil)
if err != nil {
return nil, err
}
if sslPolicyConfig != nil {
tlsConfig.SSLPolicy = sslPolicyConfig
}
}
config.TLS = tlsConfig
}
// UDP
if IsNotNull(cluster.Udp) {
var udpConfig = &serverconfigs.UDPProtocolConfig{}
err = json.Unmarshal(cluster.Udp, udpConfig)
if err != nil {
return nil, err
}
config.UDP = udpConfig
}
// DDoS
config.DDoSProtection = cluster.DecodeDDoSProtection()
// DDoS Protection
var ddosProtection = node.DecodeDDoSProtection()
if ddosProtection != nil {
if config.DDoSProtection == nil {
config.DDoSProtection = ddosProtection
} else {
config.DDoSProtection.Merge(ddosProtection)
}
}
return config, nil
}
// FindNodeClusterId 获取节点的集群ID
func (this *NSNodeDAO) FindNodeClusterId(tx *dbs.Tx, nodeId int64) (int64, error) {
return this.Query(tx).
@@ -511,197 +176,6 @@ func (this *NSNodeDAO) FindNodeClusterId(tx *dbs.Tx, nodeId int64) (int64, error
FindInt64Col(0)
}
// FindNodeActive 检查节点活跃状态
func (this *NSNodeDAO) FindNodeActive(tx *dbs.Tx, nodeId int64) (bool, error) {
isActive, err := this.Query(tx).
Pk(nodeId).
Result("isActive").
FindIntCol(0)
if err != nil {
return false, err
}
return isActive == 1, nil
}
// UpdateNodeActive 修改节点活跃状态
func (this *NSNodeDAO) UpdateNodeActive(tx *dbs.Tx, nodeId int64, isActive bool) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
_, err := this.Query(tx).
Pk(nodeId).
Set("isActive", isActive).
Set("statusIsNotified", false).
Set("inactiveNotifiedAt", 0).
Update()
return err
}
// UpdateNodeConnectedAPINodes 修改当前连接的API节点
func (this *NSNodeDAO) UpdateNodeConnectedAPINodes(tx *dbs.Tx, nodeId int64, apiNodeIds []int64) error {
if nodeId <= 0 {
return errors.New("invalid nodeId")
}
var op = NewNSNodeOperator()
op.Id = nodeId
if len(apiNodeIds) > 0 {
apiNodeIdsJSON, err := json.Marshal(apiNodeIds)
if err != nil {
return errors.Wrap(err)
}
op.ConnectedAPINodes = apiNodeIdsJSON
} else {
op.ConnectedAPINodes = "[]"
}
err := this.Save(tx, op)
return err
}
// FindAllNotifyingInactiveNodesWithClusterId 取得某个集群所有等待通知离线离线的节点
func (this *NSNodeDAO) FindAllNotifyingInactiveNodesWithClusterId(tx *dbs.Tx, clusterId int64) (result []*NSNode, err error) {
_, err = this.Query(tx).
State(NSNodeStateEnabled).
Attr("clusterId", clusterId).
Attr("isOn", true). // 只监控启用的节点
Attr("isInstalled", true). // 只监控已经安装的节点
Attr("isActive", false). // 当前已经离线的
Attr("statusIsNotified", false).
Result("id", "name").
Slice(&result).
FindAll()
return
}
// UpdateNodeStatusIsNotified 设置状态已经通知
func (this *NSNodeDAO) UpdateNodeStatusIsNotified(tx *dbs.Tx, nodeId int64) error {
return this.Query(tx).
Pk(nodeId).
Set("statusIsNotified", true).
Set("inactiveNotifiedAt", time.Now().Unix()).
UpdateQuickly()
}
// FindNodeInactiveNotifiedAt 读取上次的节点离线通知时间
func (this *NSNodeDAO) FindNodeInactiveNotifiedAt(tx *dbs.Tx, nodeId int64) (int64, error) {
return this.Query(tx).
Pk(nodeId).
Result("inactiveNotifiedAt").
FindInt64Col(0)
}
// FindAllNodeIdsMatch 匹配节点并返回节点ID
func (this *NSNodeDAO) FindAllNodeIdsMatch(tx *dbs.Tx, clusterId int64, includeSecondaryNodes bool, isOn configutils.BoolState) (result []int64, err error) {
query := this.Query(tx)
query.State(NSNodeStateEnabled)
if clusterId > 0 {
query.Attr("clusterId", clusterId)
}
if isOn == configutils.BoolStateYes {
query.Attr("isOn", true)
} else if isOn == configutils.BoolStateNo {
query.Attr("isOn", false)
}
query.Result("id")
ones, _, err := query.FindOnes()
if err != nil {
return nil, err
}
for _, one := range ones {
result = append(result, one.GetInt64("id"))
}
return
}
// UpdateNodeInstallStatus 修改节点的安装状态
func (this *NSNodeDAO) UpdateNodeInstallStatus(tx *dbs.Tx, nodeId int64, status *NodeInstallStatus) error {
if status == nil {
_, err := this.Query(tx).
Pk(nodeId).
Set("installStatus", "null").
Update()
return err
}
data, err := json.Marshal(status)
if err != nil {
return err
}
_, err = this.Query(tx).
Pk(nodeId).
Set("installStatus", string(data)).
Update()
return err
}
// FindEnabledNodeIdsWithClusterId 查找集群下的所有节点
func (this *NSNodeDAO) FindEnabledNodeIdsWithClusterId(tx *dbs.Tx, clusterId int64) ([]int64, error) {
if clusterId <= 0 {
return nil, nil
}
ones, err := this.Query(tx).
ResultPk().
Attr("clusterId", clusterId).
State(NSNodeStateEnabled).
FindAll()
if err != nil {
return nil, err
}
var result = []int64{}
for _, one := range ones {
result = append(result, int64(one.(*NSNode).Id))
}
return result, nil
}
// FindNodeDDoSProtection 获取节点的DDOS设置
func (this *NSNodeDAO) FindNodeDDoSProtection(tx *dbs.Tx, nodeId int64) (*ddosconfigs.ProtectionConfig, error) {
one, err := this.Query(tx).
Result("ddosProtection").
Pk(nodeId).
Find()
if one == nil || err != nil {
return nil, err
}
return one.(*NSNode).DecodeDDoSProtection(), nil
}
// UpdateNodeDDoSProtection 设置集群的DDOS设置
func (this *NSNodeDAO) UpdateNodeDDoSProtection(tx *dbs.Tx, nodeId int64, ddosProtection *ddosconfigs.ProtectionConfig) error {
if nodeId <= 0 {
return ErrNotFound
}
var op = NewNSNodeOperator()
op.Id = nodeId
if ddosProtection == nil {
op.DdosProtection = "{}"
} else {
ddosProtectionJSON, err := json.Marshal(ddosProtection)
if err != nil {
return err
}
op.DdosProtection = ddosProtectionJSON
}
err := this.Save(tx, op)
if err != nil {
return err
}
clusterId, err := this.FindNodeClusterId(tx, nodeId)
if err != nil {
return err
}
if clusterId > 0 {
return SharedNodeTaskDAO.CreateNodeTask(tx, nodeconfigs.NodeRoleDNS, clusterId, nodeId, 0, NSNodeTaskTypeDDosProtectionChanged, 0)
}
return nil
}
// NotifyUpdate 通知更新
func (this *NSNodeDAO) NotifyUpdate(tx *dbs.Tx, nodeId int64) error {
// TODO 先什么都不做

View File

@@ -1,6 +0,0 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -1,18 +1,7 @@
package rpcutils
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/authority"
"github.com/TeaOSLab/EdgeAPI/internal/encrypt"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"google.golang.org/grpc/metadata"
)
type UserType = string
@@ -33,181 +22,6 @@ const (
UserTypeReport = "report"
)
// ValidateRequest 校验请求
func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserType, resultNodeId int64, userId int64, err error) {
if ctx == nil {
err = errors.New("context should not be nil")
return
}
// 支持直接认证
plainCtx, ok := ctx.(*PlainContext)
if ok {
userType = plainCtx.UserType
userId = plainCtx.UserId
if len(userTypes) > 0 && !lists.ContainsString(userTypes, userType) {
userType = UserTypeNone
userId = 0
}
if userId <= 0 {
err = errors.New("context: can not find user or permission denied")
}
return
}
// 是否是模拟测试
{
mockCtx, isMock := ctx.(*MockNodeContext)
if isMock {
return UserTypeNode, 0, mockCtx.NodeId, nil
}
}
{
mockCtx, isMock := ctx.(*MockAdminNodeContext)
if isMock {
return UserTypeAdmin, 0, mockCtx.AdminId, nil
}
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return UserTypeNone, 0, 0, errors.New("context: need 'nodeId'")
}
nodeIds := md.Get("nodeid")
if len(nodeIds) == 0 || len(nodeIds[0]) == 0 {
return UserTypeNone, 0, 0, errors.New("context: need 'nodeId'")
}
nodeId := nodeIds[0]
// 获取角色Node信息
apiToken, err := models.SharedApiTokenDAO.FindEnabledTokenWithNodeCacheable(nil, nodeId)
if err != nil {
utils.PrintError(err)
return UserTypeNone, 0, 0, err
}
nodeUserId := int64(0)
if apiToken == nil {
return UserTypeNode, 0, 0, errors.New("context: can not find api token for node '" + nodeId + "'")
}
tokens := md.Get("token")
if len(tokens) == 0 || len(tokens[0]) == 0 {
return UserTypeNone, 0, 0, errors.New("context: need 'token'")
}
token := tokens[0]
data, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return UserTypeNone, 0, 0, err
}
method, err := encrypt.NewMethodInstance(teaconst.EncryptMethod, apiToken.Secret, nodeId)
if err != nil {
utils.PrintError(err)
return UserTypeNone, 0, 0, err
}
data, err = method.Decrypt(data)
if err != nil {
return UserTypeNone, 0, 0, err
}
if len(data) == 0 {
return UserTypeNone, 0, 0, errors.New("invalid token")
}
var m = maps.Map{}
err = json.Unmarshal(data, &m)
if err != nil {
return UserTypeNone, 0, 0, errors.New("decode token error: " + err.Error())
}
t := m.GetString("type")
if len(userTypes) > 0 && !lists.ContainsString(userTypes, t) {
return UserTypeNone, 0, 0, errors.New("not supported node type: '" + t + "'")
}
switch apiToken.Role {
case UserTypeNode:
// TODO 需要检查集群是否已经删除
nodeIntId, err := models.SharedNodeDAO.FindEnabledNodeIdWithUniqueIdCacheable(nil, nodeId)
if err != nil {
return UserTypeNode, 0, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeNode, 0, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
nodeUserId = nodeIntId
resultNodeId = nodeIntId
case UserTypeCluster:
clusterId, err := models.SharedNodeClusterDAO.FindEnabledClusterIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeCluster, 0, 0, errors.New("context: " + err.Error())
}
if clusterId <= 0 {
return UserTypeCluster, 0, 0, errors.New("context: not found cluster with id '" + nodeId + "'")
}
nodeUserId = clusterId
resultNodeId = clusterId
case UserTypeUser:
nodeIntId, err := models.SharedUserNodeDAO.FindEnabledUserNodeIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeUser, 0, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeUser, 0, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
resultNodeId = nodeIntId
case UserTypeMonitor:
nodeIntId, err := models.SharedMonitorNodeDAO.FindEnabledMonitorNodeIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeMonitor, 0, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeMonitor, 0, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
resultNodeId = nodeIntId
case UserTypeAuthority:
nodeIntId, err := authority.SharedAuthorityNodeDAO.FindEnabledAuthorityNodeIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeAuthority, 0, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeAuthority, 0, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
nodeUserId = nodeIntId
resultNodeId = nodeIntId
case UserTypeDNS:
nodeIntId, err := models.SharedNSNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeDNS, nodeIntId, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeDNS, nodeIntId, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
nodeUserId = nodeIntId
resultNodeId = nodeIntId
case UserTypeReport:
nodeIntId, err := models.SharedReportNodeDAO.FindEnabledNodeIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeReport, nodeIntId, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeReport, nodeIntId, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
nodeUserId = nodeIntId
resultNodeId = nodeIntId
}
if nodeUserId > 0 {
return t, resultNodeId, nodeUserId, nil
} else {
return t, resultNodeId, m.GetInt64("userId"), nil
}
}
// Wrap 包装错误
func Wrap(description string, err error) error {
if err == nil {

View File

@@ -0,0 +1,153 @@
//go:build !plus
package rpcutils
import (
"context"
"encoding/base64"
"encoding/json"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/encrypt"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"google.golang.org/grpc/metadata"
)
// ValidateRequest 校验请求
func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserType, resultNodeId int64, userId int64, err error) {
if ctx == nil {
err = errors.New("context should not be nil")
return
}
// 支持直接认证
plainCtx, ok := ctx.(*PlainContext)
if ok {
userType = plainCtx.UserType
userId = plainCtx.UserId
if len(userTypes) > 0 && !lists.ContainsString(userTypes, userType) {
userType = UserTypeNone
userId = 0
}
if userId <= 0 {
err = errors.New("context: can not find user or permission denied")
}
return
}
// 是否是模拟测试
{
mockCtx, isMock := ctx.(*MockNodeContext)
if isMock {
return UserTypeNode, 0, mockCtx.NodeId, nil
}
}
{
mockCtx, isMock := ctx.(*MockAdminNodeContext)
if isMock {
return UserTypeAdmin, 0, mockCtx.AdminId, nil
}
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return UserTypeNone, 0, 0, errors.New("context: need 'nodeId'")
}
nodeIds := md.Get("nodeid")
if len(nodeIds) == 0 || len(nodeIds[0]) == 0 {
return UserTypeNone, 0, 0, errors.New("context: need 'nodeId'")
}
nodeId := nodeIds[0]
// 获取角色Node信息
apiToken, err := models.SharedApiTokenDAO.FindEnabledTokenWithNodeCacheable(nil, nodeId)
if err != nil {
utils.PrintError(err)
return UserTypeNone, 0, 0, err
}
nodeUserId := int64(0)
if apiToken == nil {
return UserTypeNode, 0, 0, errors.New("context: can not find api token for node '" + nodeId + "'")
}
tokens := md.Get("token")
if len(tokens) == 0 || len(tokens[0]) == 0 {
return UserTypeNone, 0, 0, errors.New("context: need 'token'")
}
token := tokens[0]
data, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return UserTypeNone, 0, 0, err
}
method, err := encrypt.NewMethodInstance(teaconst.EncryptMethod, apiToken.Secret, nodeId)
if err != nil {
utils.PrintError(err)
return UserTypeNone, 0, 0, err
}
data, err = method.Decrypt(data)
if err != nil {
return UserTypeNone, 0, 0, err
}
if len(data) == 0 {
return UserTypeNone, 0, 0, errors.New("invalid token")
}
var m = maps.Map{}
err = json.Unmarshal(data, &m)
if err != nil {
return UserTypeNone, 0, 0, errors.New("decode token error: " + err.Error())
}
t := m.GetString("type")
if len(userTypes) > 0 && !lists.ContainsString(userTypes, t) {
return UserTypeNone, 0, 0, errors.New("not supported node type: '" + t + "'")
}
switch apiToken.Role {
case UserTypeNode:
// TODO 需要检查集群是否已经删除
nodeIntId, err := models.SharedNodeDAO.FindEnabledNodeIdWithUniqueIdCacheable(nil, nodeId)
if err != nil {
return UserTypeNode, 0, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeNode, 0, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
nodeUserId = nodeIntId
resultNodeId = nodeIntId
case UserTypeCluster:
clusterId, err := models.SharedNodeClusterDAO.FindEnabledClusterIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeCluster, 0, 0, errors.New("context: " + err.Error())
}
if clusterId <= 0 {
return UserTypeCluster, 0, 0, errors.New("context: not found cluster with id '" + nodeId + "'")
}
nodeUserId = clusterId
resultNodeId = clusterId
case UserTypeUser:
nodeIntId, err := models.SharedUserNodeDAO.FindEnabledUserNodeIdWithUniqueId(nil, nodeId)
if err != nil {
return UserTypeUser, 0, 0, errors.New("context: " + err.Error())
}
if nodeIntId <= 0 {
return UserTypeUser, 0, 0, errors.New("context: not found node with id '" + nodeId + "'")
}
resultNodeId = nodeIntId
}
if nodeUserId > 0 {
return t, resultNodeId, nodeUserId, nil
} else {
return t, resultNodeId, m.GetInt64("userId"), nil
}
}

View File

@@ -4,9 +4,7 @@ import (
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/systemconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
@@ -448,33 +446,6 @@ func (this *SQLExecutor) checkMetricItems(db *dbs.DB) error {
return nil
}
// 检查自建DNS全局设置
func (this *SQLExecutor) checkNS(db *dbs.DB) error {
// 访问日志
{
one, err := db.FindOne("SELECT id FROM edgeSysSettings WHERE code=? LIMIT 1", systemconfigs.SettingCodeNSAccessLogSetting)
if err != nil {
return err
}
if len(one) == 0 {
ref := &dnsconfigs.NSAccessLogRef{
IsPrior: false,
IsOn: true,
LogMissingDomains: false,
}
refJSON, err := json.Marshal(ref)
if err != nil {
return err
}
_, err = db.Exec("INSERT edgeSysSettings (code, value) VALUES (?, ?)", systemconfigs.SettingCodeNSAccessLogSetting, refJSON)
if err != nil {
return err
}
}
}
return nil
}
// 更新版本号
func (this *SQLExecutor) updateVersion(db *dbs.DB, version string) error {
stmt, err := db.Prepare("SELECT COUNT(*) FROM edgeVersions")

View File

@@ -0,0 +1,13 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package setup
import (
"github.com/iwind/TeaGo/dbs"
)
// 检查自建DNS全局设置
func (this *SQLExecutor) checkNS(db *dbs.DB) error {
return nil
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models/stats"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/dnsconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
@@ -18,7 +17,6 @@ import (
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"regexp"
)
type upgradeVersion struct {
@@ -298,75 +296,6 @@ func upgradeV0_2_5(db *dbs.DB) error {
return nil
}
// v0.2.8.1
func upgradeV0_2_8_1(db *dbs.DB) error {
// 访问日志设置
{
one, err := db.FindOne("SELECT id FROM edgeSysSettings WHERE code=? LIMIT 1", systemconfigs.SettingCodeNSAccessLogSetting)
if err != nil {
return err
}
if len(one) == 0 {
ref := &dnsconfigs.NSAccessLogRef{
IsPrior: false,
IsOn: true,
LogMissingDomains: false,
}
refJSON, err := json.Marshal(ref)
if err != nil {
return err
}
_, err = db.Exec("INSERT edgeSysSettings (code, value) VALUES (?, ?)", systemconfigs.SettingCodeNSAccessLogSetting, refJSON)
if err != nil {
return err
}
}
}
// 升级EdgeDNS线路
ones, _, err := db.FindOnes("SELECT id, dnsRoutes FROM edgeNodes WHERE dnsRoutes IS NOT NULL")
if err != nil {
return err
}
for _, one := range ones {
var nodeId = one.GetInt64("id")
var dnsRoutes = one.GetString("dnsRoutes")
if len(dnsRoutes) == 0 {
continue
}
var m = map[string][]string{}
err = json.Unmarshal([]byte(dnsRoutes), &m)
if err != nil {
continue
}
var isChanged = false
var reg = regexp.MustCompile(`^\d+$`)
for k, routes := range m {
for index, route := range routes {
if reg.MatchString(route) {
route = "id:" + route
isChanged = true
}
routes[index] = route
}
m[k] = routes
}
if isChanged {
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodes SET dnsRoutes=? WHERE id=? LIMIT 1", string(mJSON), nodeId)
if err != nil {
return err
}
}
}
return nil
}
// v0.3.0
func upgradeV0_3_0(db *dbs.DB) error {
// 升级健康检查
@@ -846,58 +775,3 @@ func upgradeV0_4_11(db *dbs.DB) error {
return nil
}
// v0.5.3
func upgradeV0_5_3(db *dbs.DB) error {
// 升级ns domains中的status字段
{
_, err := db.Exec("UPDATE edgeNSDomains SET status='" + dnsconfigs.NSDomainStatusVerified + "'")
if err != nil {
return err
}
}
// 升级集群服务配置
{
type oldGlobalConfig struct {
// HTTP & HTTPS相关配置
HTTPAll struct {
MatchDomainStrictly bool `yaml:"matchDomainStrictly" json:"matchDomainStrictly"` // 是否严格匹配域名
AllowMismatchDomains []string `yaml:"allowMismatchDomains" json:"allowMismatchDomains"` // 允许的不匹配的域名
DefaultDomain string `yaml:"defaultDomain" json:"defaultDomain"` // 默认的域名
DomainMismatchAction *serverconfigs.DomainMismatchAction `yaml:"domainMismatchAction" json:"domainMismatchAction"` // 不匹配时采取的动作
} `yaml:"httpAll" json:"httpAll"`
}
value, err := db.FindCol(0, "SELECT value FROM edgeSysSettings WHERE code='serverGlobalConfig'")
if err != nil {
return err
}
if value != nil {
var valueJSON = []byte(types.String(value))
var oldConfig = &oldGlobalConfig{}
err = json.Unmarshal(valueJSON, oldConfig)
if err == nil {
var newConfig = &serverconfigs.GlobalServerConfig{}
newConfig.HTTPAll.MatchDomainStrictly = oldConfig.HTTPAll.MatchDomainStrictly
newConfig.HTTPAll.AllowMismatchDomains = oldConfig.HTTPAll.AllowMismatchDomains
newConfig.HTTPAll.DefaultDomain = oldConfig.HTTPAll.DefaultDomain
if oldConfig.HTTPAll.DomainMismatchAction != nil {
newConfig.HTTPAll.DomainMismatchAction = oldConfig.HTTPAll.DomainMismatchAction
}
newConfig.HTTPAll.AllowNodeIP = true
newConfig.Log.RecordServerError = false
newConfigJSON, err := json.Marshal(newConfig)
if err == nil {
_, err = db.Exec("UPDATE edgeNodeClusters SET globalServerConfig=?", newConfigJSON)
if err != nil {
return err
}
}
}
}
}
return nil
}

View File

@@ -0,0 +1,105 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package setup
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
"regexp"
)
// v0.2.8.1
func upgradeV0_2_8_1(db *dbs.DB) error {
// 升级EdgeDNS线路
ones, _, err := db.FindOnes("SELECT id, dnsRoutes FROM edgeNodes WHERE dnsRoutes IS NOT NULL")
if err != nil {
return err
}
for _, one := range ones {
var nodeId = one.GetInt64("id")
var dnsRoutes = one.GetString("dnsRoutes")
if len(dnsRoutes) == 0 {
continue
}
var m = map[string][]string{}
err = json.Unmarshal([]byte(dnsRoutes), &m)
if err != nil {
continue
}
var isChanged = false
var reg = regexp.MustCompile(`^\d+$`)
for k, routes := range m {
for index, route := range routes {
if reg.MatchString(route) {
route = "id:" + route
isChanged = true
}
routes[index] = route
}
m[k] = routes
}
if isChanged {
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
_, err = db.Exec("UPDATE edgeNodes SET dnsRoutes=? WHERE id=? LIMIT 1", string(mJSON), nodeId)
if err != nil {
return err
}
}
}
return nil
}
// v0.5.3
func upgradeV0_5_3(db *dbs.DB) error {
// 升级集群服务配置
{
type oldGlobalConfig struct {
// HTTP & HTTPS相关配置
HTTPAll struct {
MatchDomainStrictly bool `yaml:"matchDomainStrictly" json:"matchDomainStrictly"` // 是否严格匹配域名
AllowMismatchDomains []string `yaml:"allowMismatchDomains" json:"allowMismatchDomains"` // 允许的不匹配的域名
DefaultDomain string `yaml:"defaultDomain" json:"defaultDomain"` // 默认的域名
DomainMismatchAction *serverconfigs.DomainMismatchAction `yaml:"domainMismatchAction" json:"domainMismatchAction"` // 不匹配时采取的动作
} `yaml:"httpAll" json:"httpAll"`
}
value, err := db.FindCol(0, "SELECT value FROM edgeSysSettings WHERE code='serverGlobalConfig'")
if err != nil {
return err
}
if value != nil {
var valueJSON = []byte(types.String(value))
var oldConfig = &oldGlobalConfig{}
err = json.Unmarshal(valueJSON, oldConfig)
if err == nil {
var newConfig = &serverconfigs.GlobalServerConfig{}
newConfig.HTTPAll.MatchDomainStrictly = oldConfig.HTTPAll.MatchDomainStrictly
newConfig.HTTPAll.AllowMismatchDomains = oldConfig.HTTPAll.AllowMismatchDomains
newConfig.HTTPAll.DefaultDomain = oldConfig.HTTPAll.DefaultDomain
if oldConfig.HTTPAll.DomainMismatchAction != nil {
newConfig.HTTPAll.DomainMismatchAction = oldConfig.HTTPAll.DomainMismatchAction
}
newConfig.HTTPAll.AllowNodeIP = true
newConfig.Log.RecordServerError = false
newConfigJSON, err := json.Marshal(newConfig)
if err == nil {
_, err = db.Exec("UPDATE edgeNodeClusters SET globalServerConfig=?", newConfigJSON)
if err != nil {
return err
}
}
}
}
}
return nil
}