mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
实现websocket基本功能
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -42,15 +41,17 @@ type HTTPRequest struct {
|
||||
|
||||
// 内部参数
|
||||
writer *HTTPWriter
|
||||
web *serverconfigs.HTTPWebConfig
|
||||
rawURI string // 原始的URI
|
||||
uri string // 经过rewrite等运算之后的URI
|
||||
varMapping map[string]string // 变量集合
|
||||
requestFromTime time.Time // 请求开始时间
|
||||
requestCost float64 // 请求耗时
|
||||
filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空
|
||||
origin *serverconfigs.OriginConfig // 源站
|
||||
errors []string // 错误信息
|
||||
web *serverconfigs.HTTPWebConfig // Web配置,重要提示:由于引用了别的共享的配置,所以操作中只能读取不要修改
|
||||
reverseProxyRef *serverconfigs.ReverseProxyRef // 反向代理引用
|
||||
reverseProxy *serverconfigs.ReverseProxyConfig // 反向代理配置,重要提示:由于引用了别的共享的配置,所以操作中只能读取不要修改
|
||||
rawURI string // 原始的URI
|
||||
uri string // 经过rewrite等运算之后的URI
|
||||
varMapping map[string]string // 变量集合
|
||||
requestFromTime time.Time // 请求开始时间
|
||||
requestCost float64 // 请求耗时
|
||||
filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空
|
||||
origin *serverconfigs.OriginConfig // 源站
|
||||
errors []string // 错误信息
|
||||
}
|
||||
|
||||
// 初始化
|
||||
@@ -68,7 +69,13 @@ func (this *HTTPRequest) Do() {
|
||||
// 初始化
|
||||
this.init()
|
||||
|
||||
// 配置
|
||||
// 当前服务的反向代理配置
|
||||
if this.Server.ReverseProxyRef != nil && this.Server.ReverseProxy != nil {
|
||||
this.reverseProxyRef = this.Server.ReverseProxyRef
|
||||
this.reverseProxy = this.Server.ReverseProxy
|
||||
}
|
||||
|
||||
// Web配置
|
||||
err := this.configureWeb(this.Server.Web, true, 0)
|
||||
if err != nil {
|
||||
this.write500()
|
||||
@@ -99,7 +106,7 @@ func (this *HTTPRequest) Do() {
|
||||
// 开始调用
|
||||
func (this *HTTPRequest) doBegin() {
|
||||
// 重写规则
|
||||
// TODO
|
||||
// TODO 需要实现
|
||||
|
||||
// 临时关闭页面
|
||||
if this.web.Shutdown != nil && this.web.Shutdown.IsOn {
|
||||
@@ -107,11 +114,10 @@ func (this *HTTPRequest) doBegin() {
|
||||
return
|
||||
}
|
||||
|
||||
// 缓存
|
||||
// TODO
|
||||
|
||||
// root
|
||||
// TODO 从本地文件中读取
|
||||
// TODO 增加stripPrefix
|
||||
// TODO 增加URLEncode的处理方式
|
||||
// TODO ROOT支持变量
|
||||
if this.web.Root != nil && this.web.Root.IsOn {
|
||||
// 如果处理成功,则终止请求的处理
|
||||
if this.doRoot() {
|
||||
@@ -124,12 +130,11 @@ func (this *HTTPRequest) doBegin() {
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse
|
||||
// TODO
|
||||
logs.Println("reverse proxy")
|
||||
|
||||
// WebSocket
|
||||
// TODO
|
||||
// Reverse Proxy
|
||||
if this.reverseProxyRef != nil && this.reverseProxyRef.IsOn && this.reverseProxy != nil && this.reverseProxy.IsOn {
|
||||
this.doReverseProxy()
|
||||
return
|
||||
}
|
||||
|
||||
// Fastcgi
|
||||
// TODO
|
||||
@@ -210,6 +215,12 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
||||
this.web.Charset = web.Charset
|
||||
}
|
||||
|
||||
// websocket
|
||||
if web.WebsocketRef != nil && (web.WebsocketRef.IsPrior || isTop) {
|
||||
this.web.WebsocketRef = web.WebsocketRef
|
||||
this.web.Websocket = web.Websocket
|
||||
}
|
||||
|
||||
// locations
|
||||
if len(web.LocationRefs) > 0 {
|
||||
var resultLocation *serverconfigs.HTTPLocationConfig
|
||||
@@ -232,10 +243,19 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
||||
}
|
||||
}
|
||||
}
|
||||
if resultLocation != nil && resultLocation.Web != nil {
|
||||
err := this.configureWeb(resultLocation.Web, false, redirects)
|
||||
if err != nil {
|
||||
return err
|
||||
if resultLocation != nil {
|
||||
// Reverse Proxy
|
||||
if resultLocation.ReverseProxyRef != nil && resultLocation.ReverseProxyRef.IsPrior {
|
||||
this.reverseProxyRef = resultLocation.ReverseProxyRef
|
||||
this.reverseProxy = resultLocation.ReverseProxy
|
||||
}
|
||||
|
||||
// Web
|
||||
if resultLocation.Web != nil {
|
||||
err := this.configureWeb(resultLocation.Web, false, redirects)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -707,6 +727,7 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
||||
}
|
||||
|
||||
// 处理自定义Request Header
|
||||
// TODO 处理一些被Golang转换了的Header,比如Websocket
|
||||
func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
||||
if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
|
||||
// 删除某些Header
|
||||
|
||||
14
internal/nodes/http_request_reverse_proxy.go
Normal file
14
internal/nodes/http_request_reverse_proxy.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package nodes
|
||||
|
||||
// 处理反向代理
|
||||
func (this *HTTPRequest) doReverseProxy() {
|
||||
// 判断是否为Websocket请求
|
||||
if this.RawReq.Header.Get("Upgrade") == "websocket" {
|
||||
this.doWebsocket()
|
||||
return
|
||||
}
|
||||
|
||||
// 普通HTTP请求
|
||||
// TODO
|
||||
}
|
||||
|
||||
@@ -48,6 +48,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
}
|
||||
|
||||
rootDir := this.web.Root.Dir
|
||||
if this.web.Root.HasVariables() {
|
||||
rootDir = this.Format(rootDir)
|
||||
}
|
||||
if !filepath.IsAbs(rootDir) {
|
||||
rootDir = Tea.Root + Tea.DS + rootDir
|
||||
}
|
||||
@@ -149,22 +152,24 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
|
||||
respHeader := this.writer.Header()
|
||||
|
||||
// mime type
|
||||
if !(this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn && this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE")) {
|
||||
if this.web.ResponseHeaderPolicy == nil || !this.web.ResponseHeaderPolicy.IsOn || !this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE") {
|
||||
ext := filepath.Ext(requestPath)
|
||||
if len(ext) > 0 {
|
||||
mimeType := mime.TypeByExtension(ext)
|
||||
if len(mimeType) > 0 {
|
||||
if _, found := textMimeMap[mimeType]; found {
|
||||
semicolonIndex := strings.Index(mimeType, ";")
|
||||
mimeTypeKey := mimeType
|
||||
if semicolonIndex > 0 {
|
||||
mimeTypeKey = mimeType[:semicolonIndex]
|
||||
}
|
||||
|
||||
if _, found := textMimeMap[mimeTypeKey]; found {
|
||||
if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
|
||||
charset := this.web.Charset.Charset
|
||||
|
||||
// 去掉里面的charset设置
|
||||
index := strings.Index(mimeType, "charset=")
|
||||
if index > 0 {
|
||||
respHeader.Set("Content-Type", mimeType[:index+len("charset=")]+charset)
|
||||
} else {
|
||||
respHeader.Set("Content-Type", mimeType+"; charset="+charset)
|
||||
if this.web.Charset.IsUpper {
|
||||
charset = strings.ToUpper(charset)
|
||||
}
|
||||
respHeader.Set("Content-Type", mimeTypeKey+"; charset="+charset)
|
||||
} else {
|
||||
respHeader.Set("Content-Type", mimeType)
|
||||
}
|
||||
|
||||
93
internal/nodes/http_request_websocket.go
Normal file
93
internal/nodes/http_request_websocket.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// 处理Websocket请求
|
||||
func (this *HTTPRequest) doWebsocket() {
|
||||
if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// 校验来源
|
||||
requestOrigin := this.RawReq.Header.Get("Origin")
|
||||
if len(requestOrigin) > 0 {
|
||||
u, err := url.Parse(requestOrigin)
|
||||
if err == nil {
|
||||
if !this.web.Websocket.MatchOrigin(u.Host) {
|
||||
this.writer.WriteHeader(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
requestCall := shared.NewRequestCall()
|
||||
origin := this.reverseProxy.NextOrigin(requestCall)
|
||||
if origin == nil {
|
||||
err := errors.New(this.requestPath() + ": no available backends for websocket")
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
return
|
||||
}
|
||||
|
||||
this.processRequestHeaders(this.RawReq.Header)
|
||||
|
||||
// 设置指定的来源域
|
||||
if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
|
||||
newRequestOrigin := this.web.Websocket.RequestOrigin
|
||||
if this.web.Websocket.RequestOriginHasVariables() {
|
||||
newRequestOrigin = this.Format(newRequestOrigin)
|
||||
}
|
||||
this.RawReq.Header.Set("Origin", newRequestOrigin)
|
||||
}
|
||||
|
||||
// TODO 修改RequestURI
|
||||
// TODO 实现handshakeTimeout
|
||||
// TODO 修改 Websocket- 为 WebSocket-
|
||||
|
||||
// TODO 增加N次错误重试
|
||||
originConn, err := OriginConnect(origin)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = originConn.Close()
|
||||
}()
|
||||
|
||||
err = this.RawReq.Write(originConn)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := this.writer.Hijack()
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
this.addError(err)
|
||||
this.write500()
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = clientConn.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(clientConn, originConn)
|
||||
_ = clientConn.Close()
|
||||
_ = originConn.Close()
|
||||
}()
|
||||
_, _ = io.Copy(originConn, clientConn)
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
||||
if origin == nil {
|
||||
continue
|
||||
}
|
||||
conn, err = origin.Connect()
|
||||
conn, err = OriginConnect(origin)
|
||||
if err != nil {
|
||||
logs.Println("[TCP_LISTENER]unable to connect origin: " + origin.Addr.Host + ":" + origin.Addr.PortRange + ": " + err.Error())
|
||||
continue
|
||||
|
||||
33
internal/nodes/origin_utils.go
Normal file
33
internal/nodes/origin_utils.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net"
|
||||
)
|
||||
|
||||
// 连接源站
|
||||
func OriginConnect(origin *serverconfigs.OriginConfig) (net.Conn, error) {
|
||||
if origin.Addr == nil {
|
||||
return nil, errors.New("origin server address should not be empty")
|
||||
}
|
||||
|
||||
switch origin.Addr.Protocol {
|
||||
case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
return net.DialTimeout("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, origin.ConnTimeoutDuration())
|
||||
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS:
|
||||
// TODO 支持TCP4/TCP6
|
||||
// TODO 支持指定特定网卡
|
||||
// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
|
||||
// TODO 支持使用证书
|
||||
return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{})
|
||||
}
|
||||
|
||||
// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
|
||||
|
||||
return nil, errors.New("invalid scheme '" + origin.Addr.Protocol.String() + "'")
|
||||
}
|
||||
Reference in New Issue
Block a user