Files
EdgeAPI/internal/db/models/reverse_proxy_dao.go

641 lines
17 KiB
Go
Raw Normal View History

2020-09-15 14:44:11 +08:00
package models
import (
"encoding/json"
"errors"
"fmt"
2021-11-11 14:16:42 +08:00
"github.com/TeaOSLab/EdgeAPI/internal/utils"
2020-09-15 14:44:11 +08:00
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
2020-09-15 14:44:11 +08:00
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
2020-09-15 14:44:11 +08:00
"github.com/iwind/TeaGo/types"
)
// Values written to and matched against the "state" column of a reverse proxy record.
const (
	ReverseProxyStateEnabled = 1 // enabled
	ReverseProxyStateDisabled = 0 // disabled
)
// ReverseProxyDAO is the data access object for reverse proxy records; it is defined on top of dbs.DAO.
type ReverseProxyDAO dbs.DAO
// NewReverseProxyDAO creates a DAO bound to the "edgeReverseProxies" table, using
// ReverseProxy as the row model and "id" as the primary key.
func NewReverseProxyDAO() *ReverseProxyDAO {
	var dao = &ReverseProxyDAO{
		DAOObject: dbs.DAOObject{
			DB:     Tea.Env,
			Table:  "edgeReverseProxies",
			Model:  new(ReverseProxy),
			PkName: "id",
		},
	}
	return dbs.NewDAO(dao).(*ReverseProxyDAO)
}
2020-10-13 20:05:13 +08:00
// SharedReverseProxyDAO is the package-wide ReverseProxyDAO instance. It stays nil until
// the callback registered in init has run.
var SharedReverseProxyDAO *ReverseProxyDAO

