[API节点]支持HTTP API

This commit is contained in:
刘祥超
2021-01-01 20:49:09 +08:00
parent f905bb7066
commit b4ab1f0ec8
18 changed files with 667 additions and 80 deletions

View File

@@ -0,0 +1,68 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/rands"
"time"
)
type APIAccessTokenDAO dbs.DAO
func NewAPIAccessTokenDAO() *APIAccessTokenDAO {
return dbs.NewDAO(&APIAccessTokenDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeAPIAccessTokens",
Model: new(APIAccessToken),
PkName: "id",
},
}).(*APIAccessTokenDAO)
}
var SharedAPIAccessTokenDAO *APIAccessTokenDAO
func init() {
dbs.OnReady(func() {
SharedAPIAccessTokenDAO = NewAPIAccessTokenDAO()
})
}
// 生成AccessToken
func (this *APIAccessTokenDAO) GenerateAccessToken(userId int64) (token string, expiresAt int64, err error) {
// 查询以前的
accessToken, err := this.Query().
Attr("userId", userId).
Find()
if err != nil {
return "", 0, err
}
token = rands.String(128) // TODO 增强安全性,将来使用 base64_encode(encrypt(salt+random)) 算法来代替
expiresAt = time.Now().Unix() + 7200
op := NewAPIAccessTokenOperator()
if accessToken != nil {
op.Id = accessToken.(*APIAccessToken).Id
}
op.UserId = userId
op.Token = token
op.CreatedAt = time.Now().Unix()
op.ExpiredAt = expiresAt
err = this.Save(op)
return
}
// 查找AccessToken
func (this *APIAccessTokenDAO) FindAccessToken(token string) (*APIAccessToken, error) {
one, err := this.Query().
Attr("token", token).
Find()
if one == nil || err != nil {
return nil, err
}
return one.(*APIAccessToken), nil
}

View File

@@ -0,0 +1,5 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
)

View File

