mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-10 20:50:25 +08:00
实现基本的反向代理
This commit is contained in:
40
internal/nodes/http_client.go
Normal file
40
internal/nodes/http_client.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// HTTP客户端
|
||||
type HTTPClient struct {
|
||||
rawClient *http.Client
|
||||
accessAt int64
|
||||
}
|
||||
|
||||
// 获取新客户端对象
|
||||
func NewHTTPClient(rawClient *http.Client) *HTTPClient {
|
||||
return &HTTPClient{
|
||||
rawClient: rawClient,
|
||||
accessAt: utils.UnixTime(),
|
||||
}
|
||||
}
|
||||
|
||||
// 获取原始客户端对象
|
||||
func (this *HTTPClient) RawClient() *http.Client {
|
||||
return this.rawClient
|
||||
}
|
||||
|
||||
// 更新访问时间
|
||||
func (this *HTTPClient) UpdateAccessTime() {
|
||||
this.accessAt = utils.UnixTime()
|
||||
}
|
||||
|
||||
// 获取访问时间
|
||||
func (this *HTTPClient) AccessTime() int64 {
|
||||
return this.accessAt
|
||||
}
|
||||
|
||||
// 关闭
|
||||
func (this *HTTPClient) Close() {
|
||||
this.rawClient.CloseIdleConnections()
|
||||
}
|
||||
149
internal/nodes/http_client_pool.go
Normal file
149
internal/nodes/http_client_pool.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTP客户端池单例
|
||||
var SharedHTTPClientPool = NewHTTPClientPool()
|
||||
|
||||
// 客户端池
|
||||
type HTTPClientPool struct {
|
||||
clientExpiredDuration time.Duration
|
||||
clientsMap map[string]*HTTPClient // backend key => client
|
||||
locker sync.Mutex
|
||||
}
|
||||
|
||||
// 获取新对象
|
||||
func NewHTTPClientPool() *HTTPClientPool {
|
||||
pool := &HTTPClientPool{
|
||||
clientExpiredDuration: 3600 * time.Second,
|
||||
clientsMap: map[string]*HTTPClient{},
|
||||
}
|
||||
|
||||
go pool.cleanClients()
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// 根据地址获取客户端
|
||||
func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig) (rawClient *http.Client, realAddr string, err error) {
|
||||
if origin.Addr == nil {
|
||||
return nil, "", errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
|
||||
}
|
||||
|
||||
key := origin.UniqueKey()
|
||||
originAddr := origin.Addr.PickAddress()
|
||||
if origin.Addr.HostHasVariables() {
|
||||
originAddr = req.Format(originAddr)
|
||||
}
|
||||
key += "@" + originAddr
|
||||
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
client, found := this.clientsMap[key]
|
||||
if found {
|
||||
client.UpdateAccessTime()
|
||||
return client.RawClient(), originAddr, nil
|
||||
}
|
||||
|
||||
maxConnections := origin.MaxConns
|
||||
connectionTimeout := origin.ConnTimeoutDuration()
|
||||
readTimeout := origin.ReadTimeoutDuration()
|
||||
idleTimeout := origin.IdleTimeoutDuration()
|
||||
idleConns := origin.MaxIdleConns
|
||||
|
||||
// 超时时间
|
||||
if connectionTimeout <= 0 {
|
||||
connectionTimeout = 15 * time.Second
|
||||
}
|
||||
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = 2 * time.Minute
|
||||
}
|
||||
|
||||
numberCPU := runtime.NumCPU()
|
||||
if numberCPU < 8 {
|
||||
numberCPU = 8
|
||||
}
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = numberCPU * 2
|
||||
}
|
||||
|
||||
if idleConns <= 0 {
|
||||
idleConns = numberCPU
|
||||
}
|
||||
//logs.Println("[ORIGIN]max connections:", maxConnections)
|
||||
|
||||
// TLS通讯
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
if origin.Cert != nil {
|
||||
obj := origin.Cert.CertObject()
|
||||
if obj != nil {
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
tlsConfig.Certificates = []tls.Certificate{*obj}
|
||||
if len(origin.Cert.ServerName) > 0 {
|
||||
tlsConfig.ServerName = origin.Cert.ServerName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 握手配置
|
||||
return (&net.Dialer{
|
||||
Timeout: connectionTimeout,
|
||||
KeepAlive: 1 * time.Minute,
|
||||
}).DialContext(ctx, network, originAddr)
|
||||
},
|
||||
MaxIdleConns: 0,
|
||||
MaxIdleConnsPerHost: idleConns,
|
||||
MaxConnsPerHost: maxConnections,
|
||||
IdleConnTimeout: idleTimeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 0, // 不限
|
||||
TLSClientConfig: tlsConfig,
|
||||
Proxy: nil,
|
||||
}
|
||||
|
||||
rawClient = &http.Client{
|
||||
Timeout: readTimeout,
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
this.clientsMap[key] = NewHTTPClient(rawClient)
|
||||
|
||||
return rawClient, originAddr, nil
|
||||
}
|
||||
|
||||
// 清理不使用的Client
|
||||
func (this *HTTPClientPool) cleanClients() {
|
||||
ticker := time.NewTicker(this.clientExpiredDuration)
|
||||
for range ticker.C {
|
||||
currentAt := time.Now().Unix()
|
||||
|
||||
this.locker.Lock()
|
||||
for k, client := range this.clientsMap {
|
||||
if client.AccessTime() < currentAt+86400 { // 超过 N 秒没有调用就关闭
|
||||
delete(this.clientsMap, k)
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
this.locker.Unlock()
|
||||
}
|
||||
}
|
||||
78
internal/nodes/http_client_pool_test.go
Normal file
78
internal/nodes/http_client_pool_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestHTTPClientPool_Client(t *testing.T) {
|
||||
pool := NewHTTPClientPool()
|
||||
|
||||
{
|
||||
origin := &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
{
|
||||
client, addr, err := pool.Client(nil, origin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("addr:", addr, "client:", client)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
client, addr, err := pool.Client(nil, origin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("addr:", addr, "client:", client)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientPool_cleanClients(t *testing.T) {
|
||||
origin := &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pool := NewHTTPClientPool()
|
||||
pool.clientExpiredDuration = 2 * time.Second
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Log("get", i)
|
||||
_, _, _ = pool.Client(nil, origin)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHTTPClientPool_Client(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
origin := &serverconfigs.OriginConfig{
|
||||
Id: 1,
|
||||
Version: 2,
|
||||
Addr: &serverconfigs.NetworkAddressConfig{Host: "127.0.0.1", PortRange: "1234"},
|
||||
}
|
||||
err := origin.Init()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
pool := NewHTTPClientPool()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = pool.Client(nil, origin)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -51,6 +52,7 @@ type HTTPRequest struct {
|
||||
requestCost float64 // 请求耗时
|
||||
filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空
|
||||
origin *serverconfigs.OriginConfig // 源站
|
||||
originAddr string // 源站实际地址
|
||||
errors []string // 错误信息
|
||||
}
|
||||
|
||||
@@ -78,7 +80,7 @@ func (this *HTTPRequest) Do() {
|
||||
// Web配置
|
||||
err := this.configureWeb(this.Server.Web, true, 0)
|
||||
if err != nil {
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
@@ -232,6 +234,7 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
||||
if !location.IsOn {
|
||||
continue
|
||||
}
|
||||
logs.Println("rawPath:", rawPath, "location:", location.Pattern) // TODO
|
||||
if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched {
|
||||
if len(varMapping) > 0 {
|
||||
this.addVarMapping(varMapping)
|
||||
@@ -398,9 +401,9 @@ func (this *HTTPRequest) Format(source string) string {
|
||||
if this.origin != nil {
|
||||
switch suffix {
|
||||
case "address", "addr":
|
||||
return this.origin.RealAddr()
|
||||
return this.originAddr
|
||||
case "host":
|
||||
addr := this.origin.RealAddr()
|
||||
addr := this.originAddr
|
||||
index := strings.Index(addr, ":")
|
||||
if index > -1 {
|
||||
return addr[:index]
|
||||
@@ -674,7 +677,9 @@ func (this *HTTPRequest) requestServerPort() int {
|
||||
// 设置代理相关头部信息
|
||||
// 参考:https://tools.ietf.org/html/rfc7239
|
||||
func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
||||
delete(header, "Connection")
|
||||
if this.RawReq.Header.Get("Connection") == "close" {
|
||||
this.RawReq.Header.Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
remoteAddr := this.RawReq.RemoteAddr
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
@@ -728,6 +733,8 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
||||
|
||||
// 处理自定义Request Header
|
||||
func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
||||
this.fixRequestHeader(reqHeader)
|
||||
|
||||
if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
|
||||
// 删除某些Header
|
||||
for name := range reqHeader {
|
||||
@@ -742,12 +749,17 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
||||
continue
|
||||
}
|
||||
oldValues, _ := this.RawReq.Header[header.Name]
|
||||
newHeaderValue := header.Value // 因为我们不能修改header,所以在这里使用新变量
|
||||
if header.HasVariables() {
|
||||
oldValues = append(oldValues, this.Format(header.Value))
|
||||
} else {
|
||||
oldValues = append(oldValues, header.Value)
|
||||
newHeaderValue = this.Format(header.Value)
|
||||
}
|
||||
oldValues = append(oldValues, newHeaderValue)
|
||||
reqHeader[header.Name] = oldValues
|
||||
|
||||
// 支持修改Host
|
||||
if header.Name == "Host" && len(header.Value) > 0 {
|
||||
this.RawReq.Host = newHeaderValue
|
||||
}
|
||||
}
|
||||
|
||||
// Set
|
||||
@@ -755,10 +767,15 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
||||
if !header.IsOn {
|
||||
continue
|
||||
}
|
||||
newHeaderValue := header.Value // 因为我们不能修改header,所以在这里使用新变量
|
||||
if header.HasVariables() {
|
||||
reqHeader[header.Name] = []string{this.Format(header.Value)}
|
||||
} else {
|
||||
reqHeader[header.Name] = []string{header.Value}
|
||||
newHeaderValue = this.Format(header.Value)
|
||||
}
|
||||
reqHeader[header.Name] = []string{newHeaderValue}
|
||||
|
||||
// 支持修改Host
|
||||
if header.Name == "Host" && len(header.Value) > 0 {
|
||||
this.RawReq.Host = newHeaderValue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,11 @@ func (this *HTTPRequest) write404() {
|
||||
_, _ = this.writer.Write([]byte(msg))
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) write500() {
|
||||
func (this *HTTPRequest) write500(err error) {
|
||||
if err != nil {
|
||||
this.addError(err)
|
||||
}
|
||||
|
||||
statusCode := http.StatusInternalServerError
|
||||
if this.doPage(statusCode) {
|
||||
return
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -11,9 +18,46 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
return
|
||||
}
|
||||
|
||||
// StripPrefix
|
||||
if len(this.reverseProxy.StripPrefix) > 0 {
|
||||
// 对URL的处理
|
||||
stripPrefix := this.reverseProxy.StripPrefix
|
||||
requestURI := this.reverseProxy.RequestURI
|
||||
requestURIHasVariables := this.reverseProxy.RequestURIHasVariables()
|
||||
requestHost := this.reverseProxy.RequestHost
|
||||
requestHostHasVariables := this.reverseProxy.RequestHostHasVariables()
|
||||
|
||||
// 源站
|
||||
requestCall := shared.NewRequestCall()
|
||||
origin := this.reverseProxy.NextOrigin(requestCall)
|
||||
if origin == nil {
|
||||
err := errors.New(this.requestPath() + ": no available backends for reverse proxy")
|
||||
logs.Error(err)
|
||||
this.write500(err)
|
||||
return
|
||||
}
|
||||
this.origin = origin // 设置全局变量是为了日志等处理
|
||||
if len(origin.StripPrefix) > 0 {
|
||||
stripPrefix = origin.StripPrefix
|
||||
}
|
||||
if len(origin.RequestURI) > 0 {
|
||||
requestURI = origin.RequestURI
|
||||
requestURIHasVariables = origin.RequestURIHasVariables()
|
||||
}
|
||||
if len(origin.RequestHost) > 0 {
|
||||
requestHost = origin.RequestHost
|
||||
requestHostHasVariables = origin.RequestHostHasVariables()
|
||||
}
|
||||
|
||||
// 处理Scheme
|
||||
if origin.Addr == nil {
|
||||
err := errors.New(this.requestPath() + ": origin '" + strconv.FormatInt(origin.Id, 10) + "' does not has a address")
|
||||
logs.Error(err)
|
||||
this.write500(err)
|
||||
return
|
||||
}
|
||||
this.RawReq.URL.Scheme = origin.Addr.Protocol.Primary().Scheme()
|
||||
|
||||
// StripPrefix
|
||||
if len(stripPrefix) > 0 {
|
||||
if stripPrefix[0] != '/' {
|
||||
stripPrefix = "/" + stripPrefix
|
||||
}
|
||||
@@ -24,11 +68,11 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
}
|
||||
|
||||
// RequestURI
|
||||
if len(this.reverseProxy.RequestURI) > 0 {
|
||||
if this.reverseProxy.RequestURIHasVariables() {
|
||||
this.uri = this.Format(this.reverseProxy.RequestURI)
|
||||
if len(requestURI) > 0 {
|
||||
if requestURIHasVariables {
|
||||
this.uri = this.Format(requestURI)
|
||||
} else {
|
||||
this.uri = this.reverseProxy.RequestURI
|
||||
this.uri = requestURI
|
||||
}
|
||||
if len(this.uri) == 0 || this.uri[0] != '/' {
|
||||
this.uri = "/" + this.uri
|
||||
@@ -47,6 +91,18 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
this.uri = utils.CleanPath(this.uri)
|
||||
}
|
||||
|
||||
// RequestHost
|
||||
if len(requestHost) > 0 {
|
||||
if requestHostHasVariables {
|
||||
this.RawReq.Host = this.Format(requestHost)
|
||||
} else {
|
||||
this.RawReq.Host = this.reverseProxy.RequestHost
|
||||
}
|
||||
this.RawReq.URL.Host = this.RawReq.Host
|
||||
} else {
|
||||
this.RawReq.URL.Host = this.Host
|
||||
}
|
||||
|
||||
// 重组请求URL
|
||||
questionMark := strings.Index(this.uri, "?")
|
||||
if questionMark > -1 {
|
||||
@@ -56,16 +112,11 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
this.RawReq.URL.Path = this.uri
|
||||
this.RawReq.URL.RawQuery = ""
|
||||
}
|
||||
this.RawReq.RequestURI = ""
|
||||
|
||||
// RequestHost
|
||||
if len(this.reverseProxy.RequestHost) > 0 {
|
||||
if this.reverseProxy.RequestHostHasVariables() {
|
||||
this.RawReq.Host = this.Format(this.reverseProxy.RequestHost)
|
||||
} else {
|
||||
this.RawReq.Host = this.reverseProxy.RequestHost
|
||||
}
|
||||
this.RawReq.URL.Host = this.RawReq.Host
|
||||
}
|
||||
// 处理Header
|
||||
this.setForwardHeaders(this.RawReq.Header)
|
||||
this.processRequestHeaders(this.RawReq.Header)
|
||||
|
||||
// 判断是否为Websocket请求
|
||||
if this.RawReq.Header.Get("Upgrade") == "websocket" {
|
||||
@@ -73,6 +124,110 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
return
|
||||
}
|
||||
|
||||
// 普通HTTP请求
|
||||
// 获取请求客户端
|
||||
client, addr, err := SharedHTTPClientPool.Client(this, origin)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.write500(err)
|
||||
return
|
||||
}
|
||||
|
||||
this.originAddr = addr
|
||||
|
||||
// 开始请求
|
||||
resp, err := client.Do(this.RawReq)
|
||||
if err != nil {
|
||||
// 客户端取消请求,则不提示
|
||||
httpErr, ok := err.(*url.Error)
|
||||
if !ok || httpErr.Err != context.Canceled {
|
||||
// TODO 如果超过最大失败次数,则下线
|
||||
|
||||
this.write500(err)
|
||||
logs.Println("[proxy]'" + this.RawReq.URL.String() + "': " + err.Error())
|
||||
} else {
|
||||
// 是否为客户端方面的错误
|
||||
isClientError := false
|
||||
if ok {
|
||||
if httpErr.Err == context.Canceled {
|
||||
isClientError = true
|
||||
this.addError(errors.New(httpErr.Op + " " + httpErr.URL + ": client closed the connection"))
|
||||
this.writer.WriteHeader(499) // 仿照nginx
|
||||
}
|
||||
}
|
||||
|
||||
if !isClientError {
|
||||
this.write500(err)
|
||||
}
|
||||
}
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// WAF对出站进行检查
|
||||
// TODO
|
||||
|
||||
// TODO 清除源站错误次数
|
||||
|
||||
// 特殊页面
|
||||
// TODO
|
||||
|
||||
// 设置Charset
|
||||
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
|
||||
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||
contentTypes, ok := resp.Header["Content-Type"]
|
||||
if ok && len(contentTypes) > 0 {
|
||||
contentType := contentTypes[0]
|
||||
if _, found := textMimeMap[contentType]; found {
|
||||
resp.Header["Content-Type"][0] = contentType + "; charset=" + this.web.Charset.Charset
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 响应Header
|
||||
this.writer.AddHeaders(resp.Header)
|
||||
this.processResponseHeaders(resp.StatusCode)
|
||||
|
||||
// 是否需要刷新
|
||||
shouldFlush := this.RawReq.Header.Get("Accept") == "text/event-stream"
|
||||
|
||||
// 准备
|
||||
this.writer.Prepare(resp.ContentLength)
|
||||
|
||||
// 设置响应代码
|
||||
this.writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
// 输出到客户端
|
||||
pool := this.bytePool(resp.ContentLength)
|
||||
buf := pool.Get()
|
||||
if shouldFlush {
|
||||
for {
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
_, err = this.writer.Write(buf[:n])
|
||||
this.writer.Flush()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
err = readErr
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
|
||||
}
|
||||
pool.Put(buf)
|
||||
|
||||
err1 := resp.Body.Close()
|
||||
if err1 != nil {
|
||||
logs.Error(err1)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,9 +107,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
return
|
||||
} else {
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -137,9 +136,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
return
|
||||
} else {
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -220,9 +218,8 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
|
||||
reader, err := os.OpenFile(filePath, os.O_RDONLY, 0444)
|
||||
if err != nil {
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
req.Run()
|
||||
req.Do()
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == false)
|
||||
}
|
||||
{
|
||||
@@ -29,7 +29,7 @@ func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
req.Run()
|
||||
req.Do()
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,8 +35,7 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logs.Error(errors.New(req.URL.String() + ": " + err.Error()))
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -30,20 +28,6 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
}
|
||||
}
|
||||
|
||||
requestCall := shared.NewRequestCall()
|
||||
origin := this.reverseProxy.NextOrigin(requestCall)
|
||||
if origin == nil {
|
||||
err := errors.New(this.requestPath() + ": no available backends for websocket")
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
return
|
||||
}
|
||||
|
||||
// 处理Header
|
||||
this.processRequestHeaders(this.RawReq.Header)
|
||||
this.fixRequestHeader(this.RawReq.Header) // 处理 Websocket -> WebSocket
|
||||
|
||||
// 设置指定的来源域
|
||||
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
|
||||
newRequestOrigin := this.web.Websocket.RequestOrigin
|
||||
@@ -54,11 +38,10 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
}
|
||||
|
||||
// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
|
||||
originConn, err := OriginConnect(origin)
|
||||
originConn, err := OriginConnect(this.origin)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -68,16 +51,14 @@ func (this *HTTPRequest) doWebsocket() {
|
||||
err = this.RawReq.Write(originConn)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := this.writer.Hijack()
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
this.write500(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
|
||||
@@ -118,6 +118,9 @@ func (this *HTTPWriter) AddHeaders(header http.Header) {
|
||||
return
|
||||
}
|
||||
for key, value := range header {
|
||||
if key == "Connection" {
|
||||
continue
|
||||
}
|
||||
for _, v := range value {
|
||||
this.writer.Header().Add(key, v)
|
||||
}
|
||||
|
||||
26
internal/utils/time.go
Normal file
26
internal/utils/time.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var unixTime = time.Now().Unix()
|
||||
var unixTimerIsReady = false
|
||||
|
||||
func init() {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
unixTimerIsReady = true
|
||||
unixTime = time.Now().Unix()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 最快获取时间戳的方式,通常用在不需要特别精确时间戳的场景
|
||||
func UnixTime() int64 {
|
||||
if unixTimerIsReady {
|
||||
return unixTime
|
||||
}
|
||||
return time.Now().Unix()
|
||||
}
|
||||
13
internal/utils/time_test.go
Normal file
13
internal/utils/time_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUnixTime(t *testing.T) {
|
||||
for i := 0; i < 5; i++ {
|
||||
t.Log(UnixTime(), "real:", time.Now().Unix())
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user