mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	实现websocket基本功能
This commit is contained in:
		@@ -7,7 +7,6 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -42,15 +41,17 @@ type HTTPRequest struct {
 | 
			
		||||
 | 
			
		||||
	// 内部参数
 | 
			
		||||
	writer          *HTTPWriter
 | 
			
		||||
	web             *serverconfigs.HTTPWebConfig
 | 
			
		||||
	rawURI          string                      // 原始的URI
 | 
			
		||||
	uri             string                      // 经过rewrite等运算之后的URI
 | 
			
		||||
	varMapping      map[string]string           // 变量集合
 | 
			
		||||
	requestFromTime time.Time                   // 请求开始时间
 | 
			
		||||
	requestCost     float64                     // 请求耗时
 | 
			
		||||
	filePath        string                      // 请求的文件名,仅在读取Root目录下的内容时不为空
 | 
			
		||||
	origin          *serverconfigs.OriginConfig // 源站
 | 
			
		||||
	errors          []string                    // 错误信息
 | 
			
		||||
	web             *serverconfigs.HTTPWebConfig      // Web配置,重要提示:由于引用了别的共享的配置,所以操作中只能读取不要修改
 | 
			
		||||
	reverseProxyRef *serverconfigs.ReverseProxyRef    // 反向代理引用
 | 
			
		||||
	reverseProxy    *serverconfigs.ReverseProxyConfig // 反向代理配置,重要提示:由于引用了别的共享的配置,所以操作中只能读取不要修改
 | 
			
		||||
	rawURI          string                            // 原始的URI
 | 
			
		||||
	uri             string                            // 经过rewrite等运算之后的URI
 | 
			
		||||
	varMapping      map[string]string                 // 变量集合
 | 
			
		||||
	requestFromTime time.Time                         // 请求开始时间
 | 
			
		||||
	requestCost     float64                           // 请求耗时
 | 
			
		||||
	filePath        string                            // 请求的文件名,仅在读取Root目录下的内容时不为空
 | 
			
		||||
	origin          *serverconfigs.OriginConfig       // 源站
 | 
			
		||||
	errors          []string                          // 错误信息
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 初始化
 | 
			
		||||
@@ -68,7 +69,13 @@ func (this *HTTPRequest) Do() {
 | 
			
		||||
	// 初始化
 | 
			
		||||
	this.init()
 | 
			
		||||
 | 
			
		||||
	// 配置
 | 
			
		||||
	// 当前服务的反向代理配置
 | 
			
		||||
	if this.Server.ReverseProxyRef != nil && this.Server.ReverseProxy != nil {
 | 
			
		||||
		this.reverseProxyRef = this.Server.ReverseProxyRef
 | 
			
		||||
		this.reverseProxy = this.Server.ReverseProxy
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Web配置
 | 
			
		||||
	err := this.configureWeb(this.Server.Web, true, 0)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		this.write500()
 | 
			
		||||
@@ -99,7 +106,7 @@ func (this *HTTPRequest) Do() {
 | 
			
		||||
// 开始调用
 | 
			
		||||
func (this *HTTPRequest) doBegin() {
 | 
			
		||||
	// 重写规则
 | 
			
		||||
	// TODO
 | 
			
		||||
	// TODO 需要实现
 | 
			
		||||
 | 
			
		||||
	// 临时关闭页面
 | 
			
		||||
	if this.web.Shutdown != nil && this.web.Shutdown.IsOn {
 | 
			
		||||
@@ -107,11 +114,10 @@ func (this *HTTPRequest) doBegin() {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 缓存
 | 
			
		||||
	// TODO
 | 
			
		||||
 | 
			
		||||
	// root
 | 
			
		||||
	// TODO 从本地文件中读取
 | 
			
		||||
	// TODO 增加stripPrefix
 | 
			
		||||
	// TODO 增加URLEncode的处理方式
 | 
			
		||||
	// TODO ROOT支持变量
 | 
			
		||||
	if this.web.Root != nil && this.web.Root.IsOn {
 | 
			
		||||
		// 如果处理成功,则终止请求的处理
 | 
			
		||||
		if this.doRoot() {
 | 
			
		||||
@@ -124,12 +130,11 @@ func (this *HTTPRequest) doBegin() {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Reverse
 | 
			
		||||
	// TODO
 | 
			
		||||
	logs.Println("reverse proxy")
 | 
			
		||||
 | 
			
		||||
	// WebSocket
 | 
			
		||||
	// TODO
 | 
			
		||||
	// Reverse Proxy
 | 
			
		||||
	if this.reverseProxyRef != nil && this.reverseProxyRef.IsOn && this.reverseProxy != nil && this.reverseProxy.IsOn {
 | 
			
		||||
		this.doReverseProxy()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Fastcgi
 | 
			
		||||
	// TODO
 | 
			
		||||
@@ -210,6 +215,12 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
 | 
			
		||||
		this.web.Charset = web.Charset
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// websocket
 | 
			
		||||
	if web.WebsocketRef != nil && (web.WebsocketRef.IsPrior || isTop) {
 | 
			
		||||
		this.web.WebsocketRef = web.WebsocketRef
 | 
			
		||||
		this.web.Websocket = web.Websocket
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// locations
 | 
			
		||||
	if len(web.LocationRefs) > 0 {
 | 
			
		||||
		var resultLocation *serverconfigs.HTTPLocationConfig
 | 
			
		||||
@@ -232,10 +243,19 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if resultLocation != nil && resultLocation.Web != nil {
 | 
			
		||||
			err := this.configureWeb(resultLocation.Web, false, redirects)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
		if resultLocation != nil {
 | 
			
		||||
			// Reverse Proxy
 | 
			
		||||
			if resultLocation.ReverseProxyRef != nil && resultLocation.ReverseProxyRef.IsPrior {
 | 
			
		||||
				this.reverseProxyRef = resultLocation.ReverseProxyRef
 | 
			
		||||
				this.reverseProxy = resultLocation.ReverseProxy
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Web
 | 
			
		||||
			if resultLocation.Web != nil {
 | 
			
		||||
				err := this.configureWeb(resultLocation.Web, false, redirects)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return err
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -707,6 +727,7 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 处理自定义Request Header
 | 
			
		||||
// TODO 处理一些被Golang转换了的Header,比如Websocket
 | 
			
		||||
func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
 | 
			
		||||
	if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
 | 
			
		||||
		// 删除某些Header
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										14
									
								
								internal/nodes/http_request_reverse_proxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								internal/nodes/http_request_reverse_proxy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,14 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
// 处理反向代理
 | 
			
		||||
func (this *HTTPRequest) doReverseProxy() {
 | 
			
		||||
	// 判断是否为Websocket请求
 | 
			
		||||
	if this.RawReq.Header.Get("Upgrade") == "websocket" {
 | 
			
		||||
		this.doWebsocket()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 普通HTTP请求
 | 
			
		||||
	// TODO
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -48,6 +48,9 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rootDir := this.web.Root.Dir
 | 
			
		||||
	if this.web.Root.HasVariables() {
 | 
			
		||||
		rootDir = this.Format(rootDir)
 | 
			
		||||
	}
 | 
			
		||||
	if !filepath.IsAbs(rootDir) {
 | 
			
		||||
		rootDir = Tea.Root + Tea.DS + rootDir
 | 
			
		||||
	}
 | 
			
		||||
@@ -149,22 +152,24 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
 | 
			
		||||
	respHeader := this.writer.Header()
 | 
			
		||||
 | 
			
		||||
	// mime type
 | 
			
		||||
	if !(this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn && this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE")) {
 | 
			
		||||
	if this.web.ResponseHeaderPolicy == nil || !this.web.ResponseHeaderPolicy.IsOn || !this.web.ResponseHeaderPolicy.ContainsHeader("CONTENT-TYPE") {
 | 
			
		||||
		ext := filepath.Ext(requestPath)
 | 
			
		||||
		if len(ext) > 0 {
 | 
			
		||||
			mimeType := mime.TypeByExtension(ext)
 | 
			
		||||
			if len(mimeType) > 0 {
 | 
			
		||||
				if _, found := textMimeMap[mimeType]; found {
 | 
			
		||||
				semicolonIndex := strings.Index(mimeType, ";")
 | 
			
		||||
				mimeTypeKey := mimeType
 | 
			
		||||
				if semicolonIndex > 0 {
 | 
			
		||||
					mimeTypeKey = mimeType[:semicolonIndex]
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if _, found := textMimeMap[mimeTypeKey]; found {
 | 
			
		||||
					if this.web.Charset != nil && this.web.Charset.IsOn && len(this.web.Charset.Charset) > 0 {
 | 
			
		||||
						charset := this.web.Charset.Charset
 | 
			
		||||
 | 
			
		||||
						// 去掉里面的charset设置
 | 
			
		||||
						index := strings.Index(mimeType, "charset=")
 | 
			
		||||
						if index > 0 {
 | 
			
		||||
							respHeader.Set("Content-Type", mimeType[:index+len("charset=")]+charset)
 | 
			
		||||
						} else {
 | 
			
		||||
							respHeader.Set("Content-Type", mimeType+"; charset="+charset)
 | 
			
		||||
						if this.web.Charset.IsUpper {
 | 
			
		||||
							charset = strings.ToUpper(charset)
 | 
			
		||||
						}
 | 
			
		||||
						respHeader.Set("Content-Type", mimeTypeKey+"; charset="+charset)
 | 
			
		||||
					} else {
 | 
			
		||||
						respHeader.Set("Content-Type", mimeType)
 | 
			
		||||
					}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										93
									
								
								internal/nodes/http_request_websocket.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								internal/nodes/http_request_websocket.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,93 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
 | 
			
		||||
	"github.com/iwind/TeaGo/logs"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 处理Websocket请求
 | 
			
		||||
func (this *HTTPRequest) doWebsocket() {
 | 
			
		||||
	if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
 | 
			
		||||
		this.writer.WriteHeader(http.StatusForbidden)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 校验来源
 | 
			
		||||
	requestOrigin := this.RawReq.Header.Get("Origin")
 | 
			
		||||
	if len(requestOrigin) > 0 {
 | 
			
		||||
		u, err := url.Parse(requestOrigin)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			if !this.web.Websocket.MatchOrigin(u.Host) {
 | 
			
		||||
				this.writer.WriteHeader(http.StatusForbidden)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	requestCall := shared.NewRequestCall()
 | 
			
		||||
	origin := this.reverseProxy.NextOrigin(requestCall)
 | 
			
		||||
	if origin == nil {
 | 
			
		||||
		err := errors.New(this.requestPath() + ": no available backends for websocket")
 | 
			
		||||
		logs.Error(err)
 | 
			
		||||
		this.addError(err)
 | 
			
		||||
		this.write500()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.processRequestHeaders(this.RawReq.Header)
 | 
			
		||||
 | 
			
		||||
	// 设置指定的来源域
 | 
			
		||||
	if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
 | 
			
		||||
		newRequestOrigin := this.web.Websocket.RequestOrigin
 | 
			
		||||
		if this.web.Websocket.RequestOriginHasVariables() {
 | 
			
		||||
			newRequestOrigin = this.Format(newRequestOrigin)
 | 
			
		||||
		}
 | 
			
		||||
		this.RawReq.Header.Set("Origin", newRequestOrigin)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO 修改RequestURI
 | 
			
		||||
	// TODO 实现handshakeTimeout
 | 
			
		||||
	// TODO 修改 Websocket- 为 WebSocket-
 | 
			
		||||
 | 
			
		||||
	// TODO 增加N次错误重试
 | 
			
		||||
	originConn, err := OriginConnect(origin)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Error(err)
 | 
			
		||||
		this.addError(err)
 | 
			
		||||
		this.write500()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = originConn.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	err = this.RawReq.Write(originConn)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Error(err)
 | 
			
		||||
		this.addError(err)
 | 
			
		||||
		this.write500()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	clientConn, _, err := this.writer.Hijack()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logs.Error(err)
 | 
			
		||||
		this.addError(err)
 | 
			
		||||
		this.write500()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = clientConn.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		_, _ = io.Copy(clientConn, originConn)
 | 
			
		||||
		_ = clientConn.Close()
 | 
			
		||||
		_ = originConn.Close()
 | 
			
		||||
	}()
 | 
			
		||||
	_, _ = io.Copy(originConn, clientConn)
 | 
			
		||||
}
 | 
			
		||||
@@ -100,7 +100,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
 | 
			
		||||
		if origin == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		conn, err = origin.Connect()
 | 
			
		||||
		conn, err = OriginConnect(origin)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logs.Println("[TCP_LISTENER]unable to connect origin: " + origin.Addr.Host + ":" + origin.Addr.PortRange + ": " + err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										33
									
								
								internal/nodes/origin_utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								internal/nodes/origin_utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,33 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 连接源站
 | 
			
		||||
func OriginConnect(origin *serverconfigs.OriginConfig) (net.Conn, error) {
 | 
			
		||||
	if origin.Addr == nil {
 | 
			
		||||
		return nil, errors.New("origin server address should not be empty")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch origin.Addr.Protocol {
 | 
			
		||||
	case "", serverconfigs.ProtocolTCP, serverconfigs.ProtocolHTTP:
 | 
			
		||||
		// TODO 支持TCP4/TCP6
 | 
			
		||||
		// TODO 支持指定特定网卡
 | 
			
		||||
		// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
 | 
			
		||||
		return net.DialTimeout("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, origin.ConnTimeoutDuration())
 | 
			
		||||
	case serverconfigs.ProtocolTLS, serverconfigs.ProtocolHTTPS:
 | 
			
		||||
		// TODO 支持TCP4/TCP6
 | 
			
		||||
		// TODO 支持指定特定网卡
 | 
			
		||||
		// TODO Addr支持端口范围,如果有多个端口时,随机一个端口使用
 | 
			
		||||
		// TODO 支持使用证书
 | 
			
		||||
		return tls.Dial("tcp", origin.Addr.Host+":"+origin.Addr.PortRange, &tls.Config{})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO 支持从Unix、Pipe、HTTP、HTTPS中读取数据
 | 
			
		||||
 | 
			
		||||
	return nil, errors.New("invalid scheme '" + origin.Addr.Protocol.String() + "'")
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user