Files
EdgeNode/internal/nodes/http_request_websocket.go

189 lines
4.5 KiB
Go
Raw Normal View History

2020-09-26 19:54:26 +08:00
package nodes
import (
2022-09-23 14:21:53 +08:00
"bufio"
"bytes"
"errors"
2021-12-19 11:32:26 +08:00
"github.com/TeaOSLab/EdgeNode/internal/utils"
2020-09-26 19:54:26 +08:00
"io"
"net/http"
"net/url"
)
// WebsocketResponseReader wraps an io.Reader and keeps a copy of every byte
// that has been read through it.
type WebsocketResponseReader struct {
	rawReader io.Reader // underlying source that Read delegates to
	buf       []byte    // every byte returned by rawReader so far, in read order
}

// NewWebsocketResponseReader returns a WebsocketResponseReader that reads from
// rawReader and starts with an empty capture buffer.
func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
	var reader = new(WebsocketResponseReader)
	reader.rawReader = rawReader
	return reader
}

// Read reads from the underlying reader into p, records the bytes that were
// actually read in the capture buffer, and returns the underlying reader's
// result unchanged.
func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
	n, err = this.rawReader.Read(p)
	if n <= 0 {
		return n, err
	}

	// append allocates on first use, so an empty buffer needs no special case.
	this.buf = append(this.buf, p[:n]...)
	return n, err
}
2020-09-26 19:54:26 +08:00
// 处理Websocket请求
2022-09-16 09:37:49 +08:00
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
// 设置不缓存
this.web.Cache = nil
2020-09-26 19:54:26 +08:00
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket have not been enabled yet"))
2020-09-26 19:54:26 +08:00
return
}
// TODO 实现handshakeTimeout
2020-09-26 19:54:26 +08:00
// 校验来源
var requestOrigin = this.RawReq.Header.Get("Origin")
2020-09-26 19:54:26 +08:00
if len(requestOrigin) > 0 {
u, err := url.Parse(requestOrigin)
if err == nil {
if !this.web.Websocket.MatchOrigin(u.Host) {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed"))
2020-09-26 19:54:26 +08:00
return
}
}
}
// 设置指定的来源域
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
var newRequestOrigin = this.web.Websocket.RequestOrigin
2020-09-26 19:54:26 +08:00
if this.web.Websocket.RequestOriginHasVariables() {
newRequestOrigin = this.Format(newRequestOrigin)
}
this.RawReq.Header.Set("Origin", newRequestOrigin)
}
// TODO 增加N次错误重试重试的时候需要尝试不同的源站
2022-06-29 21:58:41 +08:00
originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost)
2020-09-26 19:54:26 +08:00
if err != nil {
2022-09-16 09:37:49 +08:00
if isLastRetry {
this.write50x(err, http.StatusBadGateway, "Failed to connect origin site", "源站连接失败", false)
}
// 增加失败次数
2022-06-27 12:01:33 +08:00
SharedOriginStateManager.Fail(this.origin, requestHost, this.reverseProxy, func() {
this.reverseProxy.ResetScheduling()
})
2022-09-16 09:37:49 +08:00
shouldRetry = true
2020-09-26 19:54:26 +08:00
return
}
if !this.origin.IsOk {
SharedOriginStateManager.Success(this.origin, func() {
this.reverseProxy.ResetScheduling()
})
}
2020-09-26 19:54:26 +08:00
defer func() {
_ = originConn.Close()
}()
err = this.RawReq.Write(originConn)
if err != nil {
this.write50x(err, http.StatusBadGateway, "Failed to write request to origin site", "源站请求初始化失败", false)
2020-09-26 19:54:26 +08:00
return
}
clientConn, _, err := this.writer.Hijack()
2021-10-25 19:42:12 +08:00
if err != nil || clientConn == nil {
this.write50x(err, http.StatusInternalServerError, "Failed to get origin site connection", "获取源站连接失败", false)
2020-09-26 19:54:26 +08:00
return
}
defer func() {
_ = clientConn.Close()
}()
go func() {
2022-09-23 14:21:53 +08:00
// 读取第一个响应
var respReader = NewWebsocketResponseReader(originConn)
resp, err := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
if err != nil || resp == nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
2022-09-23 14:21:53 +08:00
_ = clientConn.Close()
_ = originConn.Close()
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
this.writer.statusCode = resp.StatusCode
2022-09-23 14:21:53 +08:00
// 将响应写回客户端
err = resp.Write(clientConn)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
// 剩余已经从源站读取的内容
var headerBytes = respReader.buf
var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
if headerIndex > 0 {
var leftBytes = headerBytes[headerIndex+4:]
if len(leftBytes) > 0 {
_, err = clientConn.Write(leftBytes)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
}
}
2022-09-23 14:21:53 +08:00
if resp.Body != nil {
_ = resp.Body.Close()
}
// 复制剩余的数据
2021-12-19 11:32:26 +08:00
var buf = utils.BytePool4k.Get()
defer utils.BytePool4k.Put(buf)
for {
n, err := originConn.Read(buf)
if n > 0 {
this.writer.sentBodyBytes += int64(n)
_, err = clientConn.Write(buf[:n])
if err != nil {
break
}
}
if err != nil {
break
}
}
2020-09-26 19:54:26 +08:00
_ = clientConn.Close()
_ = originConn.Close()
}()
2020-09-26 19:54:26 +08:00
_, _ = io.Copy(originConn, clientConn)
2022-09-16 09:37:49 +08:00
return
2020-09-26 19:54:26 +08:00
}