实现websocket基本功能

This commit is contained in:
刘祥超
2020-09-26 19:54:15 +08:00
parent 8b11e7172a
commit 08f0788349
14 changed files with 375 additions and 72 deletions

View File

@@ -69,6 +69,7 @@ func (this *APINode) listenRPC() error {
pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{})
pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{})
pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{})
pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{})
err = rpcServer.Serve(listener)
if err != nil {
return errors.New("[API]start rpc failed: " + err.Error())

View File

@@ -260,6 +260,25 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon
config.RedirectToHttps = redirectToHTTPSConfig
}
// Websocket
if IsNotNull(web.Websocket) {
ref := &serverconfigs.HTTPWebsocketRef{}
err = json.Unmarshal([]byte(web.Websocket), ref)
if err != nil {
return nil, err
}
config.WebsocketRef = ref
if ref.WebsocketId > 0 {
websocketConfig, err := SharedHTTPWebsocketDAO.ComposeWebsocketConfig(ref.WebsocketId)
if err != nil {
return nil, err
}
if websocketConfig != nil {
config.Websocket = websocketConfig
}
}
}
return config, nil
}
@@ -267,7 +286,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon
func (this *HTTPWebDAO) CreateWeb(rootJSON []byte) (int64, error) {
op := NewHTTPWebOperator()
op.State = HTTPWebStateEnabled
op.Root = rootJSON
op.Root = JSONBytes(rootJSON)
_, err := this.Save(op)
if err != nil {
return 0, err
@@ -282,7 +301,7 @@ func (this *HTTPWebDAO) UpdateWeb(webId int64, rootJSON []byte) error {
}
op := NewHTTPWebOperator()
op.Id = webId
op.Root = rootJSON
op.Root = JSONBytes(rootJSON)
_, err := this.Save(op)
return err
}
@@ -294,7 +313,7 @@ func (this *HTTPWebDAO) UpdateWebGzip(webId int64, gzipJSON []byte) error {
}
op := NewHTTPWebOperator()
op.Id = webId
op.Gzip = gzipJSON
op.Gzip = JSONBytes(gzipJSON)
_, err := this.Save(op)
return err
}
@@ -306,7 +325,7 @@ func (this *HTTPWebDAO) UpdateWebCharset(webId int64, charsetJSON []byte) error
}
op := NewHTTPWebOperator()
op.Id = webId
op.Charset = charsetJSON
op.Charset = JSONBytes(charsetJSON)
_, err := this.Save(op)
return err
}
@@ -430,3 +449,15 @@ func (this *HTTPWebDAO) UpdateWebRedirectToHTTPS(webId int64, redirectToHTTPSJSO
_, err := this.Save(op)
return err
}
// 修改Websocket设置
func (this *HTTPWebDAO) UpdateWebsocket(webId int64, websocketJSON []byte) error {
if webId <= 0 {
return errors.New("invalid webId")
}
op := NewHTTPWebOperator()
op.Id = webId
op.Websocket = JSONBytes(websocketJSON)
_, err := this.Save(op)
return err
}

View File

