优化鉴权

This commit is contained in:
刘祥超
2022-08-30 11:24:07 +08:00
parent b54c44bfb9
commit 5c885050fa
11 changed files with 3526 additions and 3384 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*_plus.go
*_plus_test.go

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,99 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package serverconfigs
import (
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/iwind/TeaGo/lists"
"net/http"
"path/filepath"
"regexp"
"strings"
)
var httpAuthTimestampRegexp = regexp.MustCompile(`^\d{10}$`)
type HTTPAuthBaseMethod struct {
Exts []string `json:"exts"`
Domains []string `json:"domains"`
}
func (this *HTTPAuthBaseMethod) SetExts(exts []string) {
this.Exts = exts
}
func (this *HTTPAuthBaseMethod) SetDomains(domains []string) {
this.Domains = domains
}
func (this *HTTPAuthBaseMethod) removeQueryArgs(query string, args []string) string {
var pieces = strings.Split(query, "&")
var result = []string{}
Loop:
for _, piece := range pieces {
for _, arg := range args {
if strings.HasPrefix(piece, arg+"=") {
continue Loop
}
}
result = append(result, piece)
}
return strings.Join(result, "&")
}
func (this *HTTPAuthBaseMethod) matchTimestamp(timestamp string) bool {
return httpAuthTimestampRegexp.MatchString(timestamp)
}
func (this *HTTPAuthBaseMethod) MatchRequest(req *http.Request) bool {
if len(this.Exts) > 0 {
var ext = filepath.Ext(req.URL.Path)
if len(ext) == 0 {
return false
}
// ext中包含点符号
ext = strings.ToLower(ext)
if !lists.ContainsString(this.Exts, ext) {
return false
}
}
if len(this.Domains) > 0 {
var domain = req.Host
if len(domain) == 0 {
return false
}
if !configutils.MatchDomains(this.Domains, domain) {
return false
}
}
return true
}
// cleanPath 清理Path中的多余的字符
func (this *HTTPAuthBaseMethod) cleanPath(path string) string {
var l = len(path)
if l == 0 {
return "/"
}
var result = []byte{'/'}
var isSlash = true
for i := 0; i < l; i++ {
if path[i] == '?' {
result = append(result, path[i:]...)
break
}
if path[i] == '\\' || path[i] == '/' {
if !isSlash {
isSlash = true
result = append(result, '/')
}
} else {
isSlash = false
result = append(result, path[i])
}
}
return string(result)
}

View File

