mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 15:51:54 +08:00 
			
		
		
		
	反向代理支持RequestPath、RequestURI等
This commit is contained in:
		@@ -727,7 +727,6 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 处理自定义Request Header
 | 
					// 处理自定义Request Header
 | 
				
			||||||
// TODO 处理一些被Golang转换了的Header,比如Websocket
 | 
					 | 
				
			||||||
func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
 | 
					func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
 | 
				
			||||||
	if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
 | 
						if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
 | 
				
			||||||
		// 删除某些Header
 | 
							// 删除某些Header
 | 
				
			||||||
@@ -768,6 +767,18 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 处理一些被Golang转换了的Header
 | 
				
			||||||
 | 
					// TODO 可以自定义要转换的Header
 | 
				
			||||||
 | 
					func (this *HTTPRequest) fixRequestHeader(header http.Header) {
 | 
				
			||||||
 | 
						for k, v := range header {
 | 
				
			||||||
 | 
							if strings.Contains(k, "-Websocket-") {
 | 
				
			||||||
 | 
								header.Del(k)
 | 
				
			||||||
 | 
								k = strings.ReplaceAll(k, "-Websocket-", "-WebSocket-")
 | 
				
			||||||
 | 
								header[k] = v
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 处理自定义Response Header
 | 
					// 处理自定义Response Header
 | 
				
			||||||
func (this *HTTPRequest) processResponseHeaders(statusCode int) {
 | 
					func (this *HTTPRequest) processResponseHeaders(statusCode int) {
 | 
				
			||||||
	responseHeader := this.writer.Header()
 | 
						responseHeader := this.writer.Header()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,72 @@
 | 
				
			|||||||
package nodes
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 处理反向代理
 | 
					// 处理反向代理
 | 
				
			||||||
func (this *HTTPRequest) doReverseProxy() {
 | 
					func (this *HTTPRequest) doReverseProxy() {
 | 
				
			||||||
 | 
						if this.reverseProxy == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// StripPrefix
 | 
				
			||||||
 | 
						if len(this.reverseProxy.StripPrefix) > 0 {
 | 
				
			||||||
 | 
							stripPrefix := this.reverseProxy.StripPrefix
 | 
				
			||||||
 | 
							if stripPrefix[0] != '/' {
 | 
				
			||||||
 | 
								stripPrefix = "/" + stripPrefix
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							this.uri = strings.TrimPrefix(this.uri, stripPrefix)
 | 
				
			||||||
 | 
							if len(this.uri) == 0 || this.uri[0] != '/' {
 | 
				
			||||||
 | 
								this.uri = "/" + this.uri
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// RequestURI
 | 
				
			||||||
 | 
						if len(this.reverseProxy.RequestURI) > 0 {
 | 
				
			||||||
 | 
							if this.reverseProxy.RequestURIHasVariables() {
 | 
				
			||||||
 | 
								this.uri = this.Format(this.reverseProxy.RequestURI)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								this.uri = this.reverseProxy.RequestURI
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if len(this.uri) == 0 || this.uri[0] != '/' {
 | 
				
			||||||
 | 
								this.uri = "/" + this.uri
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 处理RequestURI中的问号
 | 
				
			||||||
 | 
							questionMark := strings.LastIndex(this.uri, "?")
 | 
				
			||||||
 | 
							if questionMark > 0 {
 | 
				
			||||||
 | 
								path := this.uri[:questionMark]
 | 
				
			||||||
 | 
								if strings.Contains(path, "?") {
 | 
				
			||||||
 | 
									this.uri = path + "&" + this.uri[questionMark+1:]
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 去除多个/
 | 
				
			||||||
 | 
							this.uri = utils.CleanPath(this.uri)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 重组请求URL
 | 
				
			||||||
 | 
						questionMark := strings.Index(this.uri, "?")
 | 
				
			||||||
 | 
						if questionMark > -1 {
 | 
				
			||||||
 | 
							this.RawReq.URL.Path = this.uri[:questionMark]
 | 
				
			||||||
 | 
							this.RawReq.URL.RawQuery = this.uri[questionMark+1:]
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							this.RawReq.URL.Path = this.uri
 | 
				
			||||||
 | 
							this.RawReq.URL.RawQuery = ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// RequestHost
 | 
				
			||||||
 | 
						if len(this.reverseProxy.RequestHost) > 0 {
 | 
				
			||||||
 | 
							if this.reverseProxy.RequestHostHasVariables() {
 | 
				
			||||||
 | 
								this.RawReq.Host = this.Format(this.reverseProxy.RequestHost)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								this.RawReq.Host = this.reverseProxy.RequestHost
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							this.RawReq.URL.Host = this.RawReq.Host
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 判断是否为Websocket请求
 | 
						// 判断是否为Websocket请求
 | 
				
			||||||
	if this.RawReq.Header.Get("Upgrade") == "websocket" {
 | 
						if this.RawReq.Header.Get("Upgrade") == "websocket" {
 | 
				
			||||||
		this.doWebsocket()
 | 
							this.doWebsocket()
 | 
				
			||||||
@@ -11,4 +76,3 @@ func (this *HTTPRequest) doReverseProxy() {
 | 
				
			|||||||
	// 普通HTTP请求
 | 
						// 普通HTTP请求
 | 
				
			||||||
	// TODO
 | 
						// TODO
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,6 +16,8 @@ func (this *HTTPRequest) doWebsocket() {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO 实现handshakeTimeout
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 校验来源
 | 
						// 校验来源
 | 
				
			||||||
	requestOrigin := this.RawReq.Header.Get("Origin")
 | 
						requestOrigin := this.RawReq.Header.Get("Origin")
 | 
				
			||||||
	if len(requestOrigin) > 0 {
 | 
						if len(requestOrigin) > 0 {
 | 
				
			||||||
@@ -38,7 +40,9 @@ func (this *HTTPRequest) doWebsocket() {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 处理Header
 | 
				
			||||||
	this.processRequestHeaders(this.RawReq.Header)
 | 
						this.processRequestHeaders(this.RawReq.Header)
 | 
				
			||||||
 | 
						this.fixRequestHeader(this.RawReq.Header) // 处理 Websocket -> WebSocket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 设置指定的来源域
 | 
						// 设置指定的来源域
 | 
				
			||||||
	if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
 | 
						if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 {
 | 
				
			||||||
@@ -49,11 +53,7 @@ func (this *HTTPRequest) doWebsocket() {
 | 
				
			|||||||
		this.RawReq.Header.Set("Origin", newRequestOrigin)
 | 
							this.RawReq.Header.Set("Origin", newRequestOrigin)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO 修改RequestURI
 | 
						// TODO 增加N次错误重试,重试的时候需要尝试不同的源站
 | 
				
			||||||
	// TODO 实现handshakeTimeout
 | 
					 | 
				
			||||||
	// TODO 修改 Websocket- 为 WebSocket-
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// TODO 增加N次错误重试
 | 
					 | 
				
			||||||
	originConn, err := OriginConnect(origin)
 | 
						originConn, err := OriginConnect(origin)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logs.Error(err)
 | 
							logs.Error(err)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										24
									
								
								internal/utils/path.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								internal/utils/path.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 清理Path中的多余的字符
 | 
				
			||||||
 | 
					func CleanPath(path string) string {
 | 
				
			||||||
 | 
						l := len(path)
 | 
				
			||||||
 | 
						if l == 0 {
 | 
				
			||||||
 | 
							return "/"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						result := []byte{'/'}
 | 
				
			||||||
 | 
						isSlash := true
 | 
				
			||||||
 | 
						for i := 0; i < l; i++ {
 | 
				
			||||||
 | 
							if path[i] == '\\' || path[i] == '/' {
 | 
				
			||||||
 | 
								if !isSlash {
 | 
				
			||||||
 | 
									isSlash = true
 | 
				
			||||||
 | 
									result = append(result, '/')
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								isSlash = false
 | 
				
			||||||
 | 
								result = append(result, path[i])
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return string(result)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										23
									
								
								internal/utils/path_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								internal/utils/path_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
				
			|||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCleanPath(t *testing.T) {
 | 
				
			||||||
 | 
						a := assert.NewAssertion(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						a.IsTrue(CleanPath("") == "/")
 | 
				
			||||||
 | 
						a.IsTrue(CleanPath("/hello/world") == "/hello/world")
 | 
				
			||||||
 | 
						a.IsTrue(CleanPath("\\hello\\world") == "/hello/world")
 | 
				
			||||||
 | 
						a.IsTrue(CleanPath("/\\hello\\//world") == "/hello/world")
 | 
				
			||||||
 | 
						a.IsTrue(CleanPath("hello/world") == "/hello/world")
 | 
				
			||||||
 | 
						a.IsTrue(CleanPath("/hello////world") == "/hello/world")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkCleanPath(b *testing.B) {
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							_ = CleanPath("/hello///world/very/long/very//long")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user