优化代码

This commit is contained in:
刘祥超
2022-01-01 20:15:39 +08:00
parent a1212804bb
commit 336db828ad
16 changed files with 97 additions and 86 deletions

View File

@@ -12,6 +12,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"io" "io"
"io/ioutil" "io/ioutil"
@@ -38,8 +39,8 @@ type HTTPRequest struct {
// 外部参数 // 外部参数
RawReq *http.Request RawReq *http.Request
RawWriter http.ResponseWriter RawWriter http.ResponseWriter
Server *serverconfigs.ServerConfig ReqServer *serverconfigs.ServerConfig
host string // 请求的Host ReqHost string // 请求的Host
ServerName string // 实际匹配到的Host ServerName string // 实际匹配到的Host
ServerAddr string // 实际启动的服务器监听地址 ServerAddr string // 实际启动的服务器监听地址
IsHTTP bool IsHTTP bool
@@ -98,7 +99,7 @@ func (this *HTTPRequest) init() {
// this.uri = this.RawReq.URL.RequestURI() // this.uri = this.RawReq.URL.RequestURI()
// 之所以不使用RequestURI()是不想让URL中的Path被Encode // 之所以不使用RequestURI()是不想让URL中的Path被Encode
var urlPath = this.RawReq.URL.Path var urlPath = this.RawReq.URL.Path
if this.Server.Web != nil && this.Server.Web.MergeSlashes { if this.ReqServer.Web != nil && this.ReqServer.Web.MergeSlashes {
urlPath = utils.CleanPath(urlPath) urlPath = utils.CleanPath(urlPath)
this.web.MergeSlashes = true this.web.MergeSlashes = true
} }
@@ -129,13 +130,13 @@ func (this *HTTPRequest) Do() {
this.init() this.init()
// 当前服务的反向代理配置 // 当前服务的反向代理配置
if this.Server.ReverseProxyRef != nil && this.Server.ReverseProxy != nil { if this.ReqServer.ReverseProxyRef != nil && this.ReqServer.ReverseProxy != nil {
this.reverseProxyRef = this.Server.ReverseProxyRef this.reverseProxyRef = this.ReqServer.ReverseProxyRef
this.reverseProxy = this.Server.ReverseProxy this.reverseProxy = this.ReqServer.ReverseProxy
} }
// Web配置 // Web配置
err := this.configureWeb(this.Server.Web, true, 0) err := this.configureWeb(this.ReqServer.Web, true, 0)
if err != nil { if err != nil {
this.write50x(err, http.StatusInternalServerError, false) this.write50x(err, http.StatusInternalServerError, false)
this.doEnd() this.doEnd()
@@ -161,14 +162,14 @@ func (this *HTTPRequest) Do() {
} }
// 套餐 // 套餐
if this.Server.UserPlan != nil && !this.Server.UserPlan.IsAvailable() { if this.ReqServer.UserPlan != nil && !this.ReqServer.UserPlan.IsAvailable() {
this.doPlanExpires() this.doPlanExpires()
this.doEnd() this.doEnd()
return return
} }
// 流量限制 // 流量限制
if this.Server.TrafficLimit != nil && this.Server.TrafficLimit.IsOn && !this.Server.TrafficLimit.IsEmpty() && this.Server.TrafficLimitStatus != nil && this.Server.TrafficLimitStatus.IsValid() { if this.ReqServer.TrafficLimit != nil && this.ReqServer.TrafficLimit.IsOn && !this.ReqServer.TrafficLimit.IsEmpty() && this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() {
this.doTrafficLimit() this.doTrafficLimit()
this.doEnd() this.doEnd()
return return
@@ -310,14 +311,14 @@ func (this *HTTPRequest) doEnd() {
// 流量统计 // 流量统计
// TODO 增加是否开启开关 // TODO 增加是否开启开关
// TODO 增加Header统计考虑从Conn中读取 // TODO 增加Header统计考虑从Conn中读取
if this.Server != nil { if this.ReqServer != nil {
if this.isCached { if this.isCached {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId()) stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
} else { } else {
if this.isAttack { if this.isAttack {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId()) stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
} else { } else {
stats.SharedTrafficStatManager.Add(this.Server.Id, this.host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0, this.Server.ShouldCheckTrafficLimit(), this.Server.PlanId()) stats.SharedTrafficStatManager.Add(this.ReqServer.Id, this.ReqHost, this.writer.sentBodyBytes, 0, 1, 0, 0, 0, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId())
} }
} }
} }
@@ -545,7 +546,7 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
} }
if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched { if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched {
// 检查专属域名 // 检查专属域名
if len(location.Domains) > 0 && !configutils.MatchDomains(location.Domains, this.host) { if len(location.Domains) > 0 && !configutils.MatchDomains(location.Domains, this.ReqHost) {
continue continue
} }
@@ -627,7 +628,7 @@ func (this *HTTPRequest) Format(source string) string {
if this.IsHTTPS { if this.IsHTTPS {
scheme = "https" scheme = "https"
} }
return scheme + "://" + this.host + this.rawURI return scheme + "://" + this.ReqHost + this.rawURI
case "requestPath": case "requestPath":
return this.Path() return this.Path()
case "requestPathExtension": case "requestPathExtension":
@@ -674,7 +675,7 @@ func (this *HTTPRequest) Format(source string) string {
case "timestamp": case "timestamp":
return strconv.FormatInt(this.requestFromTime.Unix(), 10) return strconv.FormatInt(this.requestFromTime.Unix(), 10)
case "host": case "host":
return this.host return this.ReqHost
case "referer": case "referer":
return this.RawReq.Referer() return this.RawReq.Referer()
case "referer.host": case "referer.host":
@@ -792,7 +793,7 @@ func (this *HTTPRequest) Format(source string) string {
// host // host
if prefix == "host" { if prefix == "host" {
pieces := strings.Split(this.host, ".") pieces := strings.Split(this.ReqHost, ".")
switch suffix { switch suffix {
case "first": case "first":
if len(pieces) > 0 { if len(pieces) > 0 {
@@ -1089,14 +1090,22 @@ func (this *HTTPRequest) Id() string {
return this.requestId return this.requestId
} }
func (this *HTTPRequest) Server() maps.Map {
return maps.Map{"id": this.ReqServer.Id}
}
func (this *HTTPRequest) Node() maps.Map {
return maps.Map{"id": teaconst.NodeId}
}
// URL 获取完整的URL // URL 获取完整的URL
func (this *HTTPRequest) URL() string { func (this *HTTPRequest) URL() string {
return this.requestScheme() + "://" + this.host + this.uri return this.requestScheme() + "://" + this.ReqHost + this.uri
} }
// Host 获取Host // Host 获取Host
func (this *HTTPRequest) Host() string { func (this *HTTPRequest) Host() string {
return this.host return this.ReqHost
} }
func (this *HTTPRequest) Proto() string { func (this *HTTPRequest) Proto() string {
@@ -1186,6 +1195,24 @@ func (this *HTTPRequest) Done() {
this.isDone = true this.isDone = true
} }
// Close 关闭连接
func (this *HTTPRequest) Close() {
this.Done()
requestConn := this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
conn, ok := requestConn.(net.Conn)
if ok {
_ = conn.Close()
return
}
return
}
// 设置代理相关头部信息 // 设置代理相关头部信息
// 参考https://tools.ietf.org/html/rfc7239 // 参考https://tools.ietf.org/html/rfc7239
func (this *HTTPRequest) setForwardHeaders(header http.Header) { func (this *HTTPRequest) setForwardHeaders(header http.Header) {
@@ -1226,9 +1253,9 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
/**{ /**{
forwarded, ok := header["Forwarded"] forwarded, ok := header["Forwarded"]
if ok { if ok {
header["Forwarded"] = []string{strings.Join(forwarded, ", ") + ", by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.host + "; proto=" + this.rawScheme} header["Forwarded"] = []string{strings.Join(forwarded, ", ") + ", by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.ReqHost + "; proto=" + this.rawScheme}
} else { } else {
header["Forwarded"] = []string{"by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.host + "; proto=" + this.rawScheme} header["Forwarded"] = []string{"by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.ReqHost + "; proto=" + this.rawScheme}
} }
}**/ }**/
@@ -1239,7 +1266,7 @@ func (this *HTTPRequest) setForwardHeaders(header http.Header) {
if this.reverseProxy != nil && this.reverseProxy.ShouldAddXForwardedHostHeader() { if this.reverseProxy != nil && this.reverseProxy.ShouldAddXForwardedHostHeader() {
if _, ok := header["X-Forwarded-Host"]; !ok { if _, ok := header["X-Forwarded-Host"]; !ok {
this.RawReq.Header.Set("X-Forwarded-Host", this.host) this.RawReq.Header.Set("X-Forwarded-Host", this.ReqHost)
} }
} }
@@ -1279,7 +1306,7 @@ func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
} }
// 域名 // 域名
if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.host) { if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.ReqHost) {
continue continue
} }
@@ -1363,7 +1390,7 @@ func (this *HTTPRequest) processResponseHeaders(statusCode int) {
} }
// 域名 // 域名
if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.host) { if len(header.Domains) > 0 && !configutils.MatchDomains(header.Domains, this.ReqHost) {
continue continue
} }
@@ -1397,13 +1424,13 @@ func (this *HTTPRequest) processResponseHeaders(statusCode int) {
// HSTS // HSTS
if this.IsHTTPS && if this.IsHTTPS &&
this.Server.HTTPS != nil && this.ReqServer.HTTPS != nil &&
this.Server.HTTPS.SSLPolicy != nil && this.ReqServer.HTTPS.SSLPolicy != nil &&
this.Server.HTTPS.SSLPolicy.IsOn && this.ReqServer.HTTPS.SSLPolicy.IsOn &&
this.Server.HTTPS.SSLPolicy.HSTS != nil && this.ReqServer.HTTPS.SSLPolicy.HSTS != nil &&
this.Server.HTTPS.SSLPolicy.HSTS.IsOn && this.ReqServer.HTTPS.SSLPolicy.HSTS.IsOn &&
this.Server.HTTPS.SSLPolicy.HSTS.Match(this.host) { this.ReqServer.HTTPS.SSLPolicy.HSTS.Match(this.ReqHost) {
responseHeader.Set(this.Server.HTTPS.SSLPolicy.HSTS.HeaderKey(), this.Server.HTTPS.SSLPolicy.HSTS.HeaderValue()) responseHeader.Set(this.ReqServer.HTTPS.SSLPolicy.HSTS.HeaderKey(), this.ReqServer.HTTPS.SSLPolicy.HSTS.HeaderValue())
} }
} }
@@ -1464,22 +1491,6 @@ func (this *HTTPRequest) canIgnore(err error) bool {
return false return false
} }
// 关闭当前连接
func (this *HTTPRequest) closeConn() {
requestConn := this.RawReq.Context().Value(HTTPConnContextKey)
if requestConn == nil {
return
}
conn, ok := requestConn.(net.Conn)
if ok {
_ = conn.Close()
return
}
return
}
// 检查连接是否已关闭 // 检查连接是否已关闭
func (this *HTTPRequest) isConnClosed() bool { func (this *HTTPRequest) isConnClosed() bool {
requestConn := this.RawReq.Context().Value(HTTPConnContextKey) requestConn := this.RawReq.Context().Value(HTTPConnContextKey)

View File

@@ -45,7 +45,7 @@ func (this *HTTPRequest) doAuth() (shouldStop bool) {
if len(method.Realm) > 0 { if len(method.Realm) > 0 {
headerValue += method.Realm headerValue += method.Realm
} else { } else {
headerValue += this.host headerValue += this.ReqHost
} }
headerValue += "\"" headerValue += "\""
if len(method.Charset) > 0 { if len(method.Charset) > 0 {

View File

@@ -22,7 +22,7 @@ import (
func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.cacheCanTryStale = false this.cacheCanTryStale = false
cachePolicy := this.Server.HTTPCachePolicy cachePolicy := this.ReqServer.HTTPCachePolicy
if cachePolicy == nil || !cachePolicy.IsOn { if cachePolicy == nil || !cachePolicy.IsOn {
return return
} }
@@ -138,7 +138,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if err == nil { if err == nil {
for _, rpcServerService := range rpcClient.ServerRPCList() { for _, rpcServerService := range rpcClient.ServerRPCList() {
_, err = rpcServerService.PurgeServerCache(rpcClient.Context(), &pb.PurgeServerCacheRequest{ _, err = rpcServerService.PurgeServerCache(rpcClient.Context(), &pb.PurgeServerCacheRequest{
Domains: []string{this.host}, Domains: []string{this.ReqHost},
Keys: []string{key}, Keys: []string{key},
Prefixes: nil, Prefixes: nil,
}) })

View File

@@ -52,13 +52,13 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
} }
} }
if !env.Has("SERVER_NAME") { if !env.Has("SERVER_NAME") {
env["SERVER_NAME"] = this.host env["SERVER_NAME"] = this.ReqHost
} }
if !env.Has("REQUEST_URI") { if !env.Has("REQUEST_URI") {
env["REQUEST_URI"] = this.uri env["REQUEST_URI"] = this.uri
} }
if !env.Has("HOST") { if !env.Has("HOST") {
env["HOST"] = this.host env["HOST"] = this.ReqHost
} }
if len(this.ServerAddr) > 0 { if len(this.ServerAddr) > 0 {
@@ -149,7 +149,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
host, found := params["HTTP_HOST"] host, found := params["HTTP_HOST"]
if !found || len(host) == 0 { if !found || len(host) == 0 {
params["HTTP_HOST"] = this.host params["HTTP_HOST"] = this.ReqHost
} }
fcgiReq := fcgi.NewRequest() fcgiReq := fcgi.NewRequest()

View File

@@ -13,7 +13,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
if this.web.MergeSlashes { if this.web.MergeSlashes {
urlPath = utils.CleanPath(urlPath) urlPath = utils.CleanPath(urlPath)
} }
fullURL := this.requestScheme() + "://" + this.host + urlPath fullURL := this.requestScheme() + "://" + this.ReqHost + urlPath
for _, u := range this.web.HostRedirects { for _, u := range this.web.HostRedirects {
if !u.IsOn { if !u.IsOn {
continue continue

View File

@@ -19,9 +19,9 @@ func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
if requestConn != nil { if requestConn != nil {
clientConn, ok := requestConn.(ClientConnInterface) clientConn, ok := requestConn.(ClientConnInterface)
if ok && !clientConn.IsBound() { if ok && !clientConn.IsBound() {
if !clientConn.Bind(this.Server.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) { if !clientConn.Bind(this.ReqServer.Id, this.requestRemoteAddr(true), this.web.RequestLimit.MaxConns, this.web.RequestLimit.MaxConnsPerIP) {
this.writeCode(http.StatusTooManyRequests) this.writeCode(http.StatusTooManyRequests)
this.closeConn() this.Close()
return true return true
} }
} }

View File

@@ -93,7 +93,7 @@ func (this *HTTPRequest) log() {
accessLog := &pb.HTTPAccessLog{ accessLog := &pb.HTTPAccessLog{
RequestId: this.requestId, RequestId: this.requestId,
NodeId: sharedNodeConfig.Id, NodeId: sharedNodeConfig.Id,
ServerId: this.Server.Id, ServerId: this.ReqServer.Id,
RemoteAddr: this.requestRemoteAddr(true), RemoteAddr: this.requestRemoteAddr(true),
RawRemoteAddr: addr, RawRemoteAddr: addr,
RemotePort: int32(this.requestRemotePort()), RemotePort: int32(this.requestRemotePort()),
@@ -114,7 +114,7 @@ func (this *HTTPRequest) log() {
TimeLocal: this.requestFromTime.Format("2/Jan/2006:15:04:05 -0700"), TimeLocal: this.requestFromTime.Format("2/Jan/2006:15:04:05 -0700"),
Msec: float64(this.requestFromTime.Unix()) + float64(this.requestFromTime.Nanosecond())/1000000000, Msec: float64(this.requestFromTime.Unix()) + float64(this.requestFromTime.Nanosecond())/1000000000,
Timestamp: this.requestFromTime.Unix(), Timestamp: this.requestFromTime.Unix(),
Host: this.host, Host: this.ReqHost,
Referer: referer, Referer: referer,
UserAgent: userAgent, UserAgent: userAgent,
Request: this.requestString(), Request: this.requestString(),

View File

@@ -50,7 +50,7 @@ func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
} }
func (this *HTTPRequest) MetricServerId() int64 { func (this *HTTPRequest) MetricServerId() int64 {
return this.Server.Id return this.ReqServer.Id
} }
func (this *HTTPRequest) MetricCategory() string { func (this *HTTPRequest) MetricCategory() string {

View File

@@ -36,12 +36,12 @@ func (this *HTTPRequest) doReverseProxy() {
requestCall := shared.NewRequestCall() requestCall := shared.NewRequestCall()
requestCall.Request = this.RawReq requestCall.Request = this.RawReq
requestCall.Formatter = this.Format requestCall.Formatter = this.Format
requestCall.Domain = this.host requestCall.Domain = this.ReqHost
origin := this.reverseProxy.NextOrigin(requestCall) origin := this.reverseProxy.NextOrigin(requestCall)
requestCall.CallResponseCallbacks(this.writer) requestCall.CallResponseCallbacks(this.writer)
if origin == nil { if origin == nil {
err := errors.New(this.URL() + ": no available origin sites for reverse proxy") err := errors.New(this.URL() + ": no available origin sites for reverse proxy")
remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil) remotelogs.ServerError(this.ReqServer.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil)
this.write50x(err, http.StatusBadGateway, true) this.write50x(err, http.StatusBadGateway, true)
return return
} }
@@ -129,7 +129,7 @@ func (this *HTTPRequest) doReverseProxy() {
this.RawReq.Host = hostname this.RawReq.Host = hostname
this.RawReq.URL.Host = this.RawReq.Host this.RawReq.URL.Host = this.RawReq.Host
} else { } else {
this.RawReq.URL.Host = this.host this.RawReq.URL.Host = this.ReqHost
} }
// 重组请求URL // 重组请求URL

View File

@@ -15,7 +15,7 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeProxy { if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeProxy {
// 外部URL // 外部URL
if this.rewriteIsExternalURL { if this.rewriteIsExternalURL {
host := this.host host := this.ReqHost
if len(this.rewriteRule.ProxyHost) > 0 { if len(this.rewriteRule.ProxyHost) > 0 {
host = this.rewriteRule.ProxyHost host = this.rewriteRule.ProxyHost
} }

View File

@@ -6,11 +6,11 @@ import (
// 统计 // 统计
func (this *HTTPRequest) doStat() { func (this *HTTPRequest) doStat() {
if this.Server == nil { if this.ReqServer == nil {
return return
} }
// 内置的统计 // 内置的统计
stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.Server.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack) stats.SharedHTTPRequestStatManager.AddRemoteAddr(this.ReqServer.Id, this.requestRemoteAddr(true), this.writer.SentBodyBytes(), this.isAttack)
stats.SharedHTTPRequestStatManager.AddUserAgent(this.Server.Id, this.requestHeader("User-Agent")) stats.SharedHTTPRequestStatManager.AddUserAgent(this.ReqServer.Id, this.requestHeader("User-Agent"))
} }

View File

@@ -10,8 +10,8 @@ func (this *HTTPRequest) doSubRequest(writer http.ResponseWriter, rawReq *http.R
req := &HTTPRequest{ req := &HTTPRequest{
RawReq: rawReq, RawReq: rawReq,
RawWriter: writer, RawWriter: writer,
Server: this.Server, ReqServer: this.ReqServer,
host: this.host, ReqHost: this.ReqHost,
ServerName: this.ServerName, ServerName: this.ServerName,
ServerAddr: this.ServerAddr, ServerAddr: this.ServerAddr,
IsHTTP: this.IsHTTP, IsHTTP: this.IsHTTP,

View File

@@ -8,7 +8,7 @@ import (
// 流量限制 // 流量限制
func (this *HTTPRequest) doTrafficLimit() { func (this *HTTPRequest) doTrafficLimit() {
var config = this.Server.TrafficLimit var config = this.ReqServer.TrafficLimit
this.tags = append(this.tags, "bandwidth") this.tags = append(this.tags, "bandwidth")

View File

@@ -31,16 +31,16 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
} }
// 是否在全局名单中 // 是否在全局名单中
if !iplibrary.AllowIP(remoteAddr, this.Server.Id) { if !iplibrary.AllowIP(remoteAddr, this.ReqServer.Id) {
this.disableLog = true this.disableLog = true
this.closeConn() this.Close()
return true return true
} }
// 检查是否在临时黑名单中 // 检查是否在临时黑名单中
if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.Server.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) { if waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeService, this.ReqServer.Id, remoteAddr) || waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, remoteAddr) {
this.disableLog = true this.disableLog = true
this.closeConn() this.Close()
return true return true
} }
@@ -57,8 +57,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
} }
// 公用的防火墙设置 // 公用的防火墙设置
if this.Server.HTTPFirewallPolicy != nil && this.Server.HTTPFirewallPolicy.IsOn { if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.Server.HTTPFirewallPolicy) blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy)
if blocked { if blocked {
return true return true
} }
@@ -208,7 +208,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
} }
// 添加统计 // 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions) stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
} }
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode) this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
@@ -228,8 +228,8 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
} }
// 公用的防火墙设置 // 公用的防火墙设置
if this.Server.HTTPFirewallPolicy != nil && this.Server.HTTPFirewallPolicy.IsOn { if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked := this.checkWAFResponse(this.Server.HTTPFirewallPolicy, resp) blocked := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp)
if blocked { if blocked {
return true return true
} }
@@ -266,7 +266,7 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
} }
// 添加统计 // 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.Server.Id, this.firewallRuleGroupId, ruleSet.Actions) stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
} }
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode) this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
@@ -313,12 +313,12 @@ func (this *HTTPRequest) WAFRestoreBody(data []byte) {
// WAFServerId 服务ID // WAFServerId 服务ID
func (this *HTTPRequest) WAFServerId() int64 { func (this *HTTPRequest) WAFServerId() int64 {
return this.Server.Id return this.ReqServer.Id
} }
// WAFClose 关闭连接 // WAFClose 关闭连接
func (this *HTTPRequest) WAFClose() { func (this *HTTPRequest) WAFClose() {
this.closeConn() this.Close()
} }
func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) { func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) {

View File

@@ -441,8 +441,8 @@ func (this *HTTPWriter) Close() {
StaleAt: expiredAt + int64(this.calculateStaleLife()), StaleAt: expiredAt + int64(this.calculateStaleLife()),
HeaderSize: this.cacheWriter.HeaderSize(), HeaderSize: this.cacheWriter.HeaderSize(),
BodySize: this.cacheWriter.BodySize(), BodySize: this.cacheWriter.BodySize(),
Host: this.req.host, Host: this.req.ReqHost,
ServerId: this.req.Server.Id, ServerId: this.req.ReqServer.Id,
}) })
} }
} }
@@ -566,7 +566,7 @@ func (this *HTTPWriter) prepareCache(size int64) {
return return
} }
cachePolicy := this.req.Server.HTTPCachePolicy cachePolicy := this.req.ReqServer.HTTPCachePolicy
if cachePolicy == nil || !cachePolicy.IsOn { if cachePolicy == nil || !cachePolicy.IsOn {
return return
} }

View File

@@ -208,8 +208,8 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
req := &HTTPRequest{ req := &HTTPRequest{
RawReq: rawReq, RawReq: rawReq,
RawWriter: rawWriter, RawWriter: rawWriter,
Server: server, ReqServer: server,
host: reqHost, ReqHost: reqHost,
ServerName: serverName, ServerName: serverName,
ServerAddr: this.addr, ServerAddr: this.addr,
IsHTTP: this.isHTTP, IsHTTP: this.isHTTP,