修复Websocket无法正常交互的问题

This commit is contained in:
刘祥超
2022-09-30 16:34:21 +08:00
parent 5742dfb263
commit 92a20e3c9a
3 changed files with 111 additions and 1 deletions

View File

@@ -2,6 +2,7 @@ package nodes
import (
"bufio"
"bytes"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"io"
@@ -9,8 +10,36 @@ import (
"net/url"
)
// WebsocketResponseReader Websocket响应Reader
type WebsocketResponseReader struct {
rawReader io.Reader
buf []byte
}
func NewWebsocketResponseReader(rawReader io.Reader) *WebsocketResponseReader {
return &WebsocketResponseReader{
rawReader: rawReader,
}
}
func (this *WebsocketResponseReader) Read(p []byte) (n int, err error) {
n, err = this.rawReader.Read(p)
if n > 0 {
if len(this.buf) == 0 {
this.buf = make([]byte, n)
copy(this.buf, p[:n])
} else {
this.buf = append(this.buf, p[:n]...)
}
}
return
}
// 处理Websocket请求
func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shouldRetry bool) {
// 设置不缓存
this.web.Cache = nil
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket have not been enabled yet"))
@@ -84,14 +113,20 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
go func() {
// 读取第一个响应
resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq)
var respReader = NewWebsocketResponseReader(originConn)
resp, err := http.ReadResponse(bufio.NewReader(respReader), this.RawReq)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
this.writer.statusCode = resp.StatusCode
// 将响应写回客户端
err = resp.Write(clientConn)
@@ -105,6 +140,25 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
return
}
// 剩余已经从源站读取的内容
var headerBytes = respReader.buf
var headerIndex = bytes.Index(headerBytes, []byte{'\r', '\n', '\r', '\n'}) // CRLF
if headerIndex > 0 {
var leftBytes = headerBytes[headerIndex+4:]
if len(leftBytes) > 0 {
_, err = clientConn.Write(leftBytes)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
}
}
if resp.Body != nil {
_ = resp.Body.Close()
}