diff --git a/internal/db/models/http_location_dao.go b/internal/db/models/http_location_dao.go index 69057640..3c64868f 100644 --- a/internal/db/models/http_location_dao.go +++ b/internal/db/models/http_location_dao.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" @@ -36,12 +37,12 @@ func init() { }) } -// 初始化 +// Init 初始化 func (this *HTTPLocationDAO) Init() { _ = this.DAOObject.Init() } -// 启用条目 +// EnableHTTPLocation 启用条目 func (this *HTTPLocationDAO) EnableHTTPLocation(tx *dbs.Tx, id int64) error { _, err := this.Query(tx). Pk(id). @@ -50,7 +51,7 @@ func (this *HTTPLocationDAO) EnableHTTPLocation(tx *dbs.Tx, id int64) error { return err } -// 禁用条目 +// DisableHTTPLocation 禁用条目 func (this *HTTPLocationDAO) DisableHTTPLocation(tx *dbs.Tx, locationId int64) error { _, err := this.Query(tx). Pk(locationId). @@ -62,7 +63,7 @@ func (this *HTTPLocationDAO) DisableHTTPLocation(tx *dbs.Tx, locationId int64) e return this.NotifyUpdate(tx, locationId) } -// 查找启用中的条目 +// FindEnabledHTTPLocation 查找启用中的条目 func (this *HTTPLocationDAO) FindEnabledHTTPLocation(tx *dbs.Tx, id int64) (*HTTPLocation, error) { result, err := this.Query(tx). Pk(id). @@ -74,7 +75,7 @@ func (this *HTTPLocationDAO) FindEnabledHTTPLocation(tx *dbs.Tx, id int64) (*HTT return result.(*HTTPLocation), err } -// 根据主键查找名称 +// FindHTTPLocationName 根据主键查找名称 func (this *HTTPLocationDAO) FindHTTPLocationName(tx *dbs.Tx, id int64) (string, error) { return this.Query(tx). Pk(id). @@ -82,8 +83,8 @@ func (this *HTTPLocationDAO) FindHTTPLocationName(tx *dbs.Tx, id int64) (string, FindStringCol("") } -// 创建路径规则 -func (this *HTTPLocationDAO) CreateLocation(tx *dbs.Tx, parentId int64, name string, pattern string, description string, isBreak bool) (int64, error) { +// CreateLocation 创建路径规则 +func (this *HTTPLocationDAO) CreateLocation(tx *dbs.Tx, parentId int64, name string, pattern string, description string, isBreak bool, condsJSON []byte) (int64, error) { op := NewHTTPLocationOperator() op.IsOn = true op.State = HTTPLocationStateEnabled @@ -92,6 +93,11 @@ func (this *HTTPLocationDAO) CreateLocation(tx *dbs.Tx, parentId int64, name str op.Pattern = pattern op.Description = description op.IsBreak = isBreak + + if len(condsJSON) > 0 { + op.Conds = condsJSON + } + err := this.Save(tx, op) if err != nil { return 0, err @@ -99,8 +105,8 @@ func (this *HTTPLocationDAO) CreateLocation(tx *dbs.Tx, parentId int64, name str return types.Int64(op.Id), nil } -// 修改路径规则 -func (this *HTTPLocationDAO) UpdateLocation(tx *dbs.Tx, locationId int64, name string, pattern string, description string, isOn bool, isBreak bool) error { +// UpdateLocation 修改路径规则 +func (this *HTTPLocationDAO) UpdateLocation(tx *dbs.Tx, locationId int64, name string, pattern string, description string, isOn bool, isBreak bool, condsJSON []byte) error { if locationId <= 0 { return errors.New("invalid locationId") } @@ -111,6 +117,11 @@ func (this *HTTPLocationDAO) UpdateLocation(tx *dbs.Tx, locationId int64, name s op.Description = description op.IsOn = isOn op.IsBreak = isBreak + + if len(condsJSON) > 0 { + op.Conds = condsJSON + } + err := this.Save(tx, op) if err != nil { return err @@ -118,7 +129,7 @@ func (this *HTTPLocationDAO) UpdateLocation(tx *dbs.Tx, locationId int64, name s return this.NotifyUpdate(tx, locationId) } -// 组合配置 +// ComposeLocationConfig 组合配置 func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64) (*serverconfigs.HTTPLocationConfig, error) { location, err := this.FindEnabledHTTPLocation(tx, locationId) if err != nil { @@ -163,10 +174,20 @@ func (this *HTTPLocationDAO) ComposeLocationConfig(tx *dbs.Tx, locationId int64) } } + // conds + if len(location.Conds) > 0 { + conds := &shared.HTTPRequestCondsConfig{} + err = json.Unmarshal([]byte(location.Conds), conds) + if err != nil { + return nil, err + } + config.Conds = conds + } + return config, nil } -// 查找反向代理设置 +// FindLocationReverseProxy 查找反向代理设置 func (this *HTTPLocationDAO) FindLocationReverseProxy(tx *dbs.Tx, locationId int64) (*serverconfigs.ReverseProxyRef, error) { refString, err := this.Query(tx). Pk(locationId). @@ -186,7 +207,7 @@ func (this *HTTPLocationDAO) FindLocationReverseProxy(tx *dbs.Tx, locationId int return nil, nil } -// 更改反向代理设置 +// UpdateLocationReverseProxy 更改反向代理设置 func (this *HTTPLocationDAO) UpdateLocationReverseProxy(tx *dbs.Tx, locationId int64, reverseProxyJSON []byte) error { if locationId <= 0 { return errors.New("invalid locationId") @@ -201,7 +222,7 @@ func (this *HTTPLocationDAO) UpdateLocationReverseProxy(tx *dbs.Tx, locationId i return this.NotifyUpdate(tx, locationId) } -// 查找WebId +// FindLocationWebId 查找WebId func (this *HTTPLocationDAO) FindLocationWebId(tx *dbs.Tx, locationId int64) (int64, error) { webId, err := this.Query(tx). Pk(locationId). @@ -210,7 +231,7 @@ func (this *HTTPLocationDAO) FindLocationWebId(tx *dbs.Tx, locationId int64) (in return int64(webId), err } -// 更改Web设置 +// UpdateLocationWeb 更改Web设置 func (this *HTTPLocationDAO) UpdateLocationWeb(tx *dbs.Tx, locationId int64, webId int64) error { if locationId <= 0 { return errors.New("invalid locationId") @@ -225,7 +246,7 @@ func (this *HTTPLocationDAO) UpdateLocationWeb(tx *dbs.Tx, locationId int64, web return this.NotifyUpdate(tx, locationId) } -// 转换引用为配置 +// ConvertLocationRefs 转换引用为配置 func (this *HTTPLocationDAO) ConvertLocationRefs(tx *dbs.Tx, refs []*serverconfigs.HTTPLocationRef) (locations []*serverconfigs.HTTPLocationConfig, err error) { for _, ref := range refs { config, err := this.ComposeLocationConfig(tx, ref.LocationId) @@ -243,7 +264,7 @@ func (this *HTTPLocationDAO) ConvertLocationRefs(tx *dbs.Tx, refs []*serverconfi return } -// 根据WebId查找LocationId +// FindEnabledLocationIdWithWebId 根据WebId查找LocationId func (this *HTTPLocationDAO) FindEnabledLocationIdWithWebId(tx *dbs.Tx, webId int64) (locationId int64, err error) { if webId <= 0 { return @@ -254,7 +275,7 @@ func (this *HTTPLocationDAO) FindEnabledLocationIdWithWebId(tx *dbs.Tx, webId in FindInt64Col(0) } -// 通知更新 +// NotifyUpdate 通知更新 func (this *HTTPLocationDAO) NotifyUpdate(tx *dbs.Tx, locationId int64) error { webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithLocationId(tx, locationId) if err != nil { diff --git a/internal/db/models/http_rewrite_rule_dao.go b/internal/db/models/http_rewrite_rule_dao.go index ced47f47..229e02d5 100644 --- a/internal/db/models/http_rewrite_rule_dao.go +++ b/internal/db/models/http_rewrite_rule_dao.go @@ -1,8 +1,10 @@ package models import ( + "encoding/json" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" @@ -35,12 +37,12 @@ func init() { }) } -// 初始化 +// Init 初始化 func (this *HTTPRewriteRuleDAO) Init() { _ = this.DAOObject.Init() } -// 启用条目 +// EnableHTTPRewriteRule 启用条目 func (this *HTTPRewriteRuleDAO) EnableHTTPRewriteRule(tx *dbs.Tx, id int64) error { _, err := this.Query(tx). Pk(id). @@ -49,7 +51,7 @@ func (this *HTTPRewriteRuleDAO) EnableHTTPRewriteRule(tx *dbs.Tx, id int64) erro return err } -// 禁用条目 +// DisableHTTPRewriteRule 禁用条目 func (this *HTTPRewriteRuleDAO) DisableHTTPRewriteRule(tx *dbs.Tx, rewriteRuleId int64) error { _, err := this.Query(tx). Pk(rewriteRuleId). @@ -61,7 +63,7 @@ func (this *HTTPRewriteRuleDAO) DisableHTTPRewriteRule(tx *dbs.Tx, rewriteRuleId return this.NotifyUpdate(tx, rewriteRuleId) } -// 查找启用中的条目 +// FindEnabledHTTPRewriteRule 查找启用中的条目 func (this *HTTPRewriteRuleDAO) FindEnabledHTTPRewriteRule(tx *dbs.Tx, id int64) (*HTTPRewriteRule, error) { result, err := this.Query(tx). Pk(id). @@ -73,7 +75,7 @@ func (this *HTTPRewriteRuleDAO) FindEnabledHTTPRewriteRule(tx *dbs.Tx, id int64) return result.(*HTTPRewriteRule), err } -// 构造配置 +// ComposeRewriteRule 构造配置 func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(tx *dbs.Tx, rewriteRuleId int64) (*serverconfigs.HTTPRewriteRule, error) { rule, err := this.FindEnabledHTTPRewriteRule(tx, rewriteRuleId) if err != nil { @@ -93,11 +95,21 @@ func (this *HTTPRewriteRuleDAO) ComposeRewriteRule(tx *dbs.Tx, rewriteRuleId int config.ProxyHost = rule.ProxyHost config.IsBreak = rule.IsBreak == 1 config.WithQuery = rule.WithQuery == 1 + + // conds + if len(rule.Conds) > 0 { + conds := &shared.HTTPRequestCondsConfig{} + err = json.Unmarshal([]byte(rule.Conds), conds) + if err != nil { + return nil, err + } + config.Conds = conds + } return config, nil } -// 创建规则 -func (this *HTTPRewriteRuleDAO) CreateRewriteRule(tx *dbs.Tx, pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool) (int64, error) { +// CreateRewriteRule 创建规则 +func (this *HTTPRewriteRuleDAO) CreateRewriteRule(tx *dbs.Tx, pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool, condsJSON []byte) (int64, error) { op := NewHTTPRewriteRuleOperator() op.State = HTTPRewriteRuleStateEnabled op.IsOn = isOn @@ -109,12 +121,17 @@ func (this *HTTPRewriteRuleDAO) CreateRewriteRule(tx *dbs.Tx, pattern string, re op.IsBreak = isBreak op.WithQuery = withQuery op.ProxyHost = proxyHost + + if len(condsJSON) > 0 { + op.Conds = condsJSON + } + err := this.Save(tx, op) return types.Int64(op.Id), err } -// 修改规则 -func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(tx *dbs.Tx, rewriteRuleId int64, pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool) error { +// UpdateRewriteRule 修改规则 +func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(tx *dbs.Tx, rewriteRuleId int64, pattern string, replace string, mode string, redirectStatus int, isBreak bool, proxyHost string, withQuery bool, isOn bool, condsJSON []byte) error { if rewriteRuleId <= 0 { return errors.New("invalid rewriteRuleId") } @@ -128,6 +145,11 @@ func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(tx *dbs.Tx, rewriteRuleId int6 op.IsBreak = isBreak op.WithQuery = withQuery op.ProxyHost = proxyHost + + if len(condsJSON) > 0 { + op.Conds = condsJSON + } + err := this.Save(tx, op) if err != nil { return err @@ -135,7 +157,7 @@ func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(tx *dbs.Tx, rewriteRuleId int6 return this.NotifyUpdate(tx, rewriteRuleId) } -// 通知更新 +// NotifyUpdate 通知更新 func (this *HTTPRewriteRuleDAO) NotifyUpdate(tx *dbs.Tx, rewriteRuleId int64) error { webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithRewriteRuleId(tx, rewriteRuleId) if err != nil { diff --git a/internal/rpc/services/service_http_location.go b/internal/rpc/services/service_http_location.go index 8d9a8999..a1056f28 100644 --- a/internal/rpc/services/service_http_location.go +++ b/internal/rpc/services/service_http_location.go @@ -9,11 +9,12 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" ) +// HTTPLocationService 路径规则相关服务 type HTTPLocationService struct { BaseService } -// 创建路径规则 +// CreateHTTPLocation 创建路径规则 func (this *HTTPLocationService) CreateHTTPLocation(ctx context.Context, req *pb.CreateHTTPLocationRequest) (*pb.CreateHTTPLocationResponse, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -23,7 +24,7 @@ func (this *HTTPLocationService) CreateHTTPLocation(ctx context.Context, req *pb tx := this.NullTx() - locationId, err := models.SharedHTTPLocationDAO.CreateLocation(tx, req.ParentId, req.Name, req.Pattern, req.Description, req.IsBreak) + locationId, err := models.SharedHTTPLocationDAO.CreateLocation(tx, req.ParentId, req.Name, req.Pattern, req.Description, req.IsBreak, req.CondsJSON) if err != nil { return nil, err } @@ -31,7 +32,7 @@ func (this *HTTPLocationService) CreateHTTPLocation(ctx context.Context, req *pb return &pb.CreateHTTPLocationResponse{LocationId: locationId}, nil } -// 修改路径规则 +// UpdateHTTPLocation 修改路径规则 func (this *HTTPLocationService) UpdateHTTPLocation(ctx context.Context, req *pb.UpdateHTTPLocationRequest) (*pb.RPCSuccess, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -41,7 +42,7 @@ func (this *HTTPLocationService) UpdateHTTPLocation(ctx context.Context, req *pb tx := this.NullTx() - err = models.SharedHTTPLocationDAO.UpdateLocation(tx, req.LocationId, req.Name, req.Pattern, req.Description, req.IsOn, req.IsBreak) + err = models.SharedHTTPLocationDAO.UpdateLocation(tx, req.LocationId, req.Name, req.Pattern, req.Description, req.IsOn, req.IsBreak, req.CondsJSON) if err != nil { return nil, err } @@ -49,7 +50,7 @@ func (this *HTTPLocationService) UpdateHTTPLocation(ctx context.Context, req *pb return this.Success() } -// 查找路径规则配置 +// FindEnabledHTTPLocationConfig 查找路径规则配置 func (this *HTTPLocationService) FindEnabledHTTPLocationConfig(ctx context.Context, req *pb.FindEnabledHTTPLocationConfigRequest) (*pb.FindEnabledHTTPLocationConfigResponse, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -70,7 +71,7 @@ func (this *HTTPLocationService) FindEnabledHTTPLocationConfig(ctx context.Conte return &pb.FindEnabledHTTPLocationConfigResponse{LocationJSON: configJSON}, nil } -// 删除路径规则 +// DeleteHTTPLocation 删除路径规则 func (this *HTTPLocationService) DeleteHTTPLocation(ctx context.Context, req *pb.DeleteHTTPLocationRequest) (*pb.RPCSuccess, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -87,7 +88,7 @@ func (this *HTTPLocationService) DeleteHTTPLocation(ctx context.Context, req *pb return this.Success() } -// 查找反向代理设置 +// FindAndInitHTTPLocationReverseProxyConfig 查找反向代理设置 func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationReverseProxyConfigRequest) (*pb.FindAndInitHTTPLocationReverseProxyConfigResponse, error) { // 校验请求 adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) @@ -140,7 +141,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationReverseProxyConfig(ctx c }, nil } -// 初始化Web设置 +// FindAndInitHTTPLocationWebConfig 初始化Web设置 func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Context, req *pb.FindAndInitHTTPLocationWebConfigRequest) (*pb.FindAndInitHTTPLocationWebConfigResponse, error) { // 校验请求 adminId, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) @@ -179,7 +180,7 @@ func (this *HTTPLocationService) FindAndInitHTTPLocationWebConfig(ctx context.Co }, nil } -// 修改反向代理设置 +// UpdateHTTPLocationReverseProxy 修改反向代理设置 func (this *HTTPLocationService) UpdateHTTPLocationReverseProxy(ctx context.Context, req *pb.UpdateHTTPLocationReverseProxyRequest) (*pb.RPCSuccess, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) diff --git a/internal/rpc/services/service_http_rewrite_rule.go b/internal/rpc/services/service_http_rewrite_rule.go index 1d76ea28..e3e657dd 100644 --- a/internal/rpc/services/service_http_rewrite_rule.go +++ b/internal/rpc/services/service_http_rewrite_rule.go @@ -8,11 +8,12 @@ import ( "github.com/iwind/TeaGo/types" ) +// HTTPRewriteRuleService 重写规则相关服务 type HTTPRewriteRuleService struct { BaseService } -// 创建重写规则 +// CreateHTTPRewriteRule 创建重写规则 func (this *HTTPRewriteRuleService) CreateHTTPRewriteRule(ctx context.Context, req *pb.CreateHTTPRewriteRuleRequest) (*pb.CreateHTTPRewriteRuleResponse, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -22,7 +23,7 @@ func (this *HTTPRewriteRuleService) CreateHTTPRewriteRule(ctx context.Context, r tx := this.NullTx() - rewriteRuleId, err := models.SharedHTTPRewriteRuleDAO.CreateRewriteRule(tx, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn) + rewriteRuleId, err := models.SharedHTTPRewriteRuleDAO.CreateRewriteRule(tx, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn, req.CondsJSON) if err != nil { return nil, err } @@ -30,7 +31,7 @@ func (this *HTTPRewriteRuleService) CreateHTTPRewriteRule(ctx context.Context, r return &pb.CreateHTTPRewriteRuleResponse{RewriteRuleId: rewriteRuleId}, nil } -// 修改重写规则 +// UpdateHTTPRewriteRule 修改重写规则 func (this *HTTPRewriteRuleService) UpdateHTTPRewriteRule(ctx context.Context, req *pb.UpdateHTTPRewriteRuleRequest) (*pb.RPCSuccess, error) { // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -40,7 +41,7 @@ func (this *HTTPRewriteRuleService) UpdateHTTPRewriteRule(ctx context.Context, r tx := this.NullTx() - err = models.SharedHTTPRewriteRuleDAO.UpdateRewriteRule(tx, req.RewriteRuleId, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn) + err = models.SharedHTTPRewriteRuleDAO.UpdateRewriteRule(tx, req.RewriteRuleId, req.Pattern, req.Replace, req.Mode, types.Int(req.RedirectStatus), req.IsBreak, req.ProxyHost, req.WithQuery, req.IsOn, req.CondsJSON) if err != nil { return nil, err }