@@ -0,0 +1,22 @@
package models
// API访问令牌
type APIAccessToken struct {
Id uint64 `field:"id"` // ID
UserId uint32 `field:"userId"` // 用户ID
Token string `field:"token"` // 令牌
CreatedAt uint64 `field:"createdAt"` // 创建时间
ExpiredAt uint64 `field:"expiredAt"` // 过期时间
}
type APIAccessTokenOperator struct {
Id interface{} // ID
UserId interface{} // 用户ID
Token interface{} // 令牌
CreatedAt interface{} // 创建时间
ExpiredAt interface{} // 过期时间
}
func NewAPIAccessTokenOperator() *APIAccessTokenOperator {
return &APIAccessTokenOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -90,7 +90,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) {
} }
// 创建API节点 // 创建API节点
func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) {
uniqueId, err := this.genUniqueId() uniqueId, err := this.genUniqueId()
if err != nil { if err != nil {
return 0, err return 0, err
@@ -114,6 +114,13 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON
if len(httpsJSON) > 0 { if len(httpsJSON) > 0 {
op.Https = httpsJSON op.Https = httpsJSON
} }
op.RestIsOn = restIsOn
if len(restHTTPJSON) > 0 {
op.RestHTTP = restHTTPJSON
}
if len(restHTTPSJSON) > 0 {
op.RestHTTPS = restHTTPSJSON
}
if len(accessAddrsJSON) > 0 { if len(accessAddrsJSON) > 0 {
op.AccessAddrs = accessAddrsJSON op.AccessAddrs = accessAddrsJSON
} }
@@ -128,7 +135,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON
} }
// 修改API节点 // 修改API节点
func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error { func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) error {
if nodeId <= 0 { if nodeId <= 0 {
return errors.New("invalid nodeId") return errors.New("invalid nodeId")
} }
@@ -142,17 +149,28 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str
if len(httpJSON) > 0 { if len(httpJSON) > 0 {
op.Http = httpJSON op.Http = httpJSON
} else { } else {
op.Http = "null" op.Http = "{}"
} }
if len(httpsJSON) > 0 { if len(httpsJSON) > 0 {
op.Https = httpsJSON op.Https = httpsJSON
} else { } else {
op.Https = "null" op.Https = "{}"
}
op.RestIsOn = restIsOn
if len(restHTTPJSON) > 0 {
op.RestHTTP = restHTTPJSON
} else {
op.RestHTTP = "{}"
}
if len(restHTTPSJSON) > 0 {
op.RestHTTPS = restHTTPSJSON
} else {
op.RestHTTPS = "{}"
} }
if len(accessAddrsJSON) > 0 { if len(accessAddrsJSON) > 0 {
op.AccessAddrs = accessAddrsJSON op.AccessAddrs = accessAddrsJSON
} else { } else {
op.AccessAddrs = "null" op.AccessAddrs = "[]"
} }
err := this.Save(op) err := this.Save(op)

View File

@@ -11,6 +11,9 @@ type APINode struct {
Description string `field:"description"` // 描述 Description string `field:"description"` // 描述
Http string `field:"http"` // 监听的HTTP配置 Http string `field:"http"` // 监听的HTTP配置
Https string `field:"https"` // 监听的HTTPS配置 Https string `field:"https"` // 监听的HTTPS配置
RestIsOn uint8 `field:"restIsOn"` // 是否开放REST
RestHTTP string `field:"restHTTP"` // REST HTTP配置
RestHTTPS string `field:"restHTTPS"` // REST HTTPS配置
AccessAddrs string `field:"accessAddrs"` // 外部访问地址 AccessAddrs string `field:"accessAddrs"` // 外部访问地址
Order uint32 `field:"order"` // 排序 Order uint32 `field:"order"` // 排序
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
@@ -30,6 +33,9 @@ type APINodeOperator struct {
Description interface{} // 描述 Description interface{} // 描述
Http interface{} // 监听的HTTP配置 Http interface{} // 监听的HTTP配置
Https interface{} // 监听的HTTPS配置 Https interface{} // 监听的HTTPS配置
RestIsOn interface{} // 是否开放REST
RestHTTP interface{} // REST HTTP配置
RestHTTPS interface{} // REST HTTPS配置
AccessAddrs interface{} // 外部访问地址 AccessAddrs interface{} // 外部访问地址
Order interface{} // 排序 Order interface{} // 排序
State interface{} // 状态 State interface{} // 状态

View File

@@ -93,3 +93,65 @@ func (this *APINode) DecodeAccessAddrStrings() ([]string, error) {
} }
return result, nil return result, nil
} }
// 解析Rest HTTP配置
func (this *APINode) DecodeRestHTTP() (*serverconfigs.HTTPProtocolConfig, error) {
if this.RestIsOn != 1 {
return nil, nil
}
if !IsNotNull(this.RestHTTP) {
return nil, nil
}
config := &serverconfigs.HTTPProtocolConfig{}
err := json.Unmarshal([]byte(this.RestHTTP), config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, err
}
return config, nil
}
// 解析HTTPS配置
func (this *APINode) DecodeRestHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) {
if this.RestIsOn != 1 {
return nil, nil
}
if !IsNotNull(this.RestHTTPS) {
return nil, nil
}
config := &serverconfigs.HTTPSProtocolConfig{}
err := json.Unmarshal([]byte(this.RestHTTPS), config)
if err != nil {
return nil, err
}
err = config.Init()
if err != nil {
return nil, err
}
if config.SSLPolicyRef != nil {
policyId := config.SSLPolicyRef.SSLPolicyId
if policyId > 0 {
sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId)
if err != nil {
return nil, err
}
if sslPolicy != nil {
config.SSLPolicy = sslPolicy
}
}
}
err = config.Init()
if err != nil {
return nil, err
}
return config, nil
}

View File

