Files
EdgeNode/internal/nodes/http_request_websocket.go

78 lines
1.8 KiB
Go
Raw Normal View History

2020-09-26 19:54:26 +08:00
package nodes
import (
"errors"
2020-09-26 19:54:26 +08:00
"github.com/iwind/TeaGo/logs"
"io"
"net/http"
"net/url"
)
// 处理Websocket请求
func (this *HTTPRequest) doWebsocket() {
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket have not been enabled yet"))
2020-09-26 19:54:26 +08:00
return
}
// TODO 实现handshakeTimeout
2020-09-26 19:54:26 +08:00
// 校验来源
requestOrigin := this.RawReq.Header.Get("Origin")
if len(requestOrigin) > 0 {
u, err := url.Parse(requestOrigin)
if err == nil {
if !this.web.Websocket.MatchOrigin(u.Host) {
this.writer.WriteHeader(http.StatusForbidden)
this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed"))
2020-09-26 19:54:26 +08:00
return
}
}
}
// 设置指定的来源域
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
newRequestOrigin := this.web.Websocket.RequestOrigin
if this.web.Websocket.RequestOriginHasVariables() {
newRequestOrigin = this.Format(newRequestOrigin)
}
this.RawReq.Header.Set("Origin", newRequestOrigin)
}
// TODO 增加N次错误重试重试的时候需要尝试不同的源站
2020-12-03 10:17:28 +08:00
originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr)
2020-09-26 19:54:26 +08:00
if err != nil {
logs.Error(err)
2020-09-27 15:26:06 +08:00
this.write500(err)
2020-09-26 19:54:26 +08:00
return
}
defer func() {
_ = originConn.Close()
}()
err = this.RawReq.Write(originConn)
if err != nil {
logs.Error(err)
2020-09-27 15:26:06 +08:00
this.write500(err)
2020-09-26 19:54:26 +08:00
return
}
clientConn, _, err := this.writer.Hijack()
if err != nil {
logs.Error(err)
2020-09-27 15:26:06 +08:00
this.write500(err)
2020-09-26 19:54:26 +08:00
return
}
defer func() {
_ = clientConn.Close()
}()
go func() {
_, _ = io.Copy(clientConn, originConn)
_ = clientConn.Close()
_ = originConn.Close()
}()
_, _ = io.Copy(originConn, clientConn)
}