@@ -9,7 +9,7 @@ type HTTPWeb struct {
UserId uint32 `field:"userId"` // 用户ID
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
Root string `field:"root"` // 资源根目录
Root string `field:"root"` // 根目录
Charset string `field:"charset"` // 字符集
Shutdown string `field:"shutdown"` // 临时关闭页面配置
Pages string `field:"pages"` // 特殊页面
@@ -24,6 +24,7 @@ type HTTPWeb struct {
Cache string `field:"cache"` // 缓存配置
Firewall string `field:"firewall"` // 防火墙设置
Locations string `field:"locations"` // 路径规则配置
Websocket string `field:"websocket"` // Websocket设置
}
type HTTPWebOperator struct {
@@ -34,7 +35,7 @@ type HTTPWebOperator struct {
UserId interface{} // 用户ID
State interface{} // 状态
CreatedAt interface{} // 创建时间
Root interface{} // 资源根目录
Root interface{} // 根目录
Charset interface{} // 字符集
Shutdown interface{} // 临时关闭页面配置
Pages interface{} // 特殊页面
@@ -49,6 +50,7 @@ type HTTPWebOperator struct {
Cache interface{} // 缓存配置
Firewall interface{} // 防火墙设置
Locations interface{} // 路径规则配置
Websocket interface{} // Websocket设置
}
func NewHTTPWebOperator() *HTTPWebOperator {

View File

@@ -0,0 +1,146 @@
package models
import (
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
)
const (
HTTPWebsocketStateEnabled = 1 // 已启用
HTTPWebsocketStateDisabled = 0 // 已禁用
)
type HTTPWebsocketDAO dbs.DAO
func NewHTTPWebsocketDAO() *HTTPWebsocketDAO {
return dbs.NewDAO(&HTTPWebsocketDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeHTTPWebsockets",
Model: new(HTTPWebsocket),
PkName: "id",
},
}).(*HTTPWebsocketDAO)
}
var SharedHTTPWebsocketDAO = NewHTTPWebsocketDAO()
// 启用条目
func (this *HTTPWebsocketDAO) EnableHTTPWebsocket(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPWebsocketStateEnabled).
Update()
return err
}
// 禁用条目
func (this *HTTPWebsocketDAO) DisableHTTPWebsocket(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPWebsocketStateDisabled).
Update()
return err
}
// 查找启用中的条目
func (this *HTTPWebsocketDAO) FindEnabledHTTPWebsocket(id int64) (*HTTPWebsocket, error) {
result, err := this.Query().
Pk(id).
Attr("state", HTTPWebsocketStateEnabled).
Find()
if result == nil {
return nil, err
}
return result.(*HTTPWebsocket), err
}
// 组合配置
func (this *HTTPWebsocketDAO) ComposeWebsocketConfig(websocketId int64) (*serverconfigs.HTTPWebsocketConfig, error) {
websocket, err := this.FindEnabledHTTPWebsocket(websocketId)
if err != nil {
return nil, err
}
if websocket == nil {
return nil, nil
}
config := &serverconfigs.HTTPWebsocketConfig{}
config.Id = int64(websocket.Id)
config.IsOn = websocket.IsOn == 1
config.AllowAllOrigins = websocket.AllowAllOrigins == 1
if IsNotNull(websocket.AllowedOrigins) {
origins := []string{}
err = json.Unmarshal([]byte(websocket.AllowedOrigins), &origins)
if err != nil {
return nil, err
}
config.AllowedOrigins = origins
}
if IsNotNull(websocket.HandshakeTimeout) {
duration := &shared.TimeDuration{}
err = json.Unmarshal([]byte(websocket.HandshakeTimeout), duration)
if err != nil {
return nil, err
}
config.HandshakeTimeout = duration
}
config.RequestSameOrigin = websocket.RequestSameOrigin == 1
config.RequestOrigin = websocket.RequestOrigin
return config, nil
}
// 创建Websocket配置
func (this *HTTPWebsocketDAO) CreateWebsocket(handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) (websocketId int64, err error) {
op := NewHTTPWebsocketOperator()
op.IsOn = true
op.State = HTTPWebsocketStateEnabled
if len(handshakeTimeoutJSON) > 0 {
op.HandshakeTimeout = handshakeTimeoutJSON
}
op.AllowAllOrigins = allowAllOrigins
if len(allowedOrigins) > 0 {
originsJSON, err := json.Marshal(allowedOrigins)
if err != nil {
return 0, err
}
op.AllowedOrigins = originsJSON
}
op.RequestSameOrigin = requestSameOrigin
op.RequestOrigin = requestOrigin
_, err = this.Save(op)
return types.Int64(op.Id), err
}
// 修改Websocket配置
func (this *HTTPWebsocketDAO) UpdateWebsocket(websocketId int64, handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) error {
if websocketId <= 0 {
return errors.New("invalid websocketId")
}
op := NewHTTPWebsocketOperator()
op.Id = websocketId
if len(handshakeTimeoutJSON) > 0 {
op.HandshakeTimeout = handshakeTimeoutJSON
}
op.AllowAllOrigins = allowAllOrigins
if len(allowedOrigins) > 0 {
originsJSON, err := json.Marshal(allowedOrigins)
if err != nil {
return err
}
op.AllowedOrigins = originsJSON
}
op.RequestSameOrigin = requestSameOrigin
op.RequestOrigin = requestOrigin
_, err := this.Save(op)
return err
}

