diff --git a/internal/events/utils.go b/internal/events/utils.go index 209f5ba..46426ff 100644 --- a/internal/events/utils.go +++ b/internal/events/utils.go @@ -5,7 +5,7 @@ import "sync" var eventsMap = map[string][]func(){} // event => []callbacks var locker = sync.Mutex{} -// 增加事件回调 +// On 增加事件回调 func On(event string, callback func()) { locker.Lock() defer locker.Unlock() @@ -15,7 +15,7 @@ func On(event string, callback func()) { eventsMap[event] = callbacks } -// 通知事件 +// Notify 通知事件 func Notify(event string) { locker.Lock() callbacks, _ := eventsMap[event] diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 72c594c..4a52952 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -12,6 +12,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" "io" @@ -1160,6 +1161,7 @@ func (this *HTTPRequest) Method() string { return this.RawReq.Method } +// TransferEncoding 获取传输编码 func (this *HTTPRequest) TransferEncoding() string { if len(this.RawReq.TransferEncoding) > 0 { return this.RawReq.TransferEncoding[0] @@ -1167,6 +1169,15 @@ func (this *HTTPRequest) TransferEncoding() string { return "" } +// Cookie 获取Cookie +func (this *HTTPRequest) Cookie(name string) string { + c, err := this.RawReq.Cookie(name) + if err != nil { + return "" + } + return c.Value +} + // DeleteHeader 删除Header func (this *HTTPRequest) DeleteHeader(name string) { this.RawReq.Header.Del(name) @@ -1182,10 +1193,12 @@ func (this *HTTPRequest) Header() http.Header { return this.RawReq.Header } +// URI 获取当前请求的URI func (this *HTTPRequest) URI() string { return this.uri } +// SetURI 设置当前请求的URI func (this *HTTPRequest) SetURI(uri string) { this.uri = uri } @@ -1213,6 +1226,12 @@ func (this *HTTPRequest) Close() { return } +// Allow 放行 +func (this *HTTPRequest) Allow() { + logs.Println("allow") // TODO + this.web.FirewallRef = nil +} + // 设置代理相关头部信息 // 参考:https://tools.ietf.org/html/rfc7239 func (this *HTTPRequest) setForwardHeaders(header http.Header) { diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 9546367..dd54a07 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -17,6 +17,10 @@ import ( // 调用WAF func (this *HTTPRequest) doWAFRequest() (blocked bool) { + if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn { + return + } + var remoteAddr = this.requestRemoteAddr(true) // 检查是否为白名单直连 @@ -219,6 +223,10 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir // call response waf func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) { + if this.web.FirewallRef == nil || !this.web.FirewallRef.IsOn { + return + } + // 当前服务的独立设置 if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { blocked := this.checkWAFResponse(this.web.FirewallPolicy, resp) diff --git a/internal/nodes/http_writer.go b/internal/nodes/http_writer.go index e14ebb0..60521cb 100644 --- a/internal/nodes/http_writer.go +++ b/internal/nodes/http_writer.go @@ -7,6 +7,7 @@ import ( "bytes" "compress/flate" "compress/gzip" + "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/caches" "github.com/TeaOSLab/EdgeNode/internal/compressions" @@ -27,6 +28,7 @@ import ( "io" "net" "net/http" + "os" "path/filepath" "strings" "sync/atomic" @@ -217,13 +219,46 @@ func (this *HTTPWriter) WriteHeader(statusCode int) { this.statusCode = statusCode } -// Send 发送响应 +// Send 直接发送内容,并终止请求 func (this *HTTPWriter) Send(status int, body string) { this.WriteHeader(status) _, _ = this.WriteString(body) this.isFinished = true } +// SendFile 发送文件内容,并终止请求 +func (this *HTTPWriter) SendFile(status int, path string) (int64, error) { + this.WriteHeader(status) + this.isFinished = true + + fp, err := os.OpenFile(path, os.O_RDONLY, 0444) + if err != nil { + return 0, errors.New("open file '" + path + "' failed: " + err.Error()) + } + defer func() { + _ = fp.Close() + }() + + stat, err := fp.Stat() + if err != nil { + return 0, err + } + if stat.IsDir() { + return 0, errors.New("open file '" + path + "' failed: it is a directory") + } + + var bufPool = this.req.bytePool(stat.Size()) + var buf = bufPool.Get() + defer bufPool.Put(buf) + + written, err := io.CopyBuffer(this, fp, buf) + if err != nil { + return written, err + } + + return written, nil +} + // StatusCode 读取状态码 func (this *HTTPWriter) StatusCode() int { if this.statusCode == 0 {