diff --git a/internal/db/models/reverse_proxy_dao.go b/internal/db/models/reverse_proxy_dao.go index b3928728..fc0dce48 100644 --- a/internal/db/models/reverse_proxy_dao.go +++ b/internal/db/models/reverse_proxy_dao.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" @@ -149,6 +150,37 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyI config.AddHeaders = addHeaders } + // 源站相关默认设置 + config.MaxConns = int(reverseProxy.MaxConns) + config.MaxIdleConns = int(reverseProxy.MaxIdleConns) + + if IsNotNull(reverseProxy.ConnTimeout) { + connTimeout := &shared.TimeDuration{} + err = json.Unmarshal([]byte(reverseProxy.ConnTimeout), &connTimeout) + if err != nil { + return nil, err + } + config.ConnTimeout = connTimeout + } + + if IsNotNull(reverseProxy.ReadTimeout) { + readTimeout := &shared.TimeDuration{} + err = json.Unmarshal([]byte(reverseProxy.ReadTimeout), &readTimeout) + if err != nil { + return nil, err + } + config.ReadTimeout = readTimeout + } + + if IsNotNull(reverseProxy.IdleTimeout) { + idleTimeout := &shared.TimeDuration{} + err = json.Unmarshal([]byte(reverseProxy.IdleTimeout), &idleTimeout) + if err != nil { + return nil, err + } + config.IdleTimeout = idleTimeout + } + return config, nil } @@ -242,7 +274,7 @@ func (this *ReverseProxyDAO) UpdateReverseProxyBackupOrigins(tx *dbs.Tx, reverse } // 修改是否启用 -func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64, requestHostType int8, requestHost string, requestURI string, stripPrefix string, autoFlush bool, addHeaders []string) error { +func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64, requestHostType int8, requestHost string, requestURI string, stripPrefix string, autoFlush bool, addHeaders []string, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32) error { if reverseProxyId <= 0 { return errors.New("invalid reverseProxyId") } @@ -269,6 +301,38 @@ func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64 } op.AddHeaders = addHeadersJSON + if connTimeout != nil { + connTimeoutJSON, err := connTimeout.AsJSON() + if err != nil { + return err + } + op.ConnTimeout = connTimeoutJSON + } + if readTimeout != nil { + readTimeoutJSON, err := readTimeout.AsJSON() + if err != nil { + return err + } + op.ReadTimeout = readTimeoutJSON + } + if idleTimeout != nil { + idleTimeoutJSON, err := idleTimeout.AsJSON() + if err != nil { + return err + } + op.IdleTimeout = idleTimeoutJSON + } + if maxConns >= 0 { + op.MaxConns = maxConns + } else { + op.MaxConns = 0 + } + if maxIdleConns >= 0 { + op.MaxIdleConns = maxIdleConns + } else { + op.MaxIdleConns = 0 + } + err = this.Save(tx, op) if err != nil { return err diff --git a/internal/db/models/reverse_proxy_model.go b/internal/db/models/reverse_proxy_model.go index 533e44ba..573ceed9 100644 --- a/internal/db/models/reverse_proxy_model.go +++ b/internal/db/models/reverse_proxy_model.go @@ -18,6 +18,11 @@ type ReverseProxy struct { AddHeaders string `field:"addHeaders"` // 自动添加的Header列表 State uint8 `field:"state"` // 状态 CreatedAt uint64 `field:"createdAt"` // 创建时间 + ConnTimeout string `field:"connTimeout"` // 连接超时时间 + ReadTimeout string `field:"readTimeout"` // 读取超时时间 + IdleTimeout string `field:"idleTimeout"` // 空闲超时时间 + MaxConns uint32 `field:"maxConns"` // 最大并发连接数 + MaxIdleConns uint32 `field:"maxIdleConns"` // 最大空闲连接数 } type ReverseProxyOperator struct { @@ -37,6 +42,11 @@ type ReverseProxyOperator struct { AddHeaders interface{} // 自动添加的Header列表 State interface{} // 状态 CreatedAt interface{} // 创建时间 + ConnTimeout interface{} // 连接超时时间 + ReadTimeout interface{} // 读取超时时间 + IdleTimeout interface{} // 空闲超时时间 + MaxConns interface{} // 最大并发连接数 + MaxIdleConns interface{} // 最大空闲连接数 } func NewReverseProxyOperator() *ReverseProxyOperator { diff --git a/internal/rpc/services/service_reverse_proxy.go b/internal/rpc/services/service_reverse_proxy.go index 18e026a2..f1df994f 100644 --- a/internal/rpc/services/service_reverse_proxy.go +++ b/internal/rpc/services/service_reverse_proxy.go @@ -5,6 +5,7 @@ import ( "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/iwind/TeaGo/types" ) @@ -214,7 +215,32 @@ func (this *ReverseProxyService) UpdateReverseProxy(ctx context.Context, req *pb tx := this.NullTx() - err = models.SharedReverseProxyDAO.UpdateReverseProxy(tx, req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush, req.AddHeaders) + // 校验参数 + var connTimeout = &shared.TimeDuration{} + if len(req.ConnTimeoutJSON) > 0 { + err = json.Unmarshal(req.ConnTimeoutJSON, connTimeout) + if err != nil { + return nil, err + } + } + + var readTimeout = &shared.TimeDuration{} + if len(req.ReadTimeoutJSON) > 0 { + err = json.Unmarshal(req.ReadTimeoutJSON, readTimeout) + if err != nil { + return nil, err + } + } + + var idleTimeout = &shared.TimeDuration{} + if len(req.IdleTimeoutJSON) > 0 { + err = json.Unmarshal(req.IdleTimeoutJSON, idleTimeout) + if err != nil { + return nil, err + } + } + + err = models.SharedReverseProxyDAO.UpdateReverseProxy(tx, req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush, req.AddHeaders, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns) if err != nil { return nil, err }