diff --git a/internal/db/models/api_access_token_dao.go b/internal/db/models/api_access_token_dao.go new file mode 100644 index 00000000..61e211b4 --- /dev/null +++ b/internal/db/models/api_access_token_dao.go @@ -0,0 +1,68 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/rands" + "time" +) + +type APIAccessTokenDAO dbs.DAO + +func NewAPIAccessTokenDAO() *APIAccessTokenDAO { + return dbs.NewDAO(&APIAccessTokenDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeAPIAccessTokens", + Model: new(APIAccessToken), + PkName: "id", + }, + }).(*APIAccessTokenDAO) +} + +var SharedAPIAccessTokenDAO *APIAccessTokenDAO + +func init() { + dbs.OnReady(func() { + SharedAPIAccessTokenDAO = NewAPIAccessTokenDAO() + }) +} + +// 生成AccessToken +func (this *APIAccessTokenDAO) GenerateAccessToken(userId int64) (token string, expiresAt int64, err error) { + // 查询以前的 + accessToken, err := this.Query(). + Attr("userId", userId). + Find() + if err != nil { + return "", 0, err + } + + token = rands.String(128) // TODO 增强安全性,将来使用 base64_encode(encrypt(salt+random)) 算法来代替 + expiresAt = time.Now().Unix() + 7200 + + op := NewAPIAccessTokenOperator() + + if accessToken != nil { + op.Id = accessToken.(*APIAccessToken).Id + } + + op.UserId = userId + op.Token = token + op.CreatedAt = time.Now().Unix() + op.ExpiredAt = expiresAt + err = this.Save(op) + return +} + +// 查找AccessToken +func (this *APIAccessTokenDAO) FindAccessToken(token string) (*APIAccessToken, error) { + one, err := this.Query(). + Attr("token", token). + Find() + if one == nil || err != nil { + return nil, err + } + return one.(*APIAccessToken), nil +} diff --git a/internal/db/models/api_access_token_dao_test.go b/internal/db/models/api_access_token_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/api_access_token_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/api_access_token_model.go b/internal/db/models/api_access_token_model.go new file mode 100644 index 00000000..e2e38d95 --- /dev/null +++ b/internal/db/models/api_access_token_model.go @@ -0,0 +1,22 @@ +package models + +// API访问令牌 +type APIAccessToken struct { + Id uint64 `field:"id"` // ID + UserId uint32 `field:"userId"` // 用户ID + Token string `field:"token"` // 令牌 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + ExpiredAt uint64 `field:"expiredAt"` // 过期时间 +} + +type APIAccessTokenOperator struct { + Id interface{} // ID + UserId interface{} // 用户ID + Token interface{} // 令牌 + CreatedAt interface{} // 创建时间 + ExpiredAt interface{} // 过期时间 +} + +func NewAPIAccessTokenOperator() *APIAccessTokenOperator { + return &APIAccessTokenOperator{} +} diff --git a/internal/db/models/api_access_token_model_ext.go b/internal/db/models/api_access_token_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/api_access_token_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/db/models/api_node_dao.go b/internal/db/models/api_node_dao.go index ff497cbb..685e5a68 100644 --- a/internal/db/models/api_node_dao.go +++ b/internal/db/models/api_node_dao.go @@ -90,7 +90,7 @@ func (this *APINodeDAO) FindAPINodeName(id int64) (string, error) { } // 创建API节点 -func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { +func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) (nodeId int64, err error) { uniqueId, err := this.genUniqueId() if err != nil { return 0, err @@ -114,6 +114,13 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON if len(httpsJSON) > 0 { op.Https = httpsJSON } + op.RestIsOn = restIsOn + if len(restHTTPJSON) > 0 { + op.RestHTTP = restHTTPJSON + } + if len(restHTTPSJSON) > 0 { + op.RestHTTPS = restHTTPSJSON + } if len(accessAddrsJSON) > 0 { op.AccessAddrs = accessAddrsJSON } @@ -128,7 +135,7 @@ func (this *APINodeDAO) CreateAPINode(name string, description string, httpJSON } // 修改API节点 -func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, accessAddrsJSON []byte, isOn bool) error { +func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description string, httpJSON []byte, httpsJSON []byte, restIsOn bool, restHTTPJSON []byte, restHTTPSJSON []byte, accessAddrsJSON []byte, isOn bool) error { if nodeId <= 0 { return errors.New("invalid nodeId") } @@ -142,17 +149,28 @@ func (this *APINodeDAO) UpdateAPINode(nodeId int64, name string, description str if len(httpJSON) > 0 { op.Http = httpJSON } else { - op.Http = "null" + op.Http = "{}" } if len(httpsJSON) > 0 { op.Https = httpsJSON } else { - op.Https = "null" + op.Https = "{}" + } + op.RestIsOn = restIsOn + if len(restHTTPJSON) > 0 { + op.RestHTTP = restHTTPJSON + } else { + op.RestHTTP = "{}" + } + if len(restHTTPSJSON) > 0 { + op.RestHTTPS = restHTTPSJSON + } else { + op.RestHTTPS = "{}" } if len(accessAddrsJSON) > 0 { op.AccessAddrs = accessAddrsJSON } else { - op.AccessAddrs = "null" + op.AccessAddrs = "[]" } err := this.Save(op) diff --git a/internal/db/models/api_node_model.go b/internal/db/models/api_node_model.go index bf5a12ce..6d6c929c 100644 --- a/internal/db/models/api_node_model.go +++ b/internal/db/models/api_node_model.go @@ -11,6 +11,9 @@ type APINode struct { Description string `field:"description"` // 描述 Http string `field:"http"` // 监听的HTTP配置 Https string `field:"https"` // 监听的HTTPS配置 + RestIsOn uint8 `field:"restIsOn"` // 是否开放REST + RestHTTP string `field:"restHTTP"` // REST HTTP配置 + RestHTTPS string `field:"restHTTPS"` // REST HTTPS配置 AccessAddrs string `field:"accessAddrs"` // 外部访问地址 Order uint32 `field:"order"` // 排序 State uint8 `field:"state"` // 状态 @@ -30,6 +33,9 @@ type APINodeOperator struct { Description interface{} // 描述 Http interface{} // 监听的HTTP配置 Https interface{} // 监听的HTTPS配置 + RestIsOn interface{} // 是否开放REST + RestHTTP interface{} // REST HTTP配置 + RestHTTPS interface{} // REST HTTPS配置 AccessAddrs interface{} // 外部访问地址 Order interface{} // 排序 State interface{} // 状态 diff --git a/internal/db/models/api_node_model_ext.go b/internal/db/models/api_node_model_ext.go index a8b60566..1162d035 100644 --- a/internal/db/models/api_node_model_ext.go +++ b/internal/db/models/api_node_model_ext.go @@ -93,3 +93,65 @@ func (this *APINode) DecodeAccessAddrStrings() ([]string, error) { } return result, nil } + +// 解析Rest HTTP配置 +func (this *APINode) DecodeRestHTTP() (*serverconfigs.HTTPProtocolConfig, error) { + if this.RestIsOn != 1 { + return nil, nil + } + if !IsNotNull(this.RestHTTP) { + return nil, nil + } + config := &serverconfigs.HTTPProtocolConfig{} + err := json.Unmarshal([]byte(this.RestHTTP), config) + if err != nil { + return nil, err + } + + err = config.Init() + if err != nil { + return nil, err + } + + return config, nil +} + +// 解析HTTPS配置 +func (this *APINode) DecodeRestHTTPS() (*serverconfigs.HTTPSProtocolConfig, error) { + if this.RestIsOn != 1 { + return nil, nil + } + if !IsNotNull(this.RestHTTPS) { + return nil, nil + } + config := &serverconfigs.HTTPSProtocolConfig{} + err := json.Unmarshal([]byte(this.RestHTTPS), config) + if err != nil { + return nil, err + } + + err = config.Init() + if err != nil { + return nil, err + } + + if config.SSLPolicyRef != nil { + policyId := config.SSLPolicyRef.SSLPolicyId + if policyId > 0 { + sslPolicy, err := SharedSSLPolicyDAO.ComposePolicyConfig(policyId) + if err != nil { + return nil, err + } + if sslPolicy != nil { + config.SSLPolicy = sslPolicy + } + } + } + + err = config.Init() + if err != nil { + return nil, err + } + + return config, nil +} diff --git a/internal/db/models/db_node_initializer.go b/internal/db/models/db_node_initializer.go index cd6844ca..a4c259b5 100644 --- a/internal/db/models/db_node_initializer.go +++ b/internal/db/models/db_node_initializer.go @@ -2,11 +2,13 @@ package models import ( "fmt" + "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/logs" timeutil "github.com/iwind/TeaGo/utils/time" "hash/crc32" + "regexp" "strconv" "strings" "sync" @@ -46,17 +48,22 @@ func randomAccessLogDAO() (dao *HTTPAccessLogDAOWrapper) { } // 检查表格是否存在 -func findAccessLogTableName(db *dbs.DB, day string) (string, bool, error) { +func findAccessLogTableName(db *dbs.DB, day string) (tableName string, ok bool, err error) { + if !regexp.MustCompile(`^\d{8}$`).MatchString(day) { + err = errors.New("invalid day '" + day + "', should be YYYYMMDD") + return + } + config, err := db.Config() if err != nil { return "", false, err } - tableName := "edgeHTTPAccessLogs_" + day + tableName = "edgeHTTPAccessLogs_" + day cacheKey := tableName + "_" + fmt.Sprintf("%d", crc32.ChecksumIEEE([]byte(config.Dsn))) accessLogLocker.RLock() - _, ok := accessLogTableMapping[cacheKey] + _, ok = accessLogTableMapping[cacheKey] accessLogLocker.RUnlock() if ok { return tableName, true, nil diff --git a/internal/db/models/http_access_log_dao.go b/internal/db/models/http_access_log_dao.go index be05616e..6003d56f 100644 --- a/internal/db/models/http_access_log_dao.go +++ b/internal/db/models/http_access_log_dao.go @@ -168,6 +168,7 @@ func (this *HTTPAccessLogDAO) listAccessLogs(lastRequestId string, size int64, d dao := daoWrapper.DAO tableName, exists, err := findAccessLogTableName(dao.Instance, day) + logs.Println("tableName:", tableName, exists, err) // TODO if !exists { // 表格不存在则跳过 return diff --git a/internal/db/models/user_access_key_dao.go b/internal/db/models/user_access_key_dao.go index d840dbf5..7547c669 100644 --- a/internal/db/models/user_access_key_dao.go +++ b/internal/db/models/user_access_key_dao.go @@ -109,3 +109,17 @@ func (this *UserAccessKeyDAO) UpdateAccessKeyIsOn(accessKeyId int64, isOn bool) Update() return err } + +// 根据UniqueId查找AccessKey +func (this *UserAccessKeyDAO) FindAccessKeyWithUniqueId(uniqueId string) (*UserAccessKey, error) { + one, err := this.Query(). + Attr("uniqueId", uniqueId). + Attr("isOn", true). + State(UserAccessKeyStateEnabled). + Find() + if one == nil || err != nil { + return nil, err + } + + return one.(*UserAccessKey), nil +} diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index b1cba71d..33ed3891 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -75,72 +75,7 @@ func (this *APINode) Start() { // 监听RPC服务 remotelogs.Println("API_NODE", "starting RPC server ...") - // HTTP - httpConfig, err := apiNode.DecodeHTTP() - if err != nil { - remotelogs.Error("API_NODE", "decode http config: "+err.Error()) - return - } - isListening := false - if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 { - for _, listen := range httpConfig.Listen { - for _, addr := range listen.Addresses() { - listener, err := net.Listen("tcp", addr) - if err != nil { - remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error()) - continue - } - go func() { - err := this.listenRPC(listener, nil) - if err != nil { - remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error()) - return - } - }() - isListening = true - } - } - } - - // HTTPS - httpsConfig, err := apiNode.DecodeHTTPS() - if err != nil { - remotelogs.Error("API_NODE", "decode https config: "+err.Error()) - return - } - if httpsConfig != nil && - httpsConfig.IsOn && - len(httpsConfig.Listen) > 0 && - httpsConfig.SSLPolicy != nil && - httpsConfig.SSLPolicy.IsOn && - len(httpsConfig.SSLPolicy.Certs) > 0 { - certs := []tls.Certificate{} - for _, cert := range httpsConfig.SSLPolicy.Certs { - certs = append(certs, *cert.CertObject()) - } - - for _, listen := range httpsConfig.Listen { - for _, addr := range listen.Addresses() { - listener, err := net.Listen("tcp", addr) - if err != nil { - remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error()) - continue - } - go func() { - err := this.listenRPC(listener, &tls.Config{ - Certificates: certs, - }) - if err != nil { - remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error()) - return - } - }() - isListening = true - } - } - } - - // HTTP接口 + isListening := this.listenPorts(apiNode) if !isListening { remotelogs.Error("API_NODE", "the api node require at least one listening address") @@ -155,10 +90,10 @@ func (this *APINode) Start() { func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) error { var rpcServer *grpc.Server if tlsConfig == nil { - remotelogs.Println("API_NODE", "listening http://"+listener.Addr().String()+" ...") + remotelogs.Println("API_NODE", "listening GRPC http://"+listener.Addr().String()+" ...") rpcServer = grpc.NewServer() } else { - logs.Println("[API_NODE]listening https://" + listener.Addr().String() + " ...") + logs.Println("[API_NODE]listening GRPC https://" + listener.Addr().String() + " ...") rpcServer = grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) } pb.RegisterAdminServiceServer(rpcServer, &services.AdminService{}) @@ -264,3 +199,142 @@ func (this *APINode) autoUpgrade() error { logs.Println("[API_NODE]upgrade database done") return nil } + +// 启动端口 +func (this *APINode) listenPorts(apiNode *models.APINode) (isListening bool) { + // HTTP + httpConfig, err := apiNode.DecodeHTTP() + if err != nil { + remotelogs.Error("API_NODE", "decode http config: "+err.Error()) + return + } + isListening = false + if httpConfig != nil && httpConfig.IsOn && len(httpConfig.Listen) > 0 { + for _, listen := range httpConfig.Listen { + for _, addr := range listen.Addresses() { + listener, err := net.Listen("tcp", addr) + if err != nil { + remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error()) + continue + } + go func() { + err := this.listenRPC(listener, nil) + if err != nil { + remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error()) + return + } + }() + isListening = true + } + } + } + + // HTTPS + httpsConfig, err := apiNode.DecodeHTTPS() + if err != nil { + remotelogs.Error("API_NODE", "decode https config: "+err.Error()) + return + } + if httpsConfig != nil && + httpsConfig.IsOn && + len(httpsConfig.Listen) > 0 && + httpsConfig.SSLPolicy != nil && + httpsConfig.SSLPolicy.IsOn && + len(httpsConfig.SSLPolicy.Certs) > 0 { + certs := []tls.Certificate{} + for _, cert := range httpsConfig.SSLPolicy.Certs { + certs = append(certs, *cert.CertObject()) + } + + for _, listen := range httpsConfig.Listen { + for _, addr := range listen.Addresses() { + listener, err := net.Listen("tcp", addr) + if err != nil { + remotelogs.Error("API_NODE", "listening '"+addr+"' failed: "+err.Error()) + continue + } + go func() { + err := this.listenRPC(listener, &tls.Config{ + Certificates: certs, + }) + if err != nil { + remotelogs.Error("API_NODE", "listening '"+addr+"' rpc: "+err.Error()) + return + } + }() + isListening = true + } + } + } + + // Rest HTTP + restHTTPConfig, err := apiNode.DecodeRestHTTP() + if err != nil { + remotelogs.Error("API_NODE", "decode REST http config: "+err.Error()) + return + } + if restHTTPConfig != nil && restHTTPConfig.IsOn && len(restHTTPConfig.Listen) > 0 { + for _, listen := range restHTTPConfig.Listen { + for _, addr := range listen.Addresses() { + listener, err := net.Listen("tcp", addr) + if err != nil { + remotelogs.Error("API_NODE", "listening REST 'http://"+addr+"' failed: "+err.Error()) + continue + } + go func() { + remotelogs.Println("API_NODE", "listening REST http://"+addr+" ...") + server := &RestServer{} + err := server.Listen(listener) + if err != nil { + remotelogs.Error("API_NODE", "listening REST 'http://"+addr+"' failed: "+err.Error()) + return + } + }() + isListening = true + } + } + } + + // Rest HTTPS + restHTTPSConfig, err := apiNode.DecodeRestHTTPS() + if err != nil { + remotelogs.Error("API_NODE", "decode REST https config: "+err.Error()) + return + } + if restHTTPSConfig != nil && + restHTTPSConfig.IsOn && + len(restHTTPSConfig.Listen) > 0 && + restHTTPSConfig.SSLPolicy != nil && + restHTTPSConfig.SSLPolicy.IsOn && + len(restHTTPSConfig.SSLPolicy.Certs) > 0 { + for _, listen := range restHTTPSConfig.Listen { + for _, addr := range listen.Addresses() { + listener, err := net.Listen("tcp", addr) + if err != nil { + remotelogs.Error("API_NODE", "listening REST 'https://"+addr+"' failed: "+err.Error()) + continue + } + go func() { + remotelogs.Println("API_NODE", "listening REST https://"+addr+" ...") + server := &RestServer{} + + certs := []tls.Certificate{} + for _, cert := range httpsConfig.SSLPolicy.Certs { + certs = append(certs, *cert.CertObject()) + } + + err := server.ListenHTTPS(listener, &tls.Config{ + Certificates: certs, + }) + if err != nil { + remotelogs.Error("API_NODE", "listening REST 'https://"+addr+"' failed: "+err.Error()) + return + } + }() + isListening = true + } + } + } + + return +} diff --git a/internal/nodes/rest_server.go b/internal/nodes/rest_server.go new file mode 100644 index 00000000..520972ed --- /dev/null +++ b/internal/nodes/rest_server.go @@ -0,0 +1,189 @@ +package nodes + +import ( + "context" + "crypto/tls" + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/rpc/services" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/iwind/TeaGo/maps" + "io/ioutil" + "net" + "net/http" + "reflect" + "regexp" + "time" +) + +var servicePathReg = regexp.MustCompile(`^/([a-zA-Z0-9]+)/([a-zA-Z0-9]+)$`) +var servicesMap = map[string]reflect.Value{ + "APIAccessTokenService": reflect.ValueOf(new(services.APIAccessTokenService)), + "HTTPAccessLogService": reflect.ValueOf(new(services.HTTPAccessLogService)), +} + +type RestServer struct{} + +func (this *RestServer) Listen(listener net.Listener) error { + mux := http.NewServeMux() + mux.HandleFunc("/", this.handle) + server := &http.Server{} + server.Handler = mux + return server.Serve(listener) +} + +func (this *RestServer) ListenHTTPS(listener net.Listener, tlsConfig *tls.Config) error { + mux := http.NewServeMux() + mux.HandleFunc("/", this.handle) + server := &http.Server{} + server.Handler = mux + server.TLSConfig = tlsConfig + return server.ServeTLS(listener, "", "") +} + +func (this *RestServer) handle(writer http.ResponseWriter, req *http.Request) { + path := req.URL.Path + matches := servicePathReg.FindStringSubmatch(path) + if len(matches) != 3 { + writer.WriteHeader(http.StatusNotFound) + return + } + + serviceName := matches[1] + methodName := matches[2] + + serviceType, ok := servicesMap[serviceName] + if !ok { + writer.WriteHeader(http.StatusNotFound) + return + } + + method := serviceType.MethodByName(methodName) + if !method.IsValid() { + writer.WriteHeader(http.StatusNotFound) + return + } + if method.Type().NumIn() != 2 || method.Type().NumOut() != 2 { + writer.WriteHeader(http.StatusNotFound) + return + } + if method.Type().In(0).Name() != "Context" { + writer.WriteHeader(http.StatusNotFound) + return + } + + // 是否显示Pretty后的JSON + shouldPretty := req.Header.Get("Edge-Response-Pretty") == "on" + + // 上下文 + ctx := context.Background() + + if serviceName != "APIAccessTokenService" || methodName != "GetAPIAccessToken" { + // 校验TOKEN + token := req.Header.Get("Edge-Access-Token") + if len(token) == 0 { + this.writeJSON(writer, maps.Map{ + "code": 400, + "data": maps.Map{}, + "message": "require 'Edge-Access-Token' header", + }, shouldPretty) + return + } + + accessToken, err := models.SharedAPIAccessTokenDAO.FindAccessToken(token) + if err != nil { + this.writeJSON(writer, maps.Map{ + "code": 400, + "data": maps.Map{}, + "message": "server error: " + err.Error(), + }, shouldPretty) + return + } + + if accessToken == nil || int64(accessToken.ExpiredAt) < time.Now().Unix() { + this.writeJSON(writer, maps.Map{ + "code": 400, + "data": maps.Map{}, + "message": "invalid access token", + }, shouldPretty) + return + } + + if accessToken.UserId > 0 { + ctx = rpcutils.NewPlainContext("user", int64(accessToken.UserId)) + } else { + // TODO 支持更多类型的角色 + this.writeJSON(writer, maps.Map{ + "code": 400, + "data": maps.Map{}, + "message": "not supported role", + }, shouldPretty) + return + } + } + + // TODO 需要防止BODY过大攻击 + body, err := ioutil.ReadAll(req.Body) + if err != nil { + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte(err.Error())) + return + } + + // 请求数据 + reqValue := reflect.New(method.Type().In(1).Elem()).Interface() + err = json.Unmarshal(body, reqValue) + if err != nil { + writer.WriteHeader(http.StatusBadRequest) + _, _ = writer.Write([]byte(err.Error())) + return + } + + result := method.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(reqValue)}) + resultErr := result[1].Interface() + if resultErr != nil { + e, ok := resultErr.(error) + if ok { + this.writeJSON(writer, maps.Map{ + "code": 400, + "message": e.Error(), + "data": maps.Map{}, + }, shouldPretty) + } else { + this.writeJSON(writer, maps.Map{ + "code": 500, + "message": "server error: server should return a error object, but return a " + result[1].Type().String(), + "data": maps.Map{}, + }, shouldPretty) + } + } else { // 没有返回错误 + data := maps.Map{ + "code": 200, + "message": "ok", + "data": result[0].Interface(), + } + var dataJSON []byte + if shouldPretty { + dataJSON = data.AsPrettyJSON() + } else { + dataJSON = data.AsJSON() + } + if err != nil { + this.writeJSON(writer, maps.Map{ + "code": 500, + "message": "server error: marshal json failed: " + err.Error(), + "data": maps.Map{}, + }, shouldPretty) + } else { + _, _ = writer.Write(dataJSON) + } + } +} + +func (this *RestServer) writeJSON(writer http.ResponseWriter, v maps.Map, pretty bool) { + if pretty { + _, _ = writer.Write(v.AsPrettyJSON()) + } else { + _, _ = writer.Write(v.AsJSON()) + } +} diff --git a/internal/rpc/services/service_api_access_token.go b/internal/rpc/services/service_api_access_token.go new file mode 100644 index 00000000..8019b815 --- /dev/null +++ b/internal/rpc/services/service_api_access_token.go @@ -0,0 +1,40 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +// AccessToken相关服务 +type APIAccessTokenService struct { +} + +// 获取AccessToken +func (this *APIAccessTokenService) GetAPIAccessToken(ctx context.Context, req *pb.GetAPIAccessTokenRequest) (*pb.GetAPIAccessTokenResponse, error) { + if req.Type == "user" { // 用户 + accessKey, err := models.SharedUserAccessKeyDAO.FindAccessKeyWithUniqueId(req.AccessKeyId) + if err != nil { + return nil, err + } + if accessKey == nil { + return nil, errors.New("access key not found") + } + if accessKey.Secret != req.AccessKey { + return nil, errors.New("access key not found") + } + + // 创建AccessToken + token, expiresAt, err := models.SharedAPIAccessTokenDAO.GenerateAccessToken(int64(accessKey.UserId)) + if err != nil { + return nil, err + } + return &pb.GetAPIAccessTokenResponse{ + Token: token, + ExpiresAt: expiresAt, + }, nil + } else { + return nil, errors.New("unsupported type '" + req.Type + "'") + } +} diff --git a/internal/rpc/services/service_api_node.go b/internal/rpc/services/service_api_node.go index 9732a0af..f21459ea 100644 --- a/internal/rpc/services/service_api_node.go +++ b/internal/rpc/services/service_api_node.go @@ -19,7 +19,7 @@ func (this *APINodeService) CreateAPINode(ctx context.Context, req *pb.CreateAPI return nil, err } - nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) + nodeId, err := models.SharedAPINodeDAO.CreateAPINode(req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -34,7 +34,7 @@ func (this *APINodeService) UpdateAPINode(ctx context.Context, req *pb.UpdateAPI return nil, err } - err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.AccessAddrsJSON, req.IsOn) + err = models.SharedAPINodeDAO.UpdateAPINode(req.NodeId, req.Name, req.Description, req.HttpJSON, req.HttpsJSON, req.RestIsOn, req.RestHTTPJSON, req.RestHTTPSJSON, req.AccessAddrsJSON, req.IsOn) if err != nil { return nil, err } @@ -138,6 +138,9 @@ func (this *APINodeService) ListEnabledAPINodes(ctx context.Context, req *pb.Lis Description: node.Description, HttpJSON: []byte(node.Http), HttpsJSON: []byte(node.Https), + RestIsOn: node.RestIsOn == 1, + RestHTTPJSON: []byte(node.RestHTTP), + RestHTTPSJSON: []byte(node.RestHTTPS), AccessAddrsJSON: []byte(node.AccessAddrs), AccessAddrs: accessAddrs, StatusJSON: []byte(node.Status), @@ -178,6 +181,9 @@ func (this *APINodeService) FindEnabledAPINode(ctx context.Context, req *pb.Find Description: node.Description, HttpJSON: []byte(node.Http), HttpsJSON: []byte(node.Https), + RestIsOn: node.RestIsOn == 1, + RestHTTPJSON: []byte(node.RestHTTP), + RestHTTPSJSON: []byte(node.RestHTTPS), AccessAddrsJSON: []byte(node.AccessAddrs), AccessAddrs: accessAddrs, } diff --git a/internal/rpc/services/service_http_access_log.go b/internal/rpc/services/service_http_access_log.go index 45295260..f466dc05 100644 --- a/internal/rpc/services/service_http_access_log.go +++ b/internal/rpc/services/service_http_access_log.go @@ -3,12 +3,14 @@ package services import ( "context" "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/errors" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" ) // 访问日志相关服务 type HTTPAccessLogService struct { + BaseService } // 创建访问日志 @@ -34,11 +36,23 @@ func (this *HTTPAccessLogService) CreateHTTPAccessLogs(ctx context.Context, req // 列出单页访问日志 func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *pb.ListHTTPAccessLogsRequest) (*pb.ListHTTPAccessLogsResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) if err != nil { return nil, err } + // 检查服务ID + if userId > 0 { + if req.ServerId <= 0 { + return nil, errors.New("invalid serverId") + } + + err = models.SharedServerDAO.CheckUserServer(req.ServerId, userId) + if err != nil { + return nil, err + } + } + accessLogs, requestId, hasMore, err := models.SharedHTTPAccessLogDAO.ListAccessLogs(req.RequestId, req.Size, req.Day, req.ServerId, req.Reverse, req.HasError, req.FirewallPolicyId, req.FirewallRuleGroupId, req.FirewallRuleSetId) if err != nil { return nil, err diff --git a/internal/rpc/utils/plain_context.go b/internal/rpc/utils/plain_context.go new file mode 100644 index 00000000..716d4146 --- /dev/null +++ b/internal/rpc/utils/plain_context.go @@ -0,0 +1,37 @@ +package rpcutils + +import ( + "context" + "time" +) + +type PlainContext struct { + UserType string + UserId int64 + + ctx context.Context +} + +func NewPlainContext(userType string, userId int64) *PlainContext { + return &PlainContext{ + UserType: userType, + UserId: userId, + ctx: context.Background(), + } +} + +func (this *PlainContext) Deadline() (deadline time.Time, ok bool) { + return this.ctx.Deadline() +} + +func (this *PlainContext) Done() <-chan struct{} { + return this.ctx.Done() +} + +func (this *PlainContext) Err() error { + return this.ctx.Err() +} + +func (this *PlainContext) Value(key interface{}) interface{} { + return this.ctx.Value(key) +} diff --git a/internal/rpc/utils/utils.go b/internal/rpc/utils/utils.go index 2a895dc7..9536bbe3 100644 --- a/internal/rpc/utils/utils.go +++ b/internal/rpc/utils/utils.go @@ -33,6 +33,29 @@ const ( // 校验请求 func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserType, userId int64, err error) { + if ctx == nil { + err = errors.New("context should not be nil") + return + } + + // 支持直接认证 + plainCtx, ok := ctx.(*PlainContext) + if ok { + userType = plainCtx.UserType + userId = plainCtx.UserId + + if len(userTypes) > 0 && !lists.ContainsString(userTypes, userType) { + userType = UserTypeNone + userId = 0 + } + + if userId <= 0 { + err = errors.New("context: can not find user or permission denied") + } + + return + } + md, ok := metadata.FromIncomingContext(ctx) if !ok { return UserTypeNone, 0, errors.New("context: need 'nodeId'") diff --git a/internal/setup/setup.go b/internal/setup/setup.go index 496b3115..ad5c65c7 100644 --- a/internal/setup/setup.go +++ b/internal/setup/setup.go @@ -160,7 +160,7 @@ func (this *Setup) Run() error { } // 创建API节点 - nodeId, err := dao.CreateAPINode("默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, addrsJSON, true) + nodeId, err := dao.CreateAPINode("默认API节点", "这是默认创建的第一个API节点", httpJSON, httpsJSON, false, nil, nil, addrsJSON, true) if err != nil { return errors.New("create api node in database failed: " + err.Error()) }