From c26e08c1e3612e9e0b5706eb4d7cd83183ea93c4 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Wed, 29 Jun 2022 21:55:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E6=BA=90=E7=AB=99=E7=AB=AF?= =?UTF-8?q?=E5=8F=A3=E8=B7=9F=E9=9A=8F=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/origin_dao.go | 36 ++++++++++++++++++------ internal/db/models/origin_model.go | 2 ++ internal/rpc/services/service_origin.go | 37 +++++++++++++++---------- 3 files changed, 52 insertions(+), 23 deletions(-) diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index 01298b58..2359af07 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -101,7 +101,8 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, maxIdleConns int32, certRef *sslconfigs.SSLCertRef, domains []string, - host string) (originId int64, err error) { + host string, + followPort bool) (originId int64, err error) { var op = NewOriginOperator() op.AdminId = adminId op.UserId = userId @@ -167,6 +168,7 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, } op.Host = host + op.FollowPort = followPort op.State = OriginStateEnabled err = this.Save(tx, op) @@ -191,7 +193,8 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, maxIdleConns int32, certRef *sslconfigs.SSLCertRef, domains []string, - host string) error { + host string, + followPort bool) error { if originId <= 0 { return errors.New("invalid originId") } @@ -262,6 +265,7 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, } op.Host = host + op.FollowPort = followPort err := this.Save(tx, op) if err != nil { @@ -304,10 +308,11 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap RequestURI: origin.HttpRequestURI, RequestHost: origin.Host, Domains: origin.DecodeDomains(), + FollowPort: origin.FollowPort, } if IsNotNull(origin.Addr) { - addr := &serverconfigs.NetworkAddressConfig{} + var addr = &serverconfigs.NetworkAddressConfig{} err = json.Unmarshal(origin.Addr, addr) if err != nil { return nil, err @@ -316,7 +321,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap } if IsNotNull(origin.ConnTimeout) { - connTimeout := &shared.TimeDuration{} + var connTimeout = &shared.TimeDuration{} err = json.Unmarshal(origin.ConnTimeout, &connTimeout) if err != nil { return nil, err @@ -325,7 +330,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap } if IsNotNull(origin.ReadTimeout) { - readTimeout := &shared.TimeDuration{} + var readTimeout = &shared.TimeDuration{} err = json.Unmarshal(origin.ReadTimeout, &readTimeout) if err != nil { return nil, err @@ -334,7 +339,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap } if IsNotNull(origin.IdleTimeout) { - idleTimeout := &shared.TimeDuration{} + var idleTimeout = &shared.TimeDuration{} err = json.Unmarshal(origin.IdleTimeout, &idleTimeout) if err != nil { return nil, err @@ -363,7 +368,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap } if IsNotNull(origin.HttpResponseHeader) { - ref := &shared.HTTPHeaderPolicyRef{} + var ref = &shared.HTTPHeaderPolicyRef{} err = json.Unmarshal(origin.HttpResponseHeader, ref) if err != nil { return nil, err @@ -382,7 +387,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap } if IsNotNull(origin.HealthCheck) { - healthCheck := &serverconfigs.HealthCheckConfig{} + var healthCheck = &serverconfigs.HealthCheckConfig{} err = json.Unmarshal(origin.HealthCheck, healthCheck) if err != nil { return nil, err @@ -391,7 +396,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap } if IsNotNull(origin.Cert) { - ref := &sslconfigs.SSLCertRef{} + var ref = &sslconfigs.SSLCertRef{} err = json.Unmarshal(origin.Cert, ref) if err != nil { return nil, err @@ -417,6 +422,19 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap return config, nil } +// CheckUserOrigin 检查源站权限 +func (this *OriginDAO) CheckUserOrigin(tx *dbs.Tx, userId int64, originId int64) error { + reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId) + if err != nil { + return err + } + if reverseProxyId == 0 { + // 这里我们不允许源站没有被使用 + return ErrNotFound + } + return SharedReverseProxyDAO.CheckUserReverseProxy(tx, userId, reverseProxyId) +} + // NotifyUpdate 通知更新 func (this *OriginDAO) NotifyUpdate(tx *dbs.Tx, originId int64) error { reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId) diff --git a/internal/db/models/origin_model.go b/internal/db/models/origin_model.go index f0bf312c..e4f88738 100644 --- a/internal/db/models/origin_model.go +++ b/internal/db/models/origin_model.go @@ -29,6 +29,7 @@ type Origin struct { Ftp dbs.JSON `field:"ftp"` // FTP相关设置 CreatedAt uint64 `field:"createdAt"` // 创建时间 Domains dbs.JSON `field:"domains"` // 所属域名 + FollowPort bool `field:"followPort"` // 端口跟随 State uint8 `field:"state"` // 状态 } @@ -58,6 +59,7 @@ type OriginOperator struct { Ftp interface{} // FTP相关设置 CreatedAt interface{} // 创建时间 Domains interface{} // 所属域名 + FollowPort interface{} // 端口跟随 State interface{} // 状态 } diff --git a/internal/rpc/services/service_origin.go b/internal/rpc/services/service_origin.go index 5c40c8dc..7f4f0acc 100644 --- a/internal/rpc/services/service_origin.go +++ b/internal/rpc/services/service_origin.go @@ -72,7 +72,7 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi } } - originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host) + originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host, req.FollowPort) if err != nil { return nil, err } @@ -87,20 +87,23 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi return nil, err } + var tx = this.NullTx() + if userId > 0 { - // TODO 校验权限 + err = models.SharedOriginDAO.CheckUserOrigin(tx, userId, req.OriginId) + if err != nil { + return nil, err + } } if req.Addr == nil { return nil, errors.New("'addr' can not be nil") } - addrMap := maps.Map{ + var addrMap = maps.Map{ "protocol": req.Addr.Protocol, "portRange": req.Addr.PortRange, "host": req.Addr.Host, } - tx := this.NullTx() - // 校验参数 var connTimeout = &shared.TimeDuration{} if len(req.ConnTimeoutJSON) > 0 { @@ -139,7 +142,7 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi } } - err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host) + err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains, req.Host, req.FollowPort) if err != nil { return nil, err } @@ -154,11 +157,14 @@ func (this *OriginService) FindEnabledOrigin(ctx context.Context, req *pb.FindEn return nil, err } - if userId > 0 { - // TODO 校验权限 - } + var tx = this.NullTx() - tx := this.NullTx() + if userId > 0 { + err = models.SharedOriginDAO.CheckUserOrigin(tx, userId, req.OriginId) + if err != nil { + return nil, err + } + } origin, err := models.SharedOriginDAO.FindEnabledOrigin(tx, req.OriginId) if err != nil { @@ -196,11 +202,14 @@ func (this *OriginService) FindEnabledOriginConfig(ctx context.Context, req *pb. return nil, err } - if userId > 0 { - // TODO 校验权限 - } + var tx = this.NullTx() - tx := this.NullTx() + if userId > 0 { + err = models.SharedOriginDAO.CheckUserOrigin(tx, userId, req.OriginId) + if err != nil { + return nil, err + } + } config, err := models.SharedOriginDAO.ComposeOriginConfig(tx, req.OriginId, nil) if err != nil {