diff --git a/internal/apis/api_node.go b/internal/apis/api_node.go index 6f758c72..05149b6e 100644 --- a/internal/apis/api_node.go +++ b/internal/apis/api_node.go @@ -69,6 +69,7 @@ func (this *APINode) listenRPC() error { pb.RegisterHTTPCachePolicyServiceServer(rpcServer, &services.HTTPCachePolicyService{}) pb.RegisterHTTPFirewallPolicyServiceServer(rpcServer, &services.HTTPFirewallPolicyService{}) pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{}) + pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{}) err = rpcServer.Serve(listener) if err != nil { return errors.New("[API]start rpc failed: " + err.Error()) diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index 420718fe..d4edb61a 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -260,6 +260,25 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon config.RedirectToHttps = redirectToHTTPSConfig } + // Websocket + if IsNotNull(web.Websocket) { + ref := &serverconfigs.HTTPWebsocketRef{} + err = json.Unmarshal([]byte(web.Websocket), ref) + if err != nil { + return nil, err + } + config.WebsocketRef = ref + if ref.WebsocketId > 0 { + websocketConfig, err := SharedHTTPWebsocketDAO.ComposeWebsocketConfig(ref.WebsocketId) + if err != nil { + return nil, err + } + if websocketConfig != nil { + config.Websocket = websocketConfig + } + } + } + return config, nil } @@ -267,7 +286,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon func (this *HTTPWebDAO) CreateWeb(rootJSON []byte) (int64, error) { op := NewHTTPWebOperator() op.State = HTTPWebStateEnabled - op.Root = rootJSON + op.Root = JSONBytes(rootJSON) _, err := this.Save(op) if err != nil { return 0, err @@ -282,7 +301,7 @@ func (this *HTTPWebDAO) UpdateWeb(webId int64, rootJSON []byte) error { } op := NewHTTPWebOperator() op.Id = webId - op.Root = rootJSON + op.Root = JSONBytes(rootJSON) _, err := this.Save(op) return err } @@ -294,7 +313,7 @@ func (this *HTTPWebDAO) UpdateWebGzip(webId int64, gzipJSON []byte) error { } op := NewHTTPWebOperator() op.Id = webId - op.Gzip = gzipJSON + op.Gzip = JSONBytes(gzipJSON) _, err := this.Save(op) return err } @@ -306,7 +325,7 @@ func (this *HTTPWebDAO) UpdateWebCharset(webId int64, charsetJSON []byte) error } op := NewHTTPWebOperator() op.Id = webId - op.Charset = charsetJSON + op.Charset = JSONBytes(charsetJSON) _, err := this.Save(op) return err } @@ -430,3 +449,15 @@ func (this *HTTPWebDAO) UpdateWebRedirectToHTTPS(webId int64, redirectToHTTPSJSO _, err := this.Save(op) return err } + +// 修改Websocket设置 +func (this *HTTPWebDAO) UpdateWebsocket(webId int64, websocketJSON []byte) error { + if webId <= 0 { + return errors.New("invalid webId") + } + op := NewHTTPWebOperator() + op.Id = webId + op.Websocket = JSONBytes(websocketJSON) + _, err := this.Save(op) + return err +} diff --git a/internal/db/models/http_web_model.go b/internal/db/models/http_web_model.go index 97654f77..d7bca3f7 100644 --- a/internal/db/models/http_web_model.go +++ b/internal/db/models/http_web_model.go @@ -9,7 +9,7 @@ type HTTPWeb struct { UserId uint32 `field:"userId"` // 用户ID State uint8 `field:"state"` // 状态 CreatedAt uint64 `field:"createdAt"` // 创建时间 - Root string `field:"root"` // 资源根目录 + Root string `field:"root"` // 根目录 Charset string `field:"charset"` // 字符集 Shutdown string `field:"shutdown"` // 临时关闭页面配置 Pages string `field:"pages"` // 特殊页面 @@ -24,6 +24,7 @@ type HTTPWeb struct { Cache string `field:"cache"` // 缓存配置 Firewall string `field:"firewall"` // 防火墙设置 Locations string `field:"locations"` // 路径规则配置 + Websocket string `field:"websocket"` // Websocket设置 } type HTTPWebOperator struct { @@ -34,7 +35,7 @@ type HTTPWebOperator struct { UserId interface{} // 用户ID State interface{} // 状态 CreatedAt interface{} // 创建时间 - Root interface{} // 资源根目录 + Root interface{} // 根目录 Charset interface{} // 字符集 Shutdown interface{} // 临时关闭页面配置 Pages interface{} // 特殊页面 @@ -49,6 +50,7 @@ type HTTPWebOperator struct { Cache interface{} // 缓存配置 Firewall interface{} // 防火墙设置 Locations interface{} // 路径规则配置 + Websocket interface{} // Websocket设置 } func NewHTTPWebOperator() *HTTPWebOperator { diff --git a/internal/db/models/http_websocket_dao.go b/internal/db/models/http_websocket_dao.go new file mode 100644 index 00000000..2b988d10 --- /dev/null +++ b/internal/db/models/http_websocket_dao.go @@ -0,0 +1,146 @@ +package models + +import ( + "encoding/json" + "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" +) + +const ( + HTTPWebsocketStateEnabled = 1 // 已启用 + HTTPWebsocketStateDisabled = 0 // 已禁用 +) + +type HTTPWebsocketDAO dbs.DAO + +func NewHTTPWebsocketDAO() *HTTPWebsocketDAO { + return dbs.NewDAO(&HTTPWebsocketDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeHTTPWebsockets", + Model: new(HTTPWebsocket), + PkName: "id", + }, + }).(*HTTPWebsocketDAO) +} + +var SharedHTTPWebsocketDAO = NewHTTPWebsocketDAO() + +// 启用条目 +func (this *HTTPWebsocketDAO) EnableHTTPWebsocket(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", HTTPWebsocketStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *HTTPWebsocketDAO) DisableHTTPWebsocket(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", HTTPWebsocketStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *HTTPWebsocketDAO) FindEnabledHTTPWebsocket(id int64) (*HTTPWebsocket, error) { + result, err := this.Query(). + Pk(id). + Attr("state", HTTPWebsocketStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*HTTPWebsocket), err +} + +// 组合配置 +func (this *HTTPWebsocketDAO) ComposeWebsocketConfig(websocketId int64) (*serverconfigs.HTTPWebsocketConfig, error) { + websocket, err := this.FindEnabledHTTPWebsocket(websocketId) + if err != nil { + return nil, err + } + if websocket == nil { + return nil, nil + } + config := &serverconfigs.HTTPWebsocketConfig{} + config.Id = int64(websocket.Id) + config.IsOn = websocket.IsOn == 1 + config.AllowAllOrigins = websocket.AllowAllOrigins == 1 + + if IsNotNull(websocket.AllowedOrigins) { + origins := []string{} + err = json.Unmarshal([]byte(websocket.AllowedOrigins), &origins) + if err != nil { + return nil, err + } + config.AllowedOrigins = origins + } + + if IsNotNull(websocket.HandshakeTimeout) { + duration := &shared.TimeDuration{} + err = json.Unmarshal([]byte(websocket.HandshakeTimeout), duration) + if err != nil { + return nil, err + } + config.HandshakeTimeout = duration + } + + config.RequestSameOrigin = websocket.RequestSameOrigin == 1 + config.RequestOrigin = websocket.RequestOrigin + + return config, nil +} + +// 创建Websocket配置 +func (this *HTTPWebsocketDAO) CreateWebsocket(handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) (websocketId int64, err error) { + op := NewHTTPWebsocketOperator() + op.IsOn = true + op.State = HTTPWebsocketStateEnabled + if len(handshakeTimeoutJSON) > 0 { + op.HandshakeTimeout = handshakeTimeoutJSON + } + op.AllowAllOrigins = allowAllOrigins + if len(allowedOrigins) > 0 { + originsJSON, err := json.Marshal(allowedOrigins) + if err != nil { + return 0, err + } + op.AllowedOrigins = originsJSON + } + op.RequestSameOrigin = requestSameOrigin + op.RequestOrigin = requestOrigin + _, err = this.Save(op) + return types.Int64(op.Id), err +} + +// 修改Websocket配置 +func (this *HTTPWebsocketDAO) UpdateWebsocket(websocketId int64, handshakeTimeoutJSON []byte, allowAllOrigins bool, allowedOrigins []string, requestSameOrigin bool, requestOrigin string) error { + if websocketId <= 0 { + return errors.New("invalid websocketId") + } + op := NewHTTPWebsocketOperator() + op.Id = websocketId + if len(handshakeTimeoutJSON) > 0 { + op.HandshakeTimeout = handshakeTimeoutJSON + } + op.AllowAllOrigins = allowAllOrigins + if len(allowedOrigins) > 0 { + originsJSON, err := json.Marshal(allowedOrigins) + if err != nil { + return err + } + op.AllowedOrigins = originsJSON + } + op.RequestSameOrigin = requestSameOrigin + op.RequestOrigin = requestOrigin + _, err := this.Save(op) + return err +} diff --git a/internal/db/models/http_websocket_dao_test.go b/internal/db/models/http_websocket_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/http_websocket_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/http_websocket_model.go b/internal/db/models/http_websocket_model.go new file mode 100644 index 00000000..83f812cd --- /dev/null +++ b/internal/db/models/http_websocket_model.go @@ -0,0 +1,34 @@ +package models + +// Websocket设置 +type HTTPWebsocket struct { + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + CreatedAt uint64 `field:"createdAt"` // 创建时间 + State uint8 `field:"state"` // 状态 + IsOn uint8 `field:"isOn"` // 是否启用 + HandshakeTimeout string `field:"handshakeTimeout"` // 握手超时时间 + AllowAllOrigins uint8 `field:"allowAllOrigins"` // 是否支持所有源 + AllowedOrigins string `field:"allowedOrigins"` // 支持的源域名列表 + RequestSameOrigin uint8 `field:"requestSameOrigin"` // 是否请求一样的Origin + RequestOrigin string `field:"requestOrigin"` // 请求Origin +} + +type HTTPWebsocketOperator struct { + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + CreatedAt interface{} // 创建时间 + State interface{} // 状态 + IsOn interface{} // 是否启用 + HandshakeTimeout interface{} // 握手超时时间 + AllowAllOrigins interface{} // 是否支持所有源 + AllowedOrigins interface{} // 支持的源域名列表 + RequestSameOrigin interface{} // 是否请求一样的Origin + RequestOrigin interface{} // 请求Origin +} + +func NewHTTPWebsocketOperator() *HTTPWebsocketOperator { + return &HTTPWebsocketOperator{} +} diff --git a/internal/db/models/http_websocket_model_ext.go b/internal/db/models/http_websocket_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/http_websocket_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index 7f02e860..0ef275eb 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -175,23 +175,42 @@ func (this *OriginDAO) ComposeOriginConfig(originId int64) (*serverconfigs.Origi config.IdleTimeout = idleTimeout } - if origin.RequestHeaderPolicyId > 0 { - policyConfig, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(int64(origin.RequestHeaderPolicyId)) + // headers + if IsNotNull(origin.HttpRequestHeader) { + ref := &shared.HTTPHeaderPolicyRef{} + err = json.Unmarshal([]byte(origin.HttpRequestHeader), ref) if err != nil { return nil, err } - if policyConfig != nil { - config.RequestHeaderPolicy = policyConfig + config.RequestHeaderPolicyRef = ref + + if ref.HeaderPolicyId > 0 { + headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId) + if err != nil { + return nil, err + } + if headerPolicy != nil { + config.RequestHeaderPolicy = headerPolicy + } } } - if origin.ResponseHeaderPolicyId > 0 { - policyConfig, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(int64(origin.ResponseHeaderPolicyId)) + if IsNotNull(origin.HttpResponseHeader) { + ref := &shared.HTTPHeaderPolicyRef{} + err = json.Unmarshal([]byte(origin.HttpResponseHeader), ref) if err != nil { return nil, err } - if policyConfig != nil { - config.ResponseHeaderPolicy = policyConfig + config.ResponseHeaderPolicyRef = ref + + if ref.HeaderPolicyId > 0 { + headerPolicy, err := SharedHTTPHeaderPolicyDAO.ComposeHeaderPolicyConfig(ref.HeaderPolicyId) + if err != nil { + return nil, err + } + if headerPolicy != nil { + config.ResponseHeaderPolicy = headerPolicy + } } } diff --git a/internal/db/models/origin_model.go b/internal/db/models/origin_model.go index cb5a71f4..f79fb8a3 100644 --- a/internal/db/models/origin_model.go +++ b/internal/db/models/origin_model.go @@ -2,59 +2,59 @@ package models // 源站 type Origin struct { - Id uint32 `field:"id"` // ID - AdminId uint32 `field:"adminId"` // 管理员ID - UserId uint32 `field:"userId"` // 用户ID - IsOn uint8 `field:"isOn"` // 是否启用 - Name string `field:"name"` // 名称 - Version uint32 `field:"version"` // 版本 - Addr string `field:"addr"` // 地址 - Description string `field:"description"` // 描述 - Code string `field:"code"` // 代号 - Weight uint32 `field:"weight"` // 权重 - ConnTimeout string `field:"connTimeout"` // 连接超时 - ReadTimeout string `field:"readTimeout"` // 读超时 - IdleTimeout string `field:"idleTimeout"` // 空闲连接超时 - MaxFails uint32 `field:"maxFails"` // 最多失败次数 - MaxConns uint32 `field:"maxConns"` // 最大并发连接数 - MaxIdleConns uint32 `field:"maxIdleConns"` // 最多空闲连接数 - HttpRequestURI string `field:"httpRequestURI"` // 转发后的请求URI - RequestHeaderPolicyId uint32 `field:"requestHeaderPolicyId"` // 请求Header - ResponseHeaderPolicyId uint32 `field:"responseHeaderPolicyId"` // 响应Header - Host string `field:"host"` // 自定义主机名 - HealthCheck string `field:"healthCheck"` // 健康检查设置 - Cert string `field:"cert"` // 证书设置 - Ftp string `field:"ftp"` // FTP相关设置 - CreatedAt uint64 `field:"createdAt"` // 创建时间 - State uint8 `field:"state"` // 状态 + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + IsOn uint8 `field:"isOn"` // 是否启用 + Name string `field:"name"` // 名称 + Version uint32 `field:"version"` // 版本 + Addr string `field:"addr"` // 地址 + Description string `field:"description"` // 描述 + Code string `field:"code"` // 代号 + Weight uint32 `field:"weight"` // 权重 + ConnTimeout string `field:"connTimeout"` // 连接超时 + ReadTimeout string `field:"readTimeout"` // 读超时 + IdleTimeout string `field:"idleTimeout"` // 空闲连接超时 + MaxFails uint32 `field:"maxFails"` // 最多失败次数 + MaxConns uint32 `field:"maxConns"` // 最大并发连接数 + MaxIdleConns uint32 `field:"maxIdleConns"` // 最多空闲连接数 + HttpRequestURI string `field:"httpRequestURI"` // 转发后的请求URI + HttpRequestHeader string `field:"httpRequestHeader"` // 请求Header配置 + HttpResponseHeader string `field:"httpResponseHeader"` // 响应Header配置 + Host string `field:"host"` // 自定义主机名 + HealthCheck string `field:"healthCheck"` // 健康检查设置 + Cert string `field:"cert"` // 证书设置 + Ftp string `field:"ftp"` // FTP相关设置 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + State uint8 `field:"state"` // 状态 } type OriginOperator struct { - Id interface{} // ID - AdminId interface{} // 管理员ID - UserId interface{} // 用户ID - IsOn interface{} // 是否启用 - Name interface{} // 名称 - Version interface{} // 版本 - Addr interface{} // 地址 - Description interface{} // 描述 - Code interface{} // 代号 - Weight interface{} // 权重 - ConnTimeout interface{} // 连接超时 - ReadTimeout interface{} // 读超时 - IdleTimeout interface{} // 空闲连接超时 - MaxFails interface{} // 最多失败次数 - MaxConns interface{} // 最大并发连接数 - MaxIdleConns interface{} // 最多空闲连接数 - HttpRequestURI interface{} // 转发后的请求URI - RequestHeaderPolicyId interface{} // 请求Header - ResponseHeaderPolicyId interface{} // 响应Header - Host interface{} // 自定义主机名 - HealthCheck interface{} // 健康检查设置 - Cert interface{} // 证书设置 - Ftp interface{} // FTP相关设置 - CreatedAt interface{} // 创建时间 - State interface{} // 状态 + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + IsOn interface{} // 是否启用 + Name interface{} // 名称 + Version interface{} // 版本 + Addr interface{} // 地址 + Description interface{} // 描述 + Code interface{} // 代号 + Weight interface{} // 权重 + ConnTimeout interface{} // 连接超时 + ReadTimeout interface{} // 读超时 + IdleTimeout interface{} // 空闲连接超时 + MaxFails interface{} // 最多失败次数 + MaxConns interface{} // 最大并发连接数 + MaxIdleConns interface{} // 最多空闲连接数 + HttpRequestURI interface{} // 转发后的请求URI + HttpRequestHeader interface{} // 请求Header配置 + HttpResponseHeader interface{} // 响应Header配置 + Host interface{} // 自定义主机名 + HealthCheck interface{} // 健康检查设置 + Cert interface{} // 证书设置 + Ftp interface{} // FTP相关设置 + CreatedAt interface{} // 创建时间 + State interface{} // 状态 } func NewOriginOperator() *OriginOperator { diff --git a/internal/db/models/reverse_proxy_dao.go b/internal/db/models/reverse_proxy_dao.go index b9d9f5ec..6715feda 100644 --- a/internal/db/models/reverse_proxy_dao.go +++ b/internal/db/models/reverse_proxy_dao.go @@ -142,7 +142,7 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(reverseProxyId int64) (*s // 创建反向代理 func (this *ReverseProxyDAO) CreateReverseProxy(schedulingJSON []byte, primaryOriginsJSON []byte, backupOriginsJSON []byte) (int64, error) { op := NewReverseProxyOperator() - op.IsOn = false + op.IsOn = true op.State = ReverseProxyStateEnabled if len(schedulingJSON) > 0 { op.Scheduling = string(schedulingJSON) diff --git a/internal/rpc/services/service_http_location.go b/internal/rpc/services/service_http_location.go index 98072993..bc7e05e9 100644 --- a/internal/rpc/services/service_http_location.go +++ b/internal/rpc/services/service_http_location.go @@ -134,32 +134,32 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Co // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { - return nil, err + return nil, rpcutils.Wrap("ValidateRequest()", err) } webId, err := models.SharedHTTPLocationDAO.FindLocationWebId(req.LocationId) if err != nil { - return nil, err + return nil, rpcutils.Wrap("FindLocationWebId()", err) } if webId <= 0 { webId, err = models.SharedHTTPWebDAO.CreateWeb(nil) if err != nil { - return nil, err + return nil, rpcutils.Wrap("CreateWeb()", err) } err = models.SharedHTTPLocationDAO.UpdateLocationWeb(req.LocationId, webId) if err != nil { - return nil, err + return nil, rpcutils.Wrap("UpdateLocationWeb()", err) } } config, err := models.SharedHTTPWebDAO.ComposeWebConfig(webId) if err != nil { - return nil, err + return nil, rpcutils.Wrap("ComposeWebConfig()", err) } configJSON, err := json.Marshal(config) if err != nil { - return nil, err + return nil, rpcutils.Wrap("json.Marshal()", err) } return &pb.FindAndInitHTTPLocationWebConfigResponse{ WebJSON: configJSON, diff --git a/internal/rpc/services/service_http_web.go b/internal/rpc/services/service_http_web.go index 7bab0afd..3fc4c6f2 100644 --- a/internal/rpc/services/service_http_web.go +++ b/internal/rpc/services/service_http_web.go @@ -257,7 +257,7 @@ func (this *HTTPWebService) UpdateHTTPWebLocations(ctx context.Context, req *pb. return rpcutils.RPCUpdateSuccess() } -// 跳转到HTTPS +// 更改跳转到HTTPS设置 func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, req *pb.UpdateHTTPWebRedirectToHTTPSRequest) (*pb.RPCUpdateSuccess, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -271,3 +271,18 @@ func (this *HTTPWebService) UpdateHTTPWebRedirectToHTTPS(ctx context.Context, re } return rpcutils.RPCUpdateSuccess() } + +// 更改Websocket设置 +func (this *HTTPWebService) UpdateHTTPWebWebsocket(ctx context.Context, req *pb.UpdateHTTPWebWebsocketRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedHTTPWebDAO.UpdateWebsocket(req.WebId, req.WebsocketJSON) + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +} diff --git a/internal/rpc/services/service_http_websocket.go b/internal/rpc/services/service_http_websocket.go new file mode 100644 index 00000000..deab13ad --- /dev/null +++ b/internal/rpc/services/service_http_websocket.go @@ -0,0 +1,41 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +type HTTPWebsocketService struct { +} + +// 创建Websocket配置 +func (this *HTTPWebsocketService) CreateHTTPWebsocket(ctx context.Context, req *pb.CreateHTTPWebsocketRequest) (*pb.CreateHTTPWebsocketResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + websocketId, err := models.SharedHTTPWebsocketDAO.CreateWebsocket(req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin) + if err != nil { + return nil, err + } + return &pb.CreateHTTPWebsocketResponse{WebsocketId: websocketId}, nil +} + +// 修改Websocket配置 +func (this *HTTPWebsocketService) UpdateHTTPWebsocket(ctx context.Context, req *pb.UpdateHTTPWebsocketRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedHTTPWebsocketDAO.UpdateWebsocket(req.WebsocketId, req.HandshakeTimeoutJSON, req.AllowAllOrigins, req.AllowedOrigins, req.RequestSameOrigin, req.RequestOrigin) + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +} diff --git a/internal/rpc/utils/utils.go b/internal/rpc/utils/utils.go index 098c066e..c1cf103f 100644 --- a/internal/rpc/utils/utils.go +++ b/internal/rpc/utils/utils.go @@ -127,3 +127,11 @@ func RPCUpdateSuccess() (*pb.RPCUpdateSuccess, error) { func RPCDeleteSuccess() (*pb.RPCDeleteSuccess, error) { return &pb.RPCDeleteSuccess{}, nil } + +// 包装错误 +func Wrap(description string, err error) error { + if err == nil { + return errors.New(description) + } + return errors.New(description + ": " + err.Error()) +}