Files
EdgeAPI/internal/nodes/rest_server.go

274 lines
7.2 KiB
Go
Raw Normal View History

2021-01-01 20:49:09 +08:00
package nodes
import (
"context"
"crypto/tls"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
2022-08-04 19:36:25 +08:00
"github.com/TeaOSLab/EdgeAPI/internal/utils/sizes"
2021-01-01 20:49:09 +08:00
"github.com/iwind/TeaGo/maps"
2022-08-04 11:41:42 +08:00
"io"
2021-01-01 20:49:09 +08:00
"net"
"net/http"
"reflect"
"regexp"
"strings"
2021-01-01 20:49:09 +08:00
"time"
)
var servicePathReg = regexp.MustCompile(`^/([a-zA-Z0-9]+)/([a-zA-Z0-9]+)$`)
var restServicesMap = map[string]reflect.Value{
2021-01-01 20:49:09 +08:00
"APIAccessTokenService": reflect.ValueOf(new(services.APIAccessTokenService)),
}
type RestServer struct{}
func (this *RestServer) Listen(listener net.Listener) error {
2022-08-04 19:36:25 +08:00
var mux = http.NewServeMux()
2021-01-01 20:49:09 +08:00
mux.HandleFunc("/", this.handle)
2022-08-04 19:36:25 +08:00
var server = &http.Server{}
2021-01-01 20:49:09 +08:00
server.Handler = mux
return server.Serve(listener)
}
func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config) error {
2022-08-04 19:36:25 +08:00
var mux = http.NewServeMux()
2021-01-01 20:49:09 +08:00
mux.HandleFunc("/", this.handle)
server := &http.Server{}
server.Handler = mux
server.TLSConfig = tlsConfig
return server.ServeTLS(listener, "", "")
}
func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
2022-08-04 19:36:25 +08:00
var path = req.URL.Path
2021-01-03 21:37:47 +08:00
// 是否显示Pretty后的JSON
2022-08-04 19:36:25 +08:00
var shouldPretty = req.Header.Get("X-Edge-Response-Pretty") == "on"
2021-10-19 19:49:06 +08:00
// 兼容老的Header
var oldShouldPretty = req.Header.Get("Edge-Response-Pretty")
if len(oldShouldPretty) > 0 {
shouldPretty = oldShouldPretty == "on"
}
2021-01-03 21:37:47 +08:00
// 欢迎页
if path == "/" {
this.writeJSON(writer, maps.Map{
"code": 200,
"message": "Welcome to API",
"data": maps.Map{},
}, shouldPretty)
return
}
2022-08-04 19:36:25 +08:00
var matches = servicePathReg.FindStringSubmatch(path)
2021-01-01 20:49:09 +08:00
if len(matches) != 3 {
writer.WriteHeader(http.StatusNotFound)
this.writeJSON(writer, maps.Map{
"code": "404",
"message": "invalid api path '" + path + "'",
"data": maps.Map{},
}, shouldPretty)
2021-01-01 20:49:09 +08:00
return
}
2022-08-04 19:36:25 +08:00
var serviceName = matches[1]
var methodName = matches[2]
2021-01-01 20:49:09 +08:00
serviceType, ok := restServicesMap[serviceName]
2021-01-01 20:49:09 +08:00
if !ok {
writer.WriteHeader(http.StatusNotFound)
this.writeJSON(writer, maps.Map{
"code": "404",
"message": "service '" + serviceName + "' not found",
"data": maps.Map{},
}, shouldPretty)
2021-01-01 20:49:09 +08:00
return
}
if len(methodName) == 0 {
writer.WriteHeader(http.StatusNotFound)
this.writeJSON(writer, maps.Map{
"code": "404",
"message": "method '" + methodName + "' not found",
"data": maps.Map{},
}, shouldPretty)
return
}
// 再次查找
methodName = strings.ToUpper(string(methodName[0])) + methodName[1:]
2022-08-04 19:36:25 +08:00
var method = serviceType.MethodByName(methodName)
2021-01-01 20:49:09 +08:00
if !method.IsValid() {
2022-08-04 19:36:25 +08:00
// 兼容Enabled
if strings.Contains(methodName, "Enabled") {
methodName = strings.Replace(methodName, "Enabled", "", 1)
method = serviceType.MethodByName(methodName)
if !method.IsValid() {
writer.WriteHeader(http.StatusNotFound)
this.writeJSON(writer, maps.Map{
"code": "404",
"message": "method '" + methodName + "' not found",
"data": maps.Map{},
}, shouldPretty)
2022-08-04 19:36:25 +08:00
return
}
} else {
writer.WriteHeader(http.StatusNotFound)
this.writeJSON(writer, maps.Map{
"code": "404",
"message": "method '" + methodName + "' not found",
"data": maps.Map{},
}, shouldPretty)
2022-08-04 19:36:25 +08:00
return
}
2021-01-01 20:49:09 +08:00
}
if method.Type().NumIn() != 2 || method.Type().NumOut() != 2 {
writer.WriteHeader(http.StatusNotFound)
this.writeJSON(writer, maps.Map{
"code": "404",
"message": "method '" + methodName + "' not found",
"data": maps.Map{},
}, shouldPretty)
2021-01-01 20:49:09 +08:00
return
}
if method.Type().In(0).Name() != "Context" {
writer.WriteHeader(http.StatusNotFound)
this.writeJSON(writer, maps.Map{
"code": "404",
"message": "method '" + methodName + "' not found (or invalid context)",
"data": maps.Map{},
}, shouldPretty)
2021-01-01 20:49:09 +08:00
return
}
// 上下文
2022-08-04 19:36:25 +08:00
var ctx = context.Background()
2021-01-01 20:49:09 +08:00
if serviceName != "APIAccessTokenService" || (methodName != "GetAPIAccessToken" && methodName != "getAPIAccessToken") {
2021-01-01 20:49:09 +08:00
// 校验TOKEN
2022-08-04 19:36:25 +08:00
var token = req.Header.Get("X-Edge-Access-Token")
2021-01-01 20:49:09 +08:00
if len(token) == 0 {
2021-10-19 19:49:06 +08:00
token = req.Header.Get("Edge-Access-Token")
if len(token) == 0 {
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "require 'X-Edge-Access-Token' header",
}, shouldPretty)
return
}
2021-01-01 20:49:09 +08:00
}
accessToken, err := models.SharedAPIAccessTokenDAO.FindAccessToken(nil, token)
2021-01-01 20:49:09 +08:00
if err != nil {
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "server error: " + err.Error(),
}, shouldPretty)
return
}
if accessToken == nil || int64(accessToken.ExpiredAt) < time.Now().Unix() {
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "invalid access token",
}, shouldPretty)
return
}
if accessToken.UserId > 0 {
ctx = rpcutils.NewPlainContext("user", int64(accessToken.UserId))
} else if accessToken.AdminId > 0 {
ctx = rpcutils.NewPlainContext("admin", int64(accessToken.AdminId))
2021-01-01 20:49:09 +08:00
} else {
// TODO 支持更多类型的角色
this.writeJSON(writer, maps.Map{
"code": 400,
"data": maps.Map{},
"message": "not supported role",
}, shouldPretty)
return
}
}
2022-08-04 19:36:25 +08:00
// TODO 可以设置最大可接收内容尺寸
body, err := io.ReadAll(io.LimitReader(req.Body, 32*sizes.M))
2021-01-01 20:49:09 +08:00
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte(err.Error()))
return
}
// 如果为空,表示传的数据为空
if len(body) == 0 {
body = []byte("{}")
}
2021-01-01 20:49:09 +08:00
// 请求数据
2022-08-04 19:36:25 +08:00
var reqValue = reflect.New(method.Type().In(1).Elem()).Interface()
2021-01-01 20:49:09 +08:00
err = json.Unmarshal(body, reqValue)
if err != nil {
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("Decode request failed: " + err.Error() + ". Request body should be a valid JSON data"))
2021-01-01 20:49:09 +08:00
return
}
2022-08-04 19:36:25 +08:00
var result = method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)})
var resultErr = result[1].Interface()
2021-01-01 20:49:09 +08:00
if resultErr != nil {
e, ok := resultErr.(error)
if ok {
this.writeJSON(writer, maps.Map{
"code": 400,
"message": e.Error(),
"data": maps.Map{},
}, shouldPretty)
} else {
this.writeJSON(writer, maps.Map{
"code": 500,
"message": "server error: server should return a error object, but return a " + result[1].Type().String(),
"data": maps.Map{},
}, shouldPretty)
}
} else { // 没有返回错误
2022-08-04 19:36:25 +08:00
var data = maps.Map{
2021-01-01 20:49:09 +08:00
"code": 200,
"message": "ok",
"data": result[0].Interface(),
}
var dataJSON []byte
if shouldPretty {
dataJSON = data.AsPrettyJSON()
} else {
dataJSON = data.AsJSON()
}
if err != nil {
this.writeJSON(writer, maps.Map{
"code": 500,
"message": "server error: marshal json failed: " + err.Error(),
"data": maps.Map{},
}, shouldPretty)
} else {
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
2021-01-01 20:49:09 +08:00
_, _ = writer.Write(dataJSON)
}
}
}
func (this *RestServer) writeJSON(writer http.ResponseWriter, v maps.Map, pretty bool) {
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
2021-01-03 21:37:47 +08:00
2021-01-01 20:49:09 +08:00
if pretty {
_, _ = writer.Write(v.AsPrettyJSON())
} else {
_, _ = writer.Write(v.AsJSON())
}
}