@@ -19,6 +19,8 @@ func (this *HTTPAuthBasicMethodUser) Validate(password string) (bool, error) {
// HTTPAuthBasicMethod BasicAuth方法定义
type HTTPAuthBasicMethod struct {
HTTPAuthBaseMethod
Users []*HTTPAuthBasicMethodUser `json:"users"`
Realm string `json:"realm"`
Charset string `json:"charset"`
@@ -30,7 +32,7 @@ func NewHTTPAuthBasicMethod() *HTTPAuthBasicMethod {
return &HTTPAuthBasicMethod{}
}
func (this *HTTPAuthBasicMethod) Init(params map[string]interface{}) error {
func (this *HTTPAuthBasicMethod) Init(params map[string]any) error {
this.userMap = map[string]*HTTPAuthBasicMethodUser{}
paramsJSON, err := json.Marshal(params)
@@ -49,14 +51,15 @@ func (this *HTTPAuthBasicMethod) Init(params map[string]interface{}) error {
return nil
}
func (this *HTTPAuthBasicMethod) Filter(req *http.Request, doSubReq func(subReq *http.Request) (status int, err error), formatter func(string) string) (bool, error) {
func (this *HTTPAuthBasicMethod) Filter(req *http.Request, doSubReq func(subReq *http.Request) (status int, err error), formatter func(string) string) (ok bool, newURI string, uriChanged bool, err error) {
username, password, ok := req.BasicAuth()
if !ok {
return false, nil
return false, "", false, nil
}
user, ok := this.userMap[username]
if !ok {
return false, nil
return false, "", false, nil
}
return user.Validate(password)
ok, err = user.Validate(password)
return ok, "", false, err
}

View File

@@ -11,7 +11,7 @@ import (
)
func TestHTTPAuthBasicMethodUser_Validate(t *testing.T) {
a := assert.NewAssertion(t)
var a = assert.NewAssertion(t)
{
user := &HTTPAuthBasicMethodUser{
@@ -37,8 +37,7 @@ func TestHTTPAuthBasicMethodUser_Validate(t *testing.T) {
}
func TestHTTPAuthBasicMethod_Filter(t *testing.T) {
method := &HTTPAuthBasicMethod{}
var method = &HTTPAuthBasicMethod{}
err := method.Init(map[string]interface{}{
"users": []maps.Map{
{

View File

@@ -7,8 +7,17 @@ import "net/http"
// HTTPAuthMethodInterface HTTP认证接口定义
type HTTPAuthMethodInterface interface {
// Init 初始化
Init(params map[string]interface{}) error
Init(params map[string]any) error
// MatchRequest 是否匹配请求
MatchRequest(req *http.Request) bool
// Filter 过滤
Filter(req *http.Request, subReqFunc func(subReq *http.Request) (status int, err error), formatter func(string) string) (bool, error)
Filter(req *http.Request, subReqFunc func(subReq *http.Request) (status int, err error), formatter func(string) string) (ok bool, newURI string, uriChanged bool, err error)
// SetExts 设置扩展名
SetExts(exts []string)
// SetDomains 设置域名
SetDomains(domains []string)
}

View File

@@ -21,6 +21,8 @@ var httpAuthSubRequestHTTPClient = &http.Client{
// HTTPAuthSubRequestMethod 使用URL认证
type HTTPAuthSubRequestMethod struct {
HTTPAuthBaseMethod
URL string `json:"url"`
Method string `json:"method"`
@@ -34,7 +36,7 @@ func NewHTTPAuthSubRequestMethod() *HTTPAuthSubRequestMethod {
}
// Init 初始化
func (this *HTTPAuthSubRequestMethod) Init(params map[string]interface{}) error {
func (this *HTTPAuthSubRequestMethod) Init(params map[string]any) error {
paramsJSON, err := json.Marshal(params)
if err != nil {
return err
@@ -58,7 +60,7 @@ func (this *HTTPAuthSubRequestMethod) Init(params map[string]interface{}) error
}
// Filter 过滤
func (this *HTTPAuthSubRequestMethod) Filter(req *http.Request, doSubReq func(subReq *http.Request) (status int, err error), formatter func(string) string) (bool, error) {
func (this *HTTPAuthSubRequestMethod) Filter(req *http.Request, doSubReq func(subReq *http.Request) (status int, err error), formatter func(string) string) (ok bool, newURI string, uriChanged bool, err error) {
var method = this.Method
if len(method) == 0 {
method = req.Method
@@ -78,7 +80,7 @@ func (this *HTTPAuthSubRequestMethod) Filter(req *http.Request, doSubReq func(su
}
newReq, err := http.NewRequest(method, url, nil)
if err != nil {
return false, err
return false, "", false, err
}
for k, v := range req.Header {
if k != "Connection" {
@@ -89,20 +91,20 @@ func (this *HTTPAuthSubRequestMethod) Filter(req *http.Request, doSubReq func(su
if !this.isFullURL {
status, err := doSubReq(newReq)
if err != nil {
return false, err
return false, "", false, err
}
return status >= 200 && status < 300, nil
return status >= 200 && status < 300, "", false, nil
}
// TODO 需要将Header和StatusCode、ResponseBody输出到客户端
newReq.Header.Set("Referer", scheme+"://"+host+req.URL.RequestURI())
resp, err := httpAuthSubRequestHTTPClient.Do(newReq)
if err != nil {
return false, err
return false, "", false, err
}
defer func() {
_ = resp.Body.Close()
}()
return resp.StatusCode >= 200 && resp.StatusCode < 300, nil
return resp.StatusCode >= 200 && resp.StatusCode < 300, "", false, nil
}

View File

@@ -24,7 +24,7 @@ func TestHTTPAuthRequestMethod_Filter(t *testing.T) {
}
req.Header.Set("Hello", "World")
req.Header.Set("User-Agent", "GoEdge/1.0")
b, err := method.Filter(req, func(subReq *http.Request) (status int, err error) {
b, uri, uriChanged, err := method.Filter(req, func(subReq *http.Request) (status int, err error) {
return
}, func(s string) string {
return s
@@ -32,7 +32,7 @@ func TestHTTPAuthRequestMethod_Filter(t *testing.T) {
if err != nil {
t.Fatal(err)
}
t.Log("result:", b)
t.Log("result:", b, uri, uriChanged)
}
func TestHTTPAuthRequestMethod_Filter_Path(t *testing.T) {
@@ -50,7 +50,7 @@ func TestHTTPAuthRequestMethod_Filter_Path(t *testing.T) {
}
req.Header.Set("Hello", "World")
req.Header.Set("User-Agent", "GoEdge/1.0")
b, err := method.Filter(req, func(subReq *http.Request) (status int, err error) {
b, uri, uriChanged, err := method.Filter(req, func(subReq *http.Request) (status int, err error) {
status = rands.Int(200, 400)
t.Log("execute sub request:", subReq.URL, status)
return
@@ -60,5 +60,5 @@ func TestHTTPAuthRequestMethod_Filter_Path(t *testing.T) {
if err != nil {
t.Fatal(err)
}
t.Log("result:", b)
t.Log("result:", b, uri, uriChanged)
}

View File

@@ -1,4 +1,5 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
package serverconfigs

View File

@@ -3,7 +3,6 @@
package serverconfigs
import (
"errors"
"net/http"
)
@@ -18,31 +17,19 @@ type HTTPAuthPolicy struct {
method HTTPAuthMethodInterface
}
// Init 初始化
func (this *HTTPAuthPolicy) Init() error {
switch this.Type {
case HTTPAuthTypeBasicAuth:
this.method = NewHTTPAuthBasicMethod()
case HTTPAuthTypeSubRequest:
this.method = NewHTTPAuthSubRequestMethod()
}
// MatchRequest 检查是否匹配请求
func (this *HTTPAuthPolicy) MatchRequest(req *http.Request) bool {
if this.method == nil {
return errors.New("unknown auth method '" + this.Type + "'")
return false
}
err := this.method.Init(this.Params)
if err != nil {
return err
}
return nil
return this.method.MatchRequest(req)
}
// Filter 过滤
func (this *HTTPAuthPolicy) Filter(req *http.Request, subReqFunc func(subReq *http.Request) (status int, err error), formatter func(string) string) (bool, error) {
func (this *HTTPAuthPolicy) Filter(req *http.Request, subReqFunc func(subReq *http.Request) (status int, err error), formatter func(string) string) (ok bool, newURI string, uriChanged bool, err error) {
if this.method == nil {
// 如果设置正确的方法,我们直接允许请求
return true, nil
return true, "", false, nil
}
return this.method.Filter(req, subReqFunc, formatter)
}

View File

@@ -0,0 +1,28 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
package serverconfigs
import (
"errors"
)
// Init 初始化
func (this *HTTPAuthPolicy) Init() error {
switch this.Type {
case HTTPAuthTypeBasicAuth:
this.method = NewHTTPAuthBasicMethod()
case HTTPAuthTypeSubRequest:
this.method = NewHTTPAuthSubRequestMethod()
}
if this.method == nil {
return errors.New("unknown auth method '" + this.Type + "'")
}
err := this.method.Init(this.Params)
if err != nil {
return err
}
return nil
}