@@ -2,11 +2,13 @@ package models
import ( import (
"fmt" "fmt"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
timeutil "github.com/iwind/TeaGo/utils/time" timeutil "github.com/iwind/TeaGo/utils/time"
"hash/crc32" "hash/crc32"
"regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -46,17 +48,22 @@ func randomAccessLogDAO() (dao *HTTPAccessLogDAOWrapper) {
} }
// 检查表格是否存在 // 检查表格是否存在
func findAccessLogTableName(db *dbs.DB, day string) (string, bool, error) { func findAccessLogTableName(db *dbs.DB, day string) (tableName string, ok bool, err error) {
if !regexp.MustCompile(`^\d{8}$`).MatchString(day) {
err = errors.New("invalid day '" + day + "', should be YYYYMMDD")
return
}
config, err := db.Config() config, err := db.Config()
if err != nil { if err != nil {
return "", false, err return "", false, err
} }
tableName := "edgeHTTPAccessLogs_" + day tableName = "edgeHTTPAccessLogs_" + day
cacheKey := tableName + "_" + fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(config.Dsn))) cacheKey := tableName + "_" + fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(config.Dsn)))
accessLogLocker.RLock() accessLogLocker.RLock()
_, ok := accessLogTableMapping[cacheKey] _, ok = accessLogTableMapping[cacheKey]
accessLogLocker.RUnlock() accessLogLocker.RUnlock()
if ok { if ok {
return tableName, true, nil return tableName, true, nil

View File

@@ -168,6 +168,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(lastRequestId string, size int64, d
dao := daoWrapper.DAO dao := daoWrapper.DAO
tableName, exists, err := findAccessLogTableName(dao.Instance, day) tableName, exists, err := findAccessLogTableName(dao.Instance, day)
logs.Println("tableName:", tableName, exists, err) // TODO
if !exists { if !exists {
// 表格不存在则跳过 // 表格不存在则跳过
return return

View File

@@ -109,3 +109,17 @@ func (this *UserAccessKeyDAO) UpdateAccessKeyIsOn(accessKeyId int64, isOn bool)
Update() Update()
return err return err
} }
// 根据UniqueId查找AccessKey
func (this *UserAccessKeyDAO) FindAccessKeyWithUniqueId(uniqueId string) (*UserAccessKey, error) {
one, err := this.Query().
Attr("uniqueId", uniqueId).
Attr("isOn", true).
State(UserAccessKeyStateEnabled).
Find()
if one == nil || err != nil {
return nil, err
}
return one.(*UserAccessKey), nil
}

View File

@@ -75,72 +75,7 @@ func (this *APINode) Start() {
// 监听RPC服务 // 监听RPC服务
remotelogs.Println("API_NODE", "starting RPC server ...") remotelogs.Println("API_NODE", "starting RPC server ...")
// HTTP isListening := this.listenPorts(apiNode)
httpConfig, err := apiNode.DecodeHTTP()
if err != nil {
remotelogs.Error("API_NODE", "decode http config: "+err.Error())
return
}
isListening := false
if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
for _, listen := range httpConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
continue
}
go func() {
err := this.listenRPC(listener, nil)
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
return
}
}()
isListening = true
}
}
}
// HTTPS
httpsConfig, err := apiNode.DecodeHTTPS()
if err != nil {
remotelogs.Error("API_NODE", "decode https config: "+err.Error())
return
}
if httpsConfig != nil &&
httpsConfig.IsOn &&
len(httpsConfig.Listen) > 0 &&
httpsConfig.SSLPolicy != nil &&
httpsConfig.SSLPolicy.IsOn &&
len(httpsConfig.SSLPolicy.Certs) > 0 {
certs := []tls.Certificate{}
for _, cert := range httpsConfig.SSLPolicy.Certs {
certs = append(certs, *cert.CertObject())
}
for _, listen := range httpsConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
continue
}
go func() {
err := this.listenRPC(listener, &tls.Config{
Certificates: certs,
})
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
return
}
}()
isListening = true
}
}
}
// HTTP接口
if !isListening { if !isListening {
remotelogs.Error("API_NODE", "the api node require at least one listening address") remotelogs.Error("API_NODE", "the api node require at least one listening address")
@@ -155,10 +90,10 @@ func (this *APINode) Start() {
func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error { func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error {
var rpcServer *grpc.Server var rpcServer *grpc.Server
if tlsConfig == nil { if tlsConfig == nil {
remotelogs.Println("API_NODE", "listening http://"+listener.Addr().String()+" ...") remotelogs.Println("API_NODE", "listening GRPC http://"+listener.Addr().String()+" ...")
rpcServer = grpc.NewServer() rpcServer = grpc.NewServer()
} else { } else {
logs.Println("[API_NODE]listening https://" + listener.Addr().String() + " ...") logs.Println("[API_NODE]listening GRPC https://" + listener.Addr().String() + " ...")
rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
} }
pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{}) pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{})
@@ -264,3 +199,142 @@ func (this *APINode) autoUpgrade() error {
logs.Println("[API_NODE]upgrade database done") logs.Println("[API_NODE]upgrade database done")
return nil return nil
} }
// 启动端口
func (this *APINode) listenPorts(apiNode *models.APINode) (isListening bool) {
// HTTP
httpConfig, err := apiNode.DecodeHTTP()
if err != nil {
remotelogs.Error("API_NODE", "decode http config: "+err.Error())
return
}
isListening = false
if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 {
for _, listen := range httpConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
continue
}
go func() {
err := this.listenRPC(listener, nil)
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
return
}
}()
isListening = true
}
}
}
// HTTPS
httpsConfig, err := apiNode.DecodeHTTPS()
if err != nil {
remotelogs.Error("API_NODE", "decode https config: "+err.Error())
return
}
if httpsConfig != nil &&
httpsConfig.IsOn &&
len(httpsConfig.Listen) > 0 &&
httpsConfig.SSLPolicy != nil &&
httpsConfig.SSLPolicy.IsOn &&
len(httpsConfig.SSLPolicy.Certs) > 0 {
certs := []tls.Certificate{}
for _, cert := range httpsConfig.SSLPolicy.Certs {
certs = append(certs, *cert.CertObject())
}
for _, listen := range httpsConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error())
continue
}
go func() {
err := this.listenRPC(listener, &tls.Config{
Certificates: certs,
})
if err != nil {
remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error())
return
}
}()
isListening = true
}
}
}
// Rest HTTP
restHTTPConfig, err := apiNode.DecodeRestHTTP()
if err != nil {
remotelogs.Error("API_NODE", "decode REST http config: "+err.Error())
return
}
if restHTTPConfig != nil && restHTTPConfig.IsOn && len(restHTTPConfig.Listen) > 0 {
for _, listen := range restHTTPConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
remotelogs.Error("API_NODE", "listening REST 'http://"+addr+"' failed: "+err.Error())
continue
}
go func() {
remotelogs.Println("API_NODE", "listening REST http://"+addr+" ...")
server := &RestServer{}
err := server.Listen(listener)
if err != nil {
remotelogs.Error("API_NODE", "listening REST 'http://"+addr+"' failed: "+err.Error())
return
}
}()
isListening = true
}
}
}
// Rest HTTPS
restHTTPSConfig, err := apiNode.DecodeRestHTTPS()
if err != nil {
remotelogs.Error("API_NODE", "decode REST https config: "+err.Error())
return
}
if restHTTPSConfig != nil &&
restHTTPSConfig.IsOn &&
len(restHTTPSConfig.Listen) > 0 &&
restHTTPSConfig.SSLPolicy != nil &&
restHTTPSConfig.SSLPolicy.IsOn &&
len(restHTTPSConfig.SSLPolicy.Certs) > 0 {
for _, listen := range restHTTPSConfig.Listen {
for _, addr := range listen.Addresses() {
listener, err := net.Listen("tcp", addr)
if err != nil {
remotelogs.Error("API_NODE", "listening REST 'https://"+addr+"' failed: "+err.Error())
continue
}
go func() {
remotelogs.Println("API_NODE", "listening REST https://"+addr+" ...")
server := &RestServer{}
certs := []tls.Certificate{}
for _, cert := range httpsConfig.SSLPolicy.Certs {
certs = append(certs, *cert.CertObject())
}
err := server.ListenHTTPS(listener, &tls.Config{
Certificates: certs,
})
if err != nil {
remotelogs.Error("API_NODE", "listening REST 'https://"+addr+"' failed: "+err.Error())
return
}
}()
isListening = true
}
}
}
return
}

