mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
206 lines
4.8 KiB
Go
206 lines
4.8 KiB
Go
package nodes
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"errors"
|
|
"github.com/TeaOSLab/EdgeNode/internal/utils/bytepool"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
)
|
|
|
|
// WebsocketResponseReader Websocket响应Reader
|
|
type WebsocketResponseReader struct {
|
|
rawReader io.Reader
|
|
buf []byte
|
|
}
|
|
|
|
func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
|
|
return &WebsocketResponseReader{
|
|
rawReader: rawReader,
|
|
}
|
|
}
|
|
|
|
func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
|
|
n, err = this.rawReader.Read(p)
|
|
if n > 0 {
|
|
if len(this.buf) == 0 {
|
|
this.buf = make([]byte, n)
|
|
copy(this.buf, p[:n])
|
|
} else {
|
|
this.buf = append(this.buf, p[:n]...)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// 处理Websocket请求
|
|
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
|
|
// 设置不缓存
|
|
this.web.Cache = nil
|
|
|
|
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
|
|
this.writer.WriteHeader(http.StatusForbidden)
|
|
this.addError(errors.New("websocket have not been enabled yet"))
|
|
return
|
|
}
|
|
|
|
// TODO 实现handshakeTimeout
|
|
|
|
// 校验来源
|
|
var requestOrigin = this.RawReq.Header.Get("Origin")
|
|
if len(requestOrigin) > 0 {
|
|
u, err := url.Parse(requestOrigin)
|
|
if err == nil {
|
|
if !this.web.Websocket.MatchOrigin(u.Host) {
|
|
this.writer.WriteHeader(http.StatusForbidden)
|
|
this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed"))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// 标记
|
|
this.isWebsocketResponse = true
|
|
|
|
// 设置指定的来源域
|
|
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
|
|
var newRequestOrigin = this.web.Websocket.RequestOrigin
|
|
if this.web.Websocket.RequestOriginHasVariables() {
|
|
newRequestOrigin = this.Format(newRequestOrigin)
|
|
}
|
|
this.RawReq.Header.Set("Origin", newRequestOrigin)
|
|
}
|
|
|
|
// 获取当前连接
|
|
var requestConn = this.RawReq.Context().Value(HTTPConnContextKey)
|
|
if requestConn == nil {
|
|
return
|
|
}
|
|
|
|
// 连接源站
|
|
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
|
|
if err != nil {
|
|
if isLastRetry {
|
|
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
|
|
}
|
|
|
|
// 增加失败次数
|
|
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
|
|
this.reverseProxy.ResetScheduling()
|
|
})
|
|
|
|
shouldRetry = true
|
|
return
|
|
}
|
|
|
|
if !this.origin.IsOk {
|
|
SharedOriginStateManager.Success(this.origin, func() {
|
|
this.reverseProxy.ResetScheduling()
|
|
})
|
|
}
|
|
|
|
defer func() {
|
|
_ = originConn.Close()
|
|
}()
|
|
|
|
err = this.RawReq.Write(originConn)
|
|
if err != nil {
|
|
this.write50x(err, http.StatusBadGateway, "Failed to write request to origin site", "源站请求初始化失败", false)
|
|
return
|
|
}
|
|
|
|
requestClientConn, ok := requestConn.(ClientConnInterface)
|
|
if ok {
|
|
requestClientConn.SetIsPersistent(true)
|
|
}
|
|
|
|
clientConn, _, err := this.writer.Hijack()
|
|
if err != nil || clientConn == nil {
|
|
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = clientConn.Close()
|
|
}()
|
|
|
|
go func() {
|
|
// 读取第一个响应
|
|
var respReader = NewWebsocketResponseReader(originConn)
|
|
resp, respErr := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
|
|
if respErr != nil || resp == nil {
|
|
if resp != nil && resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
|
|
_ = clientConn.Close()
|
|
_ = originConn.Close()
|
|
return
|
|
}
|
|
|
|
this.ProcessResponseHeaders(resp.Header, resp.StatusCode)
|
|
this.writer.statusCode = resp.StatusCode
|
|
|
|
// 将响应写回客户端
|
|
err = resp.Write(clientConn)
|
|
if err != nil {
|
|
if resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
|
|
_ = clientConn.Close()
|
|
_ = originConn.Close()
|
|
return
|
|
}
|
|
|
|
// 剩余已经从源站读取的内容
|
|
var headerBytes = respReader.buf
|
|
var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
|
|
if headerIndex > 0 {
|
|
var leftBytes = headerBytes[headerIndex+4:]
|
|
if len(leftBytes) > 0 {
|
|
_, writeErr := clientConn.Write(leftBytes)
|
|
if writeErr != nil {
|
|
if resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
|
|
_ = clientConn.Close()
|
|
_ = originConn.Close()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
if resp.Body != nil {
|
|
_ = resp.Body.Close()
|
|
}
|
|
|
|
// 复制剩余的数据
|
|
var buf = bytepool.Pool4k.Get()
|
|
defer bytepool.Pool4k.Put(buf)
|
|
for {
|
|
n, readErr := originConn.Read(buf.Bytes)
|
|
if n > 0 {
|
|
this.writer.sentBodyBytes += int64(n)
|
|
_, writeErr := clientConn.Write(buf.Bytes[:n])
|
|
if writeErr != nil {
|
|
break
|
|
}
|
|
}
|
|
if readErr != nil {
|
|
break
|
|
}
|
|
}
|
|
_ = clientConn.Close()
|
|
_ = originConn.Close()
|
|
}()
|
|
|
|
var buf = bytepool.Pool4k.Get()
|
|
_, _ = io.CopyBuffer(originConn, clientConn, buf.Bytes)
|
|
bytepool.Pool4k.Put(buf)
|
|
|
|
return
|
|
}
|