From 7a3783a09fbdf16ce255068dfebca5ccadd840c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Tue, 26 Jan 2021 20:29:41 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=AF=E4=BB=A5=E9=85=8D=E7=BD=AE=E6=98=AF?= =?UTF-8?q?=E5=90=A6=E5=9C=A8=E5=8F=8D=E5=90=91=E4=BB=A3=E7=90=86=E4=B8=AD?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0X-Real-IP=E5=92=8CX-Forwarded-*?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/reverse_proxy_dao.go | 25 ++++++++- internal/db/models/reverse_proxy_model.go | 2 + .../rpc/services/service_reverse_proxy.go | 56 ++++++++++++++++--- 3 files changed, 74 insertions(+), 9 deletions(-) diff --git a/internal/db/models/reverse_proxy_dao.go b/internal/db/models/reverse_proxy_dao.go index 8e7ef6d1..d65430d1 100644 --- a/internal/db/models/reverse_proxy_dao.go +++ b/internal/db/models/reverse_proxy_dao.go @@ -139,6 +139,16 @@ func (this *ReverseProxyDAO) ComposeReverseProxyConfig(tx *dbs.Tx, reverseProxyI } } + // add headers + if IsNotNull(reverseProxy.AddHeaders) { + addHeaders := []string{} + err = json.Unmarshal([]byte(reverseProxy.AddHeaders), &addHeaders) + if err != nil { + return nil, err + } + config.AddHeaders = addHeaders + } + return config, nil } @@ -149,6 +159,7 @@ func (this *ReverseProxyDAO) CreateReverseProxy(tx *dbs.Tx, adminId int64, userI op.State = ReverseProxyStateEnabled op.AdminId = adminId op.UserId = userId + op.AddHeaders = []string{"X-Real-IP"} if len(schedulingJSON) > 0 { op.Scheduling = string(schedulingJSON) @@ -225,7 +236,7 @@ func (this *ReverseProxyDAO) UpdateReverseProxyBackupOrigins(tx *dbs.Tx, reverse } // 修改是否启用 -func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64, requestHostType int8, requestHost string, requestURI string, stripPrefix string, autoFlush bool) error { +func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64, requestHostType int8, requestHost string, requestURI string, stripPrefix string, autoFlush bool, addHeaders []string) error { if reverseProxyId <= 0 { return errors.New("invalid reverseProxyId") } @@ -242,7 +253,17 @@ func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64 op.RequestURI = requestURI op.StripPrefix = stripPrefix op.AutoFlush = autoFlush - err := this.Save(tx, op) + + if len(addHeaders) == 0 { + addHeaders = []string{} + } + addHeadersJSON, err := json.Marshal(addHeaders) + if err != nil { + return err + } + op.AddHeaders = addHeadersJSON + + err = this.Save(tx, op) if err != nil { return err } diff --git a/internal/db/models/reverse_proxy_model.go b/internal/db/models/reverse_proxy_model.go index 4c665438..533e44ba 100644 --- a/internal/db/models/reverse_proxy_model.go +++ b/internal/db/models/reverse_proxy_model.go @@ -15,6 +15,7 @@ type ReverseProxy struct { RequestHost string `field:"requestHost"` // 请求Host RequestURI string `field:"requestURI"` // 请求URI AutoFlush uint8 `field:"autoFlush"` // 是否自动刷新缓冲区 + AddHeaders string `field:"addHeaders"` // 自动添加的Header列表 State uint8 `field:"state"` // 状态 CreatedAt uint64 `field:"createdAt"` // 创建时间 } @@ -33,6 +34,7 @@ type ReverseProxyOperator struct { RequestHost interface{} // 请求Host RequestURI interface{} // 请求URI AutoFlush interface{} // 是否自动刷新缓冲区 + AddHeaders interface{} // 自动添加的Header列表 State interface{} // 状态 CreatedAt interface{} // 创建时间 } diff --git a/internal/rpc/services/service_reverse_proxy.go b/internal/rpc/services/service_reverse_proxy.go index 786ff020..18e026a2 100644 --- a/internal/rpc/services/service_reverse_proxy.go +++ b/internal/rpc/services/service_reverse_proxy.go @@ -43,7 +43,14 @@ func (this *ReverseProxyService) FindEnabledReverseProxy(ctx context.Context, re } if userId > 0 { - // TODO 检查权限 + serverId, err := models.SharedServerDAO.FindEnabledServerIdWithReverseProxyId(nil, req.ReverseProxyId) + if err != nil { + return nil, err + } + err = models.SharedServerDAO.CheckUserServer(nil, userId, serverId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -74,7 +81,14 @@ func (this *ReverseProxyService) FindEnabledReverseProxyConfig(ctx context.Conte } if userId > 0 { - // TODO 检查权限 + serverId, err := models.SharedServerDAO.FindEnabledServerIdWithReverseProxyId(nil, req.ReverseProxyId) + if err != nil { + return nil, err + } + err = models.SharedServerDAO.CheckUserServer(nil, userId, serverId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -101,7 +115,14 @@ func (this *ReverseProxyService) UpdateReverseProxyScheduling(ctx context.Contex } if userId > 0 { - // TODO 检查权限 + serverId, err := models.SharedServerDAO.FindEnabledServerIdWithReverseProxyId(nil, req.ReverseProxyId) + if err != nil { + return nil, err + } + err = models.SharedServerDAO.CheckUserServer(nil, userId, serverId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -123,7 +144,14 @@ func (this *ReverseProxyService) UpdateReverseProxyPrimaryOrigins(ctx context.Co } if userId > 0 { - // TODO 检查权限 + serverId, err := models.SharedServerDAO.FindEnabledServerIdWithReverseProxyId(nil, req.ReverseProxyId) + if err != nil { + return nil, err + } + err = models.SharedServerDAO.CheckUserServer(nil, userId, serverId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -145,7 +173,14 @@ func (this *ReverseProxyService) UpdateReverseProxyBackupOrigins(ctx context.Con } if userId > 0 { - // TODO 检查权限 + serverId, err := models.SharedServerDAO.FindEnabledServerIdWithReverseProxyId(nil, req.ReverseProxyId) + if err != nil { + return nil, err + } + err = models.SharedServerDAO.CheckUserServer(nil, userId, serverId) + if err != nil { + return nil, err + } } tx := this.NullTx() @@ -167,12 +202,19 @@ func (this *ReverseProxyService) UpdateReverseProxy(ctx context.Context, req *pb } if userId > 0 { - // TODO 检查权限 + serverId, err := models.SharedServerDAO.FindEnabledServerIdWithReverseProxyId(nil, req.ReverseProxyId) + if err != nil { + return nil, err + } + err = models.SharedServerDAO.CheckUserServer(nil, userId, serverId) + if err != nil { + return nil, err + } } tx := this.NullTx() - err = models.SharedReverseProxyDAO.UpdateReverseProxy(tx, req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush) + err = models.SharedReverseProxyDAO.UpdateReverseProxy(tx, req.ReverseProxyId, types.Int8(req.RequestHostType), req.RequestHost, req.RequestURI, req.StripPrefix, req.AutoFlush, req.AddHeaders) if err != nil { return nil, err }