View File

@@ -0,0 +1,189 @@
package nodes
import (
"context"
"crypto/tls"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/iwind/TeaGo/maps"
"io/ioutil"
"net"
"net/http"
"reflect"
"regexp"
"time"
)
var servicePathReg = regexp.MustCompile(`^/([a-zA-Z0-9]+)/([a-zA-Z0-9]+)$`)
var servicesMap = map[string]reflect.Value{
"APIAccessTokenService": reflect.ValueOf(new(services.APIAccessTokenService)),
"HTTPAccessLogService": reflect.ValueOf(new(services.HTTPAccessLogService)),
}
type RestServer struct{}
func (this *RestServer) Listen(listener net.Listener) error {
mux := http.NewServeMux()
mux.HandleFunc("/", this.handle)
server := &http.Server{}
server.Handler = mux
return server.Serve(listener)
}
func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config) error {
mux := http.NewServeMux()
mux.HandleFunc("/", this.handle)
server := &http.Server{}
server.Handler = mux
server.TLSConfig = tlsConfig
return server.ServeTLS(listener, "", "")
}
func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
path := req.URL.Path
matches := servicePathReg.FindStringSubmatch(path)
if len(matches) != 3 {
writer.WriteHeader(http.StatusNotFound)
return
}
serviceName := matches[1]
methodName := matches[2]
serviceType, ok := servicesMap[serviceName]
if !ok {
writer.WriteHeader(http.StatusNotFound)
return
}
method := serviceType.MethodByName(methodName)
if !method.IsValid() {
writer.WriteHeader(http.StatusNotFound)
return
}
if method.Type().NumIn() != 2 || method.Type().NumOut() != 2 {
writer.WriteHeader(http.StatusNotFound)
return
}
if method.Type().In(0).Name() != "Context" {
writer.WriteHeader(http.StatusNotFound)
return
}
// 是否显示Pretty后的JSON
shouldPretty := req.Header.Get("Edge-Response-Pretty") == "on"
// 上下文
ctx := context.Background()
if serviceName != "APIAccessTokenService" || methodName != "GetAPIAccessToken" {
// 校验TOKEN
token := req.Header.Get("Edge-Access-Token")
if len(token) == 0 {
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "require 'Edge-Access-Token' header",
}, shouldPretty)
return
}
accessToken, err := models.SharedAPIAccessTokenDAO.FindAccessToken(token)
if err != nil {
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "server error: " + err.Error(),
}, shouldPretty)
return
}
if accessToken == nil || int64(accessToken.ExpiredAt) < time.Now().Unix() {
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "invalid access token",
}, shouldPretty)
return
}
if accessToken.UserId > 0 {
ctx = rpcutils.NewPlainContext("user", int64(accessToken.UserId))
} else {
// TODO 支持更多类型的角色
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "not supported role",
}, shouldPretty)
return
}
}
// TODO 需要防止BODY过大攻击
body, err := ioutil.ReadAll(req.Body)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte(err.Error()))
return
}
// 请求数据
reqValue := reflect.New(method.Type().In(1).Elem()).Interface()
err = json.Unmarshal(body, reqValue)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte(err.Error()))
return
}
result := method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)})
resultErr := result[1].Interface()
if resultErr != nil {
e, ok := resultErr.(error)
if ok {
this.writeJSON(writer, maps.Map{
"code": 400,
"message": e.Error(),
"data": maps.Map{},
}, shouldPretty)
} else {
this.writeJSON(writer, maps.Map{
"code": 500,
"message": "server error: server should return a error object, but return a " + result[1].Type().String(),
"data": maps.Map{},
}, shouldPretty)
}
} else { // 没有返回错误
data := maps.Map{
"code": 200,
"message": "ok",
"data": result[0].Interface(),
}
var dataJSON []byte
if shouldPretty {
dataJSON = data.AsPrettyJSON()
} else {
dataJSON = data.AsJSON()
}
if err != nil {
this.writeJSON(writer, maps.Map{
"code": 500,
"message": "server error: marshal json failed: " + err.Error(),
"data": maps.Map{},
}, shouldPretty)
} else {
_, _ = writer.Write(dataJSON)
}
}
}
func (this *RestServer) writeJSON(writer http.ResponseWriter, v maps.Map, pretty bool) {
if pretty {
_, _ = writer.Write(v.AsPrettyJSON())
} else {
_, _ = writer.Write(v.AsJSON())
}
}

