mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-06 18:10:25 +08:00
[API节点]支持HTTP API
This commit is contained in:
68
internal/db/models/api_access_token_dao.go
Normal file
68
internal/db/models/api_access_token_dao.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
"github.com/iwind/TeaGo/Tea"
|
||||||
|
"github.com/iwind/TeaGo/dbs"
|
||||||
|
"github.com/iwind/TeaGo/rands"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type APIAccessTokenDAO dbs.DAO
|
||||||
|
|
||||||
|
func NewAPIAccessTokenDAO() *APIAccessTokenDAO {
|
||||||
|
return dbs.NewDAO(&APIAccessTokenDAO{
|
||||||
|
DAOObject: dbs.DAOObject{
|
||||||
|
DB: Tea.Env,
|
||||||
|
Table: "edgeAPIAccessTokens",
|
||||||
|
Model: new(APIAccessToken),
|
||||||
|
PkName: "id",
|
||||||
|
},
|
||||||
|
}).(*APIAccessTokenDAO)
|
||||||
|
}
|
||||||
|
|
||||||
|
var SharedAPIAccessTokenDAO *APIAccessTokenDAO
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
dbs.OnReady(func() {
|
||||||
|
SharedAPIAccessTokenDAO = NewAPIAccessTokenDAO()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成AccessToken
|
||||||
|
func (this *APIAccessTokenDAO) GenerateAccessToken(userId int64) (token string, expiresAt int64, err error) {
|
||||||
|
// 查询以前的
|
||||||
|
accessToken, err := this.Query().
|
||||||
|
Attr("userId", userId).
|
||||||
|
Find()
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
token = rands.String(128) // TODO 增强安全性,将来使用 base64_encode(encrypt(salt+random)) 算法来代替
|
||||||
|
expiresAt = time.Now().Unix() + 7200
|
||||||
|
|
||||||
|
op := NewAPIAccessTokenOperator()
|
||||||
|
|
||||||
|
if accessToken != nil {
|
||||||
|
op.Id = accessToken.(*APIAccessToken).Id
|
||||||
|
}
|
||||||
|
|
||||||
|
op.UserId = userId
|
||||||
|
op.Token = token
|
||||||
|
op.CreatedAt = time.Now().Unix()
|
||||||
|
op.ExpiredAt = expiresAt
|
||||||
|
err = this.Save(op)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找AccessToken
|
||||||
|
func (this *APIAccessTokenDAO) FindAccessToken(token string) (*APIAccessToken, error) {
|
||||||
|
one, err := this.Query().
|
||||||
|
Attr("token", token).
|
||||||
|
Find()
|
||||||
|
if one == nil || err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return one.(*APIAccessToken), nil
|
||||||
|
}
|
||||||
5
internal/db/models/api_access_token_dao_test.go
Normal file
5
internal/db/models/api_access_token_dao_test.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
)
|
||||||
22
internal/db/models/api_access_token_model.go
Normal file
22
internal/db/models/api_access_token_model.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
// API访问令牌
|
||||||
|
type APIAccessToken struct {
|
||||||
|
Id uint64 `field:"id"` // ID
|
||||||
|
UserId uint32 `field:"userId"` // 用户ID
|
||||||
|
Token string `field:"token"` // 令牌
|
||||||
|
CreatedAt uint64 `field:"createdAt"` // 创建时间
|
||||||
|
ExpiredAt uint64 `field:"expiredAt"` // 过期时间
|
||||||
|
}
|
||||||
|
|
||||||
|
type APIAccessTokenOperator struct {
|
||||||
|
Id interface{} // ID
|
||||||
|
UserId interface{} // 用户ID
|
||||||
|
Token interface{} // 令牌
|
||||||
|
CreatedAt interface{} // 创建时间
|
||||||
|
ExpiredAt interface{} // 过期时间
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAPIAccessTokenOperator() *APIAccessTokenOperator {
|
||||||
|
return &APIAccessTokenOperator{}
|
||||||
|
}
|
||||||
1
internal/db/models/api_access_token_model_ext.go
Normal file
1
internal/db/models/api_access_token_model_ext.go
Normal file
@@ -0,0 +1 @@
|
|||||||
|
package models
|
||||||
@@ -90,7 +90,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建API节点
|
// 创建API节点
|
||||||
func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) {
|
func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) {
|
||||||
uniqueId, err := this.genUniqueId()
|
uniqueId, err := this.genUniqueId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -114,6 +114,13 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON
|
|||||||
if len(httpsJSON) > 0 {
|
if len(httpsJSON) > 0 {
|
||||||
op.Https = httpsJSON
|
op.Https = httpsJSON
|
||||||
}
|
}
|
||||||
|
op.RestIsOn = restIsOn
|
||||||
|
if len(restHTTPJSON) > 0 {
|
||||||
|
op.RestHTTP = restHTTPJSON
|
||||||
|
}
|
||||||
|
if len(restHTTPSJSON) > 0 {
|
||||||
|
op.RestHTTPS = restHTTPSJSON
|
||||||
|
}
|
||||||
if len(accessAddrsJSON) > 0 {
|
if len(accessAddrsJSON) > 0 {
|
||||||
op.AccessAddrs = accessAddrsJSON
|
op.AccessAddrs = accessAddrsJSON
|
||||||
}
|
}
|
||||||
@@ -128,7 +135,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 修改API节点
|
// 修改API节点
|
||||||
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error {
|
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) error {
|
||||||
if nodeId <= 0 {
|
if nodeId <= 0 {
|
||||||
return errors.New("invalid nodeId")
|
return errors.New("invalid nodeId")
|
||||||
}
|
}
|
||||||
@@ -142,17 +149,28 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str
|
|||||||
if len(httpJSON) > 0 {
|
if len(httpJSON) > 0 {
|
||||||
op.Http = httpJSON
|
op.Http = httpJSON
|
||||||
} else {
|
} else {
|
||||||
op.Http = "null"
|
op.Http = "{}"
|
||||||
}
|
}
|
||||||
if len(httpsJSON) > 0 {
|
if len(httpsJSON) > 0 {
|
||||||
op.Https = httpsJSON
|
op.Https = httpsJSON
|
||||||
} else {
|
} else {
|
||||||
op.Https = "null"
|
op.Https = "{}"
|
||||||
|
}
|
||||||
|
op.RestIsOn = restIsOn
|
||||||
|
if len(restHTTPJSON) > 0 {
|
||||||
|
op.RestHTTP = restHTTPJSON
|
||||||
|
} else {
|
||||||
|
op.RestHTTP = "{}"
|
||||||
|
}
|
||||||
|
if len(restHTTPSJSON) > 0 {
|
||||||
|
op.RestHTTPS = restHTTPSJSON
|
||||||
|
} else {
|
||||||
|
op.RestHTTPS = "{}"
|
||||||
}
|
}
|
||||||
if len(accessAddrsJSON) > 0 {
|
if len(accessAddrsJSON) > 0 {
|
||||||
op.AccessAddrs = accessAddrsJSON
|
op.AccessAddrs = accessAddrsJSON
|
||||||
} else {
|
} else {
|
||||||
op.AccessAddrs = "null"
|
op.AccessAddrs = "[]"
|
||||||
}
|
}
|
||||||
|
|
||||||
err := this.Save(op)
|
err := this.Save(op)
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ type APINode struct {
|
|||||||
Description string `field:"description"` // 描述
|
Description string `field:"description"` // 描述
|
||||||
Http string `field:"http"` // 监听的HTTP配置
|
Http string `field:"http"` // 监听的HTTP配置
|
||||||
Https string `field:"https"` // 监听的HTTPS配置
|
Https string `field:"https"` // 监听的HTTPS配置
|
||||||
|
RestIsOn uint8 `field:"restIsOn"` // 是否开放REST
|
||||||
|
RestHTTP string `field:"restHTTP"` // REST HTTP配置
|
||||||
|
RestHTTPS string `field:"restHTTPS"` // REST HTTPS配置
|
||||||
AccessAddrs string `field:"accessAddrs"` // 外部访问地址
|
AccessAddrs string `field:"accessAddrs"` // 外部访问地址
|
||||||
Order uint32 `field:"order"` // 排序
|
Order uint32 `field:"order"` // 排序
|
||||||
State uint8 `field:"state"` // 状态
|
State uint8 `field:"state"` // 状态
|
||||||
@@ -30,6 +33,9 @@ type APINodeOperator struct {
|
|||||||
Description interface{} // 描述
|
Description interface{} // 描述
|
||||||
Http interface{} // 监听的HTTP配置
|
Http interface{} // 监听的HTTP配置
|
||||||
Https interface{} // 监听的HTTPS配置
|
Https interface{} // 监听的HTTPS配置
|
||||||
|
RestIsOn interface{} // 是否开放REST
|
||||||
|
RestHTTP interface{} // REST HTTP配置
|
||||||
|
RestHTTPS interface{} // REST HTTPS配置
|
||||||
AccessAddrs interface{} // 外部访问地址
|
AccessAddrs interface{} // 外部访问地址
|
||||||
Order interface{} // 排序
|
Order interface{} // 排序
|
||||||
State interface{} // 状态
|
State interface{} // 状态
|
||||||
|
|||||||
@@ -93,3 +93,65 @@ func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
|
|||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 解析Rest HTTP配置
|
||||||
|
func (this *APINode) DecodeRestHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
|
||||||
|
if this.RestIsOn != 1 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if !IsNotNull(this.RestHTTP) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
config := &serverconfigs.HTTPProtocolConfig{}
|
||||||
|
err := json.Unmarshal([]byte(this.RestHTTP), config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = config.Init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析HTTPS配置
|
||||||
|
func (this *APINode) DecodeRestHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) {
|
||||||
|
if this.RestIsOn != 1 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if !IsNotNull(this.RestHTTPS) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
config := &serverconfigs.HTTPSProtocolConfig{}
|
||||||
|
err := json.Unmarshal([]byte(this.RestHTTPS), config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = config.Init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.SSLPolicyRef != nil {
|
||||||
|
policyId := config.SSLPolicyRef.SSLPolicyId
|
||||||
|
if policyId > 0 {
|
||||||
|
sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if sslPolicy != nil {
|
||||||
|
config.SSLPolicy = sslPolicy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = config.Init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,11 +2,13 @@ package models
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||||
"github.com/iwind/TeaGo/dbs"
|
"github.com/iwind/TeaGo/dbs"
|
||||||
"github.com/iwind/TeaGo/lists"
|
"github.com/iwind/TeaGo/lists"
|
||||||
"github.com/iwind/TeaGo/logs"
|
"github.com/iwind/TeaGo/logs"
|
||||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -46,17 +48,22 @@ func randomAccessLogDAO() (dao *HTTPAccessLogDAOWrapper) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查表格是否存在
|
// 检查表格是否存在
|
||||||
func findAccessLogTableName(db *dbs.DB, day string) (string, bool, error) {
|
func findAccessLogTableName(db *dbs.DB, day string) (tableName string, ok bool, err error) {
|
||||||
|
if !regexp.MustCompile(`^\d{8}$`).MatchString(day) {
|
||||||
|
err = errors.New("invalid day '" + day + "', should be YYYYMMDD")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
config, err := db.Config()
|
config, err := db.Config()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", false, err
|
return "", false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tableName := "edgeHTTPAccessLogs_" + day
|
tableName = "edgeHTTPAccessLogs_" + day
|
||||||
cacheKey := tableName + "_" + fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(config.Dsn)))
|
cacheKey := tableName + "_" + fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(config.Dsn)))
|
||||||
|
|
||||||
accessLogLocker.RLock()
|
accessLogLocker.RLock()
|
||||||
_, ok := accessLogTableMapping[cacheKey]
|
_, ok = accessLogTableMapping[cacheKey]
|
||||||
accessLogLocker.RUnlock()
|
accessLogLocker.RUnlock()
|
||||||
if ok {
|
if ok {
|
||||||
return tableName, true, nil
|
return tableName, true, nil
|
||||||
|
|||||||
@@ -168,6 +168,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(lastRequestId string, size int64, d
|
|||||||
dao := daoWrapper.DAO
|
dao := daoWrapper.DAO
|
||||||
|
|
||||||
tableName, exists, err := findAccessLogTableName(dao.Instance, day)
|
tableName, exists, err := findAccessLogTableName(dao.Instance, day)
|
||||||
|
logs.Println("tableName:", tableName, exists, err) // TODO
|
||||||
if !exists {
|
if !exists {
|
||||||
// 表格不存在则跳过
|
// 表格不存在则跳过
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -109,3 +109,17 @@ func (this *UserAccessKeyDAO) UpdateAccessKeyIsOn(accessKeyId int64, isOn bool)
|
|||||||
Update()
|
Update()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据UniqueId查找AccessKey
|
||||||
|
func (this *UserAccessKeyDAO) FindAccessKeyWithUniqueId(uniqueId string) (*UserAccessKey, error) {
|
||||||
|
one, err := this.Query().
|
||||||
|
Attr("uniqueId", uniqueId).
|
||||||
|
Attr("isOn", true).
|
||||||
|
State(UserAccessKeyStateEnabled).
|
||||||
|
Find()
|
||||||
|
if one == nil || err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return one.(*UserAccessKey), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -75,72 +75,7 @@ func (this *APINode) Start() {
|
|||||||
// 监听RPC服务
|
// 监听RPC服务
|
||||||
remotelogs.Println("API_NODE", "starting RPC server ...")
|
remotelogs.Println("API_NODE", "starting RPC server ...")
|
||||||
|
|
||||||
// HTTP
|
isListening := this.listenPorts(apiNode)
|
||||||
httpConfig, err := apiNode.DecodeHTTP()
|
|
||||||
if err != nil {
|
|
||||||
remotelogs.Error("API_NODE", "decode http config: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
isListening := false
|
|
||||||
if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
|
|
||||||
for _, listen := range httpConfig.Listen {
|
|
||||||
for _, addr := range listen.Addresses() {
|
|
||||||
listener, err := net.Listen("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
err := this.listenRPC(listener, nil)
|
|
||||||
if err != nil {
|
|
||||||
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
isListening = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTPS
|
|
||||||
httpsConfig, err := apiNode.DecodeHTTPS()
|
|
||||||
if err != nil {
|
|
||||||
remotelogs.Error("API_NODE", "decode https config: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if httpsConfig != nil &&
|
|
||||||
httpsConfig.IsOn &&
|
|
||||||
len(httpsConfig.Listen) > 0 &&
|
|
||||||
httpsConfig.SSLPolicy != nil &&
|
|
||||||
httpsConfig.SSLPolicy.IsOn &&
|
|
||||||
len(httpsConfig.SSLPolicy.Certs) > 0 {
|
|
||||||
certs := []tls.Certificate{}
|
|
||||||
for _, cert := range httpsConfig.SSLPolicy.Certs {
|
|
||||||
certs = append(certs, *cert.CertObject())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, listen := range httpsConfig.Listen {
|
|
||||||
for _, addr := range listen.Addresses() {
|
|
||||||
listener, err := net.Listen("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
err := this.listenRPC(listener, &tls.Config{
|
|
||||||
Certificates: certs,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
isListening = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTTP接口
|
|
||||||
|
|
||||||
if !isListening {
|
if !isListening {
|
||||||
remotelogs.Error("API_NODE", "the api node require at least one listening address")
|
remotelogs.Error("API_NODE", "the api node require at least one listening address")
|
||||||
@@ -155,10 +90,10 @@ func (this *APINode) Start() {
|
|||||||
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
|
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
|
||||||
var rpcServer *grpc.Server
|
var rpcServer *grpc.Server
|
||||||
if tlsConfig == nil {
|
if tlsConfig == nil {
|
||||||
remotelogs.Println("API_NODE", "listening http://"+listener.Addr().String()+" ...")
|
remotelogs.Println("API_NODE", "listening GRPC http://"+listener.Addr().String()+" ...")
|
||||||
rpcServer = grpc.NewServer()
|
rpcServer = grpc.NewServer()
|
||||||
} else {
|
} else {
|
||||||
logs.Println("[API_NODE]listening https://" + listener.Addr().String() + " ...")
|
logs.Println("[API_NODE]listening GRPC https://" + listener.Addr().String() + " ...")
|
||||||
rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||||
}
|
}
|
||||||
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
|
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
|
||||||
@@ -264,3 +199,142 @@ func (this *APINode) autoUpgrade() error {
|
|||||||
logs.Println("[API_NODE]upgrade database done")
|
logs.Println("[API_NODE]upgrade database done")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 启动端口
|
||||||
|
func (this *APINode) listenPorts(apiNode *models.APINode) (isListening bool) {
|
||||||
|
// HTTP
|
||||||
|
httpConfig, err := apiNode.DecodeHTTP()
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "decode http config: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
isListening = false
|
||||||
|
if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
|
||||||
|
for _, listen := range httpConfig.Listen {
|
||||||
|
for _, addr := range listen.Addresses() {
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := this.listenRPC(listener, nil)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
isListening = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPS
|
||||||
|
httpsConfig, err := apiNode.DecodeHTTPS()
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "decode https config: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if httpsConfig != nil &&
|
||||||
|
httpsConfig.IsOn &&
|
||||||
|
len(httpsConfig.Listen) > 0 &&
|
||||||
|
httpsConfig.SSLPolicy != nil &&
|
||||||
|
httpsConfig.SSLPolicy.IsOn &&
|
||||||
|
len(httpsConfig.SSLPolicy.Certs) > 0 {
|
||||||
|
certs := []tls.Certificate{}
|
||||||
|
for _, cert := range httpsConfig.SSLPolicy.Certs {
|
||||||
|
certs = append(certs, *cert.CertObject())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, listen := range httpsConfig.Listen {
|
||||||
|
for _, addr := range listen.Addresses() {
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := this.listenRPC(listener, &tls.Config{
|
||||||
|
Certificates: certs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
isListening = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rest HTTP
|
||||||
|
restHTTPConfig, err := apiNode.DecodeRestHTTP()
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "decode REST http config: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if restHTTPConfig != nil && restHTTPConfig.IsOn && len(restHTTPConfig.Listen) > 0 {
|
||||||
|
for _, listen := range restHTTPConfig.Listen {
|
||||||
|
for _, addr := range listen.Addresses() {
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening REST 'http://"+addr+"' failed: "+err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
remotelogs.Println("API_NODE", "listening REST http://"+addr+" ...")
|
||||||
|
server := &RestServer{}
|
||||||
|
err := server.Listen(listener)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening REST 'http://"+addr+"' failed: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
isListening = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rest HTTPS
|
||||||
|
restHTTPSConfig, err := apiNode.DecodeRestHTTPS()
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "decode REST https config: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if restHTTPSConfig != nil &&
|
||||||
|
restHTTPSConfig.IsOn &&
|
||||||
|
len(restHTTPSConfig.Listen) > 0 &&
|
||||||
|
restHTTPSConfig.SSLPolicy != nil &&
|
||||||
|
restHTTPSConfig.SSLPolicy.IsOn &&
|
||||||
|
len(restHTTPSConfig.SSLPolicy.Certs) > 0 {
|
||||||
|
for _, listen := range restHTTPSConfig.Listen {
|
||||||
|
for _, addr := range listen.Addresses() {
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening REST 'https://"+addr+"' failed: "+err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
remotelogs.Println("API_NODE", "listening REST https://"+addr+" ...")
|
||||||
|
server := &RestServer{}
|
||||||
|
|
||||||
|
certs := []tls.Certificate{}
|
||||||
|
for _, cert := range httpsConfig.SSLPolicy.Certs {
|
||||||
|
certs = append(certs, *cert.CertObject())
|
||||||
|
}
|
||||||
|
|
||||||
|
err := server.ListenHTTPS(listener, &tls.Config{
|
||||||
|
Certificates: certs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("API_NODE", "listening REST 'https://"+addr+"' failed: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
isListening = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
189
internal/nodes/rest_server.go
Normal file
189
internal/nodes/rest_server.go
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
|
||||||
|
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
|
||||||
|
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
|
||||||
|
"github.com/iwind/TeaGo/maps"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var servicePathReg = regexp.MustCompile(`^/([a-zA-Z0-9]+)/([a-zA-Z0-9]+)$`)
|
||||||
|
var servicesMap = map[string]reflect.Value{
|
||||||
|
"APIAccessTokenService": reflect.ValueOf(new(services.APIAccessTokenService)),
|
||||||
|
"HTTPAccessLogService": reflect.ValueOf(new(services.HTTPAccessLogService)),
|
||||||
|
}
|
||||||
|
|
||||||
|
type RestServer struct{}
|
||||||
|
|
||||||
|
func (this *RestServer) Listen(listener net.Listener) error {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", this.handle)
|
||||||
|
server := &http.Server{}
|
||||||
|
server.Handler = mux
|
||||||
|
return server.Serve(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config) error {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", this.handle)
|
||||||
|
server := &http.Server{}
|
||||||
|
server.Handler = mux
|
||||||
|
server.TLSConfig = tlsConfig
|
||||||
|
return server.ServeTLS(listener, "", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
|
||||||
|
path := req.URL.Path
|
||||||
|
matches := servicePathReg.FindStringSubmatch(path)
|
||||||
|
if len(matches) != 3 {
|
||||||
|
writer.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serviceName := matches[1]
|
||||||
|
methodName := matches[2]
|
||||||
|
|
||||||
|
serviceType, ok := servicesMap[serviceName]
|
||||||
|
if !ok {
|
||||||
|
writer.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
method := serviceType.MethodByName(methodName)
|
||||||
|
if !method.IsValid() {
|
||||||
|
writer.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if method.Type().NumIn() != 2 || method.Type().NumOut() != 2 {
|
||||||
|
writer.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if method.Type().In(0).Name() != "Context" {
|
||||||
|
writer.WriteHeader(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 是否显示Pretty后的JSON
|
||||||
|
shouldPretty := req.Header.Get("Edge-Response-Pretty") == "on"
|
||||||
|
|
||||||
|
// 上下文
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if serviceName != "APIAccessTokenService" || methodName != "GetAPIAccessToken" {
|
||||||
|
// 校验TOKEN
|
||||||
|
token := req.Header.Get("Edge-Access-Token")
|
||||||
|
if len(token) == 0 {
|
||||||
|
this.writeJSON(writer, maps.Map{
|
||||||
|
"code": 400,
|
||||||
|
"data": maps.Map{},
|
||||||
|
"message": "require 'Edge-Access-Token' header",
|
||||||
|
}, shouldPretty)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := models.SharedAPIAccessTokenDAO.FindAccessToken(token)
|
||||||
|
if err != nil {
|
||||||
|
this.writeJSON(writer, maps.Map{
|
||||||
|
"code": 400,
|
||||||
|
"data": maps.Map{},
|
||||||
|
"message": "server error: " + err.Error(),
|
||||||
|
}, shouldPretty)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if accessToken == nil || int64(accessToken.ExpiredAt) < time.Now().Unix() {
|
||||||
|
this.writeJSON(writer, maps.Map{
|
||||||
|
"code": 400,
|
||||||
|
"data": maps.Map{},
|
||||||
|
"message": "invalid access token",
|
||||||
|
}, shouldPretty)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if accessToken.UserId > 0 {
|
||||||
|
ctx = rpcutils.NewPlainContext("user", int64(accessToken.UserId))
|
||||||
|
} else {
|
||||||
|
// TODO 支持更多类型的角色
|
||||||
|
this.writeJSON(writer, maps.Map{
|
||||||
|
"code": 400,
|
||||||
|
"data": maps.Map{},
|
||||||
|
"message": "not supported role",
|
||||||
|
}, shouldPretty)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO 需要防止BODY过大攻击
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = writer.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 请求数据
|
||||||
|
reqValue := reflect.New(method.Type().In(1).Elem()).Interface()
|
||||||
|
err = json.Unmarshal(body, reqValue)
|
||||||
|
if err != nil {
|
||||||
|
writer.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = writer.Write([]byte(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result := method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)})
|
||||||
|
resultErr := result[1].Interface()
|
||||||
|
if resultErr != nil {
|
||||||
|
e, ok := resultErr.(error)
|
||||||
|
if ok {
|
||||||
|
this.writeJSON(writer, maps.Map{
|
||||||
|
"code": 400,
|
||||||
|
"message": e.Error(),
|
||||||
|
"data": maps.Map{},
|
||||||
|
}, shouldPretty)
|
||||||
|
} else {
|
||||||
|
this.writeJSON(writer, maps.Map{
|
||||||
|
"code": 500,
|
||||||
|
"message": "server error: server should return a error object, but return a " + result[1].Type().String(),
|
||||||
|
"data": maps.Map{},
|
||||||
|
}, shouldPretty)
|
||||||
|
}
|
||||||
|
} else { // 没有返回错误
|
||||||
|
data := maps.Map{
|
||||||
|
"code": 200,
|
||||||
|
"message": "ok",
|
||||||
|
"data": result[0].Interface(),
|
||||||
|
}
|
||||||
|
var dataJSON []byte
|
||||||
|
if shouldPretty {
|
||||||
|
dataJSON = data.AsPrettyJSON()
|
||||||
|
} else {
|
||||||
|
dataJSON = data.AsJSON()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
this.writeJSON(writer, maps.Map{
|
||||||
|
"code": 500,
|
||||||
|
"message": "server error: marshal json failed: " + err.Error(),
|
||||||
|
"data": maps.Map{},
|
||||||
|
}, shouldPretty)
|
||||||
|
} else {
|
||||||
|
_, _ = writer.Write(dataJSON)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RestServer) writeJSON(writer http.ResponseWriter, v maps.Map, pretty bool) {
|
||||||
|
if pretty {
|
||||||
|
_, _ = writer.Write(v.AsPrettyJSON())
|
||||||
|
} else {
|
||||||
|
_, _ = writer.Write(v.AsJSON())
|
||||||
|
}
|
||||||
|
}
|
||||||
40
internal/rpc/services/service_api_access_token.go
Normal file
40
internal/rpc/services/service_api_access_token.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
|
||||||
|
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccessToken相关服务
|
||||||
|
type APIAccessTokenService struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取AccessToken
|
||||||
|
func (this *APIAccessTokenService) GetAPIAccessToken(ctx context.Context, req *pb.GetAPIAccessTokenRequest) (*pb.GetAPIAccessTokenResponse, error) {
|
||||||
|
if req.Type == "user" { // 用户
|
||||||
|
accessKey, err := models.SharedUserAccessKeyDAO.FindAccessKeyWithUniqueId(req.AccessKeyId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if accessKey == nil {
|
||||||
|
return nil, errors.New("access key not found")
|
||||||
|
}
|
||||||
|
if accessKey.Secret != req.AccessKey {
|
||||||
|
return nil, errors.New("access key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建AccessToken
|
||||||
|
token, expiresAt, err := models.SharedAPIAccessTokenDAO.GenerateAccessToken(int64(accessKey.UserId))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &pb.GetAPIAccessTokenResponse{
|
||||||
|
Token: token,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
}, nil
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("unsupported type '" + req.Type + "'")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,7 +19,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
|
nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -34,7 +34,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn)
|
err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -138,6 +138,9 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
|
|||||||
Description: node.Description,
|
Description: node.Description,
|
||||||
HttpJSON: []byte(node.Http),
|
HttpJSON: []byte(node.Http),
|
||||||
HttpsJSON: []byte(node.Https),
|
HttpsJSON: []byte(node.Https),
|
||||||
|
RestIsOn: node.RestIsOn == 1,
|
||||||
|
RestHTTPJSON: []byte(node.RestHTTP),
|
||||||
|
RestHTTPSJSON: []byte(node.RestHTTPS),
|
||||||
AccessAddrsJSON: []byte(node.AccessAddrs),
|
AccessAddrsJSON: []byte(node.AccessAddrs),
|
||||||
AccessAddrs: accessAddrs,
|
AccessAddrs: accessAddrs,
|
||||||
StatusJSON: []byte(node.Status),
|
StatusJSON: []byte(node.Status),
|
||||||
@@ -178,6 +181,9 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
|
|||||||
Description: node.Description,
|
Description: node.Description,
|
||||||
HttpJSON: []byte(node.Http),
|
HttpJSON: []byte(node.Http),
|
||||||
HttpsJSON: []byte(node.Https),
|
HttpsJSON: []byte(node.Https),
|
||||||
|
RestIsOn: node.RestIsOn == 1,
|
||||||
|
RestHTTPJSON: []byte(node.RestHTTP),
|
||||||
|
RestHTTPSJSON: []byte(node.RestHTTPS),
|
||||||
AccessAddrsJSON: []byte(node.AccessAddrs),
|
AccessAddrsJSON: []byte(node.AccessAddrs),
|
||||||
AccessAddrs: accessAddrs,
|
AccessAddrs: accessAddrs,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ package services
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
|
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
|
||||||
|
"github.com/TeaOSLab/EdgeAPI/internal/errors"
|
||||||
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
|
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 访问日志相关服务
|
// 访问日志相关服务
|
||||||
type HTTPAccessLogService struct {
|
type HTTPAccessLogService struct {
|
||||||
|
BaseService
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建访问日志
|
// 创建访问日志
|
||||||
@@ -34,11 +36,23 @@ func (this *HTTPAccessLogService) CreateHTTPAccessLogs(ctx context.Context, req
|
|||||||
// 列出单页访问日志
|
// 列出单页访问日志
|
||||||
func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *pb.ListHTTPAccessLogsRequest) (*pb.ListHTTPAccessLogsResponse, error) {
|
func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *pb.ListHTTPAccessLogsRequest) (*pb.ListHTTPAccessLogsResponse, error) {
|
||||||
// 校验请求
|
// 校验请求
|
||||||
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
|
_, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查服务ID
|
||||||
|
if userId > 0 {
|
||||||
|
if req.ServerId <= 0 {
|
||||||
|
return nil, errors.New("invalid serverId")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId)
|
accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
37
internal/rpc/utils/plain_context.go
Normal file
37
internal/rpc/utils/plain_context.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package rpcutils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PlainContext struct {
|
||||||
|
UserType string
|
||||||
|
UserId int64
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPlainContext(userType string, userId int64) *PlainContext {
|
||||||
|
return &PlainContext{
|
||||||
|
UserType: userType,
|
||||||
|
UserId: userId,
|
||||||
|
ctx: context.Background(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *PlainContext) Deadline() (deadline time.Time, ok bool) {
|
||||||
|
return this.ctx.Deadline()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *PlainContext) Done() <-chan struct{} {
|
||||||
|
return this.ctx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *PlainContext) Err() error {
|
||||||
|
return this.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *PlainContext) Value(key interface{}) interface{} {
|
||||||
|
return this.ctx.Value(key)
|
||||||
|
}
|
||||||
@@ -33,6 +33,29 @@ const (
|
|||||||
|
|
||||||
// 校验请求
|
// 校验请求
|
||||||
func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserType, userId int64, err error) {
|
func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserType, userId int64, err error) {
|
||||||
|
if ctx == nil {
|
||||||
|
err = errors.New("context should not be nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 支持直接认证
|
||||||
|
plainCtx, ok := ctx.(*PlainContext)
|
||||||
|
if ok {
|
||||||
|
userType = plainCtx.UserType
|
||||||
|
userId = plainCtx.UserId
|
||||||
|
|
||||||
|
if len(userTypes) > 0 && !lists.ContainsString(userTypes, userType) {
|
||||||
|
userType = UserTypeNone
|
||||||
|
userId = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if userId <= 0 {
|
||||||
|
err = errors.New("context: can not find user or permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return UserTypeNone, 0, errors.New("context: need 'nodeId'")
|
return UserTypeNone, 0, errors.New("context: need 'nodeId'")
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func (this *Setup) Run() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建API节点
|
// 创建API节点
|
||||||
nodeId, err := dao.CreateAPINode("默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, addrsJSON, true)
|
nodeId, err := dao.CreateAPINode("默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, false, nil, nil, addrsJSON, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("create api node in database failed: " + err.Error())
|
return errors.New("create api node in database failed: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user