// init registers a dbs.OnReady callback that assigns SharedReverseProxyDAO.
func init() {
	dbs.OnReady(func() {
		SharedReverseProxyDAO = NewReverseProxyDAO()
	})
}
2020-09-15 14:44:11 +08:00
// Init 初始化
2020-09-26 08:06:40 +08:00
func (this *ReverseProxyDAO) Init() {
_ = this.DAOObject.Init()
2020-09-26 08:06:40 +08:00
}
// EnableReverseProxy 启用条目
func (this *ReverseProxyDAO) EnableReverseProxy(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx).
2020-09-15 14:44:11 +08:00
Pk(id).
Set("state", ReverseProxyStateEnabled).
Update()
2020-09-26 08:06:40 +08:00
if err != nil {
return err
}
return this.NotifyUpdate(tx, id)
2020-09-15 14:44:11 +08:00
}
// DisableReverseProxy 禁用条目
func (this *ReverseProxyDAO) DisableReverseProxy(tx *dbs.Tx, id int64) error {
_, err := this.Query(tx).
2020-09-15 14:44:11 +08:00
Pk(id).
Set("state", ReverseProxyStateDisabled).
Update()
2020-09-26 08:06:40 +08:00
if err != nil {
return err
}
return this.NotifyUpdate(tx, id)
2020-09-15 14:44:11 +08:00
}
// FindEnabledReverseProxy 查找启用中的条目
func (this *ReverseProxyDAO) FindEnabledReverseProxy(tx *dbs.Tx, id int64) (*ReverseProxy, error) {
result, err := this.Query(tx).
2020-09-15 14:44:11 +08:00
Pk(id).
Attr("state", ReverseProxyStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*ReverseProxy), err
}
// ComposeReverseProxyConfig 根据ID组合配置
func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyId int64, dataMap *shared.DataMap, cacheMap *utils.CacheMap) (*serverconfigs.ReverseProxyConfig, error) {
2021-08-22 11:35:33 +08:00
if cacheMap == nil {
2021-11-11 14:16:42 +08:00
cacheMap = utils.NewCacheMap()
2021-08-22 11:35:33 +08:00
}
var cacheKey = this.Table + ":config:" + types.String(reverseProxyId)
2021-11-11 14:16:42 +08:00
var cache, _ = cacheMap.Get(cacheKey)
2021-08-22 11:35:33 +08:00
if cache != nil {
return cache.(*serverconfigs.ReverseProxyConfig), nil
}
reverseProxy, err := this.FindEnabledReverseProxy(tx, reverseProxyId)
2020-09-15 14:44:11 +08:00
if err != nil {
return nil, err
}
if reverseProxy == nil {
return nil, nil
}
var config = serverconfigs.NewReverseProxyConfig()
2020-09-21 11:37:17 +08:00
config.Id = int64(reverseProxy.Id)
2022-03-22 21:45:07 +08:00
config.IsOn = reverseProxy.IsOn
config.RequestHostType = types.Int8(reverseProxy.RequestHostType)
config.RequestHost = reverseProxy.RequestHost
config.RequestHostExcludingPort = reverseProxy.RequestHostExcludingPort
config.RequestURI = reverseProxy.RequestURI
config.StripPrefix = reverseProxy.StripPrefix
2020-09-27 18:41:15 +08:00
config.AutoFlush = reverseProxy.AutoFlush == 1
2022-03-14 15:42:45 +08:00
config.FollowRedirects = reverseProxy.FollowRedirects == 1
config.Retry50X = reverseProxy.Retry50X
config.Retry40X = reverseProxy.Retry40X
2020-09-15 14:44:11 +08:00
var schedulingConfig = &serverconfigs.SchedulingConfig{}
if IsNotNull(reverseProxy.Scheduling) {
err = json.Unmarshal(reverseProxy.Scheduling, schedulingConfig)
2020-09-15 14:44:11 +08:00
if err != nil {
return nil, err
}
config.Scheduling = schedulingConfig
}
if IsNotNull(reverseProxy.PrimaryOrigins) {
var originRefs = []*serverconfigs.OriginRef{}
err = json.Unmarshal(reverseProxy.PrimaryOrigins, &originRefs)
2020-09-15 14:44:11 +08:00
if err != nil {
return nil, err
}
2020-09-21 20:21:26 +08:00
for _, ref := range originRefs {
originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, ref.OriginId, dataMap, cacheMap)
2020-09-15 14:44:11 +08:00
if err != nil {
return nil, err
}
2020-09-21 20:21:26 +08:00
if originConfig != nil {
config.AddPrimaryOrigin(originConfig)
2020-09-15 14:44:11 +08:00
}
}
}
if IsNotNull(reverseProxy.BackupOrigins) {
var originRefs = []*serverconfigs.OriginRef{}
err = json.Unmarshal(reverseProxy.BackupOrigins, &originRefs)
2020-09-15 14:44:11 +08:00
if err != nil {
return nil, err
}
for _, ref := range originRefs {
originConfig, err := SharedOriginDAO.ComposeOriginConfig(tx, ref.OriginId, dataMap, cacheMap)
2020-09-15 14:44:11 +08:00
if err != nil {
return nil, err
}
2020-09-21 20:21:26 +08:00
if originConfig != nil {
config.AddBackupOrigin(originConfig)
2020-09-15 14:44:11 +08:00
}
}
}
// add headers
if IsNotNull(reverseProxy.AddHeaders) {
var addHeaders = []string{}
err = json.Unmarshal(reverseProxy.AddHeaders, &addHeaders)
if err != nil {
return nil, err
}
config.AddHeaders = addHeaders
}
// 源站相关默认设置
config.MaxConns = int(reverseProxy.MaxConns)
config.MaxIdleConns = int(reverseProxy.MaxIdleConns)
if IsNotNull(reverseProxy.ConnTimeout) {
var connTimeout = &shared.TimeDuration{}
err = json.Unmarshal(reverseProxy.ConnTimeout, &connTimeout)
if err != nil {
return nil, err
}
config.ConnTimeout = connTimeout
}
if IsNotNull(reverseProxy.ReadTimeout) {
var readTimeout = &shared.TimeDuration{}
err = json.Unmarshal(reverseProxy.ReadTimeout, &readTimeout)
if err != nil {
return nil, err
}
config.ReadTimeout = readTimeout
}
if IsNotNull(reverseProxy.IdleTimeout) {
var idleTimeout = &shared.TimeDuration{}
err = json.Unmarshal(reverseProxy.IdleTimeout, &idleTimeout)
if err != nil {
return nil, err
}
config.IdleTimeout = idleTimeout
}
2021-10-12 20:18:35 +08:00
// PROXY Protocol
if IsNotNull(reverseProxy.ProxyProtocol) {
var proxyProtocolConfig = &serverconfigs.ProxyProtocolConfig{}
err = json.Unmarshal(reverseProxy.ProxyProtocol, proxyProtocolConfig)
2021-10-12 20:18:35 +08:00
if err != nil {
return nil, err
}
config.ProxyProtocol = proxyProtocolConfig
}
2021-11-11 14:16:42 +08:00
if cacheMap != nil {
cacheMap.Put(cacheKey, config)
}
2021-08-22 11:35:33 +08:00
2020-09-15 14:44:11 +08:00
return config, nil
}
// CreateReverseProxy 创建反向代理
func (this *ReverseProxyDAO) CreateReverseProxy(tx *dbs.Tx, adminId int64, userId int64, schedulingJSON []byte, primaryOriginRefsJSON []byte, backupOriginRefsJSON []byte) (int64, error) {
// decode origins
var primaryOriginRefs []*serverconfigs.OriginRef
if len(primaryOriginRefsJSON) > 0 {
err := json.Unmarshal(primaryOriginRefsJSON, &primaryOriginRefs)
if err != nil {
return 0, fmt.Errorf("decode 'primaryOriginRefs' failed: " + err.Error())
}
}
var backupOriginRefs []*serverconfigs.OriginRef
if len(backupOriginRefsJSON) > 0 {
err := json.Unmarshal(backupOriginRefsJSON, &backupOriginRefs)
if err != nil {
return 0, fmt.Errorf("decode 'backupOriginRefs' failed: " + err.Error())
}
}
var op = NewReverseProxyOperator()
2020-09-26 19:54:15 +08:00
op.IsOn = true
2020-09-15 14:44:11 +08:00
op.State = ReverseProxyStateEnabled
2020-12-18 21:18:53 +08:00
op.AdminId = adminId
op.UserId = userId
op.RequestHostType = serverconfigs.RequestHostTypeProxyServer
op.Retry50X = false
op.Retry40X = false
defaultHeaders := []string{"X-Real-IP", "X-Forwarded-For", "X-Forwarded-By", "X-Forwarded-Host", "X-Forwarded-Proto"}
defaultHeadersJSON, err := json.Marshal(defaultHeaders)
if err != nil {
return 0, err
}
op.AddHeaders = defaultHeadersJSON
2020-12-18 21:18:53 +08:00
if IsNotNull(schedulingJSON) {
2020-09-15 14:44:11 +08:00
op.Scheduling = string(schedulingJSON)
}
if IsNotNull(primaryOriginRefsJSON) {
op.PrimaryOrigins = string(primaryOriginRefsJSON)
2020-09-15 14:44:11 +08:00
}
if IsNotNull(backupOriginRefsJSON) {
op.BackupOrigins = string(backupOriginRefsJSON)
2020-09-15 14:44:11 +08:00
}
err = this.Save(tx, op)
2020-09-15 14:44:11 +08:00
if err != nil {
return 0, err
}
var reverseProxyId = types.Int64(op.Id)
// set 'reverseProxyId' of origins
for _, originRef := range primaryOriginRefs {
err = SharedOriginDAO.UpdateOriginReverseProxyId(tx, originRef.OriginId, reverseProxyId)
if err != nil {
return 0, err
}
}
for _, originRef := range backupOriginRefs {
err = SharedOriginDAO.UpdateOriginReverseProxyId(tx, originRef.OriginId, reverseProxyId)
if err != nil {
return 0, err
}
}
return reverseProxyId, nil
2020-09-15 14:44:11 +08:00
}
// CloneReverseProxy copies the enabled reverse proxy identified by
// fromReverseProxyId into a new record and returns the new record's id.
//
// It returns (0, nil) when fromReverseProxyId is not positive or when no enabled
// record is found. Every referenced origin with a positive id is cloned through
// SharedOriginDAO.CloneOrigin and the new record points at the clones;
// references whose clone yields a non-positive id are dropped. The retry flags
// (Retry50X / Retry40X) are copied along with the other settings so the clone
// carries over the source's retry settings.
func (this *ReverseProxyDAO) CloneReverseProxy(tx *dbs.Tx, fromReverseProxyId int64) (newReverseProxyId int64, err error) {
	if fromReverseProxyId <= 0 {
		return 0, nil
	}

	reverseProxyOne, err := this.Query(tx).
		Pk(fromReverseProxyId).
		State(ReverseProxyStateEnabled).
		Find()
	if err != nil || reverseProxyOne == nil {
		return 0, err
	}
	var reverseProxy = reverseProxyOne.(*ReverseProxy)

	// cloneOriginRefs clones every origin referenced in refsJSON and returns the
	// JSON-encoded list of references pointing at the cloned origins.
	var cloneOriginRefs = func(refsJSON []byte) ([]byte, error) {
		var originRefs = []*serverconfigs.OriginRef{}
		if err := json.Unmarshal(refsJSON, &originRefs); err != nil {
			return nil, err
		}
		var newRefs = []*serverconfigs.OriginRef{}
		for _, originRef := range originRefs {
			// a JSON null element decodes to a nil ref; skip it instead of panicking
			if originRef == nil || originRef.OriginId <= 0 {
				continue
			}
			newOriginId, err := SharedOriginDAO.CloneOrigin(tx, originRef.OriginId)
			if err != nil {
				return nil, err
			}
			if newOriginId <= 0 {
				continue
			}
			newRef, err := utils.JSONClone[*serverconfigs.OriginRef](originRef)
			if err != nil {
				return nil, err
			}
			newRef.OriginId = newOriginId
			newRefs = append(newRefs, newRef)
		}
		return json.Marshal(newRefs)
	}

	var op = NewReverseProxyOperator()
	op.TemplateId = reverseProxy.TemplateId
	op.IsOn = reverseProxy.IsOn

	if IsNotNull(reverseProxy.Scheduling) {
		op.Scheduling = reverseProxy.Scheduling
	}

	if IsNotNull(reverseProxy.PrimaryOrigins) {
		newRefsJSON, err := cloneOriginRefs(reverseProxy.PrimaryOrigins)
		if err != nil {
			return 0, err
		}
		op.PrimaryOrigins = newRefsJSON
	}

	if IsNotNull(reverseProxy.BackupOrigins) {
		newRefsJSON, err := cloneOriginRefs(reverseProxy.BackupOrigins)
		if err != nil {
			return 0, err
		}
		op.BackupOrigins = newRefsJSON
	}

	op.StripPrefix = reverseProxy.StripPrefix
	op.RequestHostType = reverseProxy.RequestHostType
	op.RequestHost = reverseProxy.RequestHost
	op.RequestHostExcludingPort = reverseProxy.RequestHostExcludingPort
	op.RequestURI = reverseProxy.RequestURI
	op.AutoFlush = reverseProxy.AutoFlush

	if IsNotNull(reverseProxy.AddHeaders) {
		// TODO: clone the headers
		op.AddHeaders = reverseProxy.AddHeaders
	}

	op.State = reverseProxy.State

	if IsNotNull(reverseProxy.ConnTimeout) {
		op.ConnTimeout = reverseProxy.ConnTimeout
	}
	if IsNotNull(reverseProxy.ReadTimeout) {
		op.ReadTimeout = reverseProxy.ReadTimeout
	}
	if IsNotNull(reverseProxy.IdleTimeout) {
		op.IdleTimeout = reverseProxy.IdleTimeout
	}

	op.MaxConns = reverseProxy.MaxConns
	op.MaxIdleConns = reverseProxy.MaxIdleConns

	if IsNotNull(reverseProxy.ProxyProtocol) {
		op.ProxyProtocol = reverseProxy.ProxyProtocol
	}

	op.FollowRedirects = reverseProxy.FollowRedirects
	op.Retry50X = reverseProxy.Retry50X
	op.Retry40X = reverseProxy.Retry40X

	return this.SaveInt64(tx, op)
}
// UpdateReverseProxyScheduling 修改反向代理调度算法
func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(tx *dbs.Tx, reverseProxyId int64, schedulingJSON []byte) error {
2020-09-15 14:44:11 +08:00
if reverseProxyId <= 0 {
return errors.New("invalid reverseProxyId")
}
var op = NewReverseProxyOperator()
2020-09-15 14:44:11 +08:00
op.Id = reverseProxyId
if len(schedulingJSON) > 0 {
op.Scheduling = string(schedulingJSON)
} else {
op.Scheduling = "null"
}
err := this.Save(tx, op)
if err != nil {
return err
}
return this.NotifyUpdate(tx, reverseProxyId)
2020-09-15 14:44:11 +08:00
}
// UpdateReverseProxyPrimaryOrigins 修改主要源站
func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(tx *dbs.Tx, reverseProxyId int64, originRefsJSON []byte) error {
2020-09-15 14:44:11 +08:00
if reverseProxyId <= 0 {
return errors.New("invalid reverseProxyId")
}
// set 'reverseProxyId' of origins
if len(originRefsJSON) > 0 {
var originRefs []*serverconfigs.OriginRef
err := json.Unmarshal(originRefsJSON, &originRefs)
if err != nil {
return fmt.Errorf("decode 'originRefs' failed: " + err.Error())
}
for _, originRef := range originRefs {
err = SharedOriginDAO.UpdateOriginReverseProxyId(tx, originRef.OriginId, reverseProxyId)
if err != nil {
return err
}
}
}
var op = NewReverseProxyOperator()
2020-09-15 14:44:11 +08:00
op.Id = reverseProxyId
if len(originRefsJSON) > 0 {
op.PrimaryOrigins = originRefsJSON
2020-09-15 14:44:11 +08:00
} else {
op.PrimaryOrigins = "[]"
}
err := this.Save(tx, op)
if err != nil {
return err
}
return this.NotifyUpdate(tx, reverseProxyId)
2020-09-15 14:44:11 +08:00
}
// UpdateReverseProxyBackupOrigins 修改备用源站
func (this *ReverseProxyDAO) UpdateReverseProxyBackupOrigins(tx *dbs.Tx, reverseProxyId int64, originRefsJSON []byte) error {
2020-09-15 14:44:11 +08:00
if reverseProxyId <= 0 {
return errors.New("invalid reverseProxyId")
}
// set 'reverseProxyId' of origins
if len(originRefsJSON) > 0 {
var originRefs []*serverconfigs.OriginRef
err := json.Unmarshal(originRefsJSON, &originRefs)
if err != nil {
return fmt.Errorf("decode 'originRefs' failed: " + err.Error())
}
for _, originRef := range originRefs {
err = SharedOriginDAO.UpdateOriginReverseProxyId(tx, originRef.OriginId, reverseProxyId)
if err != nil {
return err
}
}
}
var op = NewReverseProxyOperator()
2020-09-15 14:44:11 +08:00
op.Id = reverseProxyId
if len(originRefsJSON) > 0 {
op.BackupOrigins = originRefsJSON
2020-09-15 14:44:11 +08:00
} else {
op.BackupOrigins = "[]"
}
err := this.Save(tx, op)
if err != nil {
return err
}
return this.NotifyUpdate(tx, reverseProxyId)
2020-09-15 14:44:11 +08:00
}
2020-09-16 09:09:21 +08:00
// UpdateReverseProxy 修改是否启用
2021-10-12 20:18:35 +08:00
func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx,
reverseProxyId int64,
requestHostType int8,
requestHost string,
requestHostExcludingPort bool,
2021-10-12 20:18:35 +08:00
requestURI string,
stripPrefix string,
autoFlush bool,
addHeaders []string,
connTimeout *shared.TimeDuration,
readTimeout *shared.TimeDuration,
idleTimeout *shared.TimeDuration,
maxConns int32,
maxIdleConns int32,
2022-03-14 15:42:45 +08:00
proxyProtocolJSON []byte,
followRedirects bool,
retry50X bool,
retry40X bool) error {
if reverseProxyId <= 0 {
return errors.New("invalid reverseProxyId")
}
2022-03-14 15:42:45 +08:00
var op = NewReverseProxyOperator()
op.Id = reverseProxyId
if requestHostType < 0 {
requestHostType = 0
}
op.RequestHostType = requestHostType
op.RequestHost = requestHost
op.RequestHostExcludingPort = requestHostExcludingPort
op.RequestURI = requestURI
op.StripPrefix = stripPrefix
2020-09-27 18:41:15 +08:00
op.AutoFlush = autoFlush
2022-03-14 15:42:45 +08:00
op.FollowRedirects = followRedirects
if len(addHeaders) == 0 {
addHeaders = []string{}
}
addHeadersJSON, err := json.Marshal(addHeaders)
if err != nil {
return err
}
op.AddHeaders = addHeadersJSON
if connTimeout != nil {
connTimeoutJSON, err := connTimeout.AsJSON()
if err != nil {
return err
}
op.ConnTimeout = connTimeoutJSON
}
if readTimeout != nil {
readTimeoutJSON, err := readTimeout.AsJSON()
if err != nil {
return err
}
op.ReadTimeout = readTimeoutJSON
}
if idleTimeout != nil {
idleTimeoutJSON, err := idleTimeout.AsJSON()
if err != nil {
return err
}
op.IdleTimeout = idleTimeoutJSON
}
if maxConns >= 0 {
op.MaxConns = maxConns
} else {
op.MaxConns = 0
}
if maxIdleConns >= 0 {
op.MaxIdleConns = maxIdleConns
} else {
op.MaxIdleConns = 0
}
2021-10-12 20:18:35 +08:00
if len(proxyProtocolJSON) > 0 {
op.ProxyProtocol = proxyProtocolJSON
}
op.Retry50X = retry50X
op.Retry40X = retry40X
err = this.Save(tx, op)
if err != nil {
return err
}
return this.NotifyUpdate(tx, reverseProxyId)
}
// FindReverseProxyContainsOriginId 查找包含某个源站的反向代理ID
func (this *ReverseProxyDAO) FindReverseProxyContainsOriginId(tx *dbs.Tx, originId int64) (int64, error) {
return this.Query(tx).
ResultPk().
Where("(JSON_CONTAINS(primaryOrigins, :jsonQuery) OR JSON_CONTAINS(backupOrigins, :jsonQuery))").
Param("jsonQuery", maps.Map{
"originId": originId,
}.AsJSON()).
FindInt64Col(0)
2020-09-16 09:09:21 +08:00
}
2020-09-26 08:06:40 +08:00
// CheckUserReverseProxy checks whether the user may access the reverse proxy.
//
// It returns nil when the record's userId matches. Otherwise it looks up the
// enabled server that uses the reverse proxy: ErrNotFound is returned when
// there is none, and the decision is delegated to SharedServerDAO.CheckUserServer
// when there is one.
func (this *ReverseProxyDAO) CheckUserReverseProxy(tx *dbs.Tx, userId int64, reverseProxyId int64) error {
	owned, err := this.Query(tx).
		Pk(reverseProxyId).
		Attr("userId", userId).
		Exist()
	switch {
	case err != nil:
		return err
	case owned:
		return nil
	}

	// check whether the server belongs to the user
	serverId, err := SharedServerDAO.FindEnabledServerIdWithReverseProxyId(tx, reverseProxyId)
	switch {
	case err != nil:
		return err
	case serverId == 0:
		return ErrNotFound
	}
	return SharedServerDAO.CheckUserServer(tx, userId, serverId)
}
// NotifyUpdate 通知更新
func (this *ReverseProxyDAO) NotifyUpdate(tx *dbs.Tx, reverseProxyId int64) error {
serverId, err := SharedServerDAO.FindEnabledServerIdWithReverseProxyId(tx, reverseProxyId)
if err != nil {
return err
}
if serverId > 0 {
return SharedServerDAO.NotifyUpdate(tx, serverId)
}
2021-08-01 21:54:44 +08:00
// locations
locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithReverseProxyId(tx, reverseProxyId)
if err != nil {
return err
}
if locationId > 0 {
return SharedHTTPLocationDAO.NotifyUpdate(tx, locationId)
}
// group
groupId, err := SharedServerGroupDAO.FindEnabledGroupIdWithReverseProxyId(tx, reverseProxyId)
if err != nil {
return err
}
if groupId > 0 {
return SharedServerGroupDAO.NotifyUpdate(tx, groupId)
}
return nil
2020-09-26 08:06:40 +08:00
}