View File

@@ -0,0 +1,40 @@
package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
)
// AccessToken相关服务
type APIAccessTokenService struct {
}
// 获取AccessToken
func (this *APIAccessTokenService) GetAPIAccessToken(ctx context.Context, req *pb.GetAPIAccessTokenRequest) (*pb.GetAPIAccessTokenResponse, error) {
if req.Type == "user" { // 用户
accessKey, err := models.SharedUserAccessKeyDAO.FindAccessKeyWithUniqueId(req.AccessKeyId)
if err != nil {
return nil, err
}
if accessKey == nil {
return nil, errors.New("access key not found")
}
if accessKey.Secret != req.AccessKey {
return nil, errors.New("access key not found")
}
// 创建AccessToken
token, expiresAt, err := models.SharedAPIAccessTokenDAO.GenerateAccessToken(int64(accessKey.UserId))
if err != nil {
return nil, err
}
return &pb.GetAPIAccessTokenResponse{
Token: token,
ExpiresAt: expiresAt,
}, nil
} else {
return nil, errors.New("unsupported type '" + req.Type + "'")
}
}

View File

@@ -19,7 +19,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI
return nil, err return nil, err
} }
nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -34,7 +34,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI
return nil, err return nil, err
} }
err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -138,6 +138,9 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis
Description: node.Description, Description: node.Description,
HttpJSON: []byte(node.Http), HttpJSON: []byte(node.Http),
HttpsJSON: []byte(node.Https), HttpsJSON: []byte(node.Https),
RestIsOn: node.RestIsOn == 1,
RestHTTPJSON: []byte(node.RestHTTP),
RestHTTPSJSON: []byte(node.RestHTTPS),
AccessAddrsJSON: []byte(node.AccessAddrs), AccessAddrsJSON: []byte(node.AccessAddrs),
AccessAddrs: accessAddrs, AccessAddrs: accessAddrs,
StatusJSON: []byte(node.Status), StatusJSON: []byte(node.Status),
@@ -178,6 +181,9 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find
Description: node.Description, Description: node.Description,
HttpJSON: []byte(node.Http), HttpJSON: []byte(node.Http),
HttpsJSON: []byte(node.Https), HttpsJSON: []byte(node.Https),
RestIsOn: node.RestIsOn == 1,
RestHTTPJSON: []byte(node.RestHTTP),
RestHTTPSJSON: []byte(node.RestHTTPS),
AccessAddrsJSON: []byte(node.AccessAddrs), AccessAddrsJSON: []byte(node.AccessAddrs),
AccessAddrs: accessAddrs, AccessAddrs: accessAddrs,
} }

