diff --git a/internal/apis/api_node.go b/internal/apis/api_node.go index 3378f6a8..acae94dd 100644 --- a/internal/apis/api_node.go +++ b/internal/apis/api_node.go @@ -73,6 +73,7 @@ func (this *APINode) listenRPC() error { pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{}) pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{}) pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{}) + pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{}) err = rpcServer.Serve(listener) if err != nil { return errors.New("[API]start rpc failed: " + err.Error()) diff --git a/internal/db/models/http_cache_policy_dao.go b/internal/db/models/http_cache_policy_dao.go index f65e6516..ac0cb463 100644 --- a/internal/db/models/http_cache_policy_dao.go +++ b/internal/db/models/http_cache_policy_dao.go @@ -1,9 +1,14 @@ package models import ( + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" ) const ( @@ -87,3 +92,144 @@ func (this *HTTPCachePolicyDAO) FindAllEnabledCachePolicies() (result []*HTTPCac FindAll() return } + +// 创建缓存策略 +func (this *HTTPCachePolicyDAO) CreateCachePolicy(isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) (int64, error) { + op := NewHTTPCachePolicyOperator() + op.State = HTTPCachePolicyStateEnabled + op.IsOn = isOn + op.Name = name + op.Description = description + if len(capacityJSON) > 0 { + op.Capacity = capacityJSON + } + op.MaxKeys = maxKeys + if len(maxSizeJSON) > 0 { + op.MaxSize = maxSizeJSON + } + op.Type = storageType + if len(storageOptionsJSON) > 0 { + op.Options = storageOptionsJSON + } + _, err := this.Save(op) + if err != nil { + return 0, err + } + return types.Int64(op.Id), nil +} + +// 修改缓存策略 +func (this *HTTPCachePolicyDAO) UpdateCachePolicy(policyId int64, isOn bool, name string, description string, capacityJSON []byte, maxKeys int64, maxSizeJSON []byte, storageType string, storageOptionsJSON []byte) error { + if policyId <= 0 { + return errors.New("invalid policyId") + } + + op := NewHTTPCachePolicyOperator() + op.Id = policyId + op.IsOn = isOn + op.Name = name + op.Description = description + if len(capacityJSON) > 0 { + op.Capacity = capacityJSON + } + op.MaxKeys = maxKeys + if len(maxSizeJSON) > 0 { + op.MaxSize = maxSizeJSON + } + op.Type = storageType + if len(storageOptionsJSON) > 0 { + op.Options = storageOptionsJSON + } + _, err := this.Save(op) + return errors.Wrap(err) +} + +// 组合配置 +func (this *HTTPCachePolicyDAO) ComposeCachePolicy(policyId int64) (*serverconfigs.HTTPCachePolicy, error) { + policy, err := this.FindEnabledHTTPCachePolicy(policyId) + if err != nil { + return nil, err + } + if policy == nil { + return nil, nil + } + config := &serverconfigs.HTTPCachePolicy{} + config.Id = int64(policy.Id) + config.IsOn = policy.IsOn == 1 + config.Name = policy.Name + config.Description = policy.Description + + // capacity + if IsNotNull(policy.Capacity) { + capacityConfig := &shared.SizeCapacity{} + err = json.Unmarshal([]byte(policy.Capacity), capacityConfig) + if err != nil { + return nil, err + } + config.Capacity = capacityConfig + } + + config.MaxKeys = types.Int64(policy.MaxKeys) + + // max size + if IsNotNull(policy.MaxSize) { + maxSizeConfig := &shared.SizeCapacity{} + err = json.Unmarshal([]byte(policy.MaxSize), maxSizeConfig) + if err != nil { + return nil, err + } + config.MaxSize = maxSizeConfig + } + + config.Type = policy.Type + + // options + if IsNotNull(policy.Options) { + m := map[string]interface{}{} + err = json.Unmarshal([]byte(policy.Options), &m) + if err != nil { + return nil, errors.Wrap(err) + } + config.Options = m + } + + return config, nil +} + +// 计算可用缓存策略数量 +func (this *HTTPCachePolicyDAO) CountAllEnabledHTTPCachePolicies() (int64, error) { + return this.Query(). + State(HTTPCachePolicyStateEnabled). + Count() +} + +// 列出单页的缓存策略 +func (this *HTTPCachePolicyDAO) ListEnabledHTTPCachePolicies(offset int64, size int64) ([]*serverconfigs.HTTPCachePolicy, error) { + ones, err := this.Query(). + State(HTTPCachePolicyStateEnabled). + ResultPk(). + Offset(offset). + Limit(size). + DescPk(). + FindAll() + if err != nil { + return nil, errors.Wrap(err) + } + cachePolicyIds := []int64{} + for _, one := range ones { + cachePolicyIds = append(cachePolicyIds, int64(one.(*HTTPCachePolicy).Id)) + } + if len(cachePolicyIds) == 0 { + return nil, nil + } + + cachePolicies := []*serverconfigs.HTTPCachePolicy{} + for _, policyId := range cachePolicyIds { + cachePolicyConfig, err := this.ComposeCachePolicy(policyId) + if err != nil { + return nil, errors.Wrap(err) + } + cachePolicies = append(cachePolicies, cachePolicyConfig) + } + return cachePolicies, nil +} diff --git a/internal/db/models/http_cache_policy_model.go b/internal/db/models/http_cache_policy_model.go index 0c1817b0..27f21811 100644 --- a/internal/db/models/http_cache_policy_model.go +++ b/internal/db/models/http_cache_policy_model.go @@ -2,43 +2,37 @@ package models // HTTP缓存策略 type HTTPCachePolicy struct { - Id uint32 `field:"id"` // ID - AdminId uint32 `field:"adminId"` // 管理员ID - UserId uint32 `field:"userId"` // 用户ID - TemplateId uint32 `field:"templateId"` // 模版ID - IsOn uint8 `field:"isOn"` // 是否启用 - Name string `field:"name"` // 名称 - Key string `field:"key"` // 缓存Key规则 - Capacity string `field:"capacity"` // 容量数据 - Life string `field:"life"` // 有效期 - Status string `field:"status"` // HTTP状态码列表 - MaxSize string `field:"maxSize"` // 最大尺寸 - SkipCacheControlValues string `field:"skipCacheControlValues"` // 忽略的cache-control - SkipSetCookie uint8 `field:"skipSetCookie"` // 是否忽略Set-Cookie Header - EnableRequestCachePragma uint8 `field:"enableRequestCachePragma"` // 是否支持客户端的Pragma: no-cache - Conds string `field:"conds"` // 请求条件 - CreatedAt uint64 `field:"createdAt"` // 创建时间 - State uint8 `field:"state"` // 状态 + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + TemplateId uint32 `field:"templateId"` // 模版ID + IsOn uint8 `field:"isOn"` // 是否启用 + Name string `field:"name"` // 名称 + Capacity string `field:"capacity"` // 容量数据 + MaxKeys uint64 `field:"maxKeys"` // 最多Key值 + MaxSize string `field:"maxSize"` // 最大缓存内容尺寸 + Type string `field:"type"` // 存储类型 + Options string `field:"options"` // 存储选项 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + State uint8 `field:"state"` // 状态 + Description string `field:"description"` // 描述 } type HTTPCachePolicyOperator struct { - Id interface{} // ID - AdminId interface{} // 管理员ID - UserId interface{} // 用户ID - TemplateId interface{} // 模版ID - IsOn interface{} // 是否启用 - Name interface{} // 名称 - Key interface{} // 缓存Key规则 - Capacity interface{} // 容量数据 - Life interface{} // 有效期 - Status interface{} // HTTP状态码列表 - MaxSize interface{} // 最大尺寸 - SkipCacheControlValues interface{} // 忽略的cache-control - SkipSetCookie interface{} // 是否忽略Set-Cookie Header - EnableRequestCachePragma interface{} // 是否支持客户端的Pragma: no-cache - Conds interface{} // 请求条件 - CreatedAt interface{} // 创建时间 - State interface{} // 状态 + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + TemplateId interface{} // 模版ID + IsOn interface{} // 是否启用 + Name interface{} // 名称 + Capacity interface{} // 容量数据 + MaxKeys interface{} // 最多Key值 + MaxSize interface{} // 最大缓存内容尺寸 + Type interface{} // 存储类型 + Options interface{} // 存储选项 + CreatedAt interface{} // 创建时间 + State interface{} // 状态 + Description interface{} // 描述 } func NewHTTPCachePolicyOperator() *HTTPCachePolicyOperator { diff --git a/internal/db/models/http_location_dao.go b/internal/db/models/http_location_dao.go index 1f81a116..77edca5a 100644 --- a/internal/db/models/http_location_dao.go +++ b/internal/db/models/http_location_dao.go @@ -233,3 +233,14 @@ func (this *HTTPLocationDAO) ConvertLocationRefs(refs []*serverconfigs.HTTPLocat return } + +// 根据WebId查找LocationId +func (this *HTTPLocationDAO) FindEnabledLocationIdWithWebId(webId int64) (locationId int64, err error) { + if webId <= 0 { + return + } + return this.Query(). + Attr("webId", webId). + ResultPk(). + FindInt64Col(0) +} diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index f5f52c95..91dbac63 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -9,6 +9,7 @@ import ( "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/types" + "strconv" ) const ( @@ -214,12 +215,23 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon // 缓存配置 if IsNotNull(web.Cache) { - cacheRef := &serverconfigs.HTTPCacheRef{} - err = json.Unmarshal([]byte(web.Cache), cacheRef) + cacheRefs := []*serverconfigs.HTTPCacheRef{} + err = json.Unmarshal([]byte(web.Cache), &cacheRefs) if err != nil { return nil, err } - config.CacheRef = cacheRef + for _, cacheRef := range cacheRefs { + if cacheRef.CachePolicyId > 0 { + cachePolicy, err := SharedHTTPCachePolicyDAO.ComposeCachePolicy(cacheRef.CachePolicyId) + if err != nil { + return nil, err + } + if cachePolicy != nil { + config.CacheRefs = append(config.CacheRefs, cacheRef) + config.CachePolicies = append(config.CachePolicies, cachePolicy) + } + } + } } // 防火墙配置 @@ -492,3 +504,67 @@ func (this *HTTPWebDAO) UpdateWebRewriteRules(webId int64, rewriteRulesJSON []by _, err := this.Save(op) return err } + +// 根据缓存策略ID查找所有的WebId +func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(cachePolicyId int64) ([]int64, error) { + ones, err := this.Query(). + State(HTTPWebStateEnabled). + ResultPk(). + Where(`JSON_CONTAINS(cache, '{"cachePolicyId": ` + strconv.FormatInt(cachePolicyId, 10) + ` }')`). + Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + FindAll() + if err != nil { + return nil, err + } + result := []int64{} + for _, one := range ones { + webId := int64(one.(*HTTPWeb).Id) + + // 判断是否为Location + for { + locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(webId) + if err != nil { + return nil, err + } + + // 如果非Location + if locationId == 0 { + if !this.containsInt64(result, webId) { + result = append(result, webId) + } + break + } + + // 查找包含此Location的Web + // TODO 需要支持嵌套的Location查询 + webId, err = this.FindEnabledWebIdWithLocationId(locationId) + if err != nil { + return nil, err + } + if webId == 0 { + break + } + } + } + return result, nil +} + +// 查找包含某个Location的Web +func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(locationId int64) (webId int64, err error) { + return this.Query(). + State(HTTPWebStateEnabled). + ResultPk(). + Where(`JSON_CONTAINS(locations, '{"locationId": ` + strconv.FormatInt(locationId, 10) + ` }')`). + Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + FindInt64Col(0) +} + +// 判断slice是否包含某个int64值 +func (this *HTTPWebDAO) containsInt64(values []int64, value int64) bool { + for _, v := range values { + if v == value { + return true + } + } + return false +} diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index b879c576..3c48caa6 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -663,7 +663,7 @@ func (this *ServerDAO) FindServerWebId(serverId int64) (int64, error) { } // 计算使用SSL策略的所有服务数量 -func (this *ServerDAO) CountServersWithSSLPolicyIds(sslPolicyIds []int64) (count int64, err error) { +func (this *ServerDAO) CountAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int64) (count int64, err error) { if len(sslPolicyIds) == 0 { return } @@ -678,8 +678,8 @@ func (this *ServerDAO) CountServersWithSSLPolicyIds(sslPolicyIds []int64) (count Count() } -// 查找使用SSL策略的所有服务 -func (this *ServerDAO) FindAllServersWithSSLPolicyIds(sslPolicyIds []int64) (result []*Server, err error) { +// 查找使用某个SSL策略的所有服务 +func (this *ServerDAO) FindAllEnabledServersWithSSLPolicyIds(sslPolicyIds []int64) (result []*Server, err error) { if len(sslPolicyIds) == 0 { return } @@ -698,6 +698,31 @@ func (this *ServerDAO) FindAllServersWithSSLPolicyIds(sslPolicyIds []int64) (res return } +// 计算使用某个缓存策略的所有服务数量 +func (this *ServerDAO) CountEnabledServersWithWebIds(webIds []int64) (count int64, err error) { + if len(webIds) == 0 { + return + } + return this.Query(). + State(ServerStateEnabled). + Attr("webId", webIds). + Count() +} + +// 查找使用某个缓存策略的所有服务 +func (this *ServerDAO) FindAllEnabledServersWithWebIds(webIds []int64) (result []*Server, err error) { + if len(webIds) == 0 { + return + } + _, err = this.Query(). + State(ServerStateEnabled). + Attr("webId", webIds). + AscPk(). + Slice(&result). + FindAll() + return +} + // 创建事件 func (this *ServerDAO) createEvent() error { return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index b6c24263..f08e407c 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -128,8 +128,15 @@ func (this *SSLCertDAO) UpdateCert(certId int64, isOn bool, name string, descrip op.Description = description op.ServerName = serverName op.IsCA = isCA - op.CertData = certData - op.KeyData = keyData + + // cert和key均为有重新上传才会修改 + if len(certData) > 0 { + op.CertData = certData + } + if len(keyData) > 0 { + op.KeyData = keyData + } + op.TimeBeginAt = timeBeginAt op.TimeEndAt = timeEndAt diff --git a/internal/db/models/sys_setting_dao.go b/internal/db/models/sys_setting_dao.go index 33c2ac9b..533bd6af 100644 --- a/internal/db/models/sys_setting_dao.go +++ b/internal/db/models/sys_setting_dao.go @@ -29,16 +29,16 @@ func NewSysSettingDAO() *SysSettingDAO { var SharedSysSettingDAO = NewSysSettingDAO() // 设置配置 -func (this *SysSettingDAO) UpdateSetting(code string, valueJSON []byte, args ...interface{}) error { - if len(args) > 0 { - code = fmt.Sprintf(code, args...) +func (this *SysSettingDAO) UpdateSetting(codeFormat string, valueJSON []byte, codeFormatArgs ...interface{}) error { + if len(codeFormatArgs) > 0 { + codeFormat = fmt.Sprintf(codeFormat, codeFormatArgs...) } countRetries := 3 var lastErr error for i := 0; i < countRetries; i++ { settingId, err := this.Query(). - Attr("code", code). + Attr("code", codeFormat). ResultPk(). FindInt64Col(0) if err != nil { @@ -48,7 +48,7 @@ func (this *SysSettingDAO) UpdateSetting(code string, valueJSON []byte, args ... if settingId == 0 { // 新建 op := NewSysSettingOperator() - op.Code = code + op.Code = codeFormat op.Value = valueJSON _, err = this.Save(op) if err != nil { diff --git a/internal/errors/error.go b/internal/errors/error.go new file mode 100644 index 00000000..74faa7db --- /dev/null +++ b/internal/errors/error.go @@ -0,0 +1,60 @@ +package errors + +import ( + "errors" + "path/filepath" + "runtime" + "strconv" +) + +type errorObj struct { + err error + file string + line int + funcName string +} + +func (this *errorObj) Error() string { + s := this.err.Error() + "\n " + this.file + if len(this.funcName) > 0 { + s += ":" + this.funcName + "()" + } + s += ":" + strconv.Itoa(this.line) + return s +} + +// 新错误 +func New(errText string) error { + ptr, file, line, ok := runtime.Caller(1) + funcName := "" + if ok { + frame, _ := runtime.CallersFrames([]uintptr{ptr}).Next() + funcName = filepath.Base(frame.Function) + } + return &errorObj{ + err: errors.New(errText), + file: file, + line: line, + funcName: funcName, + } +} + +// 包装已有错误 +func Wrap(err error) error { + if err == nil { + return nil + } + + ptr, file, line, ok := runtime.Caller(1) + funcName := "" + if ok { + frame, _ := runtime.CallersFrames([]uintptr{ptr}).Next() + funcName = filepath.Base(frame.Function) + } + return &errorObj{ + err: err, + file: file, + line: line, + funcName: funcName, + } +} diff --git a/internal/errors/error_test.go b/internal/errors/error_test.go new file mode 100644 index 00000000..b288d27c --- /dev/null +++ b/internal/errors/error_test.go @@ -0,0 +1,22 @@ +package errors + +import ( + "errors" + "testing" +) + +func TestNew(t *testing.T) { + t.Log(New("hello")) + t.Log(Wrap(errors.New("hello"))) + t.Log(testError1()) + t.Log(Wrap(testError1())) + t.Log(Wrap(testError2())) +} + +func testError1() error { + return New("test error1") +} + +func testError2() error { + return Wrap(testError1()) +} diff --git a/internal/rpc/services/service_http_cache_policy.go b/internal/rpc/services/service_http_cache_policy.go index 605933c9..cd13d04a 100644 --- a/internal/rpc/services/service_http_cache_policy.go +++ b/internal/rpc/services/service_http_cache_policy.go @@ -2,6 +2,7 @@ package services import ( "context" + "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/db/models" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" @@ -32,3 +33,100 @@ func (this *HTTPCachePolicyService) FindAllEnabledHTTPCachePolicies(ctx context. } return &pb.FindAllEnabledHTTPCachePoliciesResponse{CachePolicies: result}, nil } + +// 创建缓存策略 +func (this *HTTPCachePolicyService) CreateHTTPCachePolicy(ctx context.Context, req *pb.CreateHTTPCachePolicyRequest) (*pb.CreateHTTPCachePolicyResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + policyId, err := models.SharedHTTPCachePolicyDAO.CreateCachePolicy(req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxKeys, req.MaxSizeJSON, req.Type, req.OptionsJSON) + if err != nil { + return nil, err + } + return &pb.CreateHTTPCachePolicyResponse{CachePolicyId: policyId}, nil +} + +// 修改缓存策略 +func (this *HTTPCachePolicyService) UpdateHTTPCachePolicy(ctx context.Context, req *pb.UpdateHTTPCachePolicyRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedHTTPCachePolicyDAO.UpdateCachePolicy(req.CachePolicyId, req.IsOn, req.Name, req.Description, req.CapacityJSON, req.MaxKeys, req.MaxSizeJSON, req.Type, req.OptionsJSON) + if err != nil { + return nil, err + } + + return rpcutils.RPCUpdateSuccess() +} + +// 删除缓存策略 +func (this *HTTPCachePolicyService) DeleteHTTPCachePolicy(ctx context.Context, req *pb.DeleteHTTPCachePolicyRequest) (*pb.RPCDeleteSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedHTTPCachePolicyDAO.DisableHTTPCachePolicy(req.CachePolicyId) + if err != nil { + return nil, err + } + + return rpcutils.RPCDeleteSuccess() +} + +// 计算缓存策略数量 +func (this *HTTPCachePolicyService) CountAllEnabledHTTPCachePolicies(ctx context.Context, req *pb.CountAllEnabledHTTPCachePoliciesRequest) (*pb.CountAllEnabledHTTPCachePoliciesResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + count, err := models.SharedHTTPCachePolicyDAO.CountAllEnabledHTTPCachePolicies() + if err != nil { + return nil, err + } + return &pb.CountAllEnabledHTTPCachePoliciesResponse{Count: count}, nil +} + +// 列出单页的缓存策略 +func (this *HTTPCachePolicyService) ListEnabledHTTPCachePolicies(ctx context.Context, req *pb.ListEnabledHTTPCachePoliciesRequest) (*pb.ListEnabledHTTPCachePoliciesResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + cachePolicies, err := models.SharedHTTPCachePolicyDAO.ListEnabledHTTPCachePolicies(req.Offset, req.Size) + if err != nil { + return nil, err + } + cachePoliciesJSON, err := json.Marshal(cachePolicies) + if err != nil { + return nil, err + } + return &pb.ListEnabledHTTPCachePoliciesResponse{CachePoliciesJSON: cachePoliciesJSON}, nil +} + +// 查找单个缓存策略配置 +func (this *HTTPCachePolicyService) FindEnabledHTTPCachePolicyConfig(ctx context.Context, req *pb.FindEnabledHTTPCachePolicyConfigRequest) (*pb.FindEnabledHTTPCachePolicyConfigResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + cachePolicy, err := models.SharedHTTPCachePolicyDAO.ComposeCachePolicy(req.CachePolicyId) + if err != nil { + return nil, err + } + cachePolicyJSON, err := json.Marshal(cachePolicy) + return &pb.FindEnabledHTTPCachePolicyConfigResponse{CachePolicyJSON: cachePolicyJSON}, nil +} diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 7faca1bf..fab1ebcd 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -585,7 +585,7 @@ func (this *ServerService) FindAndInitServerWebConfig(ctx context.Context, req * } // 计算使用某个SSL证书的服务数量 -func (this *ServerService) CountServersWithSSLCertId(ctx context.Context, req *pb.CountServersWithSSLCertIdRequest) (*pb.CountServersWithSSLCertIdResponse, error) { +func (this *ServerService) CountAllEnabledServersWithSSLCertId(ctx context.Context, req *pb.CountAllEnabledServersWithSSLCertIdRequest) (*pb.CountAllEnabledServersWithSSLCertIdResponse, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { @@ -598,19 +598,19 @@ func (this *ServerService) CountServersWithSSLCertId(ctx context.Context, req *p } if len(policyIds) == 0 { - return &pb.CountServersWithSSLCertIdResponse{Count: 0}, nil + return &pb.CountAllEnabledServersWithSSLCertIdResponse{Count: 0}, nil } - count, err := models.SharedServerDAO.CountServersWithSSLPolicyIds(policyIds) + count, err := models.SharedServerDAO.CountAllEnabledServersWithSSLPolicyIds(policyIds) if err != nil { return nil, err } - return &pb.CountServersWithSSLCertIdResponse{Count: count}, nil + return &pb.CountAllEnabledServersWithSSLCertIdResponse{Count: count}, nil } // 查找使用某个SSL证书的所有服务 -func (this *ServerService) FindAllServersWithSSLCertId(ctx context.Context, req *pb.FindAllServersWithSSLCertIdRequest) (*pb.FindAllServersWithSSLCertIdResponse, error) { +func (this *ServerService) FindAllEnabledServersWithSSLCertId(ctx context.Context, req *pb.FindAllEnabledServersWithSSLCertIdRequest) (*pb.FindAllEnabledServersWithSSLCertIdResponse, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { @@ -622,10 +622,10 @@ func (this *ServerService) FindAllServersWithSSLCertId(ctx context.Context, req return nil, err } if len(policyIds) == 0 { - return &pb.FindAllServersWithSSLCertIdResponse{Servers: nil}, nil + return &pb.FindAllEnabledServersWithSSLCertIdResponse{Servers: nil}, nil } - servers, err := models.SharedServerDAO.FindAllServersWithSSLPolicyIds(policyIds) + servers, err := models.SharedServerDAO.FindAllEnabledServersWithSSLPolicyIds(policyIds) if err != nil { return nil, err } @@ -638,5 +638,58 @@ func (this *ServerService) FindAllServersWithSSLCertId(ctx context.Context, req Type: server.Type, }) } - return &pb.FindAllServersWithSSLCertIdResponse{Servers: result}, nil + return &pb.FindAllEnabledServersWithSSLCertIdResponse{Servers: result}, nil +} + +// 计算使用某个缓存策略的服务数量 +func (this *ServerService) CountAllEnabledServersWithCachePolicyId(ctx context.Context, req *pb.CountAllEnabledServersWithCachePolicyIdRequest) (*pb.CountAllEnabledServersWithCachePolicyIdResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + webIds, err := models.SharedHTTPWebDAO.FindAllWebIdsWithCachePolicyId(req.CachePolicyId) + if err != nil { + return nil, err + } + if len(webIds) == 0 { + return &pb.CountAllEnabledServersWithCachePolicyIdResponse{Count: 0}, nil + } + + countServers, err := models.SharedServerDAO.CountEnabledServersWithWebIds(webIds) + if err != nil { + return nil, err + } + return &pb.CountAllEnabledServersWithCachePolicyIdResponse{Count: countServers}, nil +} + +// 查找使用某个缓存策略的所有服务 +func (this *ServerService) FindAllEnabledServersWithCachePolicyId(ctx context.Context, req *pb.FindAllEnabledServersWithCachePolicyIdRequest) (*pb.FindAllEnabledServersWithCachePolicyIdResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + webIds, err := models.SharedHTTPWebDAO.FindAllWebIdsWithCachePolicyId(req.CachePolicyId) + if err != nil { + return nil, err + } + + if len(webIds) == 0 { + return &pb.FindAllEnabledServersWithCachePolicyIdResponse{Servers: nil}, nil + } + + servers, err := models.SharedServerDAO.FindAllEnabledServersWithWebIds(webIds) + result := []*pb.Server{} + for _, server := range servers { + result = append(result, &pb.Server{ + Id: int64(server.Id), + Name: server.Name, + IsOn: server.IsOn == 1, + Type: server.Type, + }) + } + return &pb.FindAllEnabledServersWithCachePolicyIdResponse{Servers: result}, nil } diff --git a/internal/rpc/services/service_sys_setting.go b/internal/rpc/services/service_sys_setting.go new file mode 100644 index 00000000..9b5354e4 --- /dev/null +++ b/internal/rpc/services/service_sys_setting.go @@ -0,0 +1,43 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +type SysSettingService struct { +} + +// 更改配置 +func (this *SysSettingService) UpdateSysSetting(ctx context.Context, req *pb.UpdateSysSettingRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedSysSettingDAO.UpdateSetting(req.Code, req.ValueJSON) + if err != nil { + return nil, err + } + + return rpcutils.RPCUpdateSuccess() +} + +// 读取配置 +func (this *SysSettingService) ReadSysSetting(ctx context.Context, req *pb.ReadSysSettingRequest) (*pb.ReadSysSettingResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + valueJSON, err := models.SharedSysSettingDAO.ReadSetting(req.Code) + if err != nil { + return nil, err + } + + return &pb.ReadSysSettingResponse{ValueJSON: valueJSON}, nil +} diff --git a/internal/utils/errors.go b/internal/utils/errors.go index a54dddab..52992767 100644 --- a/internal/utils/errors.go +++ b/internal/utils/errors.go @@ -1,7 +1,10 @@ package utils -import "github.com/iwind/TeaGo/logs" +import ( + "github.com/iwind/TeaGo/logs" +) +// 打印错误 func PrintError(err error) { // TODO 记录调用的文件名、行数 logs.Println("[ERROR]" + err.Error())