diff --git a/pkg/serverconfigs/http_web_config.go b/pkg/serverconfigs/http_web_config.go index a189dfa..22b9447 100644 --- a/pkg/serverconfigs/http_web_config.go +++ b/pkg/serverconfigs/http_web_config.go @@ -320,6 +320,14 @@ func (this *HTTPWebConfig) Init(ctx context.Context) error { } } + // referers + if this.Referers != nil { + err := this.Referers.Init() + if err != nil { + return err + } + } + return nil } diff --git a/pkg/serverconfigs/referers_config.go b/pkg/serverconfigs/referers_config.go index 66e8bc7..b93355c 100644 --- a/pkg/serverconfigs/referers_config.go +++ b/pkg/serverconfigs/referers_config.go @@ -2,7 +2,10 @@ package serverconfigs -import "github.com/TeaOSLab/EdgeCommon/pkg/configutils" +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" +) // NewReferersConfig 获取新防盗链配置对象 func NewReferersConfig() *ReferersConfig { @@ -20,9 +23,27 @@ type ReferersConfig struct { AllowDomains []string `yaml:"allowDomains" json:"allowDomains"` // 允许的来源域名列表 DenyDomains []string `yaml:"denyDomains" json:"denyDomains"` // 禁止的来源域名列表 CheckOrigin bool `yaml:"checkOrigin" json:"checkOrigin"` // 是否检查Origin + + OnlyURLPatterns []*shared.URLPattern `yaml:"onlyURLPatterns" json:"onlyURLPatterns"` // 仅限的URL + ExceptURLPatterns []*shared.URLPattern `yaml:"exceptURLPatterns" json:"exceptURLPatterns"` // 排除的URL } func (this *ReferersConfig) Init() error { + // url patterns + for _, pattern := range this.ExceptURLPatterns { + err := pattern.Init() + if err != nil { + return err + } + } + + for _, pattern := range this.OnlyURLPatterns { + err := pattern.Init() + if err != nil { + return err + } + } + return nil } @@ -54,3 +75,26 @@ func (this *ReferersConfig) MatchDomain(requestDomain string, refererDomain stri return false } + +func (this *ReferersConfig) MatchURL(url string) bool { + // except + if len(this.ExceptURLPatterns) > 0 { + for _, pattern := range this.ExceptURLPatterns { + if pattern.Match(url) { + return false + } + } + } + + // only + if len(this.OnlyURLPatterns) > 0 { + for _, pattern := range this.OnlyURLPatterns { + if pattern.Match(url) { + return true + } + } + return false + } + + return true +}