简化API

This commit is contained in:
GoEdgeLab
2022-08-04 19:36:25 +08:00
parent 7ea3c6a5a3
commit 1c221e305b

View File

@@ -7,6 +7,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/rpc/services" "github.com/TeaOSLab/EdgeAPI/internal/rpc/services"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/sizes"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
"io" "io"
"net" "net"
@@ -25,15 +26,15 @@ var restServicesMap = map[string]reflect.Value{
type RestServer struct{} type RestServer struct{}
func (this *RestServer) Listen(listener net.Listener) error { func (this *RestServer) Listen(listener net.Listener) error {
mux := http.NewServeMux() var mux = http.NewServeMux()
mux.HandleFunc("/", this.handle) mux.HandleFunc("/", this.handle)
server := &http.Server{} var server = &http.Server{}
server.Handler = mux server.Handler = mux
return server.Serve(listener) return server.Serve(listener)
} }
func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config) error { func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config) error {
mux := http.NewServeMux() var mux = http.NewServeMux()
mux.HandleFunc("/", this.handle) mux.HandleFunc("/", this.handle)
server := &http.Server{} server := &http.Server{}
server.Handler = mux server.Handler = mux
@@ -42,10 +43,10 @@ func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config
} }
func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
path := req.URL.Path var path = req.URL.Path
// 是否显示Pretty后的JSON // 是否显示Pretty后的JSON
shouldPretty := req.Header.Get("X-Edge-Response-Pretty") == "on" var shouldPretty = req.Header.Get("X-Edge-Response-Pretty") == "on"
// 兼容老的Header // 兼容老的Header
var oldShouldPretty = req.Header.Get("Edge-Response-Pretty") var oldShouldPretty = req.Header.Get("Edge-Response-Pretty")
@@ -63,14 +64,14 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
return return
} }
matches := servicePathReg.FindStringSubmatch(path) var matches = servicePathReg.FindStringSubmatch(path)
if len(matches) != 3 { if len(matches) != 3 {
writer.WriteHeader(http.StatusNotFound) writer.WriteHeader(http.StatusNotFound)
return return
} }
serviceName := matches[1] var serviceName = matches[1]
methodName := matches[2] var methodName = matches[2]
serviceType, ok := restServicesMap[serviceName] serviceType, ok := restServicesMap[serviceName]
if !ok { if !ok {
@@ -85,10 +86,20 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
// 再次查找 // 再次查找
methodName = strings.ToUpper(string(methodName[0])) + methodName[1:] methodName = strings.ToUpper(string(methodName[0])) + methodName[1:]
method := serviceType.MethodByName(methodName) var method = serviceType.MethodByName(methodName)
if !method.IsValid() { if !method.IsValid() {
writer.WriteHeader(http.StatusNotFound) // 兼容Enabled
return if strings.Contains(methodName, "Enabled") {
methodName = strings.Replace(methodName, "Enabled", "", 1)
method = serviceType.MethodByName(methodName)
if !method.IsValid() {
writer.WriteHeader(http.StatusNotFound)
return
}
} else {
writer.WriteHeader(http.StatusNotFound)
return
}
} }
if method.Type().NumIn() != 2 || method.Type().NumOut() != 2 { if method.Type().NumIn() != 2 || method.Type().NumOut() != 2 {
writer.WriteHeader(http.StatusNotFound) writer.WriteHeader(http.StatusNotFound)
@@ -100,11 +111,11 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
} }
// 上下文 // 上下文
ctx := context.Background() var ctx = context.Background()
if serviceName != "APIAccessTokenService" || (methodName != "GetAPIAccessToken" && methodName != "getAPIAccessToken") { if serviceName != "APIAccessTokenService" || (methodName != "GetAPIAccessToken" && methodName != "getAPIAccessToken") {
// 校验TOKEN // 校验TOKEN
token := req.Header.Get("X-Edge-Access-Token") var token = req.Header.Get("X-Edge-Access-Token")
if len(token) == 0 { if len(token) == 0 {
token = req.Header.Get("Edge-Access-Token") token = req.Header.Get("Edge-Access-Token")
if len(token) == 0 { if len(token) == 0 {
@@ -151,8 +162,8 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
} }
} }
// TODO 需要防止BODY过大攻击 // TODO 可以设置最大可接收内容尺寸
body, err := io.ReadAll(req.Body) body, err := io.ReadAll(io.LimitReader(req.Body, 32*sizes.M))
if err != nil { if err != nil {
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte(err.Error())) _, _ = writer.Write([]byte(err.Error()))
@@ -160,7 +171,7 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
} }
// 请求数据 // 请求数据
reqValue := reflect.New(method.Type().In(1).Elem()).Interface() var reqValue = reflect.New(method.Type().In(1).Elem()).Interface()
err = json.Unmarshal(body, reqValue) err = json.Unmarshal(body, reqValue)
if err != nil { if err != nil {
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
@@ -168,8 +179,8 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
return return
} }
result := method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)}) var result = method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)})
resultErr := result[1].Interface() var resultErr = result[1].Interface()
if resultErr != nil { if resultErr != nil {
e, ok := resultErr.(error) e, ok := resultErr.(error)
if ok { if ok {
@@ -186,7 +197,7 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) {
}, shouldPretty) }, shouldPretty)
} }
} else { // 没有返回错误 } else { // 没有返回错误
data := maps.Map{ var data = maps.Map{
"code": 200, "code": 200,
"message": "ok", "message": "ok",
"data": result[0].Interface(), "data": result[0].Interface(),