From 1c221e305b9c4cb139df5bd5ad5a04e3098afd8a Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Thu, 4 Aug 2022 19:36:25 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=80=E5=8C=96API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/rest_server.go | 49 +++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/internal/nodes/rest_server.go b/internal/nodes/rest_server.go index b620f5f6..c2450611 100644 --- a/internal/nodes/rest_server.go +++ b/internal/nodes/rest_server.go @@ -7,6 +7,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/rpc/services" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeAPI/internal/utils/sizes" "github.com/iwind/TeaGo/maps" "io" "net" @@ -25,15 +26,15 @@ var restServicesMap = map[string]reflect.Value{ type RestServer struct{} func (this *RestServer) Listen(listener net.Listener) error { - mux := http.NewServeMux() + var mux = http.NewServeMux() mux.HandleFunc("/", this.handle) - server := &http.Server{} + var server = &http.Server{} server.Handler = mux return server.Serve(listener) } func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config) error { - mux := http.NewServeMux() + var mux = http.NewServeMux() mux.HandleFunc("/", this.handle) server := &http.Server{} server.Handler = mux @@ -42,10 +43,10 @@ func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config } func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { - path := req.URL.Path + var path = req.URL.Path // 是否显示Pretty后的JSON - shouldPretty := req.Header.Get("X-Edge-Response-Pretty") == "on" + var shouldPretty = req.Header.Get("X-Edge-Response-Pretty") == "on" // 兼容老的Header var oldShouldPretty = req.Header.Get("Edge-Response-Pretty") @@ -63,14 +64,14 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { return } - matches := servicePathReg.FindStringSubmatch(path) + var matches = servicePathReg.FindStringSubmatch(path) if len(matches) != 3 { writer.WriteHeader(http.StatusNotFound) return } - serviceName := matches[1] - methodName := matches[2] + var serviceName = matches[1] + var methodName = matches[2] serviceType, ok := restServicesMap[serviceName] if !ok { @@ -85,10 +86,20 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { // 再次查找 methodName = strings.ToUpper(string(methodName[0])) + methodName[1:] - method := serviceType.MethodByName(methodName) + var method = serviceType.MethodByName(methodName) if !method.IsValid() { - writer.WriteHeader(http.StatusNotFound) - return + // 兼容Enabled + if strings.Contains(methodName, "Enabled") { + methodName = strings.Replace(methodName, "Enabled", "", 1) + method = serviceType.MethodByName(methodName) + if !method.IsValid() { + writer.WriteHeader(http.StatusNotFound) + return + } + } else { + writer.WriteHeader(http.StatusNotFound) + return + } } if method.Type().NumIn() != 2 || method.Type().NumOut() != 2 { writer.WriteHeader(http.StatusNotFound) @@ -100,11 +111,11 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { } // 上下文 - ctx := context.Background() + var ctx = context.Background() if serviceName != "APIAccessTokenService" || (methodName != "GetAPIAccessToken" && methodName != "getAPIAccessToken") { // 校验TOKEN - token := req.Header.Get("X-Edge-Access-Token") + var token = req.Header.Get("X-Edge-Access-Token") if len(token) == 0 { token = req.Header.Get("Edge-Access-Token") if len(token) == 0 { @@ -151,8 +162,8 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { } } - // TODO 需要防止BODY过大攻击 - body, err := io.ReadAll(req.Body) + // TODO 可以设置最大可接收内容尺寸 + body, err := io.ReadAll(io.LimitReader(req.Body, 32*sizes.M)) if err != nil { writer.WriteHeader(http.StatusBadRequest) _, _ = writer.Write([]byte(err.Error())) @@ -160,7 +171,7 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { } // 请求数据 - reqValue := reflect.New(method.Type().In(1).Elem()).Interface() + var reqValue = reflect.New(method.Type().In(1).Elem()).Interface() err = json.Unmarshal(body, reqValue) if err != nil { writer.WriteHeader(http.StatusBadRequest) @@ -168,8 +179,8 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { return } - result := method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)}) - resultErr := result[1].Interface() + var result = method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)}) + var resultErr = result[1].Interface() if resultErr != nil { e, ok := resultErr.(error) if ok { @@ -186,7 +197,7 @@ func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { }, shouldPretty) } } else { // 没有返回错误 - data := maps.Map{ + var data = maps.Map{ "code": 200, "message": "ok", "data": result[0].Interface(),