diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index 3e43cdc..6f5cb26 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -29,7 +29,7 @@ type HTTPClientPool struct { // NewHTTPClientPool 获取新对象 func NewHTTPClientPool() *HTTPClientPool { - pool := &HTTPClientPool{ + var pool = &HTTPClientPool{ clientExpiredDuration: 3600 * time.Second, clientsMap: map[string]*HTTPClient{}, } @@ -42,12 +42,16 @@ func NewHTTPClientPool() *HTTPClientPool { } // Client 根据地址获取客户端 -func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig, originAddr string, proxyProtocol *serverconfigs.ProxyProtocolConfig, followRedirects bool) (rawClient *http.Client, err error) { +func (this *HTTPClientPool) Client(req *HTTPRequest, + origin *serverconfigs.OriginConfig, + originAddr string, + proxyProtocol *serverconfigs.ProxyProtocolConfig, + followRedirects bool) (rawClient *http.Client, err error) { if origin.Addr == nil { return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } - key := origin.UniqueKey() + "@" + originAddr + var key = origin.UniqueKey() + "@" + originAddr this.locker.Lock() defer this.locker.Unlock() @@ -58,11 +62,11 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi return client.RawClient(), nil } - maxConnections := origin.MaxConns - connectionTimeout := origin.ConnTimeoutDuration() - readTimeout := origin.ReadTimeoutDuration() - idleTimeout := origin.IdleTimeoutDuration() - idleConns := origin.MaxIdleConns + var maxConnections = origin.MaxConns + var connectionTimeout = origin.ConnTimeoutDuration() + var readTimeout = origin.ReadTimeoutDuration() + var idleTimeout = origin.IdleTimeoutDuration() + var idleConns = origin.MaxIdleConns // 超时时间 if connectionTimeout <= 0 { @@ -73,7 +77,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi idleTimeout = 2 * time.Minute } - numberCPU := runtime.NumCPU() + var numberCPU = runtime.NumCPU() if numberCPU < 8 { numberCPU = 8 } @@ -163,7 +167,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi // 清理不使用的Client func (this *HTTPClientPool) cleanClients() { - ticker := time.NewTicker(this.clientExpiredDuration) + var ticker = time.NewTicker(this.clientExpiredDuration) for range ticker.C { currentAt := time.Now().Unix() @@ -181,11 +185,11 @@ func (this *HTTPClientPool) cleanClients() { // 支持TOA func (this *HTTPClientPool) handleTOA(req *HTTPRequest, ctx context.Context, network string, originAddr string, connectionTimeout time.Duration) (net.Conn, error) { // TODO 每个服务读取自身所属集群的TOA设置 - toaConfig := sharedTOAManager.Config() + var toaConfig = sharedTOAManager.Config() if toaConfig != nil && toaConfig.IsOn { - retries := 3 + var retries = 3 for i := 1; i <= retries; i++ { - port := int(toaConfig.RandLocalPort()) + var port = int(toaConfig.RandLocalPort()) // TODO 思考是否支持X-Real-IP/X-Forwarded-IP err := sharedTOAManager.SendMsg("add:" + strconv.Itoa(port) + ":" + req.requestRemoteAddr(true)) if err != nil { @@ -223,7 +227,7 @@ func (this *HTTPClientPool) handlePROXYProtocol(conn net.Conn, req *HTTPRequest, if reqConn != nil { destAddr = reqConn.(net.Conn).LocalAddr() } - header := proxyproto.Header{ + var header = proxyproto.Header{ Version: byte(proxyProtocol.Version), Command: proxyproto.PROXY, TransportProtocol: transportProtocol, diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 9478a9f..b856a1b 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" @@ -47,6 +48,13 @@ type HTTPRequest struct { IsHTTP bool IsHTTPS bool + // 共享参数 + nodeConfig *nodeconfigs.NodeConfig + + // ln request + isLnRequest bool + lnRemoteAddr string + // 内部参数 isSubRequest bool writer *HTTPWriter @@ -145,6 +153,9 @@ func (this *HTTPRequest) Do() { return } + // 是否为低级别节点 + this.isLnRequest = this.checkLnRequest() + // 回调事件 this.onInit() if this.writer.isFinished { @@ -152,58 +163,60 @@ func (this *HTTPRequest) Do() { return } - // 特殊URL处理 - if len(this.rawURI) > 1 && this.rawURI[1] == '.' { - // ACME - // TODO 需要配置是否启用ACME检测 - if strings.HasPrefix(this.rawURI, "/.well-known/acme-challenge/") { - this.doACME() + if !this.isLnRequest { + // 特殊URL处理 + if len(this.rawURI) > 1 && this.rawURI[1] == '.' { + // ACME + // TODO 需要配置是否启用ACME检测 + if strings.HasPrefix(this.rawURI, "/.well-known/acme-challenge/") { + this.doACME() + this.doEnd() + return + } + } + + // 套餐 + if this.ReqServer.UserPlan != nil && !this.ReqServer.UserPlan.IsAvailable() { + this.doPlanExpires() this.doEnd() return } - } - // 套餐 - if this.ReqServer.UserPlan != nil && !this.ReqServer.UserPlan.IsAvailable() { - this.doPlanExpires() - this.doEnd() - return - } - - // 流量限制 - if this.ReqServer.TrafficLimit != nil && this.ReqServer.TrafficLimit.IsOn && !this.ReqServer.TrafficLimit.IsEmpty() && this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() { - this.doTrafficLimit() - this.doEnd() - return - } - - // WAF - if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn { - if this.doWAFRequest() { + // 流量限制 + if this.ReqServer.TrafficLimit != nil && this.ReqServer.TrafficLimit.IsOn && !this.ReqServer.TrafficLimit.IsEmpty() && this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() { + this.doTrafficLimit() this.doEnd() return } - } - // 访问控制 - if !this.isSubRequest && this.web.Auth != nil && this.web.Auth.IsOn { - if this.doAuth() { - this.doEnd() - return + // WAF + if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn { + if this.doWAFRequest() { + this.doEnd() + return + } } - } - // 自动跳转到HTTPS - if this.IsHTTP && this.web.RedirectToHttps != nil && this.web.RedirectToHttps.IsOn { - if this.doRedirectToHTTPS(this.web.RedirectToHttps) { - this.doEnd() - return + // 访问控制 + if !this.isSubRequest && this.web.Auth != nil && this.web.Auth.IsOn { + if this.doAuth() { + this.doEnd() + return + } } - } - // Compression - if this.web.Compression != nil && this.web.Compression.IsOn && this.web.Compression.Level > 0 { - this.writer.SetCompression(this.web.Compression) + // 自动跳转到HTTPS + if this.IsHTTP && this.web.RedirectToHttps != nil && this.web.RedirectToHttps.IsOn { + if this.doRedirectToHTTPS(this.web.RedirectToHttps) { + this.doEnd() + return + } + } + + // Compression + if this.web.Compression != nil && this.web.Compression.IsOn && this.web.Compression.Level > 0 { + this.writer.SetCompression(this.web.Compression) + } } // 开始调用 @@ -218,58 +231,60 @@ func (this *HTTPRequest) Do() { // 开始调用 func (this *HTTPRequest) doBegin() { - // 处理request limit - if this.web.RequestLimit != nil && - this.web.RequestLimit.IsOn { - if this.doRequestLimit() { + if !this.isLnRequest { + // 处理request limit + if this.web.RequestLimit != nil && + this.web.RequestLimit.IsOn { + if this.doRequestLimit() { + return + } + } + + // 处理requestBody + if this.RawReq.ContentLength > 0 && + this.web.AccessLogRef != nil && + this.web.AccessLogRef.IsOn && + this.web.AccessLogRef.ContainsField(serverconfigs.HTTPAccessLogFieldRequestBody) { + var err error + this.requestBodyData, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, AccessLogMaxRequestBodySize)) + if err != nil { + this.write50x(err, http.StatusBadGateway, false) + return + } + this.RawReq.Body = ioutil.NopCloser(io.MultiReader(bytes.NewBuffer(this.requestBodyData), this.RawReq.Body)) + } + + // 处理健康检查 + var isHealthCheck = false + var healthCheckKey = this.RawReq.Header.Get(serverconfigs.HealthCheckHeaderName) + if len(healthCheckKey) > 0 { + if this.doHealthCheck(healthCheckKey, &isHealthCheck) { + return + } + } + + // UAM + if !isHealthCheck && this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn { + if this.doUAM() { + this.doEnd() + return + } + } + + // 跳转 + if len(this.web.HostRedirects) > 0 { + if this.doHostRedirect() { + return + } + } + + // 临时关闭页面 + if this.web.Shutdown != nil && this.web.Shutdown.IsOn { + this.doShutdown() return } } - // 处理requestBody - if this.RawReq.ContentLength > 0 && - this.web.AccessLogRef != nil && - this.web.AccessLogRef.IsOn && - this.web.AccessLogRef.ContainsField(serverconfigs.HTTPAccessLogFieldRequestBody) { - var err error - this.requestBodyData, err = ioutil.ReadAll(io.LimitReader(this.RawReq.Body, AccessLogMaxRequestBodySize)) - if err != nil { - this.write50x(err, http.StatusBadGateway, false) - return - } - this.RawReq.Body = ioutil.NopCloser(io.MultiReader(bytes.NewBuffer(this.requestBodyData), this.RawReq.Body)) - } - - // 处理健康检查 - var isHealthCheck = false - var healthCheckKey = this.RawReq.Header.Get(serverconfigs.HealthCheckHeaderName) - if len(healthCheckKey) > 0 { - if this.doHealthCheck(healthCheckKey, &isHealthCheck) { - return - } - } - - // UAM - if !isHealthCheck && this.ReqServer.UAM != nil && this.ReqServer.UAM.IsOn { - if this.doUAM() { - this.doEnd() - return - } - } - - // 跳转 - if len(this.web.HostRedirects) > 0 { - if this.doHostRedirect() { - return - } - } - - // 临时关闭页面 - if this.web.Shutdown != nil && this.web.Shutdown.IsOn { - this.doShutdown() - return - } - // 缓存 if this.web.Cache != nil && this.web.Cache.IsOn { if this.doCacheRead(false) { @@ -277,30 +292,32 @@ func (this *HTTPRequest) doBegin() { } } - // 重写规则 - if this.rewriteRule != nil { - if this.doRewrite() { - return - } - } - - // Fastcgi - if this.web.FastcgiRef != nil && this.web.FastcgiRef.IsOn && len(this.web.FastcgiList) > 0 { - if this.doFastcgi() { - return - } - } - - // root - if this.web.Root != nil && this.web.Root.IsOn { - // 如果处理成功,则终止请求的处理 - if this.doRoot() { - return + if !this.isLnRequest { + // 重写规则 + if this.rewriteRule != nil { + if this.doRewrite() { + return + } } - // 如果明确设置了终止,则也会自动终止 - if this.web.Root.IsBreak { - return + // Fastcgi + if this.web.FastcgiRef != nil && this.web.FastcgiRef.IsOn && len(this.web.FastcgiList) > 0 { + if this.doFastcgi() { + return + } + } + + // root + if this.web.Root != nil && this.web.Root.IsOn { + // 如果处理成功,则终止请求的处理 + if this.doRoot() { + return + } + + // 如果明确设置了终止,则也会自动终止 + if this.web.Root.IsBreak { + return + } } } @@ -809,9 +826,9 @@ func (this *HTTPRequest) Format(source string) string { if prefix == "node" { switch suffix { case "id": - return strconv.FormatInt(sharedNodeConfig.Id, 10) + return strconv.FormatInt(this.nodeConfig.Id, 10) case "name": - return sharedNodeConfig.Name + return this.nodeConfig.Name case "role": return teaconst.Role } @@ -970,13 +987,13 @@ func (this *HTTPRequest) Format(source string) string { if prefix == "product" { switch suffix { case "name": - if sharedNodeConfig.ProductConfig != nil && len(sharedNodeConfig.ProductConfig.Name) > 0 { - return sharedNodeConfig.ProductConfig.Name + if this.nodeConfig.ProductConfig != nil && len(this.nodeConfig.ProductConfig.Name) > 0 { + return this.nodeConfig.ProductConfig.Name } return teaconst.GlobalProductName case "version": - if sharedNodeConfig.ProductConfig != nil && len(sharedNodeConfig.ProductConfig.Version) > 0 { - return sharedNodeConfig.ProductConfig.Version + if this.nodeConfig.ProductConfig != nil && len(this.nodeConfig.ProductConfig.Version) > 0 { + return this.nodeConfig.ProductConfig.Version } return teaconst.Version } @@ -995,6 +1012,10 @@ func (this *HTTPRequest) addVarMapping(varMapping map[string]string) { // 获取请求的客户端地址 func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string { + if len(this.lnRemoteAddr) > 0 { + return this.lnRemoteAddr + } + if supportVar && len(this.remoteAddr) > 0 { return this.remoteAddr } diff --git a/internal/nodes/http_request_ln.go b/internal/nodes/http_request_ln.go new file mode 100644 index 0000000..c1d5dfc --- /dev/null +++ b/internal/nodes/http_request_ln.go @@ -0,0 +1,17 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build !plus +// +build !plus + +package nodes + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" +) + +func (this *HTTPRequest) checkLnRequest() bool { + return false +} + +func (this *HTTPRequest) getLnOrigin() *serverconfigs.OriginConfig { + return nil +} diff --git a/internal/nodes/http_request_log.go b/internal/nodes/http_request_log.go index c3b81d1..a3cadd4 100644 --- a/internal/nodes/http_request_log.go +++ b/internal/nodes/http_request_log.go @@ -92,7 +92,7 @@ func (this *HTTPRequest) log() { accessLog := &pb.HTTPAccessLog{ RequestId: this.requestId, - NodeId: sharedNodeConfig.Id, + NodeId: this.nodeConfig.Id, ServerId: this.ReqServer.Id, RemoteAddr: this.requestRemoteAddr(true), RawRemoteAddr: addr, diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index c141db4..a79234a 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -36,22 +36,40 @@ func (this *HTTPRequest) doReverseProxy() { requestCall.Request = this.RawReq requestCall.Formatter = this.Format requestCall.Domain = this.ReqHost - var origin = this.reverseProxy.NextOrigin(requestCall) - requestCall.CallResponseCallbacks(this.writer) + + var origin *serverconfigs.OriginConfig + + // 二级节点 + if this.cacheRef != nil { + origin = this.getLnOrigin() + if origin != nil { + // 强制变更原来访问的域名 + requestHost = this.ReqHost + } + } + + // 自定义源站 if origin == nil { - err := errors.New(this.URL() + ": no available origin sites for reverse proxy") - remotelogs.ServerError(this.ReqServer.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil) - this.write50x(err, http.StatusBadGateway, true) - return + origin = this.reverseProxy.NextOrigin(requestCall) + requestCall.CallResponseCallbacks(this.writer) + if origin == nil { + err := errors.New(this.URL() + ": no available origin sites for reverse proxy") + remotelogs.ServerError(this.ReqServer.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error(), "", nil) + this.write50x(err, http.StatusBadGateway, true) + return + } + + if len(origin.StripPrefix) > 0 { + stripPrefix = origin.StripPrefix + } + if len(origin.RequestURI) > 0 { + requestURI = origin.RequestURI + requestURIHasVariables = origin.RequestURIHasVariables() + } } + this.origin = origin // 设置全局变量是为了日志等处理 - if len(origin.StripPrefix) > 0 { - stripPrefix = origin.StripPrefix - } - if len(origin.RequestURI) > 0 { - requestURI = origin.RequestURI - requestURIHasVariables = origin.RequestURIHasVariables() - } + if len(origin.RequestHost) > 0 { requestHost = origin.RequestHost requestHostHasVariables = origin.RequestHostHasVariables() diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 7965100..6837b31 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -24,7 +24,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { var remoteAddr = this.requestRemoteAddr(true) // 检查是否为白名单直连 - if !Tea.IsTesting() && sharedNodeConfig.IPIsAutoAllowed(remoteAddr) { + if !Tea.IsTesting() && this.nodeConfig.IPIsAutoAllowed(remoteAddr) { return } diff --git a/internal/nodes/http_writer.go b/internal/nodes/http_writer.go index cd6a7fe..84b15fa 100644 --- a/internal/nodes/http_writer.go +++ b/internal/nodes/http_writer.go @@ -445,7 +445,7 @@ func (this *HTTPWriter) PrepareWebP(resp *http.Response, size int64) { } // 集群配置 - var policy = sharedNodeConfig.FindWebPImagePolicyWithClusterId(this.req.ReqServer.ClusterId) + var policy = this.req.nodeConfig.FindWebPImagePolicyWithClusterId(this.req.ReqServer.ClusterId) if policy == nil { policy = nodeconfigs.DefaultWebPImagePolicy } diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index dcf283d..0904fa7 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -134,7 +134,7 @@ func (this *HTTPListener) Reload(group *serverconfigs.ServerAddressGroup) { // ServerHTTP 处理HTTP请求 func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) { // 域名 - reqHost := rawReq.Host + var reqHost = rawReq.Host // TLS域名 if this.isIP(reqHost) { @@ -214,6 +214,8 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http. ServerAddr: this.addr, IsHTTP: this.isHTTP, IsHTTPS: this.isHTTPS, + + nodeConfig: sharedNodeConfig, } req.Do() } diff --git a/internal/nodes/node.go b/internal/nodes/node.go index ad7922d..1a815d9 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -339,6 +339,32 @@ func (this *Node) loop() error { return errors.New("reload common scripts failed: " + err.Error()) } + // 修改为已同步 + _, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{ + NodeTaskId: task.Id, + IsOk: true, + Error: "", + }) + if err != nil { + return err + } + case "nodeLevelChanged": + levelInfoResp, err := rpcClient.NodeRPC().FindNodeLevelInfo(nodeCtx, &pb.FindNodeLevelInfoRequest{}) + if err != nil { + return err + } + + sharedNodeConfig.Level = levelInfoResp.Level + + var parentNodes = map[int64][]*nodeconfigs.ParentNodeConfig{} + if len(levelInfoResp.ParentNodesMapJSON) > 0 { + err = json.Unmarshal(levelInfoResp.ParentNodesMapJSON, &parentNodes) + if err != nil { + return errors.New("decode level info failed: " + err.Error()) + } + } + sharedNodeConfig.ParentNodes = parentNodes + // 修改为已同步 _, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{ NodeTaskId: task.Id, diff --git a/internal/nodes/origin_state_manager.go b/internal/nodes/origin_state_manager.go index 4489096..b4d7fd8 100644 --- a/internal/nodes/origin_state_manager.go +++ b/internal/nodes/origin_state_manager.go @@ -124,9 +124,10 @@ func (this *OriginStateManager) Loop() error { // Fail 添加失败的源站 func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverseProxy *serverconfigs.ReverseProxyConfig, callback func()) { - if origin == nil { + if origin == nil || origin.Id <= 0 { return } + this.locker.Lock() state, ok := this.stateMap[origin.Id] var timestamp = time.Now().Unix() @@ -164,7 +165,7 @@ func (this *OriginStateManager) Fail(origin *serverconfigs.OriginConfig, reverse // Success 添加成功的源站 func (this *OriginStateManager) Success(origin *serverconfigs.OriginConfig, callback func()) { - if origin == nil { + if origin == nil || origin.Id <= 0 { return } @@ -182,6 +183,10 @@ func (this *OriginStateManager) Success(origin *serverconfigs.OriginConfig, call // IsAvailable 检查是否正常 func (this *OriginStateManager) IsAvailable(originId int64) bool { + if originId <= 0 { + return true + } + this.locker.RLock() _, ok := this.stateMap[originId] this.locker.RUnlock()