View File

@@ -3,12 +3,14 @@ package services
import ( import (
"context" "context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
) )
// 访问日志相关服务 // 访问日志相关服务
type HTTPAccessLogService struct { type HTTPAccessLogService struct {
BaseService
} }
// 创建访问日志 // 创建访问日志
@@ -34,11 +36,23 @@ func (this *HTTPAccessLogService) CreateHTTPAccessLogs(ctx context.Context, req
// 列出单页访问日志 // 列出单页访问日志
func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *pb.ListHTTPAccessLogsRequest) (*pb.ListHTTPAccessLogsResponse, error) { func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *pb.ListHTTPAccessLogsRequest) (*pb.ListHTTPAccessLogsResponse, error) {
// 校验请求 // 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 检查服务ID
if userId > 0 {
if req.ServerId <= 0 {
return nil, errors.New("invalid serverId")
}
err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId)
if err != nil {
return nil, err
}
}
accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId) accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -0,0 +1,37 @@
package rpcutils
import (
"context"
"time"
)
type PlainContext struct {
UserType string
UserId int64
ctx context.Context
}
func NewPlainContext(userType string, userId int64) *PlainContext {
return &PlainContext{
UserType: userType,
UserId: userId,
ctx: context.Background(),
}
}
func (this *PlainContext) Deadline() (deadline time.Time, ok bool) {
return this.ctx.Deadline()
}
func (this *PlainContext) Done() <-chan struct{} {
return this.ctx.Done()
}
func (this *PlainContext) Err() error {
return this.ctx.Err()
}
func (this *PlainContext) Value(key interface{}) interface{} {
return this.ctx.Value(key)
}

View File

@@ -33,6 +33,29 @@ const (
// 校验请求 // 校验请求
func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserType, userId int64, err error) { func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserType, userId int64, err error) {
if ctx == nil {
err = errors.New("context should not be nil")
return
}
// 支持直接认证
plainCtx, ok := ctx.(*PlainContext)
if ok {
userType = plainCtx.UserType
userId = plainCtx.UserId
if len(userTypes) > 0 && !lists.ContainsString(userTypes, userType) {
userType = UserTypeNone
userId = 0
}
if userId <= 0 {
err = errors.New("context: can not find user or permission denied")
}
return
}
md, ok := metadata.FromIncomingContext(ctx) md, ok := metadata.FromIncomingContext(ctx)
if !ok { if !ok {
return UserTypeNone, 0, errors.New("context: need 'nodeId'") return UserTypeNone, 0, errors.New("context: need 'nodeId'")

View File

@@ -160,7 +160,7 @@ func (this *Setup) Run() error {
} }
// 创建API节点 // 创建API节点
nodeId, err := dao.CreateAPINode("默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, addrsJSON, true) nodeId, err := dao.CreateAPINode("默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, false, nil, nil, addrsJSON, true)
if err != nil { if err != nil {
return errors.New("create api node in database failed: " + err.Error()) return errors.New("create api node in database failed: " + err.Error())
} }