增加服务之间拷贝配置的API(开源版本只有定义,没有完全实现)

This commit is contained in:
GoEdgeLab
2023-04-09 16:01:23 +08:00
parent a13320810d
commit 25c17aed2c
16 changed files with 434 additions and 238 deletions

View File

@@ -96,6 +96,27 @@ func (this *HTTPAuthPolicyDAO) UpdateHTTPAuthPolicy(tx *dbs.Tx, policyId int64,
return this.NotifyUpdate(tx, policyId) return this.NotifyUpdate(tx, policyId)
} }
// CloneAuthPolicy 复制策略
func (this *HTTPAuthPolicyDAO) CloneAuthPolicy(tx *dbs.Tx, fromPolicyId int64) (int64, error) {
policyOne, err := this.Query(tx).
Pk(fromPolicyId).
Find()
if err != nil || policyOne == nil {
return 0, err
}
var policy = policyOne.(*HTTPAuthPolicy)
var op = NewHTTPAuthPolicyOperator()
op.IsOn = policy.IsOn
op.Name = policy.Name
op.Type = policy.Type
if len(policy.Params) > 0 {
op.Params = policy.Params
}
op.State = policy.State
return this.SaveInt64(tx, op)
}
// ComposePolicyConfig 组合配置 // ComposePolicyConfig 组合配置
func (this *HTTPAuthPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64, cacheMap *utils.CacheMap) (*serverconfigs.HTTPAuthPolicy, error) { func (this *HTTPAuthPolicyDAO) ComposePolicyConfig(tx *dbs.Tx, policyId int64, cacheMap *utils.CacheMap) (*serverconfigs.HTTPAuthPolicy, error) {
if cacheMap == nil { if cacheMap == nil {

View File

@@ -133,6 +133,32 @@ func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []strin
return this.NotifyUpdate(tx, pageId) return this.NotifyUpdate(tx, pageId)
} }
// ClonePage 克隆页面
func (this *HTTPPageDAO) ClonePage(tx *dbs.Tx, fromPageId int64) (newPageId int64, err error) {
if fromPageId <= 0 {
return
}
pageOne, err := this.Query(tx).
Pk(fromPageId).
Find()
if err != nil || pageOne == nil {
return 0, err
}
var page = pageOne.(*HTTPPage)
var op = NewHTTPPageOperator()
op.IsOn = page.IsOn
if len(page.StatusList) > 0 {
op.StatusList = page.StatusList
}
op.Url = page.Url
op.NewStatus = page.NewStatus
op.Body = page.Body
op.BodyType = page.BodyType
op.State = page.State
return this.SaveInt64(tx, op)
}
// ComposePageConfig 组合配置 // ComposePageConfig 组合配置
func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64, cacheMap *utils.CacheMap) (*serverconfigs.HTTPPageConfig, error) { func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64, cacheMap *utils.CacheMap) (*serverconfigs.HTTPPageConfig, error) {
if cacheMap == nil { if cacheMap == nil {

View File

@@ -732,32 +732,6 @@ func (this *HTTPWebDAO) UpdateWebStat(tx *dbs.Tx, webId int64, statJSON []byte)
return this.NotifyUpdate(tx, webId) return this.NotifyUpdate(tx, webId)
} }
// CopyWebStats 拷贝统计配置
func (this *HTTPWebDAO) CopyWebStats(tx *dbs.Tx, fromWebId int64, toWebIds []int64) error {
if fromWebId <= 0 || len(toWebIds) == 0 {
return nil
}
statJSON, err := this.Query(tx).
Pk(fromWebId).
Result("stat").
FindJSONCol()
if err != nil {
return err
}
// 暂时不处理
if len(statJSON) == 0 {
return nil
}
return this.Query(tx).
Pk(toWebIds).
Reuse(false).
Set("stat", statJSON).
UpdateQuickly()
}
// UpdateWebCache 更改缓存配置 // UpdateWebCache 更改缓存配置
func (this *HTTPWebDAO) UpdateWebCache(tx *dbs.Tx, webId int64, cacheJSON []byte) error { func (this *HTTPWebDAO) UpdateWebCache(tx *dbs.Tx, webId int64, cacheJSON []byte) error {
if webId <= 0 { if webId <= 0 {
@@ -1182,8 +1156,6 @@ func (this *HTTPWebDAO) UpdateWebHostRedirects(tx *dbs.Tx, webId int64, hostRedi
return this.NotifyUpdate(tx, webId) return this.NotifyUpdate(tx, webId)
} }
// 通用设置
// FindWebHostRedirects 查找主机跳转 // FindWebHostRedirects 查找主机跳转
func (this *HTTPWebDAO) FindWebHostRedirects(tx *dbs.Tx, webId int64) ([]byte, error) { func (this *HTTPWebDAO) FindWebHostRedirects(tx *dbs.Tx, webId int64) ([]byte, error) {
col, err := this.Query(tx). col, err := this.Query(tx).

View File

@@ -159,6 +159,31 @@ func (this *HTTPWebsocketDAO) UpdateWebsocket(tx *dbs.Tx, websocketId int64, han
return this.NotifyUpdate(tx, websocketId) return this.NotifyUpdate(tx, websocketId)
} }
// CloneWebsocket 复制配置
func (this *HTTPWebsocketDAO) CloneWebsocket(tx *dbs.Tx, fromWebsocketId int64) (newWebsocketId int64, err error) {
websocketOne, err := this.Query(tx).
Pk(fromWebsocketId).
Find()
if err != nil || websocketOne == nil {
return 0, err
}
var websocket = websocketOne.(*HTTPWebsocket)
var op = NewHTTPWebsocketOperator()
op.State = websocket.State
op.IsOn = websocket.IsOn
if len(websocket.HandshakeTimeout) > 0 {
op.HandshakeTimeout = websocket.HandshakeTimeout
}
op.AllowAllOrigins = websocket.AllowAllOrigins
if len(websocket.AllowedOrigins) > 0 {
op.AllowedOrigins = websocket.AllowedOrigins
}
op.RequestSameOrigin = websocket.RequestSameOrigin
op.RequestOrigin = websocket.RequestOrigin
return this.SaveInt64(tx, op)
}
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *HTTPWebsocketDAO) NotifyUpdate(tx *dbs.Tx, websocketId int64) error { func (this *HTTPWebsocketDAO) NotifyUpdate(tx *dbs.Tx, websocketId int64) error {
webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithWebsocketId(tx, websocketId) webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithWebsocketId(tx, websocketId)

View File

@@ -2,7 +2,7 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
// Websocket设置 // HTTPWebsocket Websocket设置
type HTTPWebsocket struct { type HTTPWebsocket struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID AdminId uint32 `field:"adminId"` // 管理员ID
@@ -15,20 +15,22 @@ type HTTPWebsocket struct {
AllowedOrigins dbs.JSON `field:"allowedOrigins"` // 支持的源域名列表 AllowedOrigins dbs.JSON `field:"allowedOrigins"` // 支持的源域名列表
RequestSameOrigin uint8 `field:"requestSameOrigin"` // 是否请求一样的Origin RequestSameOrigin uint8 `field:"requestSameOrigin"` // 是否请求一样的Origin
RequestOrigin string `field:"requestOrigin"` // 请求Origin RequestOrigin string `field:"requestOrigin"` // 请求Origin
WebId uint64 `field:"webId"` // Web
} }
type HTTPWebsocketOperator struct { type HTTPWebsocketOperator struct {
Id interface{} // ID Id any // ID
AdminId interface{} // 管理员ID AdminId any // 管理员ID
UserId interface{} // 用户ID UserId any // 用户ID
CreatedAt interface{} // 创建时间 CreatedAt any // 创建时间
State interface{} // 状态 State any // 状态
IsOn interface{} // 是否启用 IsOn any // 是否启用
HandshakeTimeout interface{} // 握手超时时间 HandshakeTimeout any // 握手超时时间
AllowAllOrigins interface{} // 是否支持所有源 AllowAllOrigins any // 是否支持所有源
AllowedOrigins interface{} // 支持的源域名列表 AllowedOrigins any // 支持的源域名列表
RequestSameOrigin interface{} // 是否请求一样的Origin RequestSameOrigin any // 是否请求一样的Origin
RequestOrigin interface{} // 请求Origin RequestOrigin any // 请求Origin
WebId any // Web
} }
func NewHTTPWebsocketOperator() *HTTPWebsocketOperator { func NewHTTPWebsocketOperator() *HTTPWebsocketOperator {

View File

@@ -275,6 +275,64 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx,
return this.NotifyUpdate(tx, originId) return this.NotifyUpdate(tx, originId)
} }
// CloneOrigin 复制源站
func (this *OriginDAO) CloneOrigin(tx *dbs.Tx, fromOriginId int64) (newOriginId int64, err error) {
if fromOriginId <= 0 {
return
}
originOne, err := this.Find(tx, fromOriginId)
if err != nil || originOne == nil {
return
}
var origin = originOne.(*Origin)
var op = NewOriginOperator()
op.IsOn = origin.IsOn
op.Name = origin.Name
op.Version = origin.Version
if IsNotNull(origin.Addr) {
op.Addr = origin.Addr
}
op.Description = origin.Description
op.Code = origin.Code
op.Weight = origin.Weight
if IsNotNull(origin.ConnTimeout) {
op.ConnTimeout = origin.ConnTimeout
}
if IsNotNull(origin.ReadTimeout) {
op.ReadTimeout = origin.ReadTimeout
}
if IsNotNull(origin.IdleTimeout) {
op.IdleTimeout = origin.IdleTimeout
}
op.MaxFails = origin.MaxFails
op.MaxConns = origin.MaxConns
op.MaxIdleConns = origin.MaxIdleConns
op.HttpRequestURI = origin.HttpRequestURI
if IsNotNull(origin.HttpRequestHeader) {
op.HttpRequestHeader = origin.HttpRequestHeader
}
if IsNotNull(origin.HttpResponseHeader) {
op.HttpResponseHeader = origin.HttpResponseHeader
}
op.Host = origin.Host
if IsNotNull(origin.HealthCheck) {
op.HealthCheck = origin.HealthCheck
}
if IsNotNull(origin.Cert) {
// TODO 需要Clone证书
op.Cert = origin.Cert
}
if IsNotNull(origin.Ftp) {
op.Ftp = origin.Ftp
}
if IsNotNull(origin.Domains) {
op.Domains = origin.Domains
}
op.FollowPort = origin.FollowPort
op.State = origin.State
return this.SaveInt64(tx, op)
}
// ComposeOriginConfig 将源站信息转换为配置 // ComposeOriginConfig 将源站信息转换为配置
func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, dataMap *shared.DataMap, cacheMap *utils.CacheMap) (*serverconfigs.OriginConfig, error) { func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, dataMap *shared.DataMap, cacheMap *utils.CacheMap) (*serverconfigs.OriginConfig, error) {
if cacheMap == nil { if cacheMap == nil {

View File

@@ -243,6 +243,115 @@ func (this *ReverseProxyDAO) CreateReverseProxy(tx *dbs.Tx, adminId int64, userI
return types.Int64(op.Id), nil return types.Int64(op.Id), nil
} }
// CloneReverseProxy 复制反向代理
func (this *ReverseProxyDAO) CloneReverseProxy(tx *dbs.Tx, fromReverseProxyId int64) (newReverseProxyId int64, err error) {
if fromReverseProxyId <= 0 {
return
}
reverseProxyOne, err := this.Query(tx).
Pk(fromReverseProxyId).
State(ReverseProxyStateEnabled).
Find()
if err != nil || reverseProxyOne == nil {
return 0, err
}
var reverseProxy = reverseProxyOne.(*ReverseProxy)
var op = NewReverseProxyOperator()
op.TemplateId = reverseProxy.TemplateId
op.IsOn = reverseProxy.IsOn
if IsNotNull(reverseProxy.Scheduling) {
op.Scheduling = reverseProxy.Scheduling
}
if IsNotNull(reverseProxy.PrimaryOrigins) {
var originRefs = []*serverconfigs.OriginRef{}
err = json.Unmarshal(reverseProxy.PrimaryOrigins, &originRefs)
if err != nil {
return 0, err
}
var newRefs = []*serverconfigs.OriginRef{}
for _, originRef := range originRefs {
if originRef.OriginId > 0 {
newOriginId, err := SharedOriginDAO.CloneOrigin(tx, originRef.OriginId)
if err != nil {
return 0, err
}
if newOriginId > 0 {
newRef, err := utils.JSONClone[*serverconfigs.OriginRef](originRef)
if err != nil {
return 0, err
}
newRef.OriginId = newOriginId
newRefs = append(newRefs, newRef)
}
}
}
newRefsJSON, err := json.Marshal(newRefs)
if err != nil {
return 0, err
}
op.PrimaryOrigins = newRefsJSON
}
if IsNotNull(reverseProxy.BackupOrigins) {
var originRefs = []*serverconfigs.OriginRef{}
err = json.Unmarshal(reverseProxy.BackupOrigins, &originRefs)
if err != nil {
return 0, err
}
var newRefs = []*serverconfigs.OriginRef{}
for _, originRef := range originRefs {
if originRef.OriginId > 0 {
newOriginId, err := SharedOriginDAO.CloneOrigin(tx, originRef.OriginId)
if err != nil {
return 0, err
}
if newOriginId > 0 {
newRef, err := utils.JSONClone[*serverconfigs.OriginRef](originRef)
if err != nil {
return 0, err
}
newRef.OriginId = newOriginId
newRefs = append(newRefs, newRef)
}
}
}
newRefsJSON, err := json.Marshal(newRefs)
if err != nil {
return 0, err
}
op.BackupOrigins = newRefsJSON
}
op.StripPrefix = reverseProxy.StripPrefix
op.RequestHostType = reverseProxy.RequestHostType
op.RequestHost = reverseProxy.RequestHost
op.RequestHostExcludingPort = reverseProxy.RequestHostExcludingPort
op.RequestURI = reverseProxy.RequestURI
op.AutoFlush = reverseProxy.AutoFlush
if IsNotNull(reverseProxy.AddHeaders) {
// TODO 复制Header
op.AddHeaders = reverseProxy.AddHeaders
}
op.State = reverseProxy.State
if IsNotNull(reverseProxy.ConnTimeout) {
op.ConnTimeout = reverseProxy.ConnTimeout
}
if IsNotNull(reverseProxy.ReadTimeout) {
op.ReadTimeout = reverseProxy.ReadTimeout
}
if IsNotNull(reverseProxy.IdleTimeout) {
op.IdleTimeout = reverseProxy.IdleTimeout
}
op.MaxConns = reverseProxy.MaxConns
op.MaxIdleConns = reverseProxy.MaxIdleConns
if IsNotNull(reverseProxy.ProxyProtocol) {
op.ProxyProtocol = reverseProxy.ProxyProtocol
}
op.FollowRedirects = reverseProxy.FollowRedirects
return this.SaveInt64(tx, op)
}
// UpdateReverseProxyScheduling 修改反向代理调度算法 // UpdateReverseProxyScheduling 修改反向代理调度算法
func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(tx *dbs.Tx, reverseProxyId int64, schedulingJSON []byte) error { func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(tx *dbs.Tx, reverseProxyId int64, schedulingJSON []byte) error {
if reverseProxyId <= 0 { if reverseProxyId <= 0 {

View File

@@ -126,7 +126,7 @@ func (this *ServerDAO) FindEnabledServerBasic(tx *dbs.Tx, serverId int64) (*Serv
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(serverId). Pk(serverId).
State(ServerStateEnabled). State(ServerStateEnabled).
Result("id", "name", "description", "isOn", "type", "clusterId", "userId"). Result("id", "name", "description", "isOn", "type", "clusterId", "userId", "groupIds").
Find() Find()
if result == nil { if result == nil {
return nil, err return nil, err

View File

@@ -1,173 +0,0 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package models
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs"
)
// 服务基本信息
type clusterServerList struct {
clusterId int64
serverIds []int64
}
// CopyServerConfigToServers 拷贝服务配置到一组服务
func (this *ServerDAO) CopyServerConfigToServers(tx *dbs.Tx, fromServerId int64, toServerIds []int64, configCode serverconfigs.ConfigCode) error {
if fromServerId <= 0 {
return nil
}
if len(toServerIds) == 0 {
return nil
}
webId, err := SharedServerDAO.FindServerWebId(tx, fromServerId)
if err != nil {
return err
}
clusterServers, toWebIds, err := this.findServerClusterIdsAndWebIds(tx, toServerIds)
if err != nil {
return err
}
if len(clusterServers) == 0 {
return nil
}
switch configCode {
case serverconfigs.ConfigCodeStat: // 统计
if webId <= 0 {
return nil
}
err = SharedHTTPWebDAO.CopyWebStats(tx, webId, toWebIds)
if err != nil {
return err
}
}
// 通知更新
for _, serverList := range clusterServers {
err = SharedUpdatingServerListDAO.CreateList(tx, serverList.clusterId, serverList.serverIds)
if err != nil {
return err
}
err = SharedNodeTaskDAO.CreateClusterTask(tx, nodeconfigs.NodeRoleNode, serverList.clusterId, 0, 0, NodeTaskTypeUpdatingServers)
if err != nil {
return err
}
}
return nil
}
// 查找一组服务的集群和WebId信息
func (this *ServerDAO) findServerClusterIdsAndWebIds(tx *dbs.Tx, serverIds []int64) (clusterServers []*clusterServerList, webIds []int64, err error) {
if len(serverIds) == 0 {
return
}
ones, err := this.Query(tx).
Result("id", "webId", "clusterId").
Pk(serverIds).
Reuse(false).
FindAll()
if err != nil {
return nil, nil, err
}
var clusterMap = map[int64]*clusterServerList{} // clusterId => servers
for _, one := range ones {
var server = one.(*Server)
var clusterId = int64(server.ClusterId)
if clusterId <= 0 {
continue
}
serverList, ok := clusterMap[clusterId]
if ok {
serverList.serverIds = append(serverList.serverIds, int64(server.Id))
} else {
clusterMap[clusterId] = &clusterServerList{
clusterId: clusterId,
serverIds: []int64{int64(server.Id)},
}
}
var webId = int64(server.WebId)
if webId > 0 {
webIds = append(webIds, webId)
}
}
for _, serverList := range clusterMap {
clusterServers = append(clusterServers, serverList)
}
return
}
// CopyServerConfigToGroups 拷贝服务配置到分组
func (this *ServerDAO) CopyServerConfigToGroups(tx *dbs.Tx, fromServerId int64, groupIds []int64, configCode string) error {
if len(groupIds) == 0 {
return nil
}
var serverIds = []int64{}
for _, groupId := range groupIds {
ones, err := this.Query(tx).
ResultPk().
State(ServerStateEnabled).
Where("JSON_CONTAINS(groupIds, :groupId)").
Param("groupId", groupId).
FindAll()
if err != nil {
return err
}
for _, one := range ones {
serverIds = append(serverIds, int64(one.(*Server).Id))
}
}
return this.CopyServerConfigToServers(tx, fromServerId, serverIds, configCode)
}
// CopyServerConfigToCluster 拷贝服务配置到集群
func (this *ServerDAO) CopyServerConfigToCluster(tx *dbs.Tx, fromServerId int64, clusterId int64, configCode string) error {
ones, err := this.Query(tx).
ResultPk().
State(ServerStateEnabled).
Attr("clusterId", clusterId).
UseIndex("clusterId").
FindAll()
if err != nil {
return err
}
var serverIds = []int64{}
for _, one := range ones {
serverIds = append(serverIds, int64(one.(*Server).Id))
}
return this.CopyServerConfigToServers(tx, fromServerId, serverIds, configCode)
}
// CopyServerConfigToUser 拷贝服务配置到用户
func (this *ServerDAO) CopyServerConfigToUser(tx *dbs.Tx, fromServerId int64, userId int64, configCode string) error {
ones, err := this.Query(tx).
ResultPk().
State(ServerStateEnabled).
Attr("userId", userId).
UseIndex("userId").
FindAll()
if err != nil {
return err
}
var serverIds = []int64{}
for _, one := range ones {
serverIds = append(serverIds, int64(one.(*Server).Id))
}
return this.CopyServerConfigToServers(tx, fromServerId, serverIds, configCode)
}

View File

@@ -0,0 +1,35 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs"
)
// CopyServerConfigToServers 拷贝服务配置到一组服务
func (this *ServerDAO) CopyServerConfigToServers(tx *dbs.Tx, fromServerId int64, toServerIds []int64, configCode serverconfigs.ConfigCode) error {
return errors.New("not implemented")
}
// CopyServerConfigToGroups 拷贝服务配置到分组
func (this *ServerDAO) CopyServerConfigToGroups(tx *dbs.Tx, fromServerId int64, groupIds []int64, configCode string) error {
return errors.New("not implemented")
}
// CopyServerConfigToCluster 拷贝服务配置到集群
func (this *ServerDAO) CopyServerConfigToCluster(tx *dbs.Tx, fromServerId int64, clusterId int64, configCode string) error {
return errors.New("not implemented")
}
// CopyServerConfigToUser 拷贝服务配置到用户
func (this *ServerDAO) CopyServerConfigToUser(tx *dbs.Tx, fromServerId int64, userId int64, configCode string) error {
return errors.New("not implemented")
}
// CopyServerUAMConfigs 复制UAM设置
func (this *ServerDAO) CopyServerUAMConfigs(tx *dbs.Tx, fromServerId int64, toServerIds []int64) error {
return errors.New("not implemented")
}

View File

@@ -1,22 +0,0 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package models_test
import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/dbs"
"testing"
)
func TestServerDAO_CopyServerConfigToServers(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
var dao = models.NewServerDAO()
err := dao.CopyServerConfigToServers(tx, 10170, []int64{23, 10171}, serverconfigs.ConfigCodeStat)
if err != nil {
t.Fatal(err)
}
}

View File

@@ -424,6 +424,18 @@ func (this *ServerGroupDAO) ExistsGroup(tx *dbs.Tx, groupId int64) (bool, error)
Exist() Exist()
} }
// FindGroupUserId 读取分组所属用户
func (this *ServerGroupDAO) FindGroupUserId(tx *dbs.Tx, groupId int64) (userId int64, err error) {
if groupId <= 0 {
return
}
return this.Query(tx).
Pk(groupId).
State(ServerGroupStateEnabled).
Result("userId").
FindInt64Col(0)
}
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *ServerGroupDAO) NotifyUpdate(tx *dbs.Tx, groupId int64) error { func (this *ServerGroupDAO) NotifyUpdate(tx *dbs.Tx, groupId int64) error {
serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithGroupId(tx, groupId) serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithGroupId(tx, groupId)

View File

@@ -1299,6 +1299,16 @@ func (this *ServerService) CountAllEnabledServersWithServerGroupId(ctx context.C
} }
var tx = this.NullTx() var tx = this.NullTx()
if userId <= 0 {
// 指定用户ID可以加快查询速度
groupUserId, err := models.SharedServerGroupDAO.FindGroupUserId(tx, req.ServerGroupId)
if err != nil {
return nil, err
}
if groupUserId > 0 {
userId = groupUserId
}
}
count, err := models.SharedServerDAO.CountAllEnabledServersWithGroupId(tx, req.ServerGroupId, userId) count, err := models.SharedServerDAO.CountAllEnabledServersWithGroupId(tx, req.ServerGroupId, userId)
if err != nil { if err != nil {
@@ -1562,11 +1572,29 @@ func (this *ServerService) FindEnabledUserServerBasic(ctx context.Context, req *
return &pb.FindEnabledUserServerBasicResponse{Server: nil}, nil return &pb.FindEnabledUserServerBasicResponse{Server: nil}, nil
} }
// 集群
clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(server.ClusterId)) clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, int64(server.ClusterId))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 分组
var pbGroups = []*pb.ServerGroup{}
for _, groupId := range server.DecodeGroupIds() {
group, err := models.SharedServerGroupDAO.FindEnabledServerGroup(tx, groupId)
if err != nil {
return nil, err
}
if group == nil {
continue
}
pbGroups = append(pbGroups, &pb.ServerGroup{
Id: groupId,
Name: group.Name,
UserId: int64(group.UserId),
})
}
return &pb.FindEnabledUserServerBasicResponse{Server: &pb.Server{ return &pb.FindEnabledUserServerBasicResponse{Server: &pb.Server{
Id: int64(server.Id), Id: int64(server.Id),
Name: server.Name, Name: server.Name,
@@ -1578,6 +1606,7 @@ func (this *ServerService) FindEnabledUserServerBasic(ctx context.Context, req *
Id: int64(server.ClusterId), Id: int64(server.ClusterId),
Name: clusterName, Name: clusterName,
}, },
ServerGroups: pbGroups,
}}, nil }}, nil
} }

View File

@@ -85,21 +85,26 @@ func (this *ServerGroupService) DeleteServerGroup(ctx context.Context, req *pb.D
// FindAllEnabledServerGroups 查询所有分组 // FindAllEnabledServerGroups 查询所有分组
func (this *ServerGroupService) FindAllEnabledServerGroups(ctx context.Context, req *pb.FindAllEnabledServerGroupsRequest) (*pb.FindAllEnabledServerGroupsResponse, error) { func (this *ServerGroupService) FindAllEnabledServerGroups(ctx context.Context, req *pb.FindAllEnabledServerGroupsRequest) (*pb.FindAllEnabledServerGroupsResponse, error) {
// 校验请求 // 校验请求
_, userId, err := this.ValidateAdminAndUser(ctx, true) adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if adminId > 0 {
userId = req.UserId
}
var tx = this.NullTx() var tx = this.NullTx()
groups, err := models.SharedServerGroupDAO.FindAllEnabledGroups(tx, userId) groups, err := models.SharedServerGroupDAO.FindAllEnabledGroups(tx, userId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
result := []*pb.ServerGroup{} var result = []*pb.ServerGroup{}
for _, group := range groups { for _, group := range groups {
result = append(result, &pb.ServerGroup{ result = append(result, &pb.ServerGroup{
Id: int64(group.Id), Id: int64(group.Id),
IsOn: group.IsOn,
Name: group.Name, Name: group.Name,
}) })
} }
@@ -153,6 +158,7 @@ func (this *ServerGroupService) FindEnabledServerGroup(ctx context.Context, req
return &pb.FindEnabledServerGroupResponse{ return &pb.FindEnabledServerGroupResponse{
ServerGroup: &pb.ServerGroup{ ServerGroup: &pb.ServerGroup{
Id: int64(group.Id), Id: int64(group.Id),
IsOn: group.IsOn,
Name: group.Name, Name: group.Name,
}, },
}, nil }, nil

45
internal/utils/json.go Normal file
View File

@@ -0,0 +1,45 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils
import (
"encoding/json"
"errors"
"reflect"
)
// JSONClone 使用JSON协议克隆对象
func JSONClone[T any](ptr T) (newPtr T, err error) {
var ptrType = reflect.TypeOf(ptr)
var kind = ptrType.Kind()
if kind != reflect.Ptr && kind != reflect.Slice {
err = errors.New("JSONClone: input must be a ptr or slice")
return
}
var jsonData []byte
jsonData, err = json.Marshal(ptr)
if err != nil {
return ptr, errors.New("JSONClone: marshal failed: " + err.Error())
}
var newValue any
switch kind {
case reflect.Ptr:
newValue = reflect.New(ptrType.Elem()).Interface()
case reflect.Slice:
newValue = reflect.New(reflect.SliceOf(ptrType.Elem())).Interface()
default:
return ptr, errors.New("JSONClone: unknown data type")
}
err = json.Unmarshal(jsonData, newValue)
if err != nil {
err = errors.New("JSONClone: unmarshal failed: " + err.Error())
return
}
if kind == reflect.Slice {
newValue = reflect.Indirect(reflect.ValueOf(newValue)).Interface()
}
return newValue.(T), nil
}

View File

@@ -0,0 +1,51 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package utils_test
import (
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestJSONClone(t *testing.T) {
type user struct {
Name string
Age int
}
var u = &user{
Name: "Jack",
Age: 20,
}
newU, err := utils.JSONClone[*user](u)
if err != nil {
t.Fatal(err)
}
t.Logf("%#v", newU)
}
func TestJSONClone_Slice(t *testing.T) {
type user struct {
Name string
Age int
}
var u = []*user{
{
Name: "Jack",
Age: 20,
},
{
Name: "Lily",
Age: 18,
},
}
newU, err := utils.JSONClone[[]*user](u)
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(newU, t)
}