mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
实现HTTP部分功能
This commit is contained in:
@@ -1,9 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
TARGET=../../EdgeAdmin/internal/serverconfigs
|
||||
if [ -d ${TARGET} ]
|
||||
then
|
||||
rm -rf ../../EdgeAdmin/internal/serverconfigs
|
||||
fi
|
||||
cp -R ../internal/configs/serverconfigs ../../EdgeAdmin/internal/configs/
|
||||
cp -R ../internal/configs/serverconfigs ../../EdgeAPI/internal/configs
|
||||
1
build/configs/node.json
Executable file
1
build/configs/node.json
Executable file
@@ -0,0 +1 @@
|
||||
{"id":"db4ab72647d3d2b9c8d9b57b4f4bf47b","isOn":true,"servers":[{"id":3,"type":"httpProxy","isOn":true,"name":"我的服务","description":"","serverNames":[],"http":{"isOn":true,"listen":[{"protocol":"http","host":"127.0.0.1","portRange":"9991"}]},"https":null,"tcp":null,"tls":null,"unix":null,"udp":null,"web":{"id":35,"isOn":true,"locations":[{"id":7,"isOn":true,"pattern":"/hello","name":"","web":{"id":36,"isOn":true,"locations":null,"locationRefs":null,"gzipRef":null,"gzip":null,"charset":null,"shutdown":{"isPrior":false,"isOn":false,"url":"","status":0},"pages":null,"redirectToHttps":{"isPrior":false,"isOn":false,"status":307,"host":"www","port":11111},"root":"","indexes":null,"maxRequestBodySize":"","accessLog":null,"statRef":null,"cacheRef":null,"firewallRef":null,"websocketRef":null,"websocket":null,"requestHeaderPolicyRef":{"isPrior":false,"isOn":true,"headerPolicyId":23},"requestHeaderPolicy":{"id":23,"name":"","isOn":true,"description":"","addHeaderRefs":null,"addHeaders":null,"addTrailerRefs":null,"addTrailers":null,"setHeaderRefs":null,"setHeaders":null,"replaceHeaderRefs":null,"replaceHeaders":null,"deleteHeaders":["Cache-Control","Pragma"],"expires":null},"responseHeaderPolicyRef":{"isPrior":false,"isOn":true,"headerPolicyId":24},"responseHeaderPolicy":{"id":24,"name":"","isOn":true,"description":"","addHeaderRefs":null,"addHeaders":null,"addTrailerRefs":null,"addTrailers":null,"setHeaderRefs":null,"setHeaders":null,"replaceHeaderRefs":null,"replaceHeaders":null,"deleteHeaders":null,"expires":null},"filterRefs":null,"filterPolicies":null},"urlPrefix":"","description":"","reverseProxyRef":null,"reverseProxy":null,"isBreak":false,"children":null,"condGroups":null}],"locationRefs":[{"isOn":true,"locationId":7,"children":null}],"gzipRef":null,"gzip":null,"charset":null,"shutdown":{"isPrior":false,"isOn":false,"url":"hello.html","status":0},"pages":[{"id":14,"isOn":true,"status":["404"],"url":"pages/404.html","newStatus":0}],"redirectToHttps":{"isPrior":false,"isOn":false,"status":307,"host":"","port":0},"root":"","indexes":null,"maxRequestBodySize":"","accessLog":null,"statRef":null,"cacheRef":null,"firewallRef":null,"websocketRef":null,"websocket":null,"requestHeaderPolicyRef":{"isPrior":false,"isOn":true,"headerPolicyId":21},"requestHeaderPolicy":{"id":21,"name":"","isOn":true,"description":"","addHeaderRefs":null,"addHeaders":null,"addTrailerRefs":null,"addTrailers":null,"setHeaderRefs":[{"isOn":true,"headerId":30}],"setHeaders":[{"id":30,"isOn":true,"name":"From","value":"Edge","status":{"always":true,"codes":null}}],"replaceHeaderRefs":null,"replaceHeaders":null,"deleteHeaders":["Cache-Control","Cookie"],"expires":null},"responseHeaderPolicyRef":{"isPrior":false,"isOn":true,"headerPolicyId":22},"responseHeaderPolicy":{"id":22,"name":"","isOn":true,"description":"","addHeaderRefs":null,"addHeaders":null,"addTrailerRefs":null,"addTrailers":null,"setHeaderRefs":[{"isOn":true,"headerId":28},{"isOn":true,"headerId":29}],"setHeaders":[{"id":28,"isOn":true,"name":"Server","value":"Edge","status":{"always":true,"codes":null}},{"id":29,"isOn":true,"name":"Hello","value":"World","status":{"always":true,"codes":null}}],"replaceHeaderRefs":null,"replaceHeaders":null,"deleteHeaders":["Name"],"expires":null},"filterRefs":null,"filterPolicies":null},"reverseProxyRef":{"isPrior":false,"isOn":true,"reverseProxyId":20},"reverseProxy":{"id":20,"isOn":false,"primaryOrigins":null,"primaryOriginRefs":null,"backupOrigins":null,"backupOriginRefs":null,"scheduling":null}}],"version":114,"name":"认证啊","globalConfig":null}
|
||||
15
build/pages/403.html
Normal file
15
build/pages/403.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Error</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h3>403 Forbidden</h3>
|
||||
<p>Sorry, your access to the page has been denied. Please try again later.</p>
|
||||
|
||||
<footer>Powered by TeaEdge.</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
15
build/pages/404.html
Normal file
15
build/pages/404.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Error</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h3>404 Not Found</h3>
|
||||
<p>Sorry, the page you are looking for is not found. Please try again later.</p>
|
||||
|
||||
<footer>Powered by TeaEdge.</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
15
build/pages/50x.html
Normal file
15
build/pages/50x.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Error</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h3>An error occurred.</h3>
|
||||
<p>Sorry, the page you are looking for is currently unavailable. Please try again later.</p>
|
||||
|
||||
<footer>Powered by TeaEdge.</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
15
build/pages/shutdown_en.html
Normal file
15
build/pages/shutdown_en.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Shutdown Notice</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h3>The website is shutdown.</h3>
|
||||
<p>Sorry, the page you are looking for is currently unavailable. Please try again later.</p>
|
||||
|
||||
<footer>Powered by TeaEdge.</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
15
build/pages/shutdown_upgrade_zh.html
Normal file
15
build/pages/shutdown_upgrade_zh.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>升级中</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h3>网站升级中</h3>
|
||||
<p>为了给您提供更好的服务,我们正在升级网站,请稍后重新访问。</p>
|
||||
|
||||
<footer>Powered by TeaEdge.</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
15
build/pages/shutdown_zh.html
Normal file
15
build/pages/shutdown_zh.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>临时关闭提醒</title>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<h3>网站暂时关闭</h3>
|
||||
<p>网站已被暂时关闭,请耐心等待我们的重新开通通知。</p>
|
||||
|
||||
<footer>Powered by TeaEdge.</footer>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
2
go.mod
2
go.mod
@@ -10,7 +10,7 @@ require (
|
||||
github.com/go-ole/go-ole v1.2.4 // indirect
|
||||
github.com/go-redis/redis v6.15.8+incompatible // indirect
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||
github.com/iwind/TeaGo v0.0.0-20200910072805-729cffe36729
|
||||
github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.1 // indirect
|
||||
github.com/pquerna/ffjson v0.0.0-20190930134022-aa0246cd15f7 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -57,6 +57,8 @@ github.com/iwind/TeaGo v0.0.0-20200822074248-b1cf7248c98a h1:VaWcMNOzHHT1y8MeTA2
|
||||
github.com/iwind/TeaGo v0.0.0-20200822074248-b1cf7248c98a/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
|
||||
github.com/iwind/TeaGo v0.0.0-20200910072805-729cffe36729 h1:/v0WhSFVeNay/dA5zU9iCBXlgVDfxnztuanlauXE0gM=
|
||||
github.com/iwind/TeaGo v0.0.0-20200910072805-729cffe36729/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
|
||||
github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e h1:/xn7wUvlwaoA5IkdBUctv2OQbJSZ0/Dw8qRJmn55sJk=
|
||||
github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc=
|
||||
github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68=
|
||||
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/go-yaml/yaml"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
var sharedNodeConfig *NodeConfig = nil
|
||||
|
||||
type NodeConfig struct {
|
||||
Id string `yaml:"id" json:"id"`
|
||||
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||
Servers []*serverconfigs.ServerConfig `yaml:"servers" json:"servers"`
|
||||
Version int `yaml:"version" json:"version"`
|
||||
}
|
||||
|
||||
// 取得当前节点配置单例
|
||||
func SharedNodeConfig() (*NodeConfig, error) {
|
||||
sharedLocker.Lock()
|
||||
defer sharedLocker.Unlock()
|
||||
|
||||
if sharedNodeConfig != nil {
|
||||
return sharedNodeConfig, nil
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(Tea.ConfigFile("node.yaml"))
|
||||
if err != nil {
|
||||
return &NodeConfig{}, err
|
||||
}
|
||||
|
||||
config := &NodeConfig{}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return &NodeConfig{}, err
|
||||
}
|
||||
|
||||
sharedNodeConfig = config
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 刷新当前节点配置
|
||||
func ReloadNodeConfig() error {
|
||||
sharedLocker.Lock()
|
||||
sharedNodeConfig = nil
|
||||
sharedLocker.Unlock()
|
||||
|
||||
_, err := SharedNodeConfig()
|
||||
return err
|
||||
}
|
||||
|
||||
// 根据网络地址和协议分组
|
||||
func (this *NodeConfig) AvailableGroups() []*serverconfigs.ServerGroup {
|
||||
groupMapping := map[string]*serverconfigs.ServerGroup{} // protocol://addr => Server Group
|
||||
for _, server := range this.Servers {
|
||||
if !server.IsOn {
|
||||
continue
|
||||
}
|
||||
for _, addr := range server.FullAddresses() {
|
||||
group, ok := groupMapping[addr]
|
||||
if ok {
|
||||
group.Add(server)
|
||||
} else {
|
||||
group = serverconfigs.NewServerGroup(addr)
|
||||
group.Add(server)
|
||||
}
|
||||
groupMapping[addr] = group
|
||||
}
|
||||
}
|
||||
result := []*serverconfigs.ServerGroup{}
|
||||
for _, group := range groupMapping {
|
||||
result = append(result, group)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (this *NodeConfig) Init() error {
|
||||
for _, server := range this.Servers {
|
||||
err := server.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 写入到文件
|
||||
func (this *NodeConfig) Save() error {
|
||||
sharedLocker.Lock()
|
||||
defer sharedLocker.Unlock()
|
||||
|
||||
data, err := yaml.Marshal(this)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(Tea.ConfigFile("node.yaml"), data, 0777)
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
_ "github.com/iwind/TeaGo/bootstrap"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSharedNodeConfig(t *testing.T) {
|
||||
{
|
||||
config, err := SharedNodeConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(config)
|
||||
}
|
||||
|
||||
// read from memory cache
|
||||
{
|
||||
config, err := SharedNodeConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(config)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeConfig_Groups(t *testing.T) {
|
||||
config := &NodeConfig{}
|
||||
config.Servers = []*serverconfigs.ServerConfig{
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
Host: "127.0.0.1",
|
||||
PortRange: "1234",
|
||||
},
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "8080",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "8080",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
logs.PrintAsJSON(config.AvailableGroups(), t)
|
||||
}
|
||||
820
internal/nodes/http_request.go
Normal file
820
internal/nodes/http_request.go
Normal file
@@ -0,0 +1,820 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 环境变量
|
||||
var HOSTNAME, _ = os.Hostname()
|
||||
|
||||
// byte pool
|
||||
var bytePool256b = utils.NewBytePool(20480, 256)
|
||||
var bytePool1k = utils.NewBytePool(20480, 1024)
|
||||
var bytePool32k = utils.NewBytePool(20480, 32*1024)
|
||||
var bytePool128k = utils.NewBytePool(20480, 128*1024)
|
||||
|
||||
// HTTP请求
|
||||
type HTTPRequest struct {
|
||||
// 外部参数
|
||||
RawReq *http.Request
|
||||
RawWriter http.ResponseWriter
|
||||
Server *serverconfigs.ServerConfig
|
||||
Host string // 请求的Host
|
||||
ServerName string // 实际匹配到的Host
|
||||
ServerAddr string // 实际启动的服务器监听地址
|
||||
IsHTTP bool
|
||||
IsHTTPS bool
|
||||
|
||||
// 内部参数
|
||||
writer *HTTPWriter
|
||||
web *serverconfigs.HTTPWebConfig
|
||||
rawURI string // 原始的URI
|
||||
uri string // 经过rewrite等运算之后的URI
|
||||
varMapping map[string]string // 变量集合
|
||||
requestFromTime time.Time // 请求开始时间
|
||||
requestCost float64 // 请求耗时
|
||||
filePath string // 请求的文件名,仅在读取Root目录下的内容时不为空
|
||||
origin *serverconfigs.OriginConfig // 源站
|
||||
errors []string // 错误信息
|
||||
}
|
||||
|
||||
// 初始化
|
||||
func (this *HTTPRequest) init() {
|
||||
this.writer = NewHTTPWriter(this.RawWriter)
|
||||
this.web = &serverconfigs.HTTPWebConfig{}
|
||||
this.uri = this.RawReq.URL.RequestURI()
|
||||
this.rawURI = this.uri
|
||||
this.varMapping = map[string]string{}
|
||||
this.requestFromTime = time.Now()
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
func (this *HTTPRequest) Do() {
|
||||
// 初始化
|
||||
this.init()
|
||||
|
||||
// 配置
|
||||
err := this.configureWeb(this.Server.Web, true, 0)
|
||||
if err != nil {
|
||||
this.writeInternalServerError()
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
|
||||
// WAF
|
||||
// TODO 需要实现
|
||||
|
||||
// 访问控制
|
||||
// TODO 需要实现
|
||||
|
||||
// 自动跳转到HTTPS
|
||||
if this.IsHTTP && this.web.RedirectToHttps != nil && this.web.RedirectToHttps.IsOn {
|
||||
this.doRedirectToHTTPS(this.web.RedirectToHttps)
|
||||
this.doEnd()
|
||||
return
|
||||
}
|
||||
|
||||
// Gzip
|
||||
// TODO 需要实现
|
||||
|
||||
// 开始调用
|
||||
this.doBegin()
|
||||
}
|
||||
|
||||
// 开始调用
|
||||
func (this *HTTPRequest) doBegin() {
|
||||
// 重写规则
|
||||
// TODO
|
||||
|
||||
// 临时关闭页面
|
||||
if this.web.Shutdown != nil && this.web.Shutdown.IsOn {
|
||||
this.doShutdown()
|
||||
return
|
||||
}
|
||||
|
||||
// Origin
|
||||
// TODO
|
||||
|
||||
// WebSocket
|
||||
// TODO
|
||||
|
||||
// Fastcgi
|
||||
// TODO
|
||||
|
||||
// Server Event Sent
|
||||
// TODO 实现Location的AutoFlush
|
||||
|
||||
// root
|
||||
// TODO 从本地文件中读取
|
||||
// TODO 增加root优先级:High:优先从Root读取,Low:优先从反向代理等条件中读取
|
||||
// TODO 增加stripPrefix
|
||||
// TODO 增加URLEncode的处理方式
|
||||
|
||||
// 返回404页面
|
||||
this.writeNotFoundError()
|
||||
}
|
||||
|
||||
// 结束调用
|
||||
func (this *HTTPRequest) doEnd() {
|
||||
this.log()
|
||||
}
|
||||
|
||||
// 原始的请求URI
|
||||
func (this *HTTPRequest) RawURI() string {
|
||||
return this.rawURI
|
||||
}
|
||||
|
||||
// 配置
|
||||
func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bool, redirects int) error {
|
||||
if web == nil || !web.IsOn {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 防止跳转次数过多
|
||||
if redirects > 8 {
|
||||
return errors.New("too many redirects")
|
||||
}
|
||||
redirects++
|
||||
|
||||
// uri
|
||||
rawPath := ""
|
||||
rawQuery := ""
|
||||
qIndex := strings.Index(this.uri, "?") // question mark index
|
||||
if qIndex > -1 {
|
||||
rawPath = this.uri[:qIndex]
|
||||
rawQuery = this.uri[qIndex+1:]
|
||||
} else {
|
||||
rawPath = this.uri
|
||||
}
|
||||
_ = rawQuery // TODO 暂时不用到这个变量
|
||||
|
||||
// redirect
|
||||
if web.RedirectToHttps != nil && (web.RedirectToHttps.IsPrior || isTop) {
|
||||
this.web.RedirectToHttps = web.RedirectToHttps
|
||||
}
|
||||
|
||||
// pages
|
||||
if len(web.Pages) > 0 {
|
||||
this.web.Pages = web.Pages
|
||||
}
|
||||
|
||||
// shutdown
|
||||
if web.Shutdown != nil && (web.Shutdown.IsPrior || isTop) {
|
||||
this.web.Shutdown = web.Shutdown
|
||||
}
|
||||
|
||||
// headers
|
||||
if web.RequestHeaderPolicyRef != nil && (web.RequestHeaderPolicyRef.IsPrior || isTop) && web.RequestHeaderPolicy != nil {
|
||||
// TODO 现在是只能选一个有效的设置,未来可以选择是否合并多级别的设置
|
||||
this.web.RequestHeaderPolicy = web.RequestHeaderPolicy
|
||||
}
|
||||
if web.ResponseHeaderPolicyRef != nil && (web.ResponseHeaderPolicyRef.IsPrior || isTop) && web.ResponseHeaderPolicy != nil {
|
||||
// TODO 现在是只能选一个有效的设置,未来可以选择是否合并多级别的设置
|
||||
this.web.ResponseHeaderPolicy = web.ResponseHeaderPolicy
|
||||
}
|
||||
|
||||
// locations
|
||||
if len(web.LocationRefs) > 0 {
|
||||
var resultLocation *serverconfigs.HTTPLocationConfig
|
||||
for index, ref := range web.LocationRefs {
|
||||
if !ref.IsOn {
|
||||
continue
|
||||
}
|
||||
location := web.Locations[index]
|
||||
if !location.IsOn {
|
||||
continue
|
||||
}
|
||||
if varMapping, isMatched := location.Match(rawPath, this.Format); isMatched {
|
||||
if len(varMapping) > 0 {
|
||||
this.addVarMapping(varMapping)
|
||||
}
|
||||
resultLocation = location
|
||||
|
||||
if location.IsBreak {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if resultLocation != nil && resultLocation.Web != nil {
|
||||
err := this.configureWeb(resultLocation.Web, false, redirects)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 利用请求参数格式化字符串
|
||||
func (this *HTTPRequest) Format(source string) string {
|
||||
if len(source) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var hasVarMapping = len(this.varMapping) > 0
|
||||
|
||||
return configutils.ParseVariables(source, func(varName string) string {
|
||||
// 自定义变量
|
||||
if hasVarMapping {
|
||||
value, found := this.varMapping[varName]
|
||||
if found {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// 请求变量
|
||||
switch varName {
|
||||
case "edgeVersion":
|
||||
return teaconst.Version
|
||||
case "remoteAddr":
|
||||
return this.requestRemoteAddr()
|
||||
case "rawRemoteAddr":
|
||||
addr := this.RawReq.RemoteAddr
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err == nil {
|
||||
addr = host
|
||||
}
|
||||
return addr
|
||||
case "remotePort":
|
||||
return strconv.Itoa(this.requestRemotePort())
|
||||
case "remoteUser":
|
||||
return this.requestRemoteUser()
|
||||
case "requestURI", "requestUri":
|
||||
return this.rawURI
|
||||
case "requestPath":
|
||||
return this.requestPath()
|
||||
case "requestLength":
|
||||
return strconv.FormatInt(this.requestLength(), 10)
|
||||
case "requestTime":
|
||||
return fmt.Sprintf("%.6f", this.requestCost)
|
||||
case "requestMethod":
|
||||
return this.RawReq.Method
|
||||
case "requestFilename":
|
||||
filename := this.requestFilename()
|
||||
if len(filename) > 0 {
|
||||
return filename
|
||||
}
|
||||
|
||||
if len(this.web.Root) > 0 {
|
||||
return filepath.Clean(this.web.Root + this.requestPath())
|
||||
}
|
||||
|
||||
return ""
|
||||
case "scheme":
|
||||
if this.IsHTTP {
|
||||
return "http"
|
||||
} else {
|
||||
return "https"
|
||||
}
|
||||
case "serverProtocol", "proto":
|
||||
return this.RawReq.Proto
|
||||
case "bytesSent":
|
||||
return strconv.FormatInt(this.writer.SentBodyBytes(), 10) // TODO 加上Header长度
|
||||
case "bodyBytesSent":
|
||||
return strconv.FormatInt(this.writer.SentBodyBytes(), 10)
|
||||
case "status":
|
||||
return strconv.Itoa(this.writer.StatusCode())
|
||||
case "statusMessage":
|
||||
return http.StatusText(this.writer.StatusCode())
|
||||
case "timeISO8601":
|
||||
return this.requestFromTime.Format("2006-01-02T15:04:05.000Z07:00")
|
||||
case "timeLocal":
|
||||
return this.requestFromTime.Format("2/Jan/2006:15:04:05 -0700")
|
||||
case "msec":
|
||||
return fmt.Sprintf("%.6f", float64(this.requestFromTime.Unix())+float64(this.requestFromTime.Nanosecond())/1000000000)
|
||||
case "timestamp":
|
||||
return strconv.FormatInt(this.requestFromTime.Unix(), 10)
|
||||
case "host":
|
||||
return this.Host
|
||||
case "referer":
|
||||
return this.RawReq.Referer()
|
||||
case "userAgent":
|
||||
return this.RawReq.UserAgent()
|
||||
case "contentType":
|
||||
return this.requestContentType()
|
||||
case "request":
|
||||
return this.requestString()
|
||||
case "cookies":
|
||||
return this.requestCookiesString()
|
||||
case "args", "queryString":
|
||||
return this.requestQueryString()
|
||||
case "headers":
|
||||
return this.requestHeadersString()
|
||||
case "serverName":
|
||||
return this.ServerName
|
||||
case "serverPort":
|
||||
return strconv.Itoa(this.requestServerPort())
|
||||
case "hostname":
|
||||
return HOSTNAME
|
||||
case "documentRoot":
|
||||
return this.web.Root
|
||||
}
|
||||
|
||||
dotIndex := strings.Index(varName, ".")
|
||||
if dotIndex < 0 {
|
||||
return "${" + varName + "}"
|
||||
}
|
||||
prefix := varName[:dotIndex]
|
||||
suffix := varName[dotIndex+1:]
|
||||
|
||||
// cookie.
|
||||
if prefix == "cookie" {
|
||||
return this.requestCookie(suffix)
|
||||
}
|
||||
|
||||
// arg.
|
||||
if prefix == "arg" {
|
||||
return this.requestQueryParam(suffix)
|
||||
}
|
||||
|
||||
// header.
|
||||
if prefix == "header" || prefix == "http" {
|
||||
return this.requestHeader(suffix)
|
||||
}
|
||||
|
||||
// backend.
|
||||
if prefix == "origin" {
|
||||
if this.origin != nil {
|
||||
switch suffix {
|
||||
case "address", "addr":
|
||||
return this.origin.RealAddr()
|
||||
case "host":
|
||||
addr := this.origin.RealAddr()
|
||||
index := strings.Index(addr, ":")
|
||||
if index > -1 {
|
||||
return addr[:index]
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
case "id":
|
||||
return strconv.FormatInt(this.origin.Id, 10)
|
||||
case "scheme", "protocol":
|
||||
return this.origin.Addr.Protocol.String()
|
||||
case "code":
|
||||
return this.origin.Code
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// node
|
||||
if prefix == "node" {
|
||||
switch suffix {
|
||||
case "id":
|
||||
return sharedNodeConfig.Id
|
||||
case "name":
|
||||
return sharedNodeConfig.Name
|
||||
case "role":
|
||||
return teaconst.Role
|
||||
}
|
||||
}
|
||||
|
||||
// host
|
||||
if prefix == "host" {
|
||||
pieces := strings.Split(this.Host, ".")
|
||||
switch suffix {
|
||||
case "first":
|
||||
if len(pieces) > 0 {
|
||||
return pieces[0]
|
||||
}
|
||||
return ""
|
||||
case "last":
|
||||
if len(pieces) > 0 {
|
||||
return pieces[len(pieces)-1]
|
||||
}
|
||||
return ""
|
||||
case "0":
|
||||
if len(pieces) > 0 {
|
||||
return pieces[0]
|
||||
}
|
||||
return ""
|
||||
case "1":
|
||||
if len(pieces) > 1 {
|
||||
return pieces[1]
|
||||
}
|
||||
return ""
|
||||
case "2":
|
||||
if len(pieces) > 2 {
|
||||
return pieces[2]
|
||||
}
|
||||
return ""
|
||||
case "3":
|
||||
if len(pieces) > 3 {
|
||||
return pieces[3]
|
||||
}
|
||||
return ""
|
||||
case "4":
|
||||
if len(pieces) > 4 {
|
||||
return pieces[4]
|
||||
}
|
||||
return ""
|
||||
case "-1":
|
||||
if len(pieces) > 0 {
|
||||
return pieces[len(pieces)-1]
|
||||
}
|
||||
return ""
|
||||
case "-2":
|
||||
if len(pieces) > 1 {
|
||||
return pieces[len(pieces)-2]
|
||||
}
|
||||
return ""
|
||||
case "-3":
|
||||
if len(pieces) > 2 {
|
||||
return pieces[len(pieces)-3]
|
||||
}
|
||||
return ""
|
||||
case "-4":
|
||||
if len(pieces) > 3 {
|
||||
return pieces[len(pieces)-4]
|
||||
}
|
||||
return ""
|
||||
case "-5":
|
||||
if len(pieces) > 4 {
|
||||
return pieces[len(pieces)-5]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
return "${" + varName + "}"
|
||||
})
|
||||
}
|
||||
|
||||
// 添加变量定义
|
||||
func (this *HTTPRequest) addVarMapping(varMapping map[string]string) {
|
||||
for k, v := range varMapping {
|
||||
this.varMapping[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 获取请求的客户端地址
|
||||
func (this *HTTPRequest) requestRemoteAddr() string {
|
||||
// X-Forwarded-For
|
||||
forwardedFor := this.RawReq.Header.Get("X-Forwarded-For")
|
||||
if len(forwardedFor) > 0 {
|
||||
commaIndex := strings.Index(forwardedFor, ",")
|
||||
if commaIndex > 0 {
|
||||
return forwardedFor[:commaIndex]
|
||||
}
|
||||
return forwardedFor
|
||||
}
|
||||
|
||||
// Real-IP
|
||||
{
|
||||
realIP, ok := this.RawReq.Header["X-Real-IP"]
|
||||
if ok && len(realIP) > 0 {
|
||||
return realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Real-Ip
|
||||
{
|
||||
realIP, ok := this.RawReq.Header["X-Real-Ip"]
|
||||
if ok && len(realIP) > 0 {
|
||||
return realIP[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Remote-Addr
|
||||
remoteAddr := this.RawReq.RemoteAddr
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
} else {
|
||||
return remoteAddr
|
||||
}
|
||||
}
|
||||
|
||||
// 请求内容长度
|
||||
func (this *HTTPRequest) requestLength() int64 {
|
||||
return this.RawReq.ContentLength
|
||||
}
|
||||
|
||||
// 请求用户
|
||||
func (this *HTTPRequest) requestRemoteUser() string {
|
||||
username, _, ok := this.RawReq.BasicAuth()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
// 请求的URL中路径部分
|
||||
func (this *HTTPRequest) requestPath() string {
|
||||
uri, err := url.ParseRequestURI(this.rawURI)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return uri.Path
|
||||
}
|
||||
|
||||
// 客户端端口
|
||||
func (this *HTTPRequest) requestRemotePort() int {
|
||||
_, port, err := net.SplitHostPort(this.RawReq.RemoteAddr)
|
||||
if err == nil {
|
||||
return types.Int(port)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// 情趣的URI中的参数部分
|
||||
func (this *HTTPRequest) requestQueryString() string {
|
||||
uri, err := url.ParseRequestURI(this.uri)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return uri.RawQuery
|
||||
}
|
||||
|
||||
// 构造类似于"GET / HTTP/1.1"之类的请求字符串
|
||||
func (this *HTTPRequest) requestString() string {
|
||||
return this.RawReq.Method + " " + this.rawURI + " " + this.RawReq.Proto
|
||||
}
|
||||
|
||||
// 构造请求字符串
|
||||
func (this *HTTPRequest) requestCookiesString() string {
|
||||
var cookies = []string{}
|
||||
for _, cookie := range this.RawReq.Cookies() {
|
||||
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
|
||||
}
|
||||
return strings.Join(cookies, "&")
|
||||
}
|
||||
|
||||
// 查询单个Cookie值
|
||||
func (this *HTTPRequest) requestCookie(name string) string {
|
||||
cookie, err := this.RawReq.Cookie(name)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
// 查询请求参数值
|
||||
func (this *HTTPRequest) requestQueryParam(name string) string {
|
||||
uri, err := url.ParseRequestURI(this.rawURI)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
v, found := uri.Query()[name]
|
||||
if !found {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(v, "&")
|
||||
}
|
||||
|
||||
// 查询单个请求Header值
|
||||
func (this *HTTPRequest) requestHeader(key string) string {
|
||||
v, found := this.RawReq.Header[key]
|
||||
if !found {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(v, ";")
|
||||
}
|
||||
|
||||
// 以字符串的形式返回所有请求Header
|
||||
func (this *HTTPRequest) requestHeadersString() string {
|
||||
var headers = []string{}
|
||||
for k, v := range this.RawReq.Header {
|
||||
for _, subV := range v {
|
||||
headers = append(headers, k+": "+subV)
|
||||
}
|
||||
}
|
||||
return strings.Join(headers, ";")
|
||||
}
|
||||
|
||||
// 获取请求Content-Type值
|
||||
func (this *HTTPRequest) requestContentType() string {
|
||||
return this.RawReq.Header.Get("Content-Type")
|
||||
}
|
||||
|
||||
// 获取请求的文件名,仅在请求是读取本地文件时不为空
|
||||
func (this *HTTPRequest) requestFilename() string {
|
||||
return this.filePath
|
||||
}
|
||||
|
||||
// 请求的scheme
|
||||
func (this *HTTPRequest) requestScheme() string {
|
||||
if this.IsHTTPS {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
// 请求的服务器地址中的端口
|
||||
func (this *HTTPRequest) requestServerPort() int {
|
||||
_, port, err := net.SplitHostPort(this.ServerAddr)
|
||||
if err == nil {
|
||||
return types.Int(port)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// 设置代理相关头部信息
|
||||
// 参考:https://tools.ietf.org/html/rfc7239
|
||||
func (this *HTTPRequest) setForwardHeaders(header http.Header) {
|
||||
delete(header, "Connection")
|
||||
|
||||
remoteAddr := this.RawReq.RemoteAddr
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err == nil {
|
||||
remoteAddr = host
|
||||
}
|
||||
|
||||
// x-real-ip
|
||||
{
|
||||
_, ok1 := header["X-Real-IP"]
|
||||
_, ok2 := header["X-Real-Ip"]
|
||||
if !ok1 && !ok2 {
|
||||
header["X-Real-IP"] = []string{remoteAddr}
|
||||
}
|
||||
}
|
||||
|
||||
// X-Forwarded-For
|
||||
{
|
||||
forwardedFor, ok := header["X-Forwarded-For"]
|
||||
if ok {
|
||||
_, hasForwardHeader := this.RawReq.Header["X-Forwarded-For"]
|
||||
if hasForwardHeader {
|
||||
header["X-Forwarded-For"] = []string{strings.Join(forwardedFor, ", ") + ", " + remoteAddr}
|
||||
}
|
||||
} else {
|
||||
header["X-Forwarded-For"] = []string{remoteAddr}
|
||||
}
|
||||
}
|
||||
|
||||
// Forwarded
|
||||
/**{
|
||||
forwarded, ok := header["Forwarded"]
|
||||
if ok {
|
||||
header["Forwarded"] = []string{strings.Join(forwarded, ", ") + ", by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.host + "; proto=" + this.rawScheme}
|
||||
} else {
|
||||
header["Forwarded"] = []string{"by=" + this.serverAddr + "; for=" + remoteAddr + "; host=" + this.host + "; proto=" + this.rawScheme}
|
||||
}
|
||||
}**/
|
||||
|
||||
// others
|
||||
this.RawReq.Header.Set("X-Forwarded-By", this.ServerAddr)
|
||||
|
||||
if _, ok := header["X-Forwarded-Host"]; !ok {
|
||||
this.RawReq.Header.Set("X-Forwarded-Host", this.Host)
|
||||
}
|
||||
|
||||
if _, ok := header["X-Forwarded-Proto"]; !ok {
|
||||
this.RawReq.Header.Set("X-Forwarded-Proto", this.requestScheme())
|
||||
}
|
||||
}
|
||||
|
||||
// 处理自定义Request Header
|
||||
func (this *HTTPRequest) processRequestHeaders(reqHeader http.Header) {
|
||||
if this.web.RequestHeaderPolicy != nil && this.web.RequestHeaderPolicy.IsOn {
|
||||
// 删除某些Header
|
||||
for name := range reqHeader {
|
||||
if this.web.RequestHeaderPolicy.ContainsDeletedHeader(name) {
|
||||
reqHeader.Del(name)
|
||||
}
|
||||
}
|
||||
|
||||
// Add
|
||||
for _, header := range this.web.RequestHeaderPolicy.AddHeaders {
|
||||
if !header.IsOn {
|
||||
continue
|
||||
}
|
||||
oldValues, _ := this.RawReq.Header[header.Name]
|
||||
if header.HasVariables() {
|
||||
oldValues = append(oldValues, this.Format(header.Value))
|
||||
} else {
|
||||
oldValues = append(oldValues, header.Value)
|
||||
}
|
||||
reqHeader[header.Name] = oldValues
|
||||
}
|
||||
|
||||
// Set
|
||||
for _, header := range this.web.RequestHeaderPolicy.SetHeaders {
|
||||
if !header.IsOn {
|
||||
continue
|
||||
}
|
||||
if header.HasVariables() {
|
||||
reqHeader[header.Name] = []string{this.Format(header.Value)}
|
||||
} else {
|
||||
reqHeader[header.Name] = []string{header.Value}
|
||||
}
|
||||
}
|
||||
|
||||
// Replace
|
||||
// TODO 需要实现
|
||||
}
|
||||
}
|
||||
|
||||
// 处理自定义Response Header
|
||||
func (this *HTTPRequest) processResponseHeaders(statusCode int) {
|
||||
responseHeader := this.writer.Header()
|
||||
|
||||
// 删除/添加/替换Header
|
||||
// TODO 实现AddTrailers
|
||||
// TODO 实现ReplaceHeaders
|
||||
if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn {
|
||||
// 删除某些Header
|
||||
for name := range responseHeader {
|
||||
if this.web.ResponseHeaderPolicy.ContainsDeletedHeader(name) {
|
||||
responseHeader.Del(name)
|
||||
}
|
||||
}
|
||||
|
||||
// Add
|
||||
for _, header := range this.web.ResponseHeaderPolicy.AddHeaders {
|
||||
if !header.IsOn {
|
||||
continue
|
||||
}
|
||||
if header.Match(statusCode) {
|
||||
if this.web.ResponseHeaderPolicy.ContainsDeletedHeader(header.Name) {
|
||||
continue
|
||||
}
|
||||
oldValues, _ := responseHeader[header.Name]
|
||||
if header.HasVariables() {
|
||||
oldValues = append(oldValues, this.Format(header.Value))
|
||||
} else {
|
||||
oldValues = append(oldValues, header.Value)
|
||||
}
|
||||
responseHeader[header.Name] = oldValues
|
||||
}
|
||||
}
|
||||
|
||||
// Set
|
||||
for _, header := range this.web.ResponseHeaderPolicy.SetHeaders {
|
||||
if !header.IsOn {
|
||||
continue
|
||||
}
|
||||
if header.Match(statusCode) {
|
||||
if this.web.ResponseHeaderPolicy.ContainsDeletedHeader(header.Name) {
|
||||
continue
|
||||
}
|
||||
if header.HasVariables() {
|
||||
responseHeader[header.Name] = []string{this.Format(header.Value)}
|
||||
} else {
|
||||
responseHeader[header.Name] = []string{header.Value}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Replace
|
||||
// TODO
|
||||
}
|
||||
|
||||
// HSTS
|
||||
if this.IsHTTPS &&
|
||||
this.Server.HTTPS != nil &&
|
||||
this.Server.HTTPS.SSL != nil &&
|
||||
this.Server.HTTPS.SSL.IsOn &&
|
||||
this.Server.HTTPS.SSL.HSTS != nil &&
|
||||
this.Server.HTTPS.SSL.HSTS.IsOn &&
|
||||
this.Server.HTTPS.SSL.HSTS.Match(this.Host) {
|
||||
responseHeader.Set(this.Server.HTTPS.SSL.HSTS.HeaderKey(), this.Server.HTTPS.SSL.HSTS.HeaderValue())
|
||||
}
|
||||
}
|
||||
|
||||
// 添加错误信息
|
||||
func (this *HTTPRequest) addError(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
this.errors = append(this.errors, err.Error())
|
||||
}
|
||||
|
||||
// 日志
|
||||
func (this *HTTPRequest) log() {
|
||||
// 计算请求时间
|
||||
this.requestCost = time.Since(this.requestFromTime).Seconds()
|
||||
}
|
||||
|
||||
// 计算合适的buffer size
|
||||
func (this *HTTPRequest) bytePool(contentLength int64) *utils.BytePool {
|
||||
if contentLength <= 0 {
|
||||
return bytePool1k
|
||||
}
|
||||
if contentLength < 1024 { // 1K
|
||||
return bytePool256b
|
||||
}
|
||||
if contentLength < 32768 { // 32K
|
||||
return bytePool1k
|
||||
}
|
||||
if contentLength < 1048576 { // 1M
|
||||
return bytePool32k
|
||||
}
|
||||
return bytePool128k
|
||||
}
|
||||
28
internal/nodes/http_request_error.go
Normal file
28
internal/nodes/http_request_error.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) writeNotFoundError() {
|
||||
if this.doPage(http.StatusNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
this.processResponseHeaders(http.StatusNotFound)
|
||||
|
||||
msg := "404 page not found: '" + this.RawURI() + "'"
|
||||
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, _ = this.writer.Write([]byte(msg))
|
||||
}
|
||||
|
||||
func (this *HTTPRequest) writeInternalServerError() {
|
||||
statusCode := http.StatusInternalServerError
|
||||
if this.doPage(statusCode) {
|
||||
return
|
||||
}
|
||||
this.processResponseHeaders(statusCode)
|
||||
this.writer.WriteHeader(statusCode)
|
||||
_, _ = this.writer.Write([]byte(http.StatusText(statusCode)))
|
||||
}
|
||||
65
internal/nodes/http_request_page.go
Normal file
65
internal/nodes/http_request_page.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var urlPrefixRegexp = regexp.MustCompile("^(?i)(http|https|ftp)://")
|
||||
|
||||
// 请求特殊页面
|
||||
func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
|
||||
if len(this.web.Pages) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, page := range this.web.Pages {
|
||||
if page.Match(status) {
|
||||
if urlPrefixRegexp.MatchString(page.URL) {
|
||||
this.doURL(http.MethodGet, page.URL, "", page.NewStatus)
|
||||
return true
|
||||
} else {
|
||||
file := Tea.Root + Tea.DS + page.URL
|
||||
fp, err := os.Open(file)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
msg := "404 page not found: '" + page.URL + "'"
|
||||
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, err := this.writer.Write([]byte(msg))
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 修改状态码
|
||||
if page.NewStatus > 0 {
|
||||
// 自定义响应Headers
|
||||
this.processResponseHeaders(page.NewStatus)
|
||||
this.writer.WriteHeader(page.NewStatus)
|
||||
} else {
|
||||
this.processResponseHeaders(status)
|
||||
this.writer.WriteHeader(status)
|
||||
}
|
||||
buf := bytePool1k.Get()
|
||||
_, err = io.CopyBuffer(this.writer, fp, buf)
|
||||
bytePool1k.Put(buf)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
err = fp.Close()
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
42
internal/nodes/http_request_redirect_https.go
Normal file
42
internal/nodes/http_request_redirect_https.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.HTTPRedirectToHTTPSConfig) {
|
||||
host := this.RawReq.Host
|
||||
|
||||
if len(redirectToHTTPSConfig.Host) > 0 {
|
||||
if redirectToHTTPSConfig.Port > 0 && redirectToHTTPSConfig.Port != 443 {
|
||||
host = redirectToHTTPSConfig.Host + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
|
||||
} else {
|
||||
host = redirectToHTTPSConfig.Host
|
||||
}
|
||||
} else if redirectToHTTPSConfig.Port > 0 {
|
||||
lastIndex := strings.LastIndex(host, ":")
|
||||
if lastIndex > 0 {
|
||||
if redirectToHTTPSConfig.Port != 443 {
|
||||
host = host[:lastIndex] + ":" + strconv.Itoa(redirectToHTTPSConfig.Port)
|
||||
} else {
|
||||
host = host[:lastIndex]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lastIndex := strings.LastIndex(host, ":")
|
||||
if lastIndex > 0 {
|
||||
host = host[:lastIndex]
|
||||
}
|
||||
}
|
||||
|
||||
statusCode := http.StatusMovedPermanently
|
||||
if redirectToHTTPSConfig.Status > 0 {
|
||||
statusCode = redirectToHTTPSConfig.Status
|
||||
}
|
||||
|
||||
newURL := "https://" + host + this.RawReq.RequestURI
|
||||
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
|
||||
}
|
||||
72
internal/nodes/http_request_shutdown.go
Normal file
72
internal/nodes/http_request_shutdown.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
// 调用临时关闭页面
|
||||
func (this *HTTPRequest) doShutdown() {
|
||||
shutdown := this.web.Shutdown
|
||||
if shutdown == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if urlPrefixRegexp.MatchString(shutdown.URL) { // URL
|
||||
this.doURL(http.MethodGet, shutdown.URL, "", shutdown.Status)
|
||||
return
|
||||
}
|
||||
|
||||
// URL为空,则显示文本 TODO 未来可以自定义文本
|
||||
if len(shutdown.URL) == 0 {
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.processResponseHeaders(shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.processResponseHeaders(http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
_, err := this.writer.WriteString("The site have been shutdown.")
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 从本地文件中读取
|
||||
// TODO 支持从数据库中读取文件
|
||||
file := Tea.Root + Tea.DS + shutdown.URL
|
||||
fp, err := os.Open(file)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
msg := "404 page not found: '" + shutdown.URL + "'"
|
||||
|
||||
this.writer.WriteHeader(http.StatusNotFound)
|
||||
_, err = this.writer.Write([]byte(msg))
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 自定义响应Headers
|
||||
if shutdown.Status > 0 {
|
||||
this.processResponseHeaders(shutdown.Status)
|
||||
this.writer.WriteHeader(shutdown.Status)
|
||||
} else {
|
||||
this.processResponseHeaders(http.StatusOK)
|
||||
this.writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
buf := bytePool1k.Get()
|
||||
_, err = io.CopyBuffer(this.writer, fp, buf)
|
||||
bytePool1k.Put(buf)
|
||||
err = fp.Close()
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
}
|
||||
35
internal/nodes/http_request_test.go
Normal file
35
internal/nodes/http_request_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPRequest_RedirectToHTTPS(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
{
|
||||
req := &HTTPRequest{
|
||||
Server: &serverconfigs.ServerConfig{
|
||||
Web: &serverconfigs.HTTPWebConfig{
|
||||
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{},
|
||||
},
|
||||
},
|
||||
}
|
||||
req.Run()
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == false)
|
||||
}
|
||||
{
|
||||
req := &HTTPRequest{
|
||||
Server: &serverconfigs.ServerConfig{
|
||||
Web: &serverconfigs.HTTPWebConfig{
|
||||
RedirectToHttps: &serverconfigs.HTTPRedirectToHTTPSConfig{
|
||||
IsOn: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
req.Run()
|
||||
a.IsBool(req.web.RedirectToHttps.IsOn == true)
|
||||
}
|
||||
}
|
||||
68
internal/nodes/http_request_url.go
Normal file
68
internal/nodes/http_request_url.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 请求某个URL
|
||||
func (this *HTTPRequest) doURL(method string, url string, host string, statusCode int) {
|
||||
req, err := http.NewRequest(method, url, this.RawReq.Body)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// 修改Host
|
||||
if len(host) > 0 {
|
||||
req.Host = this.Format(host)
|
||||
}
|
||||
|
||||
// 添加当前Header
|
||||
req.Header = this.RawReq.Header
|
||||
|
||||
// 代理头部
|
||||
this.setForwardHeaders(req.Header)
|
||||
|
||||
// 自定义请求Header
|
||||
this.processRequestHeaders(req.Header)
|
||||
|
||||
var client = utils.SharedHttpClient(60 * time.Second)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logs.Error(errors.New(req.URL.String() + ": " + err.Error()))
|
||||
this.addError(err)
|
||||
this.writeInternalServerError()
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// Header
|
||||
if statusCode <= 0 {
|
||||
this.processResponseHeaders(resp.StatusCode)
|
||||
} else {
|
||||
this.processResponseHeaders(statusCode)
|
||||
}
|
||||
|
||||
this.writer.AddHeaders(resp.Header)
|
||||
this.writer.Prepare(resp.ContentLength)
|
||||
|
||||
// 设置响应代码
|
||||
if statusCode <= 0 {
|
||||
this.writer.WriteHeader(resp.StatusCode)
|
||||
} else {
|
||||
this.writer.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// 输出内容
|
||||
pool := this.bytePool(resp.ContentLength)
|
||||
buf := pool.Get()
|
||||
_, err = io.CopyBuffer(this.writer, resp.Body, buf)
|
||||
pool.Put(buf)
|
||||
}
|
||||
246
internal/nodes/http_writer.go
Normal file
246
internal/nodes/http_writer.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 响应Writer
|
||||
type HTTPWriter struct {
|
||||
writer http.ResponseWriter
|
||||
|
||||
gzipConfig *serverconfigs.HTTPGzipConfig
|
||||
gzipWriter *gzip.Writer
|
||||
|
||||
statusCode int
|
||||
sentBodyBytes int64
|
||||
|
||||
bodyCopying bool
|
||||
body []byte
|
||||
gzipBodyBuffer *bytes.Buffer // 当使用gzip压缩时使用
|
||||
gzipBodyWriter *gzip.Writer // 当使用gzip压缩时使用
|
||||
}
|
||||
|
||||
// 包装对象
|
||||
func NewHTTPWriter(httpResponseWriter http.ResponseWriter) *HTTPWriter {
|
||||
return &HTTPWriter{
|
||||
writer: httpResponseWriter,
|
||||
}
|
||||
}
|
||||
|
||||
// 重置
|
||||
func (this *HTTPWriter) Reset(httpResponseWriter http.ResponseWriter) {
|
||||
this.writer = httpResponseWriter
|
||||
|
||||
this.gzipConfig = nil
|
||||
this.gzipWriter = nil
|
||||
|
||||
this.statusCode = 0
|
||||
this.sentBodyBytes = 0
|
||||
|
||||
this.bodyCopying = false
|
||||
this.body = nil
|
||||
this.gzipBodyBuffer = nil
|
||||
this.gzipBodyWriter = nil
|
||||
}
|
||||
|
||||
// 设置Gzip
|
||||
func (this *HTTPWriter) Gzip(config *serverconfigs.HTTPGzipConfig) {
|
||||
this.gzipConfig = config
|
||||
}
|
||||
|
||||
// 准备输出
|
||||
func (this *HTTPWriter) Prepare(size int64) {
|
||||
if this.gzipConfig == nil || this.gzipConfig.Level <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 尺寸和类型
|
||||
if size < this.gzipConfig.MinBytes() {
|
||||
return
|
||||
}
|
||||
|
||||
contentType := this.Header().Get("Content-Type")
|
||||
if !this.gzipConfig.MatchContentType(contentType) {
|
||||
return
|
||||
}
|
||||
|
||||
// 如果已经有编码则不处理
|
||||
if len(this.writer.Header().Get("Content-Encoding")) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// gzip writer
|
||||
var err error = nil
|
||||
this.gzipWriter, err = gzip.NewWriterLevel(this.writer, int(this.gzipConfig.Level))
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// body copy
|
||||
if this.bodyCopying {
|
||||
this.gzipBodyBuffer = bytes.NewBuffer([]byte{})
|
||||
this.gzipBodyWriter, err = gzip.NewWriterLevel(this.gzipBodyBuffer, int(this.gzipConfig.Level))
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
header := this.writer.Header()
|
||||
header.Set("Content-Encoding", "gzip")
|
||||
header.Set("Transfer-Encoding", "chunked")
|
||||
header.Set("Vary", "Accept-Encoding")
|
||||
header.Del("Content-Length")
|
||||
}
|
||||
|
||||
// 包装前的原始的Writer
|
||||
func (this *HTTPWriter) Raw() http.ResponseWriter {
|
||||
return this.writer
|
||||
}
|
||||
|
||||
// 获取Header
|
||||
func (this *HTTPWriter) Header() http.Header {
|
||||
if this.writer == nil {
|
||||
return http.Header{}
|
||||
}
|
||||
return this.writer.Header()
|
||||
}
|
||||
|
||||
// 添加一组Header
|
||||
func (this *HTTPWriter) AddHeaders(header http.Header) {
|
||||
if this.writer == nil {
|
||||
return
|
||||
}
|
||||
for key, value := range header {
|
||||
for _, v := range value {
|
||||
this.writer.Header().Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 写入数据
|
||||
func (this *HTTPWriter) Write(data []byte) (n int, err error) {
|
||||
if this.writer != nil {
|
||||
if this.gzipWriter != nil {
|
||||
n, err = this.gzipWriter.Write(data)
|
||||
} else {
|
||||
n, err = this.writer.Write(data)
|
||||
}
|
||||
if n > 0 {
|
||||
this.sentBodyBytes += int64(n)
|
||||
}
|
||||
} else {
|
||||
if n == 0 {
|
||||
n = len(data) // 防止出现short write错误
|
||||
}
|
||||
}
|
||||
if this.bodyCopying {
|
||||
if this.gzipBodyWriter != nil {
|
||||
_, err := this.gzipBodyWriter.Write(data)
|
||||
if err != nil {
|
||||
logs.Error(err)
|
||||
}
|
||||
} else {
|
||||
this.body = append(this.body, data...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 写入字符串
|
||||
func (this *HTTPWriter) WriteString(s string) (n int, err error) {
|
||||
return this.Write([]byte(s))
|
||||
}
|
||||
|
||||
// 读取发送的字节数
|
||||
func (this *HTTPWriter) SentBodyBytes() int64 {
|
||||
return this.sentBodyBytes
|
||||
}
|
||||
|
||||
// 写入状态码
|
||||
func (this *HTTPWriter) WriteHeader(statusCode int) {
|
||||
if this.writer != nil {
|
||||
this.writer.WriteHeader(statusCode)
|
||||
}
|
||||
this.statusCode = statusCode
|
||||
}
|
||||
|
||||
// 读取状态码
|
||||
func (this *HTTPWriter) StatusCode() int {
|
||||
if this.statusCode == 0 {
|
||||
return http.StatusOK
|
||||
}
|
||||
return this.statusCode
|
||||
}
|
||||
|
||||
// 设置拷贝Body数据
|
||||
func (this *HTTPWriter) SetBodyCopying(b bool) {
|
||||
this.bodyCopying = b
|
||||
}
|
||||
|
||||
// 判断是否在拷贝Body数据
|
||||
func (this *HTTPWriter) BodyIsCopying() bool {
|
||||
return this.bodyCopying
|
||||
}
|
||||
|
||||
// 读取拷贝的Body数据
|
||||
func (this *HTTPWriter) Body() []byte {
|
||||
return this.body
|
||||
}
|
||||
|
||||
// 读取Header二进制数据
|
||||
func (this *HTTPWriter) HeaderData() []byte {
|
||||
if this.writer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
resp := &http.Response{}
|
||||
resp.Header = this.Header()
|
||||
if this.statusCode == 0 {
|
||||
this.statusCode = http.StatusOK
|
||||
}
|
||||
resp.StatusCode = this.statusCode
|
||||
resp.ProtoMajor = 1
|
||||
resp.ProtoMinor = 1
|
||||
|
||||
resp.ContentLength = 1 // Trick:这样可以屏蔽Content-Length
|
||||
|
||||
writer := bytes.NewBuffer([]byte{})
|
||||
_ = resp.Write(writer)
|
||||
return writer.Bytes()
|
||||
}
|
||||
|
||||
// 关闭
|
||||
func (this *HTTPWriter) Close() {
|
||||
if this.gzipWriter != nil {
|
||||
if this.bodyCopying && this.gzipBodyWriter != nil {
|
||||
_ = this.gzipBodyWriter.Close()
|
||||
this.body = this.gzipBodyBuffer.Bytes()
|
||||
}
|
||||
_ = this.gzipWriter.Close()
|
||||
this.gzipWriter = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack
|
||||
func (this *HTTPWriter) Hijack() (conn net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
hijack, ok := this.writer.(http.Hijacker)
|
||||
if ok {
|
||||
return hijack.Hijack()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Flush
|
||||
func (this *HTTPWriter) Flush() {
|
||||
flusher, ok := this.writer.(http.Flusher)
|
||||
if ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (this *Listener) Listen() error {
|
||||
Listener: netListener,
|
||||
}
|
||||
default:
|
||||
return errors.New("unknown protocol '" + protocol + "'")
|
||||
return errors.New("unknown protocol '" + protocol.String() + "'")
|
||||
}
|
||||
|
||||
this.listener.Init()
|
||||
|
||||
@@ -80,7 +80,7 @@ func (this *BaseListener) matchSSL(group *serverconfigs.ServerGroup, domain stri
|
||||
// 如果域名为空,则取第一个
|
||||
// 通常域名为空是因为是直接通过IP访问的
|
||||
if len(domain) == 0 {
|
||||
if serverconfigs.SharedGlobalConfig().HTTPAll.MatchDomainStrictly {
|
||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
return nil, nil, errors.New("no tls server name matched")
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ func (this *BaseListener) findNamedServer(group *serverconfigs.ServerGroup, name
|
||||
maxNamedServers := 10240
|
||||
|
||||
// 是否严格匹配域名
|
||||
matchDomainStrictly := serverconfigs.SharedGlobalConfig().HTTPAll.MatchDomainStrictly
|
||||
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
|
||||
|
||||
// 如果只有一个server,则默认为这个
|
||||
if countServers == 1 && !matchDomainStrictly {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"golang.org/x/net/http2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -15,6 +16,9 @@ type HTTPListener struct {
|
||||
Group *serverconfigs.ServerGroup
|
||||
Listener net.Listener
|
||||
|
||||
addr string
|
||||
isHTTP bool
|
||||
isHTTPS bool
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
@@ -24,15 +28,19 @@ func (this *HTTPListener) Serve() error {
|
||||
this.handleHTTP(writer, request)
|
||||
})
|
||||
|
||||
this.addr = this.Group.Addr()
|
||||
this.isHTTP = this.Group.IsHTTP()
|
||||
this.isHTTPS = this.Group.IsHTTPS()
|
||||
|
||||
this.httpServer = &http.Server{
|
||||
Addr: this.Group.Addr(),
|
||||
Addr: this.addr,
|
||||
Handler: handler,
|
||||
IdleTimeout: 2 * time.Minute,
|
||||
}
|
||||
this.httpServer.SetKeepAlivesEnabled(true)
|
||||
|
||||
// HTTP协议
|
||||
if this.Group.IsHTTP() {
|
||||
if this.isHTTP {
|
||||
err := this.httpServer.Serve(this.Listener)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
@@ -40,7 +48,7 @@ func (this *HTTPListener) Serve() error {
|
||||
}
|
||||
|
||||
// HTTPS协议
|
||||
if this.Group.IsHTTPS() {
|
||||
if this.isHTTPS {
|
||||
this.httpServer.TLSConfig = this.buildTLSConfig(this.Group)
|
||||
|
||||
// support http/2
|
||||
@@ -65,6 +73,86 @@ func (this *HTTPListener) Close() error {
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *HTTPListener) handleHTTP(writer http.ResponseWriter, req *http.Request) {
|
||||
writer.Write([]byte("Hello, World"))
|
||||
// 处理HTTP请求
|
||||
func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http.Request) {
|
||||
// 域名
|
||||
reqHost := rawReq.Host
|
||||
|
||||
// TLS域名
|
||||
if this.isIP(reqHost) {
|
||||
if rawReq.TLS != nil {
|
||||
serverName := rawReq.TLS.ServerName
|
||||
if len(serverName) > 0 {
|
||||
// 端口
|
||||
index := strings.LastIndex(reqHost, ":")
|
||||
if index >= 0 {
|
||||
reqHost = serverName + reqHost[index:]
|
||||
} else {
|
||||
reqHost = serverName
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 防止空Host
|
||||
if len(reqHost) == 0 {
|
||||
ctx := rawReq.Context()
|
||||
if ctx != nil {
|
||||
addr := ctx.Value(http.LocalAddrContextKey)
|
||||
if addr != nil {
|
||||
reqHost = addr.(net.Addr).String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
domain, _, err := net.SplitHostPort(reqHost)
|
||||
if err != nil {
|
||||
domain = reqHost
|
||||
}
|
||||
|
||||
server, serverName := this.findNamedServer(this.Group, domain)
|
||||
if server == nil {
|
||||
// 严格匹配域名模式下,我们拒绝用户访问
|
||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
hijacker, ok := rawWriter.(http.Hijacker)
|
||||
if ok {
|
||||
conn, _, _ := hijacker.Hijack()
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(rawWriter, "404 page not found: '"+rawReq.URL.String()+"'", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 包装新请求对象
|
||||
req := &HTTPRequest{
|
||||
RawReq: rawReq,
|
||||
RawWriter: rawWriter,
|
||||
Server: server,
|
||||
Host: reqHost,
|
||||
ServerName: serverName,
|
||||
ServerAddr: this.addr,
|
||||
IsHTTP: this.isHTTP,
|
||||
IsHTTPS: this.isHTTPS,
|
||||
}
|
||||
req.Do()
|
||||
}
|
||||
|
||||
func (this *HTTPListener) isIP(host string) bool {
|
||||
// IPv6
|
||||
if strings.Index(host, "[") > -1 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, b := range host {
|
||||
if b >= 'a' && b <= 'z' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"net/url"
|
||||
@@ -14,7 +14,7 @@ var sharedListenerManager = NewListenerManager()
|
||||
type ListenerManager struct {
|
||||
listenersMap map[string]*Listener // addr => *Listener
|
||||
locker sync.Mutex
|
||||
lastConfig *configs.NodeConfig
|
||||
lastConfig *nodeconfigs.NodeConfig
|
||||
}
|
||||
|
||||
func NewListenerManager() *ListenerManager {
|
||||
@@ -23,7 +23,7 @@ func NewListenerManager() *ListenerManager {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *ListenerManager) Start(node *configs.NodeConfig) error {
|
||||
func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListenerManager_Listen(t *testing.T) {
|
||||
manager := NewListenerManager()
|
||||
err := manager.Start(&configs.NodeConfig{
|
||||
Servers: []*configs.ServerConfig{
|
||||
err := manager.Start(&nodeconfigs.NodeConfig{
|
||||
Servers: []*serverconfigs.ServerConfig{
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &configs.HTTPProtocolConfig{
|
||||
BaseProtocol: configs.BaseProtocol{
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*configs.NetworkAddressConfig{
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: configs.ProtocolHTTP,
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1234",
|
||||
},
|
||||
},
|
||||
@@ -25,12 +26,12 @@ func TestListenerManager_Listen(t *testing.T) {
|
||||
},
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &configs.HTTPProtocolConfig{
|
||||
BaseProtocol: configs.BaseProtocol{
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*configs.NetworkAddressConfig{
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: configs.ProtocolHTTP,
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1235",
|
||||
},
|
||||
},
|
||||
@@ -43,16 +44,16 @@ func TestListenerManager_Listen(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.Start(&configs.NodeConfig{
|
||||
Servers: []*configs.ServerConfig{
|
||||
err = manager.Start(&nodeconfigs.NodeConfig{
|
||||
Servers: []*serverconfigs.ServerConfig{
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &configs.HTTPProtocolConfig{
|
||||
BaseProtocol: configs.BaseProtocol{
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*configs.NetworkAddressConfig{
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: configs.ProtocolHTTP,
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1234",
|
||||
},
|
||||
},
|
||||
@@ -61,12 +62,12 @@ func TestListenerManager_Listen(t *testing.T) {
|
||||
},
|
||||
{
|
||||
IsOn: true,
|
||||
HTTP: &configs.HTTPProtocolConfig{
|
||||
BaseProtocol: configs.BaseProtocol{
|
||||
HTTP: &serverconfigs.HTTPProtocolConfig{
|
||||
BaseProtocol: serverconfigs.BaseProtocol{
|
||||
IsOn: true,
|
||||
Listen: []*configs.NetworkAddressConfig{
|
||||
Listen: []*serverconfigs.NetworkAddressConfig{
|
||||
{
|
||||
Protocol: configs.ProtocolHTTP,
|
||||
Protocol: serverconfigs.ProtocolHTTP,
|
||||
PortRange: "1236",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListener_Listen(t *testing.T) {
|
||||
listener := NewListener()
|
||||
|
||||
group := configs.NewServerGroup("http://:1234")
|
||||
group := serverconfigs.NewServerGroup("http://:1234")
|
||||
|
||||
listener.Reload(group)
|
||||
err := listener.Listen()
|
||||
|
||||
@@ -3,8 +3,8 @@ package nodes
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
@@ -12,7 +12,8 @@ import (
|
||||
)
|
||||
|
||||
var stop = make(chan bool)
|
||||
var lastVersion = -1
|
||||
var lastVersion = int64(-1)
|
||||
var sharedNodeConfig *nodeconfigs.NodeConfig
|
||||
|
||||
// 节点
|
||||
type Node struct {
|
||||
@@ -36,11 +37,17 @@ func (this *Node) Start() {
|
||||
go NewNodeStatusExecutor().Listen()
|
||||
|
||||
// 读取配置
|
||||
nodeConfig, err := configs.SharedNodeConfig()
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err != nil {
|
||||
logs.Println("[NODE]start failed: read node config failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
err = nodeConfig.Init()
|
||||
if err != nil {
|
||||
logs.Println("[NODE]init node config failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
sharedNodeConfig = nodeConfig
|
||||
|
||||
// 设置rlimit
|
||||
_ = utils.SetRLimit(1024 * 1024)
|
||||
@@ -61,13 +68,14 @@ func (this *Node) syncConfig(isFirstTime bool) error {
|
||||
if err != nil {
|
||||
return errors.New("[NODE]create rpc client failed: " + err.Error())
|
||||
}
|
||||
// TODO 这里考虑只同步版本号有变更的
|
||||
configResp, err := rpcClient.NodeRPC().ComposeNodeConfig(rpcClient.Context(), &pb.ComposeNodeConfigRequest{})
|
||||
if err != nil {
|
||||
return errors.New("[NODE]read config from rpc failed: " + err.Error())
|
||||
}
|
||||
configBytes := configResp.ConfigJSON
|
||||
nodeConfig := &configs.NodeConfig{}
|
||||
err = json.Unmarshal(configBytes, nodeConfig)
|
||||
configJSON := configResp.NodeJSON
|
||||
nodeConfig := &nodeconfigs.NodeConfig{}
|
||||
err = json.Unmarshal(configJSON, nodeConfig)
|
||||
if err != nil {
|
||||
return errors.New("[NODE]decode config failed: " + err.Error())
|
||||
}
|
||||
@@ -84,12 +92,16 @@ func (this *Node) syncConfig(isFirstTime bool) error {
|
||||
}
|
||||
lastVersion = nodeConfig.Version
|
||||
|
||||
// 刷新配置
|
||||
err = configs.ReloadNodeConfig()
|
||||
err = nodeConfig.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 刷新配置
|
||||
logs.Println("[NODE]reload config ...")
|
||||
nodeconfigs.ResetNodeConfig(nodeConfig)
|
||||
sharedNodeConfig = nodeConfig
|
||||
|
||||
if !isFirstTime {
|
||||
return sharedListenerManager.Start(nodeConfig)
|
||||
}
|
||||
@@ -99,7 +111,8 @@ func (this *Node) syncConfig(isFirstTime bool) error {
|
||||
|
||||
// 启动同步计时器
|
||||
func (this *Node) startSyncTimer() {
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
// TODO 这个时间间隔可以自行设置
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
err := this.syncConfig(false)
|
||||
|
||||
@@ -2,7 +2,9 @@ package nodes
|
||||
|
||||
// 节点状态
|
||||
type NodeStatus struct {
|
||||
Version string `json:"version"`
|
||||
BuildVersion string `json:"buildVersion"` // 编译版本
|
||||
ConfigVersion int64 `json:"configVersion"` // 节点配置版本
|
||||
|
||||
Hostname string `json:"hostname"`
|
||||
HostIP string `json:"hostIP"`
|
||||
CPUUsage float64 `json:"cpuUsage"`
|
||||
|
||||
@@ -32,7 +32,8 @@ func (this *NodeStatusExecutor) Listen() {
|
||||
this.cpuUpdatedTime = time.Now()
|
||||
this.update()
|
||||
|
||||
ticker := time.NewTicker(60 * time.Second)
|
||||
// TODO 这个时间间隔可以配置
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
for range ticker.C {
|
||||
this.isFirstTime = false
|
||||
this.update()
|
||||
@@ -40,8 +41,13 @@ func (this *NodeStatusExecutor) Listen() {
|
||||
}
|
||||
|
||||
func (this *NodeStatusExecutor) update() {
|
||||
if sharedNodeConfig == nil {
|
||||
return
|
||||
}
|
||||
|
||||
status := &NodeStatus{}
|
||||
status.Version = teaconst.Version
|
||||
status.BuildVersion = teaconst.Version
|
||||
status.ConfigVersion = sharedNodeConfig.Version
|
||||
status.IsActive = true
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
package nodes
|
||||
|
||||
type OriginServer struct {
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
package nodes
|
||||
|
||||
type Request struct {
|
||||
}
|
||||
@@ -41,6 +41,10 @@ func TestSharedRPC_Stream(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = client.Send(&pb.NodeStreamRequest{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for {
|
||||
resp, err := client.Recv()
|
||||
if err != nil {
|
||||
|
||||
90
internal/utils/byte_pool.go
Normal file
90
internal/utils/byte_pool.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// pool for get byte slice
|
||||
type BytePool struct {
|
||||
c chan []byte
|
||||
length int
|
||||
ticker *Ticker
|
||||
|
||||
lastSize int
|
||||
}
|
||||
|
||||
// 创建新对象
|
||||
func NewBytePool(maxSize, length int) *BytePool {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 1024
|
||||
}
|
||||
if length <= 0 {
|
||||
length = 128
|
||||
}
|
||||
pool := &BytePool{
|
||||
c: make(chan []byte, maxSize),
|
||||
length: length,
|
||||
}
|
||||
pool.start()
|
||||
return pool
|
||||
}
|
||||
|
||||
func (this *BytePool) start() {
|
||||
// 清除Timer
|
||||
this.ticker = NewTicker(1 * time.Minute)
|
||||
go func() {
|
||||
for this.ticker.Next() {
|
||||
currentSize := len(this.c)
|
||||
if currentSize <= 32 || this.lastSize == 0 || this.lastSize != currentSize {
|
||||
this.lastSize = currentSize
|
||||
continue
|
||||
}
|
||||
|
||||
i := 0
|
||||
For:
|
||||
for {
|
||||
select {
|
||||
case _ = <-this.c:
|
||||
i++
|
||||
if i >= currentSize/2 {
|
||||
break For
|
||||
}
|
||||
default:
|
||||
break For
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 获取一个新的byte slice
|
||||
func (this *BytePool) Get() (b []byte) {
|
||||
select {
|
||||
case b = <-this.c:
|
||||
default:
|
||||
b = make([]byte, this.length)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 放回一个使用过的byte slice
|
||||
func (this *BytePool) Put(b []byte) {
|
||||
if cap(b) != this.length {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case this.c <- b:
|
||||
default:
|
||||
// 已达最大容量,则抛弃
|
||||
}
|
||||
}
|
||||
|
||||
// 当前的数量
|
||||
func (this *BytePool) Size() int {
|
||||
return len(this.c)
|
||||
}
|
||||
|
||||
// 销毁
|
||||
func (this *BytePool) Destroy() {
|
||||
this.ticker.Stop()
|
||||
}
|
||||
41
internal/utils/byte_pool_test.go
Normal file
41
internal/utils/byte_pool_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewBytePool(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
pool := NewBytePool(5, 8)
|
||||
buf := pool.Get()
|
||||
a.IsTrue(len(buf) == 8)
|
||||
a.IsTrue(len(pool.c) == 0)
|
||||
|
||||
pool.Put(buf)
|
||||
a.IsTrue(len(pool.c) == 1)
|
||||
|
||||
pool.Get()
|
||||
a.IsTrue(len(pool.c) == 0)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
pool.Put(buf)
|
||||
}
|
||||
t.Log(len(pool.c))
|
||||
a.IsTrue(len(pool.c) == 5)
|
||||
}
|
||||
|
||||
func BenchmarkBytePool_Get(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
pool := NewBytePool(1024, 1)
|
||||
for i := 0; i < b.N; i++ {
|
||||
buf := pool.Get()
|
||||
_ = buf
|
||||
pool.Put(buf)
|
||||
}
|
||||
|
||||
b.Log(pool.Size())
|
||||
}
|
||||
53
internal/utils/http.go
Normal file
53
internal/utils/http.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTP请求客户端管理
|
||||
var timeoutClientMap = map[time.Duration]*http.Client{} // timeout => Client
|
||||
var timeoutClientLocker = sync.Mutex{}
|
||||
|
||||
// 导出响应
|
||||
func DumpResponse(resp *http.Response) (header []byte, body []byte, err error) {
|
||||
header, err = httputil.DumpResponse(resp, false)
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取一个新的Client
|
||||
func NewHTTPClient(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 4096,
|
||||
MaxIdleConnsPerHost: 32,
|
||||
MaxConnsPerHost: 32,
|
||||
IdleConnTimeout: 2 * time.Minute,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSHandshakeTimeout: 0,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 获取一个公用的Client
|
||||
func SharedHttpClient(timeout time.Duration) *http.Client {
|
||||
timeoutClientLocker.Lock()
|
||||
defer timeoutClientLocker.Unlock()
|
||||
|
||||
client, ok := timeoutClientMap[timeout]
|
||||
if ok {
|
||||
return client
|
||||
}
|
||||
client = NewHTTPClient(timeout)
|
||||
timeoutClientMap[timeout] = client
|
||||
return client
|
||||
}
|
||||
32
internal/utils/http_test.go
Normal file
32
internal/utils/http_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewHTTPClient(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
client := NewHTTPClient(1 * time.Second)
|
||||
a.IsTrue(client.Timeout == 1*time.Second)
|
||||
|
||||
client2 := NewHTTPClient(1 * time.Second)
|
||||
a.IsTrue(client != client2)
|
||||
}
|
||||
|
||||
func TestSharedHTTPClient(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
_ = SharedHttpClient(2 * time.Second)
|
||||
_ = SharedHttpClient(3 * time.Second)
|
||||
|
||||
client := SharedHttpClient(1 * time.Second)
|
||||
a.IsTrue(client.Timeout == 1*time.Second)
|
||||
|
||||
client2 := SharedHttpClient(1 * time.Second)
|
||||
a.IsTrue(client == client2)
|
||||
|
||||
t.Log(timeoutClientMap)
|
||||
}
|
||||
47
internal/utils/ticker.go
Normal file
47
internal/utils/ticker.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// 类似于time.Ticker,但能够真正地停止
|
||||
type Ticker struct {
|
||||
raw *time.Ticker
|
||||
|
||||
S chan bool
|
||||
C <-chan time.Time
|
||||
|
||||
isStopped bool
|
||||
}
|
||||
|
||||
// 创建新Ticker
|
||||
func NewTicker(duration time.Duration) *Ticker {
|
||||
raw := time.NewTicker(duration)
|
||||
return &Ticker{
|
||||
raw: raw,
|
||||
C: raw.C,
|
||||
S: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// 查找下一个Tick
|
||||
func (this *Ticker) Next() bool {
|
||||
select {
|
||||
case <-this.raw.C:
|
||||
return true
|
||||
case <-this.S:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 停止
|
||||
func (this *Ticker) Stop() {
|
||||
if this.isStopped {
|
||||
return
|
||||
}
|
||||
|
||||
this.isStopped = true
|
||||
|
||||
this.raw.Stop()
|
||||
this.S <- true
|
||||
}
|
||||
52
internal/utils/ticker_test.go
Normal file
52
internal/utils/ticker_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTicker(t *testing.T) {
|
||||
ticker := NewTicker(3 * time.Second)
|
||||
go func() {
|
||||
time.Sleep(10 * time.Second)
|
||||
ticker.Stop()
|
||||
}()
|
||||
for ticker.Next() {
|
||||
logs.Println("tick")
|
||||
}
|
||||
t.Log("finished")
|
||||
}
|
||||
|
||||
func TestTicker2(t *testing.T) {
|
||||
ticker := NewTicker(1 * time.Second)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Second)
|
||||
ticker.Stop()
|
||||
}()
|
||||
for {
|
||||
logs.Println("loop")
|
||||
select {
|
||||
case <-ticker.C:
|
||||
logs.Println("tick")
|
||||
case <-ticker.S:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickerEvery(t *testing.T) {
|
||||
i := 0
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
Every(2*time.Second, func(ticker *Ticker) {
|
||||
i++
|
||||
logs.Println("TestTickerEvery i:", i)
|
||||
if i >= 4 {
|
||||
ticker.Stop()
|
||||
wg.Done()
|
||||
}
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
15
internal/utils/ticker_utils.go
Normal file
15
internal/utils/ticker_utils.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package utils
|
||||
|
||||
import "time"
|
||||
|
||||
// 定时运行某个函数
|
||||
func Every(duration time.Duration, f func(ticker *Ticker)) *Ticker {
|
||||
ticker := NewTicker(duration)
|
||||
go func() {
|
||||
for ticker.Next() {
|
||||
f(ticker)
|
||||
}
|
||||
}()
|
||||
|
||||
return ticker
|
||||
}
|
||||
Reference in New Issue
Block a user