View File

@@ -0,0 +1,5 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
)

View File

@@ -0,0 +1,34 @@
package models
// Websocket设置
type HTTPWebsocket struct {
Id uint32 `field:"id"` // ID
AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID
CreatedAt uint64 `field:"createdAt"` // 创建时间
State uint8 `field:"state"` // 状态
IsOn uint8 `field:"isOn"` // 是否启用
HandshakeTimeout string `field:"handshakeTimeout"` // 握手超时时间
AllowAllOrigins uint8 `field:"allowAllOrigins"` // 是否支持所有源
AllowedOrigins string `field:"allowedOrigins"` // 支持的源域名列表
RequestSameOrigin uint8 `field:"requestSameOrigin"` // 是否请求一样的Origin
RequestOrigin string `field:"requestOrigin"` // 请求Origin
}
type HTTPWebsocketOperator struct {
Id interface{} // ID
AdminId interface{} // 管理员ID
UserId interface{} // 用户ID
CreatedAt interface{} // 创建时间
State interface{} // 状态
IsOn interface{} // 是否启用
HandshakeTimeout interface{} // 握手超时时间
AllowAllOrigins interface{} // 是否支持所有源
AllowedOrigins interface{} // 支持的源域名列表
RequestSameOrigin interface{} // 是否请求一样的Origin
RequestOrigin interface{} // 请求Origin
}
func NewHTTPWebsocketOperator() *HTTPWebsocketOperator {
return &HTTPWebsocketOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -175,23 +175,42 @@ func (this *OriginDAO) ComposeOriginConfig(originId int64) (*serverconfigs.Origi
config.IdleTimeout = idleTimeout
}
if origin.RequestHeaderPolicyId > 0 {
policyConfig, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(int64(origin.RequestHeaderPolicyId))
// headers
if IsNotNull(origin.HttpRequestHeader) {
ref := &shared.HTTPHeaderPolicyRef{}
err = json.Unmarshal([]byte(origin.HttpRequestHeader), ref)
if err != nil {
return nil, err
}
if policyConfig != nil {
config.RequestHeaderPolicy = policyConfig
config.RequestHeaderPolicyRef = ref
if ref.HeaderPolicyId > 0 {
headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId)
if err != nil {
return nil, err
}
if headerPolicy != nil {
config.RequestHeaderPolicy = headerPolicy
}
}
}
if origin.ResponseHeaderPolicyId > 0 {
policyConfig, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(int64(origin.ResponseHeaderPolicyId))
if IsNotNull(origin.HttpResponseHeader) {
ref := &shared.HTTPHeaderPolicyRef{}
err = json.Unmarshal([]byte(origin.HttpResponseHeader), ref)
if err != nil {
return nil, err
}
if policyConfig != nil {
config.ResponseHeaderPolicy = policyConfig
config.ResponseHeaderPolicyRef = ref
if ref.HeaderPolicyId > 0 {
headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId)
if err != nil {
return nil, err
}
if headerPolicy != nil {
config.ResponseHeaderPolicy = headerPolicy
}
}
}

View File

@@ -19,8 +19,8 @@ type Origin struct {
MaxConns uint32 `field:"maxConns"` // 最大并发连接数
MaxIdleConns uint32 `field:"maxIdleConns"` // 最多空闲连接数
HttpRequestURI string `field:"httpRequestURI"` // 转发后的请求URI
RequestHeaderPolicyId uint32 `field:"requestHeaderPolicyId"` // 请求Header
ResponseHeaderPolicyId uint32 `field:"responseHeaderPolicyId"` // 响应Header
HttpRequestHeader string `field:"httpRequestHeader"` // 请求Header配置
HttpResponseHeader string `field:"httpResponseHeader"` // 响应Header配置
Host string `field:"host"` // 自定义主机名
HealthCheck string `field:"healthCheck"` // 健康检查设置
Cert string `field:"cert"` // 证书设置
@@ -47,8 +47,8 @@ type OriginOperator struct {
MaxConns interface{} // 最大并发连接数
MaxIdleConns interface{} // 最多空闲连接数
HttpRequestURI interface{} // 转发后的请求URI
RequestHeaderPolicyId interface{} // 请求Header
ResponseHeaderPolicyId interface{} // 响应Header
HttpRequestHeader interface{} // 请求Header配置
HttpResponseHeader interface{} // 响应Header配置
Host interface{} // 自定义主机名
HealthCheck interface{} // 健康检查设置
Cert interface{} // 证书设置

View File

@@ -142,7 +142,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*s
// 创建反向代理
func (this *ReverseProxyDAO) CreateReverseProxy(schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) {
op := NewReverseProxyOperator()
op.IsOn = false
op.IsOn = true
op.State = ReverseProxyStateEnabled
if len(schedulingJSON) > 0 {
op.Scheduling = string(schedulingJSON)

View File

@@ -134,32 +134,32 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Co
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
return nil, rpcutils.Wrap("ValidateRequest()", err)
}
webId, err := models.SharedHTTPLocationDAO.FindLocationWebId(req.LocationId)
if err != nil {
return nil, err
return nil, rpcutils.Wrap("FindLocationWebId()", err)
}
if webId <= 0 {
webId, err = models.SharedHTTPWebDAO.CreateWeb(nil)
if err != nil {
return nil, err
return nil, rpcutils.Wrap("CreateWeb()", err)
}
err = models.SharedHTTPLocationDAO.UpdateLocationWeb(req.LocationId, webId)
if err != nil {
return nil, err
return nil, rpcutils.Wrap("UpdateLocationWeb()", err)
}
}
config, err := models.SharedHTTPWebDAO.ComposeWebConfig(webId)
if err != nil {
return nil, err
return nil, rpcutils.Wrap("ComposeWebConfig()", err)
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
return nil, rpcutils.Wrap("json.Marshal()", err)
}
return &pb.FindAndInitHTTPLocationWebConfigResponse{
WebJSON: configJSON,

View File

@@ -257,7 +257,7 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb.
return rpcutils.RPCUpdateSuccess()
}
// 跳转到HTTPS
// 更改跳转到HTTPS设置
func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, req *pb.UpdateHTTPWebRedirectToHTTPSRequest) (*pb.RPCUpdateSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
@@ -271,3 +271,18 @@ func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, re
}
return rpcutils.RPCUpdateSuccess()
}
// 更改Websocket设置
func (this *HTTPWebService) UpdateHTTPWebWebsocket(ctx context.Context, req *pb.UpdateHTTPWebWebsocketRequest) (*pb.RPCUpdateSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
err = models.SharedHTTPWebDAO.UpdateWebsocket(req.WebId, req.WebsocketJSON)
if err != nil {
return nil, err
}
return rpcutils.RPCUpdateSuccess()
}

View File

@@ -0,0 +1,41 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
type HTTPWebsocketService struct {
}
// 创建Websocket配置
func (this *HTTPWebsocketService) CreateHTTPWebsocket(ctx context.Context, req *pb.CreateHTTPWebsocketRequest) (*pb.CreateHTTPWebsocketResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
websocketId, err := models.SharedHTTPWebsocketDAO.CreateWebsocket(req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin)
if err != nil {
return nil, err
}
return &pb.CreateHTTPWebsocketResponse{WebsocketId: websocketId}, nil
}
// 修改Websocket配置
func (this *HTTPWebsocketService) UpdateHTTPWebsocket(ctx context.Context, req *pb.UpdateHTTPWebsocketRequest) (*pb.RPCUpdateSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
err = models.SharedHTTPWebsocketDAO.UpdateWebsocket(req.WebsocketId, req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin)
if err != nil {
return nil, err
}
return rpcutils.RPCUpdateSuccess()
}

View File

@@ -127,3 +127,11 @@ func RPCUpdateSuccess() (*pb.RPCUpdateSuccess, error) {
func RPCDeleteSuccess() (*pb.RPCDeleteSuccess, error) {
return &pb.RPCDeleteSuccess{}, nil
}
// 包装错误
func Wrap(description string, err error) error {
if err == nil {
return errors.New(description)
}
return errors.New(description + ": " + err.Error())
}