mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			78 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			78 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package nodes
 | 
						||
 | 
						||
import (
 | 
						||
	"errors"
 | 
						||
	"github.com/iwind/TeaGo/logs"
 | 
						||
	"io"
 | 
						||
	"net/http"
 | 
						||
	"net/url"
 | 
						||
)
 | 
						||
 | 
						||
// 处理Websocket请求
 | 
						||
func (this *HTTPRequest) doWebsocket() {
 | 
						||
	if this.web.WebsocketRef == nil || !this.web.WebsocketRef.IsOn || this.web.Websocket == nil || !this.web.Websocket.IsOn {
 | 
						||
		this.writer.WriteHeader(http.StatusForbidden)
 | 
						||
		this.addError(errors.New("websocket have not been enabled yet"))
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	// TODO 实现handshakeTimeout
 | 
						||
 | 
						||
	// 校验来源
 | 
						||
	requestOrigin := this.RawReq.Header.Get("Origin")
 | 
						||
	if len(requestOrigin) > 0 {
 | 
						||
		u, err := url.Parse(requestOrigin)
 | 
						||
		if err == nil {
 | 
						||
			if !this.web.Websocket.MatchOrigin(u.Host) {
 | 
						||
				this.writer.WriteHeader(http.StatusForbidden)
 | 
						||
				this.addError(errors.New("websocket origin '" + requestOrigin + "' not been allowed"))
 | 
						||
				return
 | 
						||
			}
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	// 设置指定的来源域
 | 
						||
	if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
 | 
						||
		newRequestOrigin := this.web.Websocket.RequestOrigin
 | 
						||
		if this.web.Websocket.RequestOriginHasVariables() {
 | 
						||
			newRequestOrigin = this.Format(newRequestOrigin)
 | 
						||
		}
 | 
						||
		this.RawReq.Header.Set("Origin", newRequestOrigin)
 | 
						||
	}
 | 
						||
 | 
						||
	// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
 | 
						||
	originConn, err := OriginConnect(this.origin, this.RawReq.RemoteAddr)
 | 
						||
	if err != nil {
 | 
						||
		logs.Error(err)
 | 
						||
		this.write500(err)
 | 
						||
		return
 | 
						||
	}
 | 
						||
	defer func() {
 | 
						||
		_ = originConn.Close()
 | 
						||
	}()
 | 
						||
 | 
						||
	err = this.RawReq.Write(originConn)
 | 
						||
	if err != nil {
 | 
						||
		logs.Error(err)
 | 
						||
		this.write500(err)
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	clientConn, _, err := this.writer.Hijack()
 | 
						||
	if err != nil {
 | 
						||
		logs.Error(err)
 | 
						||
		this.write500(err)
 | 
						||
		return
 | 
						||
	}
 | 
						||
	defer func() {
 | 
						||
		_ = clientConn.Close()
 | 
						||
	}()
 | 
						||
 | 
						||
	go func() {
 | 
						||
		_, _ = io.Copy(clientConn, originConn)
 | 
						||
		_ = clientConn.Close()
 | 
						||
		_ = originConn.Close()
 | 
						||
	}()
 | 
						||
	_, _ = io.Copy(originConn, clientConn)
 | 
						||
}
 |