mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-10 20:50:25 +08:00
实现基本的反向代理
This commit is contained in:
40
internal/nodes/http_client.go
Normal file
40
internal/nodes/http_client.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTP客户端
|
||||||
|
type HTTPClient struct {
|
||||||
|
rawClient *http.Client
|
||||||
|
accessAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取新客户端对象
|
||||||
|
func NewHTTPClient(rawClient *http.Client) *HTTPClient {
|
||||||
|
return &HTTPClient{
|
||||||
|
rawClient: rawClient,
|
||||||
|
accessAt: utils.UnixTime(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取原始客户端对象
|
||||||
|
func (this *HTTPClient) RawClient() *http.Client {
|
||||||
|
return this.rawClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新访问时间
|
||||||
|
func (this *HTTPClient) UpdateAccessTime() {
|
||||||
|
this.accessAt = utils.UnixTime()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取访问时间
|
||||||
|
func (this *HTTPClient) AccessTime() int64 {
|
||||||
|
return this.accessAt
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭
|
||||||
|
func (this *HTTPClient) Close() {
|
||||||
|
this.rawClient.CloseIdleConnections()
|
||||||
|
}
|
||||||
149
internal/nodes/http_client_pool.go
Normal file
149
internal/nodes/http_client_pool.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTP客户端池单例
|
||||||
|
var SharedHTTPClientPool = NewHTTPClientPool()
|
||||||
|
|
||||||
|
// 客户端池
|
||||||
|
type HTTPClientPool struct {
|
||||||
|
clientExpiredDuration time.Duration
|
||||||
|
clientsMap map[string]*HTTPClient // backend key => client
|
||||||
|
locker sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取新对象
|
||||||
|
func NewHTTPClientPool() *HTTPClientPool {
|
||||||
|
pool := &HTTPClientPool{
|
||||||
|
clientExpiredDuration: 3600 * time.Second,
|
||||||
|
clientsMap: map[string]*HTTPClient{},
|
||||||
|
}
|
||||||
|
|
||||||
|
go pool.cleanClients()
|
||||||
|
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据地址获取客户端
|
||||||
|
func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig) (rawClient *http.Client, realAddr string, err error) {
|
||||||
|
if origin.Addr == nil {
|
||||||
|
return nil, "", errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
key := origin.UniqueKey()
|
||||||
|
originAddr := origin.Addr.PickAddress()
|
||||||
|
if origin.Addr.HostHasVariables() {
|
||||||
|
originAddr = req.Format(originAddr)
|
||||||
|
}
|
||||||
|
key += "@" + originAddr
|
||||||
|
|
||||||
|
this.locker.Lock()
|
||||||
|
defer this.locker.Unlock()
|
||||||
|
|
||||||
|
client, found := this.clientsMap[key]
|
||||||
|
if found {
|
||||||
|
client.UpdateAccessTime()
|
||||||
|
return client.RawClient(), originAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
maxConnections := origin.MaxConns
|
||||||
|
connectionTimeout := origin.ConnTimeoutDuration()
|
||||||
|
readTimeout := origin.ReadTimeoutDuration()
|
||||||
|
idleTimeout := origin.IdleTimeoutDuration()
|
||||||
|
idleConns := origin.MaxIdleConns
|
||||||
|
|
||||||
|
// 超时时间
|
||||||
|
if connectionTimeout <= 0 {
|
||||||
|
connectionTimeout = 15 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
if idleTimeout <= 0 {
|
||||||
|
idleTimeout = 2 * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
numberCPU := runtime.NumCPU()
|
||||||
|
if numberCPU < 8 {
|
||||||
|
numberCPU = 8
|
||||||
|
}
|
||||||
|
if maxConnections <= 0 {
|
||||||
|
maxConnections = numberCPU * 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if idleConns <= 0 {
|
||||||
|
idleConns = numberCPU
|
||||||
|
}
|
||||||
|
//logs.Println("[ORIGIN]max connections:", maxConnections)
|
||||||
|
|
||||||
|
// TLS通讯
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
if origin.Cert != nil {
|
||||||
|
obj := origin.Cert.CertObject()
|
||||||
|
if obj != nil {
|
||||||
|
tlsConfig.InsecureSkipVerify = false
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{*obj}
|
||||||
|
if len(origin.Cert.ServerName) > 0 {
|
||||||
|
tlsConfig.ServerName = origin.Cert.ServerName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
// 握手配置
|
||||||
|
return (&net.Dialer{
|
||||||
|
Timeout: connectionTimeout,
|
||||||
|
KeepAlive: 1 * time.Minute,
|
||||||
|
}).DialContext(ctx, network, originAddr)
|
||||||
|
},
|
||||||
|
MaxIdleConns: 0,
|
||||||
|
MaxIdleConnsPerHost: idleConns,
|
||||||
|
MaxConnsPerHost: maxConnections,
|
||||||
|
IdleConnTimeout: idleTimeout,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 0, // 不限
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
Proxy: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
rawClient = &http.Client{
|
||||||
|
Timeout: readTimeout,
|
||||||
|
Transport: transport,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
this.clientsMap[key] = NewHTTPClient(rawClient)
|
||||||
|
|
||||||
|
return rawClient, originAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理不使用的Client
|
||||||
|
func (this *HTTPClientPool) cleanClients() {
|
||||||
|
ticker := time.NewTicker(this.clientExpiredDuration)
|
||||||
|
for range ticker.C {
|
||||||
|
currentAt := time.Now().Unix()
|
||||||
|
|
||||||
|
this.locker.Lock()
|
||||||
|
for k, client := range this.clientsMap {
|
||||||
|
if client.AccessTime() < currentAt+86400 { // 超过 N 秒没有调用就关闭
|
||||||
|
delete(this.clientsMap, k)
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.locker.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
78
internal/nodes/http_client_pool_test.go
Normal file
78
internal/nodes/http_client_pool_test.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHTTPClientPool_Client(t *testing.T) {
|
||||||
|
pool := NewHTTPClientPool()
|
||||||
|
|
||||||
|
{
|
||||||
|
origin := &serverconfigs.OriginConfig{
|
||||||
|
Id: 1,
|
||||||
|
Version: 2,
|
||||||
|
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||||
|
}
|
||||||
|
err := origin.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
client, addr, err := pool.Client(nil, origin)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("addr:", addr, "client:", client)
|
||||||
|
}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
client, addr, err := pool.Client(nil, origin)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("addr:", addr, "client:", client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||||
|
origin := &serverconfigs.OriginConfig{
|
||||||
|
Id: 1,
|
||||||
|
Version: 2,
|
||||||
|
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||||
|
}
|
||||||
|
err := origin.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := NewHTTPClientPool()
|
||||||
|
pool.clientExpiredDuration = 2 * time.Second
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
t.Log("get", i)
|
||||||
|
_, _, _ = pool.Client(nil, origin)
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkHTTPClientPool_Client(b *testing.B) {
|
||||||
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
|
origin := &serverconfigs.OriginConfig{
|
||||||
|
Id: 1,
|
||||||
|
Version: 2,
|
||||||
|
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||||
|
}
|
||||||
|
err := origin.Init()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := NewHTTPClientPool()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = pool.Client(nil, origin)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
"github.com/iwind/TeaGo/types"
|
"github.com/iwind/TeaGo/types"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -51,6 +52,7 @@ type HTTPRequest struct {
|
|||||||
requestCost float64 // 请求耗时
|
requestCost float64 // 请求耗时
|
||||||
filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空
|
filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空
|
||||||
origin *serverconfigs.OriginConfig // 源站
|
origin *serverconfigs.OriginConfig // 源站
|
||||||
|
originAddr string // 源站实际地址
|
||||||
errors []string // 错误信息
|
errors []string // 错误信息
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,7 +80,7 @@ func (this *HTTPRequest) Do() {
|
|||||||
// Web配置
|
// Web配置
|
||||||
err := this.configureWeb(this.Server.Web, true, 0)
|
err := this.configureWeb(this.Server.Web, true, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
this.write500()
|
this.write500(err)
|
||||||
this.doEnd()
|
this.doEnd()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -232,6 +234,7 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
|||||||
if !location.IsOn {
|
if !location.IsOn {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
logs.Println("rawPath:", rawPath, "location:", location.Pattern) // TODO
|
||||||
if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched {
|
if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched {
|
||||||
if len(varMapping) > 0 {
|
if len(varMapping) > 0 {
|
||||||
this.addVarMapping(varMapping)
|
this.addVarMapping(varMapping)
|
||||||
@@ -398,9 +401,9 @@ func (this *HTTPRequest) Format(source string) string {
|
|||||||
if this.origin != nil {
|
if this.origin != nil {
|
||||||
switch suffix {
|
switch suffix {
|
||||||
case "address", "addr":
|
case "address", "addr":
|
||||||
return this.origin.RealAddr()
|
return this.originAddr
|
||||||
case "host":
|
case "host":
|
||||||
addr := this.origin.RealAddr()
|
addr := this.originAddr
|
||||||
index := strings.Index(addr, ":")
|
index := strings.Index(addr, ":")
|
||||||
if index > -1 {
|
if index > -1 {
|
||||||
return addr[:index]
|
return addr[:index]
|
||||||
@@ -674,7 +677,9 @@ func (this *HTTPRequest) requestServerPort() int {
|
|||||||
// 设置代理相关头部信息
|
// 设置代理相关头部信息
|
||||||
// 参考:https://tools.ietf.org/html/rfc7239
|
// 参考:https://tools.ietf.org/html/rfc7239
|
||||||
func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
||||||
delete(header, "Connection")
|
if this.RawReq.Header.Get("Connection") == "close" {
|
||||||
|
this.RawReq.Header.Set("Connection", "keep-alive")
|
||||||
|
}
|
||||||
|
|
||||||
remoteAddr := this.RawReq.RemoteAddr
|
remoteAddr := this.RawReq.RemoteAddr
|
||||||
host, _, err := net.SplitHostPort(remoteAddr)
|
host, _, err := net.SplitHostPort(remoteAddr)
|
||||||
@@ -728,6 +733,8 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
|||||||
|
|
||||||
// 处理自定义Request Header
|
// 处理自定义Request Header
|
||||||
func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
||||||
|
this.fixRequestHeader(reqHeader)
|
||||||
|
|
||||||
if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
|
if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
|
||||||
// 删除某些Header
|
// 删除某些Header
|
||||||
for name := range reqHeader {
|
for name := range reqHeader {
|
||||||
@@ -742,12 +749,17 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
oldValues, _ := this.RawReq.Header[header.Name]
|
oldValues, _ := this.RawReq.Header[header.Name]
|
||||||
|
newHeaderValue := header.Value // 因为我们不能修改header,所以在这里使用新变量
|
||||||
if header.HasVariables() {
|
if header.HasVariables() {
|
||||||
oldValues = append(oldValues, this.Format(header.Value))
|
newHeaderValue = this.Format(header.Value)
|
||||||
} else {
|
|
||||||
oldValues = append(oldValues, header.Value)
|
|
||||||
}
|
}
|
||||||
|
oldValues = append(oldValues, newHeaderValue)
|
||||||
reqHeader[header.Name] = oldValues
|
reqHeader[header.Name] = oldValues
|
||||||
|
|
||||||
|
// 支持修改Host
|
||||||
|
if header.Name == "Host" && len(header.Value) > 0 {
|
||||||
|
this.RawReq.Host = newHeaderValue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set
|
// Set
|
||||||
@@ -755,10 +767,15 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
|||||||
if !header.IsOn {
|
if !header.IsOn {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
newHeaderValue := header.Value // 因为我们不能修改header,所以在这里使用新变量
|
||||||
if header.HasVariables() {
|
if header.HasVariables() {
|
||||||
reqHeader[header.Name] = []string{this.Format(header.Value)}
|
newHeaderValue = this.Format(header.Value)
|
||||||
} else {
|
}
|
||||||
reqHeader[header.Name] = []string{header.Value}
|
reqHeader[header.Name] = []string{newHeaderValue}
|
||||||
|
|
||||||
|
// 支持修改Host
|
||||||
|
if header.Name == "Host" && len(header.Value) > 0 {
|
||||||
|
this.RawReq.Host = newHeaderValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ func (this *HTTPRequest) write404() {
|
|||||||
_, _ = this.writer.Write([]byte(msg))
|
_, _ = this.writer.Write([]byte(msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *HTTPRequest) write500() {
|
func (this *HTTPRequest) write500(err error) {
|
||||||
|
if err != nil {
|
||||||
|
this.addError(err)
|
||||||
|
}
|
||||||
|
|
||||||
statusCode := http.StatusInternalServerError
|
statusCode := http.StatusInternalServerError
|
||||||
if this.doPage(statusCode) {
|
if this.doPage(statusCode) {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,7 +1,14 @@
|
|||||||
package nodes
|
package nodes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -11,9 +18,46 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// StripPrefix
|
// 对URL的处理
|
||||||
if len(this.reverseProxy.StripPrefix) > 0 {
|
|
||||||
stripPrefix := this.reverseProxy.StripPrefix
|
stripPrefix := this.reverseProxy.StripPrefix
|
||||||
|
requestURI := this.reverseProxy.RequestURI
|
||||||
|
requestURIHasVariables := this.reverseProxy.RequestURIHasVariables()
|
||||||
|
requestHost := this.reverseProxy.RequestHost
|
||||||
|
requestHostHasVariables := this.reverseProxy.RequestHostHasVariables()
|
||||||
|
|
||||||
|
// 源站
|
||||||
|
requestCall := shared.NewRequestCall()
|
||||||
|
origin := this.reverseProxy.NextOrigin(requestCall)
|
||||||
|
if origin == nil {
|
||||||
|
err := errors.New(this.requestPath() + ": no available backends for reverse proxy")
|
||||||
|
logs.Error(err)
|
||||||
|
this.write500(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
this.origin = origin // 设置全局变量是为了日志等处理
|
||||||
|
if len(origin.StripPrefix) > 0 {
|
||||||
|
stripPrefix = origin.StripPrefix
|
||||||
|
}
|
||||||
|
if len(origin.RequestURI) > 0 {
|
||||||
|
requestURI = origin.RequestURI
|
||||||
|
requestURIHasVariables = origin.RequestURIHasVariables()
|
||||||
|
}
|
||||||
|
if len(origin.RequestHost) > 0 {
|
||||||
|
requestHost = origin.RequestHost
|
||||||
|
requestHostHasVariables = origin.RequestHostHasVariables()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理Scheme
|
||||||
|
if origin.Addr == nil {
|
||||||
|
err := errors.New(this.requestPath() + ": origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
|
||||||
|
logs.Error(err)
|
||||||
|
this.write500(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme()
|
||||||
|
|
||||||
|
// StripPrefix
|
||||||
|
if len(stripPrefix) > 0 {
|
||||||
if stripPrefix[0] != '/' {
|
if stripPrefix[0] != '/' {
|
||||||
stripPrefix = "/" + stripPrefix
|
stripPrefix = "/" + stripPrefix
|
||||||
}
|
}
|
||||||
@@ -24,11 +68,11 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RequestURI
|
// RequestURI
|
||||||
if len(this.reverseProxy.RequestURI) > 0 {
|
if len(requestURI) > 0 {
|
||||||
if this.reverseProxy.RequestURIHasVariables() {
|
if requestURIHasVariables {
|
||||||
this.uri = this.Format(this.reverseProxy.RequestURI)
|
this.uri = this.Format(requestURI)
|
||||||
} else {
|
} else {
|
||||||
this.uri = this.reverseProxy.RequestURI
|
this.uri = requestURI
|
||||||
}
|
}
|
||||||
if len(this.uri) == 0 || this.uri[0] != '/' {
|
if len(this.uri) == 0 || this.uri[0] != '/' {
|
||||||
this.uri = "/" + this.uri
|
this.uri = "/" + this.uri
|
||||||
@@ -47,6 +91,18 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
this.uri = utils.CleanPath(this.uri)
|
this.uri = utils.CleanPath(this.uri)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequestHost
|
||||||
|
if len(requestHost) > 0 {
|
||||||
|
if requestHostHasVariables {
|
||||||
|
this.RawReq.Host = this.Format(requestHost)
|
||||||
|
} else {
|
||||||
|
this.RawReq.Host = this.reverseProxy.RequestHost
|
||||||
|
}
|
||||||
|
this.RawReq.URL.Host = this.RawReq.Host
|
||||||
|
} else {
|
||||||
|
this.RawReq.URL.Host = this.Host
|
||||||
|
}
|
||||||
|
|
||||||
// 重组请求URL
|
// 重组请求URL
|
||||||
questionMark := strings.Index(this.uri, "?")
|
questionMark := strings.Index(this.uri, "?")
|
||||||
if questionMark > -1 {
|
if questionMark > -1 {
|
||||||
@@ -56,16 +112,11 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
this.RawReq.URL.Path = this.uri
|
this.RawReq.URL.Path = this.uri
|
||||||
this.RawReq.URL.RawQuery = ""
|
this.RawReq.URL.RawQuery = ""
|
||||||
}
|
}
|
||||||
|
this.RawReq.RequestURI = ""
|
||||||
|
|
||||||
// RequestHost
|
// 处理Header
|
||||||
if len(this.reverseProxy.RequestHost) > 0 {
|
this.setForwardHeaders(this.RawReq.Header)
|
||||||
if this.reverseProxy.RequestHostHasVariables() {
|
this.processRequestHeaders(this.RawReq.Header)
|
||||||
this.RawReq.Host = this.Format(this.reverseProxy.RequestHost)
|
|
||||||
} else {
|
|
||||||
this.RawReq.Host = this.reverseProxy.RequestHost
|
|
||||||
}
|
|
||||||
this.RawReq.URL.Host = this.RawReq.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// 判断是否为Websocket请求
|
// 判断是否为Websocket请求
|
||||||
if this.RawReq.Header.Get("Upgrade") == "websocket" {
|
if this.RawReq.Header.Get("Upgrade") == "websocket" {
|
||||||
@@ -73,6 +124,110 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 普通HTTP请求
|
// 获取请求客户端
|
||||||
// TODO
|
client, addr, err := SharedHTTPClientPool.Client(this, origin)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
this.write500(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
this.originAddr = addr
|
||||||
|
|
||||||
|
// 开始请求
|
||||||
|
resp, err := client.Do(this.RawReq)
|
||||||
|
if err != nil {
|
||||||
|
// 客户端取消请求,则不提示
|
||||||
|
httpErr, ok := err.(*url.Error)
|
||||||
|
if !ok || httpErr.Err != context.Canceled {
|
||||||
|
// TODO 如果超过最大失败次数,则下线
|
||||||
|
|
||||||
|
this.write500(err)
|
||||||
|
logs.Println("[proxy]'" + this.RawReq.URL.String() + "': " + err.Error())
|
||||||
|
} else {
|
||||||
|
// 是否为客户端方面的错误
|
||||||
|
isClientError := false
|
||||||
|
if ok {
|
||||||
|
if httpErr.Err == context.Canceled {
|
||||||
|
isClientError = true
|
||||||
|
this.addError(errors.New(httpErr.Op + " " + httpErr.URL + ": client closed the connection"))
|
||||||
|
this.writer.WriteHeader(499) // 仿照nginx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isClientError {
|
||||||
|
this.write500(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// WAF对出站进行检查
|
||||||
|
// TODO
|
||||||
|
|
||||||
|
// TODO 清除源站错误次数
|
||||||
|
|
||||||
|
// 特殊页面
|
||||||
|
// TODO
|
||||||
|
|
||||||
|
// 设置Charset
|
||||||
|
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
|
||||||
|
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||||
|
contentTypes, ok := resp.Header["Content-Type"]
|
||||||
|
if ok && len(contentTypes) > 0 {
|
||||||
|
contentType := contentTypes[0]
|
||||||
|
if _, found := textMimeMap[contentType]; found {
|
||||||
|
resp.Header["Content-Type"][0] = contentType + "; charset=" + this.web.Charset.Charset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 响应Header
|
||||||
|
this.writer.AddHeaders(resp.Header)
|
||||||
|
this.processResponseHeaders(resp.StatusCode)
|
||||||
|
|
||||||
|
// 是否需要刷新
|
||||||
|
shouldFlush := this.RawReq.Header.Get("Accept") == "text/event-stream"
|
||||||
|
|
||||||
|
// 准备
|
||||||
|
this.writer.Prepare(resp.ContentLength)
|
||||||
|
|
||||||
|
// 设置响应代码
|
||||||
|
this.writer.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
// 输出到客户端
|
||||||
|
pool := this.bytePool(resp.ContentLength)
|
||||||
|
buf := pool.Get()
|
||||||
|
if shouldFlush {
|
||||||
|
for {
|
||||||
|
n, readErr := resp.Body.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
_, err = this.writer.Write(buf[:n])
|
||||||
|
this.writer.Flush()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
err = readErr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
|
||||||
|
}
|
||||||
|
pool.Put(buf)
|
||||||
|
|
||||||
|
err1 := resp.Body.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
logs.Error(err1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
this.addError(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,9 +107,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
this.write500()
|
this.write500(err)
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
this.addError(err)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,9 +136,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
this.write500()
|
this.write500(err)
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
this.addError(err)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -220,9 +218,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
|||||||
|
|
||||||
reader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
|
reader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
this.write500()
|
this.write500(err)
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
this.addError(err)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
req.Run()
|
req.Do()
|
||||||
a.IsBool(req.web.RedirectToHttps.IsOn == false)
|
a.IsBool(req.web.RedirectToHttps.IsOn == false)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
@@ -29,7 +29,7 @@ func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
req.Run()
|
req.Do()
|
||||||
a.IsBool(req.web.RedirectToHttps.IsOn == true)
|
a.IsBool(req.web.RedirectToHttps.IsOn == true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,8 +35,7 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
|
|||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.Error(errors.New(req.URL.String() + ": " + err.Error()))
|
logs.Error(errors.New(req.URL.String() + ": " + err.Error()))
|
||||||
this.addError(err)
|
this.write500(err)
|
||||||
this.write500()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package nodes
|
package nodes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
|
||||||
"github.com/iwind/TeaGo/logs"
|
"github.com/iwind/TeaGo/logs"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -30,20 +28,6 @@ func (this *HTTPRequest) doWebsocket() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
requestCall := shared.NewRequestCall()
|
|
||||||
origin := this.reverseProxy.NextOrigin(requestCall)
|
|
||||||
if origin == nil {
|
|
||||||
err := errors.New(this.requestPath() + ": no available backends for websocket")
|
|
||||||
logs.Error(err)
|
|
||||||
this.addError(err)
|
|
||||||
this.write500()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理Header
|
|
||||||
this.processRequestHeaders(this.RawReq.Header)
|
|
||||||
this.fixRequestHeader(this.RawReq.Header) // 处理 Websocket -> WebSocket
|
|
||||||
|
|
||||||
// 设置指定的来源域
|
// 设置指定的来源域
|
||||||
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
|
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
|
||||||
newRequestOrigin := this.web.Websocket.RequestOrigin
|
newRequestOrigin := this.web.Websocket.RequestOrigin
|
||||||
@@ -54,11 +38,10 @@ func (this *HTTPRequest) doWebsocket() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
||||||
originConn, err := OriginConnect(origin)
|
originConn, err := OriginConnect(this.origin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
this.addError(err)
|
this.write500(err)
|
||||||
this.write500()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -68,16 +51,14 @@ func (this *HTTPRequest) doWebsocket() {
|
|||||||
err = this.RawReq.Write(originConn)
|
err = this.RawReq.Write(originConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
this.addError(err)
|
this.write500(err)
|
||||||
this.write500()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
clientConn, _, err := this.writer.Hijack()
|
clientConn, _, err := this.writer.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logs.Error(err)
|
logs.Error(err)
|
||||||
this.addError(err)
|
this.write500(err)
|
||||||
this.write500()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -118,6 +118,9 @@ func (this *HTTPWriter) AddHeaders(header http.Header) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
for key, value := range header {
|
for key, value := range header {
|
||||||
|
if key == "Connection" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, v := range value {
|
for _, v := range value {
|
||||||
this.writer.Header().Add(key, v)
|
this.writer.Header().Add(key, v)
|
||||||
}
|
}
|
||||||
|
|||||||
26
internal/utils/time.go
Normal file
26
internal/utils/time.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var unixTime = time.Now().Unix()
|
||||||
|
var unixTimerIsReady = false
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
ticker := time.NewTicker(500 * time.Millisecond)
|
||||||
|
go func() {
|
||||||
|
for range ticker.C {
|
||||||
|
unixTimerIsReady = true
|
||||||
|
unixTime = time.Now().Unix()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最快获取时间戳的方式,通常用在不需要特别精确时间戳的场景
|
||||||
|
func UnixTime() int64 {
|
||||||
|
if unixTimerIsReady {
|
||||||
|
return unixTime
|
||||||
|
}
|
||||||
|
return time.Now().Unix()
|
||||||
|
}
|
||||||
13
internal/utils/time_test.go
Normal file
13
internal/utils/time_test.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUnixTime(t *testing.T) {
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
t.Log(UnixTime(), "real:", time.Now().Unix())
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user