实现源站端口跟随功能

This commit is contained in:
GoEdgeLab
2022-06-29 21:55:57 +08:00
parent 12c28f3fba
commit c26e08c1e3
3 changed files with 52 additions and 23 deletions

View File

@@ -101,7 +101,8 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx,
maxIdleConns int32,
certRef *sslconfigs.SSLCertRef,
domains []string,
host string) (originId int64, err error) {
host string,
followPort bool) (originId int64, err error) {
var op = NewOriginOperator()
op.AdminId = adminId
op.UserId = userId
@@ -167,6 +168,7 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx,
}
op.Host = host
op.FollowPort = followPort
op.State = OriginStateEnabled
err = this.Save(tx, op)
@@ -191,7 +193,8 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx,
maxIdleConns int32,
certRef *sslconfigs.SSLCertRef,
domains []string,
host string) error {
host string,
followPort bool) error {
if originId <= 0 {
return errors.New("invalid originId")
}
@@ -262,6 +265,7 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx,
}
op.Host = host
op.FollowPort = followPort
err := this.Save(tx, op)
if err != nil {
@@ -304,10 +308,11 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
RequestURI: origin.HttpRequestURI,
RequestHost: origin.Host,
Domains: origin.DecodeDomains(),
FollowPort: origin.FollowPort,
}
if IsNotNull(origin.Addr) {
addr := &serverconfigs.NetworkAddressConfig{}
var addr = &serverconfigs.NetworkAddressConfig{}
err = json.Unmarshal(origin.Addr, addr)
if err != nil {
return nil, err
@@ -316,7 +321,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
}
if IsNotNull(origin.ConnTimeout) {
connTimeout := &shared.TimeDuration{}
var connTimeout = &shared.TimeDuration{}
err = json.Unmarshal(origin.ConnTimeout, &connTimeout)
if err != nil {
return nil, err
@@ -325,7 +330,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
}
if IsNotNull(origin.ReadTimeout) {
readTimeout := &shared.TimeDuration{}
var readTimeout = &shared.TimeDuration{}
err = json.Unmarshal(origin.ReadTimeout, &readTimeout)
if err != nil {
return nil, err
@@ -334,7 +339,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
}
if IsNotNull(origin.IdleTimeout) {
idleTimeout := &shared.TimeDuration{}
var idleTimeout = &shared.TimeDuration{}
err = json.Unmarshal(origin.IdleTimeout, &idleTimeout)
if err != nil {
return nil, err
@@ -363,7 +368,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
}
if IsNotNull(origin.HttpResponseHeader) {
ref := &shared.HTTPHeaderPolicyRef{}
var ref = &shared.HTTPHeaderPolicyRef{}
err = json.Unmarshal(origin.HttpResponseHeader, ref)
if err != nil {
return nil, err
@@ -382,7 +387,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
}
if IsNotNull(origin.HealthCheck) {
healthCheck := &serverconfigs.HealthCheckConfig{}
var healthCheck = &serverconfigs.HealthCheckConfig{}
err = json.Unmarshal(origin.HealthCheck, healthCheck)
if err != nil {
return nil, err
@@ -391,7 +396,7 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
}
if IsNotNull(origin.Cert) {
ref := &sslconfigs.SSLCertRef{}
var ref = &sslconfigs.SSLCertRef{}
err = json.Unmarshal(origin.Cert, ref)
if err != nil {
return nil, err
@@ -417,6 +422,19 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64, cacheMap
return config, nil
}
// CheckUserOrigin 检查源站权限
func (this *OriginDAO) CheckUserOrigin(tx *dbs.Tx, userId int64, originId int64) error {
reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId)
if err != nil {
return err
}
if reverseProxyId == 0 {
// 这里我们不允许源站没有被使用
return ErrNotFound
}
return SharedReverseProxyDAO.CheckUserReverseProxy(tx, userId, reverseProxyId)
}
// NotifyUpdate 通知更新
func (this *OriginDAO) NotifyUpdate(tx *dbs.Tx, originId int64) error {
reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId)