mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	实现WAF
This commit is contained in:
		
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@@ -7,6 +7,7 @@ replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
 | 
				
			|||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect
 | 
						github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect
 | 
				
			||||||
	github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
 | 
						github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
 | 
				
			||||||
 | 
						github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f
 | 
				
			||||||
	github.com/dchest/siphash v1.2.1
 | 
						github.com/dchest/siphash v1.2.1
 | 
				
			||||||
	github.com/go-ole/go-ole v1.2.4 // indirect
 | 
						github.com/go-ole/go-ole v1.2.4 // indirect
 | 
				
			||||||
	github.com/go-yaml/yaml v2.1.0+incompatible
 | 
						github.com/go-yaml/yaml v2.1.0+incompatible
 | 
				
			||||||
@@ -14,4 +15,5 @@ require (
 | 
				
			|||||||
	github.com/shirou/gopsutil v2.20.9+incompatible
 | 
						github.com/shirou/gopsutil v2.20.9+incompatible
 | 
				
			||||||
	golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7
 | 
						golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7
 | 
				
			||||||
	google.golang.org/grpc v1.32.0
 | 
						google.golang.org/grpc v1.32.0
 | 
				
			||||||
 | 
						gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							@@ -14,6 +14,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
 | 
				
			|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
					github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
				
			||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 | 
					github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 | 
				
			||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
					github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 | 
				
			||||||
 | 
					github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f h1:q/DpyjJjZs94bziQ7YkBmIlpqbVP7yw179rnzoNVX1M=
 | 
				
			||||||
 | 
					github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f/go.mod h1:QGrK8vMWWHQYQ3QU9bw9Y9OPNfxccGzfb41qjvVeXtY=
 | 
				
			||||||
github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4=
 | 
					github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4=
 | 
				
			||||||
github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4=
 | 
					github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4=
 | 
				
			||||||
github.com/dgryski/go-rendezvous v0.0.0-20200624174652-8d2f3be8b2d9/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
 | 
					github.com/dgryski/go-rendezvous v0.0.0-20200624174652-8d2f3be8b2d9/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								internal/grids/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								internal/grids/README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,2 @@
 | 
				
			|||||||
 | 
					# Memory Grid
 | 
				
			||||||
 | 
					Cache items in memory, using partitions and LRU.
 | 
				
			||||||
							
								
								
									
										186
									
								
								internal/grids/cell.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								internal/grids/cell.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,186 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"math"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Cell struct {
 | 
				
			||||||
 | 
						LimitSize  int64
 | 
				
			||||||
 | 
						LimitCount int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mapping    map[uint64]*Item // key => item
 | 
				
			||||||
 | 
						list       *List            // { item1, item2, ... }
 | 
				
			||||||
 | 
						totalBytes int64
 | 
				
			||||||
 | 
						locker     sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewCell() *Cell {
 | 
				
			||||||
 | 
						return &Cell{
 | 
				
			||||||
 | 
							mapping: map[uint64]*Item{},
 | 
				
			||||||
 | 
							list:    NewList(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Cell) Write(hashKey uint64, item *Item) {
 | 
				
			||||||
 | 
						if item == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oldItem, ok := this.mapping[hashKey]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							this.list.Remove(oldItem)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if this.LimitSize > 0 {
 | 
				
			||||||
 | 
								this.totalBytes -= oldItem.Size()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// limit count
 | 
				
			||||||
 | 
						if this.LimitCount > 0 && len(this.mapping) >= this.LimitCount {
 | 
				
			||||||
 | 
							this.locker.Unlock()
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// trim memory
 | 
				
			||||||
 | 
						size := item.Size()
 | 
				
			||||||
 | 
						shouldTrim := false
 | 
				
			||||||
 | 
						if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size {
 | 
				
			||||||
 | 
							this.Trim()
 | 
				
			||||||
 | 
							shouldTrim = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// compare again
 | 
				
			||||||
 | 
						if shouldTrim {
 | 
				
			||||||
 | 
							if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size {
 | 
				
			||||||
 | 
								this.locker.Unlock()
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.totalBytes += size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.list.Add(item)
 | 
				
			||||||
 | 
						this.mapping[hashKey] = item
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Cell) Increase64(key []byte, expireAt int64, hashKey uint64, delta int64) (result int64) {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						item, ok := this.mapping[hashKey]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							// reset to zero if expired
 | 
				
			||||||
 | 
							if item.ExpireAt < time.Now().Unix() {
 | 
				
			||||||
 | 
								item.ValueInt64 = 0
 | 
				
			||||||
 | 
								item.ExpireAt = expireAt
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							item.IncreaseInt64(delta)
 | 
				
			||||||
 | 
							result = item.ValueInt64
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							item := NewItem(key, ItemInt64)
 | 
				
			||||||
 | 
							item.ValueInt64 = delta
 | 
				
			||||||
 | 
							item.ExpireAt = expireAt
 | 
				
			||||||
 | 
							this.mapping[hashKey] = item
 | 
				
			||||||
 | 
							result = delta
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Cell) Read(hashKey uint64) *Item {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						item, ok := this.mapping[hashKey]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							this.list.Remove(item)
 | 
				
			||||||
 | 
							this.list.Add(item)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if item.ExpireAt < time.Now().Unix() {
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return item
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Cell) Stat() *CellStat {
 | 
				
			||||||
 | 
						this.locker.RLock()
 | 
				
			||||||
 | 
						defer this.locker.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &CellStat{
 | 
				
			||||||
 | 
							TotalBytes: this.totalBytes,
 | 
				
			||||||
 | 
							CountItems: len(this.mapping),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// trim NOT ACTIVE items
 | 
				
			||||||
 | 
					// should called in locker context
 | 
				
			||||||
 | 
					func (this *Cell) Trim() {
 | 
				
			||||||
 | 
						l := len(this.mapping)
 | 
				
			||||||
 | 
						if l == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						inactiveSize := int(math.Ceil(float64(l) / 10)) // trim 10% items
 | 
				
			||||||
 | 
						this.list.Range(func(item *Item) (goNext bool) {
 | 
				
			||||||
 | 
							inactiveSize--
 | 
				
			||||||
 | 
							delete(this.mapping, item.HashKey())
 | 
				
			||||||
 | 
							this.list.Remove(item)
 | 
				
			||||||
 | 
							this.totalBytes -= item.Size()
 | 
				
			||||||
 | 
							return inactiveSize > 0
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Cell) Delete(hashKey uint64) {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						item, ok := this.mapping[hashKey]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							delete(this.mapping, hashKey)
 | 
				
			||||||
 | 
							this.list.Remove(item)
 | 
				
			||||||
 | 
							this.totalBytes -= item.Size()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// range all items in the cell
 | 
				
			||||||
 | 
					func (this *Cell) Range(f func(item *Item)) {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						for _, item := range this.mapping {
 | 
				
			||||||
 | 
							f(item)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Cell) Recycle() {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						if len(this.mapping) == 0 {
 | 
				
			||||||
 | 
							this.locker.Unlock()
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						timestamp := time.Now().Unix()
 | 
				
			||||||
 | 
						for key, item := range this.mapping {
 | 
				
			||||||
 | 
							if item.ExpireAt < timestamp {
 | 
				
			||||||
 | 
								delete(this.mapping, key)
 | 
				
			||||||
 | 
								this.list.Remove(item)
 | 
				
			||||||
 | 
								this.totalBytes -= item.Size()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Cell) Reset() {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						this.list.Reset()
 | 
				
			||||||
 | 
						this.mapping = map[uint64]*Item{}
 | 
				
			||||||
 | 
						this.totalBytes = 0
 | 
				
			||||||
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										6
									
								
								internal/grids/cell_stat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								internal/grids/cell_stat.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CellStat struct {
 | 
				
			||||||
 | 
						TotalBytes int64
 | 
				
			||||||
 | 
						CountItems int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										214
									
								
								internal/grids/cell_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								internal/grids/cell_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,214 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCell_List(t *testing.T) {
 | 
				
			||||||
 | 
						cell := NewCell()
 | 
				
			||||||
 | 
						cell.Write(1, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 1,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Write(2, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 2,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Write(3, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 3,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							t.Log("====")
 | 
				
			||||||
 | 
							l := cell.list
 | 
				
			||||||
 | 
							for e := l.head; e != nil; e = e.Next {
 | 
				
			||||||
 | 
								t.Log("element:", e.ValueInt64)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cell.Write(1, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 1,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Write(3, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 3,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Write(3, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 3,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Write(2, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 2,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Delete(2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							t.Log("====")
 | 
				
			||||||
 | 
							l := cell.list
 | 
				
			||||||
 | 
							for e := l.head; e != nil; e = e.Next {
 | 
				
			||||||
 | 
								t.Log("element:", e.ValueInt64)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, m := range cell.mapping {
 | 
				
			||||||
 | 
							t.Log(m.ValueInt64)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCell_LimitSize(t *testing.T) {
 | 
				
			||||||
 | 
						cell := NewCell()
 | 
				
			||||||
 | 
						cell.LimitSize = 1024
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := int64(0); i < 100; i ++ {
 | 
				
			||||||
 | 
							key := []byte(fmt.Sprintf("%d", i))
 | 
				
			||||||
 | 
							cell.Write(HashKey(key), &Item{
 | 
				
			||||||
 | 
								Key:        key,
 | 
				
			||||||
 | 
								ValueInt64: i,
 | 
				
			||||||
 | 
								Type:       ItemInt64,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("totalBytes:", cell.totalBytes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							t.Log("====")
 | 
				
			||||||
 | 
							l := cell.list
 | 
				
			||||||
 | 
							s := []string{}
 | 
				
			||||||
 | 
							for e := l.head; e != nil; e = e.Next {
 | 
				
			||||||
 | 
								s = append(s, fmt.Sprintf("%d", e.ValueInt64))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("{ " + strings.Join(s, ", ") + " }")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("mapping:", len(cell.mapping))
 | 
				
			||||||
 | 
						s := []string{}
 | 
				
			||||||
 | 
						for _, item := range cell.mapping {
 | 
				
			||||||
 | 
							s = append(s, fmt.Sprintf("%d", item.ValueInt64))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log("{ " + strings.Join(s, ", ") + " }")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCell_MemoryUsage(t *testing.T) {
 | 
				
			||||||
 | 
						//runtime.GOMAXPROCS(4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cell := NewCell()
 | 
				
			||||||
 | 
						cell.LimitSize = 1024 * 1024 * 1024 * 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						before := time.Now()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg := sync.WaitGroup{}
 | 
				
			||||||
 | 
						wg.Add(4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for j := 0; j < 4; j ++ {
 | 
				
			||||||
 | 
							go func(j int) {
 | 
				
			||||||
 | 
								start := j * 50 * 10000
 | 
				
			||||||
 | 
								for i := start; i < start+50*10000; i ++ {
 | 
				
			||||||
 | 
									key := []byte(strconv.Itoa(i) + "VERY_LONG_STRING")
 | 
				
			||||||
 | 
									cell.Write(HashKey(key), &Item{
 | 
				
			||||||
 | 
										Key:        key,
 | 
				
			||||||
 | 
										ValueInt64: int64(i),
 | 
				
			||||||
 | 
										Type:       ItemInt64,
 | 
				
			||||||
 | 
									})
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								wg.Done()
 | 
				
			||||||
 | 
							}(j)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
						t.Log("items:", len(cell.mapping))
 | 
				
			||||||
 | 
						t.Log(time.Since(before).Seconds(), "s", "totalBytes:", cell.totalBytes/1024/1024, "M")
 | 
				
			||||||
 | 
						//time.Sleep(10 * time.Second)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkCell_Write(b *testing.B) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cell := NewCell()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i ++ {
 | 
				
			||||||
 | 
							key := []byte(strconv.Itoa(i) + "_LONG_KEY_LONG_KEY_LONG_KEY_LONG_KEY")
 | 
				
			||||||
 | 
							cell.Write(HashKey(key), &Item{
 | 
				
			||||||
 | 
								Key:        key,
 | 
				
			||||||
 | 
								ValueInt64: int64(i),
 | 
				
			||||||
 | 
								Type:       ItemInt64,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						b.Log("items:", len(cell.mapping))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCell_Read(t *testing.T) {
 | 
				
			||||||
 | 
						cell := NewCell()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cell.Write(1, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 1,
 | 
				
			||||||
 | 
							ExpireAt:   time.Now().Unix() + 3600,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Write(2, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 2,
 | 
				
			||||||
 | 
							ExpireAt:   time.Now().Unix() + 3600,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						cell.Write(3, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 3,
 | 
				
			||||||
 | 
							ExpireAt:   time.Now().Unix() + 3600,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							s := []string{}
 | 
				
			||||||
 | 
							cell.list.Range(func(item *Item) (goNext bool) {
 | 
				
			||||||
 | 
								s = append(s, fmt.Sprintf("%d", item.ValueInt64))
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							t.Log("before:", s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(cell.Read(1).ValueInt64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							s := []string{}
 | 
				
			||||||
 | 
							cell.list.Range(func(item *Item) (goNext bool) {
 | 
				
			||||||
 | 
								s = append(s, fmt.Sprintf("%d", item.ValueInt64))
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							t.Log("after:", s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(cell.Read(2).ValueInt64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							s := []string{}
 | 
				
			||||||
 | 
							cell.list.Range(func(item *Item) (goNext bool) {
 | 
				
			||||||
 | 
								s = append(s, fmt.Sprintf("%d", item.ValueInt64))
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							t.Log("after:", s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCell_Recycle(t *testing.T) {
 | 
				
			||||||
 | 
						cell := NewCell()
 | 
				
			||||||
 | 
						cell.Write(1, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 1,
 | 
				
			||||||
 | 
							ExpireAt:   time.Now().Unix() - 1,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cell.Write(2, &Item{
 | 
				
			||||||
 | 
							ValueInt64: 2,
 | 
				
			||||||
 | 
							ExpireAt:   time.Now().Unix() + 1,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cell.Recycle()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							s := []string{}
 | 
				
			||||||
 | 
							cell.list.Range(func(item *Item) (goNext bool) {
 | 
				
			||||||
 | 
								s = append(s, fmt.Sprintf("%d", item.ValueInt64))
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							t.Log("after:", s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(cell.list.Len(), cell.totalBytes)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										225
									
								
								internal/grids/grid.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										225
									
								
								internal/grids/grid.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,225 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"compress/gzip"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/timers"
 | 
				
			||||||
 | 
						"math"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Memory Cache Grid
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// |           Grid                 |
 | 
				
			||||||
 | 
					// | cell1, cell2, ..., cell1024    |
 | 
				
			||||||
 | 
					// | item1, item2, ..., item1000000 |
 | 
				
			||||||
 | 
					type Grid struct {
 | 
				
			||||||
 | 
						cells      []*Cell
 | 
				
			||||||
 | 
						countCells uint64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recycleIndex    int
 | 
				
			||||||
 | 
						recycleLooper   *timers.Looper
 | 
				
			||||||
 | 
						recycleInterval int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						gzipLevel int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						limitSize  int64
 | 
				
			||||||
 | 
						limitCount int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewGrid(countCells int, opt ...interface{}) *Grid {
 | 
				
			||||||
 | 
						grid := &Grid{
 | 
				
			||||||
 | 
							recycleIndex: -1,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, o := range opt {
 | 
				
			||||||
 | 
							switch x := o.(type) {
 | 
				
			||||||
 | 
							case *CompressOpt:
 | 
				
			||||||
 | 
								grid.gzipLevel = x.Level
 | 
				
			||||||
 | 
							case *LimitSizeOpt:
 | 
				
			||||||
 | 
								grid.limitSize = x.Size
 | 
				
			||||||
 | 
							case *LimitCountOpt:
 | 
				
			||||||
 | 
								grid.limitCount = x.Count
 | 
				
			||||||
 | 
							case *RecycleIntervalOpt:
 | 
				
			||||||
 | 
								grid.recycleInterval = x.Interval
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cells := []*Cell{}
 | 
				
			||||||
 | 
						if countCells <= 0 {
 | 
				
			||||||
 | 
							countCells = 1
 | 
				
			||||||
 | 
						} else if countCells > 100*10000 {
 | 
				
			||||||
 | 
							countCells = 100 * 10000
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i := 0; i < countCells; i++ {
 | 
				
			||||||
 | 
							cell := NewCell()
 | 
				
			||||||
 | 
							cell.LimitSize = int64(math.Floor(float64(grid.limitSize) / float64(countCells)))
 | 
				
			||||||
 | 
							cell.LimitCount = int(math.Floor(float64(grid.limitCount)) / float64(countCells))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							cells = append(cells, cell)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						grid.cells = cells
 | 
				
			||||||
 | 
						grid.countCells = uint64(len(cells))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid.recycleTimer()
 | 
				
			||||||
 | 
						return grid
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// get all cells in the grid
 | 
				
			||||||
 | 
					func (this *Grid) Cells() []*Cell {
 | 
				
			||||||
 | 
						return this.cells
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) WriteItem(item *Item) {
 | 
				
			||||||
 | 
						if this.countCells <= 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						hashKey := item.HashKey()
 | 
				
			||||||
 | 
						this.cellForHashKey(hashKey).Write(hashKey, item)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) WriteInt64(key []byte, value int64, lifeSeconds int64) {
 | 
				
			||||||
 | 
						this.WriteItem(&Item{
 | 
				
			||||||
 | 
							Key:        key,
 | 
				
			||||||
 | 
							Type:       ItemInt64,
 | 
				
			||||||
 | 
							ValueInt64: value,
 | 
				
			||||||
 | 
							ExpireAt:   time.Now().Unix() + lifeSeconds,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) IncreaseInt64(key []byte, delta int64, lifeSeconds int64) (result int64) {
 | 
				
			||||||
 | 
						hashKey := HashKey(key)
 | 
				
			||||||
 | 
						return this.cellForHashKey(hashKey).Increase64(key, time.Now().Unix()+lifeSeconds, hashKey, delta)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) WriteString(key []byte, value string, lifeSeconds int64) {
 | 
				
			||||||
 | 
						this.WriteBytes(key, []byte(value), lifeSeconds)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) WriteBytes(key []byte, value []byte, lifeSeconds int64) {
 | 
				
			||||||
 | 
						isCompressed := false
 | 
				
			||||||
 | 
						if this.gzipLevel != gzip.NoCompression {
 | 
				
			||||||
 | 
							buf := bytes.NewBuffer([]byte{})
 | 
				
			||||||
 | 
							writer, err := gzip.NewWriterLevel(buf, this.gzipLevel)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logs.Error(err)
 | 
				
			||||||
 | 
								this.WriteItem(&Item{
 | 
				
			||||||
 | 
									Key:        key,
 | 
				
			||||||
 | 
									Type:       ItemBytes,
 | 
				
			||||||
 | 
									ValueBytes: value,
 | 
				
			||||||
 | 
									ExpireAt:   time.Now().Unix() + lifeSeconds,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							_, err = writer.Write([]byte(value))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logs.Error(err)
 | 
				
			||||||
 | 
								this.WriteItem(&Item{
 | 
				
			||||||
 | 
									Key:        key,
 | 
				
			||||||
 | 
									Type:       ItemBytes,
 | 
				
			||||||
 | 
									ValueBytes: value,
 | 
				
			||||||
 | 
									ExpireAt:   time.Now().Unix() + lifeSeconds,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err = writer.Close()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logs.Error(err)
 | 
				
			||||||
 | 
								this.WriteItem(&Item{
 | 
				
			||||||
 | 
									Key:        key,
 | 
				
			||||||
 | 
									Type:       ItemBytes,
 | 
				
			||||||
 | 
									ValueBytes: value,
 | 
				
			||||||
 | 
									ExpireAt:   time.Now().Unix() + lifeSeconds,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = buf.Bytes()
 | 
				
			||||||
 | 
							isCompressed = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.WriteItem(&Item{
 | 
				
			||||||
 | 
							Key:          key,
 | 
				
			||||||
 | 
							Type:         ItemBytes,
 | 
				
			||||||
 | 
							ValueBytes:   value,
 | 
				
			||||||
 | 
							ExpireAt:     time.Now().Unix() + lifeSeconds,
 | 
				
			||||||
 | 
							IsCompressed: isCompressed,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) WriteInterface(key []byte, value interface{}, lifeSeconds int64) {
 | 
				
			||||||
 | 
						this.WriteItem(&Item{
 | 
				
			||||||
 | 
							Key:            key,
 | 
				
			||||||
 | 
							Type:           ItemInterface,
 | 
				
			||||||
 | 
							ValueInterface: value,
 | 
				
			||||||
 | 
							ExpireAt:       time.Now().Unix() + lifeSeconds,
 | 
				
			||||||
 | 
							IsCompressed:   false,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) Read(key []byte) *Item {
 | 
				
			||||||
 | 
						if this.countCells <= 0 {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						hashKey := HashKey(key)
 | 
				
			||||||
 | 
						return this.cellForHashKey(hashKey).Read(hashKey)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) Stat() *Stat {
 | 
				
			||||||
 | 
						stat := &Stat{}
 | 
				
			||||||
 | 
						for _, cell := range this.cells {
 | 
				
			||||||
 | 
							cellStat := cell.Stat()
 | 
				
			||||||
 | 
							stat.CountItems += cellStat.CountItems
 | 
				
			||||||
 | 
							stat.TotalBytes += cellStat.TotalBytes
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return stat
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) Delete(key []byte) {
 | 
				
			||||||
 | 
						if this.countCells <= 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						hashKey := HashKey(key)
 | 
				
			||||||
 | 
						this.cellForHashKey(hashKey).Delete(hashKey)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) Reset() {
 | 
				
			||||||
 | 
						for _, cell := range this.cells {
 | 
				
			||||||
 | 
							cell.Reset()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) Destroy() {
 | 
				
			||||||
 | 
						if this.recycleLooper != nil {
 | 
				
			||||||
 | 
							this.recycleLooper.Stop()
 | 
				
			||||||
 | 
							this.recycleLooper = nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.cells = nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) cellForHashKey(hashKey uint64) *Cell {
 | 
				
			||||||
 | 
						if hashKey < 0 {
 | 
				
			||||||
 | 
							return this.cells[-hashKey%this.countCells]
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return this.cells[hashKey%this.countCells]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Grid) recycleTimer() {
 | 
				
			||||||
 | 
						duration := 1 * time.Minute
 | 
				
			||||||
 | 
						if this.recycleInterval > 0 {
 | 
				
			||||||
 | 
							duration = time.Duration(this.recycleInterval) * time.Second
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.recycleLooper = timers.Loop(duration, func(looper *timers.Looper) {
 | 
				
			||||||
 | 
							if this.countCells == 0 {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							this.recycleIndex++
 | 
				
			||||||
 | 
							if this.recycleIndex > int(this.countCells-1) {
 | 
				
			||||||
 | 
								this.recycleIndex = 0
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							this.cells[this.recycleIndex].Recycle()
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										204
									
								
								internal/grids/grid_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										204
									
								
								internal/grids/grid_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,204 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"compress/gzip"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_Write(t *testing.T) {
 | 
				
			||||||
 | 
						grid := NewGrid(5, NewRecycleIntervalOpt(2), NewLimitSizeOpt(10240))
 | 
				
			||||||
 | 
						t.Log("123456:", grid.Read([]byte("123456")))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid.WriteInt64([]byte("abc"), 1, 5)
 | 
				
			||||||
 | 
						t.Log(grid.Read([]byte("abc")).ValueInt64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid.WriteString([]byte("abc"), "123", 5)
 | 
				
			||||||
 | 
						t.Log(string(grid.Read([]byte("abc")).Bytes()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid.WriteBytes([]byte("abc"), []byte("123"), 5)
 | 
				
			||||||
 | 
						t.Log(grid.Read([]byte("abc")).Bytes())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid.Delete([]byte("abc"))
 | 
				
			||||||
 | 
						t.Log(grid.Read([]byte("abc")))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < 100; i++ {
 | 
				
			||||||
 | 
							grid.WriteInt64([]byte(fmt.Sprintf("%d", i)), 123, 1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("before recycle:")
 | 
				
			||||||
 | 
						for index, cell := range grid.cells {
 | 
				
			||||||
 | 
							t.Log("cell:", index, len(cell.mapping), "items")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						time.Sleep(3 * time.Second)
 | 
				
			||||||
 | 
						t.Log("after recycle:")
 | 
				
			||||||
 | 
						for index, cell := range grid.cells {
 | 
				
			||||||
 | 
							t.Log("cell:", index, len(cell.mapping), "items")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid.Destroy()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_Write_LimitCount(t *testing.T) {
 | 
				
			||||||
 | 
						grid := NewGrid(2, NewLimitCountOpt(10))
 | 
				
			||||||
 | 
						for i := 0; i < 100; i++ {
 | 
				
			||||||
 | 
							grid.WriteInt64([]byte(strconv.Itoa(i)), int64(i), 30)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(grid.Stat().CountItems, "items")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_Compress(t *testing.T) {
 | 
				
			||||||
 | 
						grid := NewGrid(5, NewCompressOpt(1))
 | 
				
			||||||
 | 
						grid.WriteString([]byte("hello"), strings.Repeat("abcd", 10240), 30)
 | 
				
			||||||
 | 
						t.Log(len(string(grid.Read([]byte("hello")).String())))
 | 
				
			||||||
 | 
						t.Log(len(grid.Read([]byte("hello")).ValueBytes))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkMemoryGrid_Performance(b *testing.B) {
 | 
				
			||||||
 | 
						grid := NewGrid(1024)
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							grid.WriteInt64([]byte("key:"+strconv.Itoa(i)), int64(i), 3600)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_Performance(t *testing.T) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid := NewGrid(1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						now := time.Now()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s := []byte(strings.Repeat("abcd", 10*1024))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < 100000; i++ {
 | 
				
			||||||
 | 
							grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, 1)), s, 3600)
 | 
				
			||||||
 | 
							item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, 1)))
 | 
				
			||||||
 | 
							if item != nil {
 | 
				
			||||||
 | 
								_ = item.String()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						countItems := 0
 | 
				
			||||||
 | 
						for _, cell := range grid.cells {
 | 
				
			||||||
 | 
							countItems += len(cell.mapping)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(countItems, "items")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(time.Since(now).Seconds()*1000, "ms")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_Performance_Concurrent(t *testing.T) {
 | 
				
			||||||
 | 
						//runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid := NewGrid(1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						now := time.Now()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s := []byte(strings.Repeat("abcd", 10*1024))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg := sync.WaitGroup{}
 | 
				
			||||||
 | 
						wg.Add(runtime.NumCPU())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for c := 0; c < runtime.NumCPU(); c++ {
 | 
				
			||||||
 | 
							go func(c int) {
 | 
				
			||||||
 | 
								defer wg.Done()
 | 
				
			||||||
 | 
								for i := 0; i < 50000; i++ {
 | 
				
			||||||
 | 
									grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, c)), s, 3600)
 | 
				
			||||||
 | 
									item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, c)))
 | 
				
			||||||
 | 
									if item != nil {
 | 
				
			||||||
 | 
										_ = item.String()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}(c)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
						countItems := 0
 | 
				
			||||||
 | 
						for _, cell := range grid.cells {
 | 
				
			||||||
 | 
							countItems += len(cell.mapping)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(countItems, "items")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(time.Since(now).Seconds()*1000, "ms")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_CompressPerformance(t *testing.T) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid := NewGrid(1024, NewCompressOpt(gzip.BestCompression))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						now := time.Now()
 | 
				
			||||||
 | 
						data := []byte(strings.Repeat("abcd", 1024))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < 100000; i++ {
 | 
				
			||||||
 | 
							grid.WriteBytes([]byte(fmt.Sprintf("key:%d", i)), data, 3600)
 | 
				
			||||||
 | 
							item := grid.Read([]byte(fmt.Sprintf("key:%d", i+100)))
 | 
				
			||||||
 | 
							if item != nil {
 | 
				
			||||||
 | 
								_ = item.String()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						countItems := 0
 | 
				
			||||||
 | 
						for _, cell := range grid.cells {
 | 
				
			||||||
 | 
							countItems += len(cell.mapping)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(countItems, "items")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(time.Since(now).Seconds()*1000, "ms")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_IncreaseInt64(t *testing.T) {
 | 
				
			||||||
 | 
						grid := NewGrid(1024)
 | 
				
			||||||
 | 
						grid.WriteInt64([]byte("abc"), 123, 10)
 | 
				
			||||||
 | 
						grid.IncreaseInt64([]byte("abc"), 123, 10)
 | 
				
			||||||
 | 
						grid.IncreaseInt64([]byte("abc"), 123, 10)
 | 
				
			||||||
 | 
						item := grid.Read([]byte("abc"))
 | 
				
			||||||
 | 
						if item == nil {
 | 
				
			||||||
 | 
							t.Fatal("item == nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if item.ValueInt64 != 369 {
 | 
				
			||||||
 | 
							t.Fatal("not 369")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_Destroy(t *testing.T) {
 | 
				
			||||||
 | 
						grid := NewGrid(1024)
 | 
				
			||||||
 | 
						grid.WriteInt64([]byte("abc"), 123, 10)
 | 
				
			||||||
 | 
						t.Log(grid.recycleLooper, grid.cells)
 | 
				
			||||||
 | 
						grid.Destroy()
 | 
				
			||||||
 | 
						t.Log(grid.recycleLooper, grid.cells)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if grid.recycleLooper != nil {
 | 
				
			||||||
 | 
							t.Fatal("looper != nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMemoryGrid_Recycle(t *testing.T) {
 | 
				
			||||||
 | 
						cell := NewCell()
 | 
				
			||||||
 | 
						timestamp := time.Now().Unix()
 | 
				
			||||||
 | 
						for i := 0; i < 300_0000; i++ {
 | 
				
			||||||
 | 
							cell.Write(uint64(i), &Item{
 | 
				
			||||||
 | 
								ExpireAt: timestamp - 30,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						before := time.Now()
 | 
				
			||||||
 | 
						cell.Recycle()
 | 
				
			||||||
 | 
						t.Log(time.Since(before).Seconds()*1000, "ms")
 | 
				
			||||||
 | 
						t.Log(len(cell.mapping))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						runtime.GC()
 | 
				
			||||||
 | 
						printMem(t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func printMem(t *testing.T) {
 | 
				
			||||||
 | 
						mem := &runtime.MemStats{}
 | 
				
			||||||
 | 
						runtime.ReadMemStats(mem)
 | 
				
			||||||
 | 
						t.Log(mem.Alloc/1024/1024, "M")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										88
									
								
								internal/grids/item.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								internal/grids/item.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"compress/gzip"
 | 
				
			||||||
 | 
						"github.com/dchest/siphash"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"unsafe"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ItemType = int8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						ItemInt64     = 1
 | 
				
			||||||
 | 
						ItemBytes     = 2
 | 
				
			||||||
 | 
						ItemInterface = 3
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func HashKey(key []byte) uint64 {
 | 
				
			||||||
 | 
						return siphash.Hash(0, 0, key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Item struct {
 | 
				
			||||||
 | 
						Key            []byte
 | 
				
			||||||
 | 
						ExpireAt       int64
 | 
				
			||||||
 | 
						Type           ItemType
 | 
				
			||||||
 | 
						ValueInt64     int64
 | 
				
			||||||
 | 
						ValueBytes     []byte
 | 
				
			||||||
 | 
						ValueInterface interface{}
 | 
				
			||||||
 | 
						IsCompressed   bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// linked list
 | 
				
			||||||
 | 
						Prev *Item
 | 
				
			||||||
 | 
						Next *Item
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						size int64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewItem(key []byte, dataType ItemType) *Item {
 | 
				
			||||||
 | 
						return &Item{
 | 
				
			||||||
 | 
							Key:  key,
 | 
				
			||||||
 | 
							Type: dataType,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Item) HashKey() uint64 {
 | 
				
			||||||
 | 
						return HashKey(this.Key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Item) IncreaseInt64(delta int64) {
 | 
				
			||||||
 | 
						atomic.AddInt64(&this.ValueInt64, delta)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Item) Bytes() []byte {
 | 
				
			||||||
 | 
						if this.IsCompressed {
 | 
				
			||||||
 | 
							reader, err := gzip.NewReader(bytes.NewBuffer(this.ValueBytes))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logs.Error(err)
 | 
				
			||||||
 | 
								return this.ValueBytes
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							buf := make([]byte, 256)
 | 
				
			||||||
 | 
							dataBuf := bytes.NewBuffer([]byte{})
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								n, err := reader.Read(buf)
 | 
				
			||||||
 | 
								if n > 0 {
 | 
				
			||||||
 | 
									dataBuf.Write(buf[:n])
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return dataBuf.Bytes()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return this.ValueBytes
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Item) String() string {
 | 
				
			||||||
 | 
						return string(this.Bytes())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Item) Size() int64 {
 | 
				
			||||||
 | 
						if this.size == 0 {
 | 
				
			||||||
 | 
							this.size = int64(int(unsafe.Sizeof(*this)) + len(this.Key) + len(this.ValueBytes))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return this.size
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										69
									
								
								internal/grids/item_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								internal/grids/item_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/md5"
 | 
				
			||||||
 | 
						"github.com/dchest/siphash"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestItem_Size(t *testing.T) {
 | 
				
			||||||
 | 
						item := &Item{
 | 
				
			||||||
 | 
							ValueInt64: 1024,
 | 
				
			||||||
 | 
							Key:        []byte("123"),
 | 
				
			||||||
 | 
							ValueBytes: []byte("Hello, World"),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(item.Size())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkItem_Size(b *testing.B) {
 | 
				
			||||||
 | 
						item := &Item{
 | 
				
			||||||
 | 
							ValueInt64: 1024,
 | 
				
			||||||
 | 
							Key:        []byte("123"),
 | 
				
			||||||
 | 
							ValueBytes: []byte("Hello, World"),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i ++ {
 | 
				
			||||||
 | 
							_ = item.Size()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestItem_HashKey(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log(HashKey([]byte("2")))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestItem_siphash(t *testing.T) {
 | 
				
			||||||
 | 
						result := siphash.Hash(0, 0, []byte("123456"))
 | 
				
			||||||
 | 
						t.Log(result)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestItem_unique(t *testing.T) {
 | 
				
			||||||
 | 
						m := map[uint64]bool{}
 | 
				
			||||||
 | 
						for i := 0; i < 1000*10000; i ++ {
 | 
				
			||||||
 | 
							s := "Hello,World,LONG KEY,LONG KEY,LONG KEY,LONG KEY" + strconv.Itoa(i)
 | 
				
			||||||
 | 
							result := siphash.Hash(0, 0, []byte(s))
 | 
				
			||||||
 | 
							_, ok := m[result]
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								t.Log("found same", i)
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								m[result] = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(siphash.Hash(0, 0, []byte("01")))
 | 
				
			||||||
 | 
						t.Log(siphash.Hash(0, 0, []byte("10")))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkItem_HashKeyMd5(b *testing.B) {
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i ++ {
 | 
				
			||||||
 | 
							h := md5.New()
 | 
				
			||||||
 | 
							h.Write([]byte("HELLO_KEY_" + strconv.Itoa(i)))
 | 
				
			||||||
 | 
							_ = h.Sum(nil)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkItem_siphash(b *testing.B) {
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i ++ {
 | 
				
			||||||
 | 
							_ = siphash.Hash(0, 0, []byte("HELLO_KEY_"+strconv.Itoa(i)))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										68
									
								
								internal/grids/list.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								internal/grids/list.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,68 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type List struct {
 | 
				
			||||||
 | 
						head *Item
 | 
				
			||||||
 | 
						end  *Item
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewList() *List {
 | 
				
			||||||
 | 
						return &List{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *List) Add(item *Item) {
 | 
				
			||||||
 | 
						if item == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if this.end != nil {
 | 
				
			||||||
 | 
							this.end.Next = item
 | 
				
			||||||
 | 
							item.Prev = this.end
 | 
				
			||||||
 | 
							item.Next = nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.end = item
 | 
				
			||||||
 | 
						if this.head == nil {
 | 
				
			||||||
 | 
							this.head = item
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *List) Remove(item *Item) {
 | 
				
			||||||
 | 
						if item == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if item.Prev != nil {
 | 
				
			||||||
 | 
							item.Prev.Next = item.Next
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if item.Next != nil {
 | 
				
			||||||
 | 
							item.Next.Prev = item.Prev
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if item == this.head {
 | 
				
			||||||
 | 
							this.head = item.Next
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if item == this.end {
 | 
				
			||||||
 | 
							this.end = item.Prev
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						item.Prev = nil
 | 
				
			||||||
 | 
						item.Next = nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *List) Len() int {
 | 
				
			||||||
 | 
						l := 0
 | 
				
			||||||
 | 
						for e := this.head; e != nil; e = e.Next {
 | 
				
			||||||
 | 
							l ++
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return l
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *List) Range(f func(item *Item) (goNext bool)) {
 | 
				
			||||||
 | 
						for e := this.head; e != nil; e = e.Next {
 | 
				
			||||||
 | 
							goNext := f(e)
 | 
				
			||||||
 | 
							if !goNext {
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *List) Reset() {
 | 
				
			||||||
 | 
						this.head = nil
 | 
				
			||||||
 | 
						this.end = nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										64
									
								
								internal/grids/list_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								internal/grids/list_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,64 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestList(t *testing.T) {
 | 
				
			||||||
 | 
						l := &List{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var e1 *Item = nil
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							e := &Item{
 | 
				
			||||||
 | 
								ValueInt64: 1,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							l.Add(e)
 | 
				
			||||||
 | 
							e1 = e
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var e2 *Item = nil
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							e := &Item{
 | 
				
			||||||
 | 
								ValueInt64: 2,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							l.Add(e)
 | 
				
			||||||
 | 
							e2 = e
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var e3 *Item = nil
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							e := &Item{
 | 
				
			||||||
 | 
								ValueInt64: 3,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							l.Add(e)
 | 
				
			||||||
 | 
							e3 = e
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var e4 *Item = nil
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							e := &Item{
 | 
				
			||||||
 | 
								ValueInt64: 4,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							l.Add(e)
 | 
				
			||||||
 | 
							e4 = e
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						l.Remove(e1)
 | 
				
			||||||
 | 
						//l.Remove(e2)
 | 
				
			||||||
 | 
						//l.Remove(e3)
 | 
				
			||||||
 | 
						l.Remove(e4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for e := l.head; e != nil; e = e.Next {
 | 
				
			||||||
 | 
							t.Log(e.ValueInt64)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("e1, e2, e3, e4, head, end:", e1, e2, e3, e4)
 | 
				
			||||||
 | 
						if l.head != nil {
 | 
				
			||||||
 | 
							t.Log("head:", l.head.ValueInt64)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							t.Log("head: nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if l.end != nil {
 | 
				
			||||||
 | 
							t.Log("end:", l.end.ValueInt64)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							t.Log("end: nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										11
									
								
								internal/grids/opt_compress.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								internal/grids/opt_compress.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CompressOpt struct {
 | 
				
			||||||
 | 
						Level int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewCompressOpt(level int) *CompressOpt {
 | 
				
			||||||
 | 
						return &CompressOpt{
 | 
				
			||||||
 | 
							Level: level,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										11
									
								
								internal/grids/opt_limit_count.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								internal/grids/opt_limit_count.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LimitCountOpt struct {
 | 
				
			||||||
 | 
						Count int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewLimitCountOpt(count int) *LimitCountOpt {
 | 
				
			||||||
 | 
						return &LimitCountOpt{
 | 
				
			||||||
 | 
							Count: count,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										11
									
								
								internal/grids/opt_limit_size.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								internal/grids/opt_limit_size.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LimitSizeOpt struct {
 | 
				
			||||||
 | 
						Size int64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewLimitSizeOpt(size int64) *LimitSizeOpt {
 | 
				
			||||||
 | 
						return &LimitSizeOpt{
 | 
				
			||||||
 | 
							Size: size,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										11
									
								
								internal/grids/opt_recycle_interval.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								internal/grids/opt_recycle_interval.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RecycleIntervalOpt struct {
 | 
				
			||||||
 | 
						Interval int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewRecycleIntervalOpt(interval int) *RecycleIntervalOpt {
 | 
				
			||||||
 | 
						return &RecycleIntervalOpt{
 | 
				
			||||||
 | 
							Interval: interval,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										6
									
								
								internal/grids/stat.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								internal/grids/stat.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					package grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Stat struct {
 | 
				
			||||||
 | 
						TotalBytes int64
 | 
				
			||||||
 | 
						CountItems int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -96,7 +96,11 @@ func (this *HTTPRequest) Do() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// WAF
 | 
						// WAF
 | 
				
			||||||
	// TODO 需要实现
 | 
						if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
 | 
				
			||||||
 | 
							if this.doWAFRequest() {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 访问控制
 | 
						// 访问控制
 | 
				
			||||||
	// TODO 需要实现
 | 
						// TODO 需要实现
 | 
				
			||||||
@@ -253,6 +257,12 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
 | 
				
			|||||||
		this.web.Cache = web.Cache
 | 
							this.web.Cache = web.Cache
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// waf
 | 
				
			||||||
 | 
						if web.FirewallRef != nil && (web.FirewallRef.IsPrior || isTop) {
 | 
				
			||||||
 | 
							this.web.FirewallRef = web.FirewallRef
 | 
				
			||||||
 | 
							this.web.FirewallPolicy = web.FirewallPolicy
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 重写规则
 | 
						// 重写规则
 | 
				
			||||||
	if len(web.RewriteRefs) > 0 {
 | 
						if len(web.RewriteRefs) > 0 {
 | 
				
			||||||
		for index, ref := range web.RewriteRefs {
 | 
							for index, ref := range web.RewriteRefs {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -166,12 +166,26 @@ func (this *HTTPRequest) doReverseProxy() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// WAF对出站进行检查
 | 
						// WAF对出站进行检查
 | 
				
			||||||
	// TODO
 | 
						if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
 | 
				
			||||||
 | 
							if this.doWAFResponse(resp) {
 | 
				
			||||||
 | 
								err = resp.Body.Close()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									logs.Error(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO 清除源站错误次数
 | 
						// TODO 清除源站错误次数
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 特殊页面
 | 
						// 特殊页面
 | 
				
			||||||
	// TODO
 | 
						if len(this.web.Pages) > 0 && this.doPage(resp.StatusCode) {
 | 
				
			||||||
 | 
							err = resp.Body.Close()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logs.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 设置Charset
 | 
						// 设置Charset
 | 
				
			||||||
	// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
 | 
						// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										51
									
								
								internal/nodes/http_request_waf.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								internal/nodes/http_request_waf.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,51 @@
 | 
				
			|||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 调用WAF
 | 
				
			||||||
 | 
					func (this *HTTPRequest) doWAFRequest() (blocked bool) {
 | 
				
			||||||
 | 
						w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id)
 | 
				
			||||||
 | 
						if w == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						goNext, _, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Error(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if ruleSet != nil {
 | 
				
			||||||
 | 
							if ruleSet.Action != waf.ActionAllow {
 | 
				
			||||||
 | 
								// TODO 记录日志
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return !goNext
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// call response waf
 | 
				
			||||||
 | 
					func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
 | 
				
			||||||
 | 
						w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id)
 | 
				
			||||||
 | 
						if w == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						goNext, _, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Error(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if ruleSet != nil {
 | 
				
			||||||
 | 
							if ruleSet.Action != waf.ActionAllow {
 | 
				
			||||||
 | 
								// TODO 记录日志
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return !goNext
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -104,6 +104,7 @@ func (this *Node) syncConfig(isFirstTime bool) error {
 | 
				
			|||||||
	logs.Println("[NODE]reload config ...")
 | 
						logs.Println("[NODE]reload config ...")
 | 
				
			||||||
	nodeconfigs.ResetNodeConfig(nodeConfig)
 | 
						nodeconfigs.ResetNodeConfig(nodeConfig)
 | 
				
			||||||
	caches.SharedManager.UpdatePolicies(nodeConfig.AllCachePolicies())
 | 
						caches.SharedManager.UpdatePolicies(nodeConfig.AllCachePolicies())
 | 
				
			||||||
 | 
						sharedWAFManager.UpdatePolicies(nodeConfig.AllHTTPFirewallPolicies())
 | 
				
			||||||
	sharedNodeConfig = nodeConfig
 | 
						sharedNodeConfig = nodeConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !isFirstTime {
 | 
						if !isFirstTime {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										175
									
								
								internal/nodes/waf_manager.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										175
									
								
								internal/nodes/waf_manager.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,175 @@
 | 
				
			|||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/errors"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var sharedWAFManager = NewWAFManager()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WAF管理器
 | 
				
			||||||
 | 
					type WAFManager struct {
 | 
				
			||||||
 | 
						mapping map[int64]*waf.WAF // policyId => WAF
 | 
				
			||||||
 | 
						locker  sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 获取新对象
 | 
				
			||||||
 | 
					func NewWAFManager() *WAFManager {
 | 
				
			||||||
 | 
						return &WAFManager{
 | 
				
			||||||
 | 
							mapping: map[int64]*waf.WAF{},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 更新策略
 | 
				
			||||||
 | 
					func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m := map[int64]*waf.WAF{}
 | 
				
			||||||
 | 
						for _, p := range policies {
 | 
				
			||||||
 | 
							w, err := this.convertWAF(p)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logs.Println("[WAF]initialize policy '" + strconv.FormatInt(p.Id, 10) + "' failed: " + err.Error())
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if w == nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							m[p.Id] = w
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.mapping = m
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 查找WAF
 | 
				
			||||||
 | 
					func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
 | 
				
			||||||
 | 
						this.locker.RLock()
 | 
				
			||||||
 | 
						w, _ := this.mapping[policyId]
 | 
				
			||||||
 | 
						this.locker.RUnlock()
 | 
				
			||||||
 | 
						return w
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 判断是否包含int64
 | 
				
			||||||
 | 
					func (this *WAFManager) containsInt64(values []int64, value int64) bool {
 | 
				
			||||||
 | 
						for _, v := range values {
 | 
				
			||||||
 | 
							if v == value {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 将Policy转换为WAF
 | 
				
			||||||
 | 
					func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (*waf.WAF, error) {
 | 
				
			||||||
 | 
						if policy == nil {
 | 
				
			||||||
 | 
							return nil, errors.New("policy should not be nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						w := &waf.WAF{
 | 
				
			||||||
 | 
							Id:   strconv.FormatInt(policy.Id, 10),
 | 
				
			||||||
 | 
							IsOn: policy.IsOn,
 | 
				
			||||||
 | 
							Name: policy.Name,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// inbound
 | 
				
			||||||
 | 
						if policy.Inbound != nil && policy.Inbound.IsOn {
 | 
				
			||||||
 | 
							for _, group := range policy.Inbound.Groups {
 | 
				
			||||||
 | 
								g := &waf.RuleGroup{
 | 
				
			||||||
 | 
									Id:          strconv.FormatInt(group.Id, 10),
 | 
				
			||||||
 | 
									IsOn:        group.IsOn,
 | 
				
			||||||
 | 
									Name:        group.Name,
 | 
				
			||||||
 | 
									Description: group.Description,
 | 
				
			||||||
 | 
									Code:        group.Code,
 | 
				
			||||||
 | 
									IsInbound:   true,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// rule sets
 | 
				
			||||||
 | 
								for _, set := range group.Sets {
 | 
				
			||||||
 | 
									s := &waf.RuleSet{
 | 
				
			||||||
 | 
										Id:            strconv.FormatInt(set.Id, 10),
 | 
				
			||||||
 | 
										Code:          set.Code,
 | 
				
			||||||
 | 
										IsOn:          set.IsOn,
 | 
				
			||||||
 | 
										Name:          set.Name,
 | 
				
			||||||
 | 
										Description:   set.Description,
 | 
				
			||||||
 | 
										Connector:     set.Connector,
 | 
				
			||||||
 | 
										Action:        set.Action,
 | 
				
			||||||
 | 
										ActionOptions: set.ActionOptions,
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// rules
 | 
				
			||||||
 | 
									for _, rule := range set.Rules {
 | 
				
			||||||
 | 
										r := &waf.Rule{
 | 
				
			||||||
 | 
											Description:       rule.Description,
 | 
				
			||||||
 | 
											Param:             rule.Param,
 | 
				
			||||||
 | 
											Operator:          rule.Operator,
 | 
				
			||||||
 | 
											Value:             rule.Value,
 | 
				
			||||||
 | 
											IsCaseInsensitive: rule.IsCaseInsensitive,
 | 
				
			||||||
 | 
											CheckpointOptions: rule.CheckpointOptions,
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										s.Rules = append(s.Rules, r)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									g.RuleSets = append(g.RuleSets, s)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								w.Inbound = append(w.Inbound, g)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// outbound
 | 
				
			||||||
 | 
						if policy.Outbound != nil && policy.Outbound.IsOn {
 | 
				
			||||||
 | 
							for _, group := range policy.Outbound.Groups {
 | 
				
			||||||
 | 
								g := &waf.RuleGroup{
 | 
				
			||||||
 | 
									Id:          strconv.FormatInt(group.Id, 10),
 | 
				
			||||||
 | 
									IsOn:        group.IsOn,
 | 
				
			||||||
 | 
									Name:        group.Name,
 | 
				
			||||||
 | 
									Description: group.Description,
 | 
				
			||||||
 | 
									Code:        group.Code,
 | 
				
			||||||
 | 
									IsInbound:   true,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// rule sets
 | 
				
			||||||
 | 
								for _, set := range group.Sets {
 | 
				
			||||||
 | 
									s := &waf.RuleSet{
 | 
				
			||||||
 | 
										Id:            strconv.FormatInt(set.Id, 10),
 | 
				
			||||||
 | 
										Code:          set.Code,
 | 
				
			||||||
 | 
										IsOn:          set.IsOn,
 | 
				
			||||||
 | 
										Name:          set.Name,
 | 
				
			||||||
 | 
										Description:   set.Description,
 | 
				
			||||||
 | 
										Connector:     set.Connector,
 | 
				
			||||||
 | 
										Action:        set.Action,
 | 
				
			||||||
 | 
										ActionOptions: set.ActionOptions,
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// rules
 | 
				
			||||||
 | 
									for _, rule := range set.Rules {
 | 
				
			||||||
 | 
										r := &waf.Rule{
 | 
				
			||||||
 | 
											Description:       rule.Description,
 | 
				
			||||||
 | 
											Param:             rule.Param,
 | 
				
			||||||
 | 
											Operator:          rule.Operator,
 | 
				
			||||||
 | 
											Value:             rule.Value,
 | 
				
			||||||
 | 
											IsCaseInsensitive: rule.IsCaseInsensitive,
 | 
				
			||||||
 | 
											CheckpointOptions: rule.CheckpointOptions,
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										s.Rules = append(s.Rules, r)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									g.RuleSets = append(g.RuleSets, s)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								w.Outbound = append(w.Outbound, g)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// action
 | 
				
			||||||
 | 
						// TODO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := w.Init()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return w, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										44
									
								
								internal/nodes/waf_manager_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/nodes/waf_manager_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
				
			|||||||
 | 
					package nodes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestWAFManager_convert(t *testing.T) {
 | 
				
			||||||
 | 
						p := &firewallconfigs.HTTPFirewallPolicy{
 | 
				
			||||||
 | 
							Id:   1,
 | 
				
			||||||
 | 
							IsOn: true,
 | 
				
			||||||
 | 
							Inbound: &firewallconfigs.HTTPFirewallInboundConfig{
 | 
				
			||||||
 | 
								IsOn: true,
 | 
				
			||||||
 | 
								Groups: []*firewallconfigs.HTTPFirewallRuleGroup{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										Id: 1,
 | 
				
			||||||
 | 
										Sets: []*firewallconfigs.HTTPFirewallRuleSet{
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												Id: 1,
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
											{
 | 
				
			||||||
 | 
												Id: 2,
 | 
				
			||||||
 | 
												Rules: []*firewallconfigs.HTTPFirewallRule{
 | 
				
			||||||
 | 
													{
 | 
				
			||||||
 | 
														Id: 1,
 | 
				
			||||||
 | 
													},
 | 
				
			||||||
 | 
													{
 | 
				
			||||||
 | 
														Id: 2,
 | 
				
			||||||
 | 
													},
 | 
				
			||||||
 | 
												},
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						w, err := sharedWAFManager.convertWAF(p)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						logs.PrintAsJSON(w, t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										77
									
								
								internal/utils/get.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								internal/utils/get.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,77 @@
 | 
				
			|||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Get(object interface{}, keys []string) interface{} {
 | 
				
			||||||
 | 
						if len(keys) == 0 {
 | 
				
			||||||
 | 
							return object
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if object == nil {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						firstKey := keys[0]
 | 
				
			||||||
 | 
						keys = keys[1:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value := reflect.ValueOf(object)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !value.IsValid() {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if value.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
							value = value.Elem()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if value.Kind() == reflect.Struct {
 | 
				
			||||||
 | 
							field := value.FieldByName(firstKey)
 | 
				
			||||||
 | 
							if !field.IsValid() {
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(keys) == 0 {
 | 
				
			||||||
 | 
								return field.Interface()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return Get(field.Interface(), keys)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if value.Kind() == reflect.Map {
 | 
				
			||||||
 | 
							mapKey := reflect.ValueOf(firstKey)
 | 
				
			||||||
 | 
							mapValue := value.MapIndex(mapKey)
 | 
				
			||||||
 | 
							if !mapValue.IsValid() {
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(keys) == 0 {
 | 
				
			||||||
 | 
								return mapValue.Interface()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return Get(mapValue.Interface(), keys)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if value.Kind() == reflect.Slice {
 | 
				
			||||||
 | 
							if RegexpDigitNumber.MatchString(firstKey) {
 | 
				
			||||||
 | 
								firstKeyInt := types.Int(firstKey)
 | 
				
			||||||
 | 
								if value.Len() > firstKeyInt {
 | 
				
			||||||
 | 
									result := value.Index(firstKeyInt).Interface()
 | 
				
			||||||
 | 
									if len(keys) == 0 {
 | 
				
			||||||
 | 
										return result
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									return Get(result, keys)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										79
									
								
								internal/utils/get_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								internal/utils/get_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,79 @@
 | 
				
			|||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestGetStruct(t *testing.T) {
 | 
				
			||||||
 | 
						object := struct {
 | 
				
			||||||
 | 
							Name  string
 | 
				
			||||||
 | 
							Age   int
 | 
				
			||||||
 | 
							Books []string
 | 
				
			||||||
 | 
							Extend struct {
 | 
				
			||||||
 | 
								Location struct {
 | 
				
			||||||
 | 
									City string
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							Name:  "lu",
 | 
				
			||||||
 | 
							Age:   20,
 | 
				
			||||||
 | 
							Books: []string{"Golang"},
 | 
				
			||||||
 | 
							Extend: struct {
 | 
				
			||||||
 | 
								Location struct {
 | 
				
			||||||
 | 
									City string
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}{
 | 
				
			||||||
 | 
								Location: struct {
 | 
				
			||||||
 | 
									City string
 | 
				
			||||||
 | 
								}{
 | 
				
			||||||
 | 
									City: "Beijing",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Name"}) != "lu" {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]Name != lu")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Age"}) != 20 {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]Age != 20")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Books", "0"}) != "Golang" {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]books.0 != Golang")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log("Extend.Location:", Get(object, []string{"Extend", "Location"}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]Extend.Location.City != Beijing")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestGetMap(t *testing.T) {
 | 
				
			||||||
 | 
						object := map[string]interface{}{
 | 
				
			||||||
 | 
							"Name": "lu",
 | 
				
			||||||
 | 
							"Age":  20,
 | 
				
			||||||
 | 
							"Extend": map[string]interface{}{
 | 
				
			||||||
 | 
								"Location": map[string]interface{}{
 | 
				
			||||||
 | 
									"City": "Beijing",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Name"}) != "lu" {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]Name != lu")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Age"}) != 20 {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]Age != 20")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Books", "0"}) != nil {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]books.0 != nil")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Log(Get(object, []string{"Extend", "Location"}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" {
 | 
				
			||||||
 | 
							t.Fatal("[ERROR]Extend.Location.City != Beijing")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										37
									
								
								internal/utils/string.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								internal/utils/string.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,37 @@
 | 
				
			|||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"unsafe"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// convert bytes to string
 | 
				
			||||||
 | 
					func UnsafeBytesToString(bs []byte) string {
 | 
				
			||||||
 | 
						return *(*string)(unsafe.Pointer(&bs))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// convert string to bytes
 | 
				
			||||||
 | 
					func UnsafeStringToBytes(s string) []byte {
 | 
				
			||||||
 | 
						return *(*[]byte)(unsafe.Pointer(&s))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// format address
 | 
				
			||||||
 | 
					func FormatAddress(addr string) string {
 | 
				
			||||||
 | 
						if strings.HasSuffix(addr, "unix:") {
 | 
				
			||||||
 | 
							return addr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						addr = strings.Replace(addr, " ", "", -1)
 | 
				
			||||||
 | 
						addr = strings.Replace(addr, "\t", "", -1)
 | 
				
			||||||
 | 
						addr = strings.Replace(addr, ":", ":", -1)
 | 
				
			||||||
 | 
						addr = strings.TrimSpace(addr)
 | 
				
			||||||
 | 
						return addr
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// format address list
 | 
				
			||||||
 | 
					func FormatAddressList(addrList []string) []string {
 | 
				
			||||||
 | 
						result := []string{}
 | 
				
			||||||
 | 
						for _, addr := range addrList {
 | 
				
			||||||
 | 
							result = append(result, FormatAddress(addr))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										56
									
								
								internal/utils/string_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								internal/utils/string_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,56 @@
 | 
				
			|||||||
 | 
					package utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBytesToString(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log(UnsafeBytesToString([]byte("Hello,World")))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestStringToBytes(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log(string(UnsafeStringToBytes("Hello,World")))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkBytesToString(b *testing.B) {
 | 
				
			||||||
 | 
						data := []byte("Hello,World")
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							_ = UnsafeBytesToString(data)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkBytesToString2(b *testing.B) {
 | 
				
			||||||
 | 
						data := []byte("Hello,World")
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							_ = string(data)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkStringToBytes(b *testing.B) {
 | 
				
			||||||
 | 
						s := strings.Repeat("Hello,World", 1024)
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							_ = UnsafeStringToBytes(s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkStringToBytes2(b *testing.B) {
 | 
				
			||||||
 | 
						s := strings.Repeat("Hello,World", 1024)
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							_ = []byte(s)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFormatAddress(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log(FormatAddress("127.0.0.1:1234"))
 | 
				
			||||||
 | 
						t.Log(FormatAddress("127.0.0.1 : 1234"))
 | 
				
			||||||
 | 
						t.Log(FormatAddress("127.0.0.1:1234"))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFormatAddressList(t *testing.T) {
 | 
				
			||||||
 | 
						t.Log(FormatAddressList([]string{
 | 
				
			||||||
 | 
							"127.0.0.1:1234",
 | 
				
			||||||
 | 
							"127.0.0.1 : 1234",
 | 
				
			||||||
 | 
							"127.0.0.1:1234",
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										48
									
								
								internal/waf/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								internal/waf/README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,48 @@
 | 
				
			|||||||
 | 
					# WAF
 | 
				
			||||||
 | 
					A basic WAF for TeaWeb.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Config Constructions
 | 
				
			||||||
 | 
					~~~
 | 
				
			||||||
 | 
					WAF
 | 
				
			||||||
 | 
					  Inbound
 | 
				
			||||||
 | 
						  Rule Groups
 | 
				
			||||||
 | 
							Rule Sets
 | 
				
			||||||
 | 
							  Rules
 | 
				
			||||||
 | 
								Checkpoint Param <Operator> Value
 | 
				
			||||||
 | 
					  Outbound
 | 
				
			||||||
 | 
					  	  Rule Groups
 | 
				
			||||||
 | 
					  	    ... 				
 | 
				
			||||||
 | 
					~~~
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Apply WAF
 | 
				
			||||||
 | 
					~~~
 | 
				
			||||||
 | 
					Request  -->  WAF  -->   Backends
 | 
				
			||||||
 | 
								/
 | 
				
			||||||
 | 
					Response  <-- WAF <----		
 | 
				
			||||||
 | 
					~~~
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Coding
 | 
				
			||||||
 | 
					~~~go
 | 
				
			||||||
 | 
					waf := teawaf.NewWAF()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// add rule groups here
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					err := waf.Init()
 | 
				
			||||||
 | 
					if err != nil {
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					waf.Start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// match http request
 | 
				
			||||||
 | 
					// (req *http.Request, responseWriter http.ResponseWriter)
 | 
				
			||||||
 | 
					goNext, ruleSet, _ := waf.MatchRequest(req, responseWriter)
 | 
				
			||||||
 | 
					if ruleSet != nil {
 | 
				
			||||||
 | 
						log.Println("meet rule set:", ruleSet.Name, "action:", ruleSet.Action)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					if !goNext {
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// stop the waf
 | 
				
			||||||
 | 
					// waf.Stop()
 | 
				
			||||||
 | 
					~~~
 | 
				
			||||||
							
								
								
									
										14
									
								
								internal/waf/action_allow.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								internal/waf/action_allow.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,14 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AllowAction struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						// do nothing
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										94
									
								
								internal/waf/action_block.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								internal/waf/action_block.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,94 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/Tea"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// url client configure
 | 
				
			||||||
 | 
					var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://")
 | 
				
			||||||
 | 
					var httpClient = utils.SharedHttpClient(5 * time.Second)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type BlockAction struct {
 | 
				
			||||||
 | 
						StatusCode int    `yaml:"statusCode" json:"statusCode"`
 | 
				
			||||||
 | 
						Body       string `yaml:"body" json:"body"` // supports HTML
 | 
				
			||||||
 | 
						URL        string `yaml:"url" json:"url"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						if writer != nil {
 | 
				
			||||||
 | 
							// if status code eq 444, we close the connection
 | 
				
			||||||
 | 
							if this.StatusCode == 444 {
 | 
				
			||||||
 | 
								hijack, ok := writer.(http.Hijacker)
 | 
				
			||||||
 | 
								if ok {
 | 
				
			||||||
 | 
									conn, _, _ := hijack.Hijack()
 | 
				
			||||||
 | 
									if conn != nil {
 | 
				
			||||||
 | 
										_ = conn.Close()
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// output response
 | 
				
			||||||
 | 
							if this.StatusCode > 0 {
 | 
				
			||||||
 | 
								writer.WriteHeader(this.StatusCode)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								writer.WriteHeader(http.StatusForbidden)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if len(this.URL) > 0 {
 | 
				
			||||||
 | 
								if urlPrefixReg.MatchString(this.URL) {
 | 
				
			||||||
 | 
									req, err := http.NewRequest(http.MethodGet, this.URL, nil)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logs.Error(err)
 | 
				
			||||||
 | 
										return false
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									resp, err := httpClient.Do(req)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logs.Error(err)
 | 
				
			||||||
 | 
										return false
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									defer func() {
 | 
				
			||||||
 | 
										_ = resp.Body.Close()
 | 
				
			||||||
 | 
									}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									for k, v := range resp.Header {
 | 
				
			||||||
 | 
										for _, v1 := range v {
 | 
				
			||||||
 | 
											writer.Header().Add(k, v1)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									buf := make([]byte, 1024)
 | 
				
			||||||
 | 
									_, _ = io.CopyBuffer(writer, resp.Body, buf)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									path := this.URL
 | 
				
			||||||
 | 
									if !filepath.IsAbs(this.URL) {
 | 
				
			||||||
 | 
										path = Tea.Root + string(os.PathSeparator) + path
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									data, err := ioutil.ReadFile(path)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logs.Error(err)
 | 
				
			||||||
 | 
										return false
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									_, _ = writer.Write(data)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if len(this.Body) > 0 {
 | 
				
			||||||
 | 
								_, _ = writer.Write([]byte(this.Body))
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								_, _ = writer.Write([]byte("The request is blocked by " + teaconst.ProductName))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										39
									
								
								internal/waf/action_captcha.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								internal/waf/action_captcha.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,39 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						stringutil "github.com/iwind/TeaGo/utils/string"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var captchaSalt = stringutil.Rand(32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						CaptchaSeconds = 600 // 10 minutes
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CaptchaAction struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						// TEAWEB_CAPTCHA:
 | 
				
			||||||
 | 
						cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
 | 
				
			||||||
 | 
						if err == nil && cookie != nil && len(cookie.Value) > 32 {
 | 
				
			||||||
 | 
							m := cookie.Value[:32]
 | 
				
			||||||
 | 
							timestamp := cookie.Value[32:]
 | 
				
			||||||
 | 
							if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						refURL := request.URL.String()
 | 
				
			||||||
 | 
						if len(request.Referer()) > 0 {
 | 
				
			||||||
 | 
							refURL = request.Referer()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										12
									
								
								internal/waf/action_definition.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								internal/waf/action_definition.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "reflect"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// action definition
 | 
				
			||||||
 | 
					type ActionDefinition struct {
 | 
				
			||||||
 | 
						Name        string
 | 
				
			||||||
 | 
						Code        ActionString
 | 
				
			||||||
 | 
						Description string
 | 
				
			||||||
 | 
						Instance    ActionInterface
 | 
				
			||||||
 | 
						Type        reflect.Type
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										34
									
								
								internal/waf/action_go_group.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								internal/waf/action_go_group.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GoGroupAction struct {
 | 
				
			||||||
 | 
						GroupId string `yaml:"groupId" json:"groupId"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						group := waf.FindRuleGroup(this.GroupId)
 | 
				
			||||||
 | 
						if group == nil || !group.IsOn {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						b, set, err := group.MatchRequest(request)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Error(err)
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !b {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						actionObject := FindActionInstance(set.Action, set.ActionOptions)
 | 
				
			||||||
 | 
						if actionObject == nil {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return actionObject.Perform(waf, request, writer)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										37
									
								
								internal/waf/action_go_set.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								internal/waf/action_go_set.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,37 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type GoSetAction struct {
 | 
				
			||||||
 | 
						GroupId string `yaml:"groupId" json:"groupId"`
 | 
				
			||||||
 | 
						SetId   string `yaml:"setId" json:"setId"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						group := waf.FindRuleGroup(this.GroupId)
 | 
				
			||||||
 | 
						if group == nil || !group.IsOn {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						set := group.FindRuleSet(this.SetId)
 | 
				
			||||||
 | 
						if set == nil || !set.IsOn {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						b, err := set.MatchRequest(request)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Error(err)
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !b {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						actionObject := FindActionInstance(set.Action, set.ActionOptions)
 | 
				
			||||||
 | 
						if actionObject == nil {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return actionObject.Perform(waf, request, writer)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										5
									
								
								internal/waf/action_instance.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								internal/waf/action_instance.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Action struct {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										13
									
								
								internal/waf/action_log.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								internal/waf/action_log.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type LogAction struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/action_type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/action_type.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ActionString = string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						ActionLog     = "log"     // allow and log
 | 
				
			||||||
 | 
						ActionBlock   = "block"   // block
 | 
				
			||||||
 | 
						ActionCaptcha = "captcha" // block and show captcha
 | 
				
			||||||
 | 
						ActionAllow   = "allow"   // allow
 | 
				
			||||||
 | 
						ActionGoGroup = "go_group" // go to next rule group
 | 
				
			||||||
 | 
						ActionGoSet   = "go_set"   // go to next rule set
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ActionInterface interface {
 | 
				
			||||||
 | 
						Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										82
									
								
								internal/waf/action_utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								internal/waf/action_utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,82 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var AllActions = []*ActionDefinition{
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "阻止",
 | 
				
			||||||
 | 
							Code:     ActionBlock,
 | 
				
			||||||
 | 
							Instance: new(BlockAction),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "允许通过",
 | 
				
			||||||
 | 
							Code:     ActionAllow,
 | 
				
			||||||
 | 
							Instance: new(AllowAction),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "允许并记录日志",
 | 
				
			||||||
 | 
							Code:     ActionLog,
 | 
				
			||||||
 | 
							Instance: new(LogAction),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "Captcha验证码",
 | 
				
			||||||
 | 
							Code:     ActionCaptcha,
 | 
				
			||||||
 | 
							Instance: new(CaptchaAction),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "跳到下一个规则分组",
 | 
				
			||||||
 | 
							Code:     ActionGoGroup,
 | 
				
			||||||
 | 
							Instance: new(GoGroupAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(GoGroupAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:     "跳到下一个规则集",
 | 
				
			||||||
 | 
							Code:     ActionGoSet,
 | 
				
			||||||
 | 
							Instance: new(GoSetAction),
 | 
				
			||||||
 | 
							Type:     reflect.TypeOf(new(GoSetAction)).Elem(),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
 | 
				
			||||||
 | 
						for _, def := range AllActions {
 | 
				
			||||||
 | 
							if def.Code == action {
 | 
				
			||||||
 | 
								if def.Type != nil {
 | 
				
			||||||
 | 
									// create new instance
 | 
				
			||||||
 | 
									ptrValue := reflect.New(def.Type)
 | 
				
			||||||
 | 
									instance := ptrValue.Interface().(ActionInterface)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									if len(options) > 0 {
 | 
				
			||||||
 | 
										count := def.Type.NumField()
 | 
				
			||||||
 | 
										for i := 0; i < count; i++ {
 | 
				
			||||||
 | 
											field := def.Type.Field(i)
 | 
				
			||||||
 | 
											tag, ok := field.Tag.Lookup("yaml")
 | 
				
			||||||
 | 
											if ok {
 | 
				
			||||||
 | 
												v, ok := options[tag]
 | 
				
			||||||
 | 
												if ok && reflect.TypeOf(v) == field.Type {
 | 
				
			||||||
 | 
													ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									return instance
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// return shared instance
 | 
				
			||||||
 | 
								return def.Instance
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func FindActionName(action ActionString) string {
 | 
				
			||||||
 | 
						for _, def := range AllActions {
 | 
				
			||||||
 | 
							if def.Code == action {
 | 
				
			||||||
 | 
								return def.Name
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										29
									
								
								internal/waf/action_utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								internal/waf/action_utils_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFindActionInstance(t *testing.T) {
 | 
				
			||||||
 | 
						a := assert.NewAssertion(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
 | 
				
			||||||
 | 
						t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
 | 
				
			||||||
 | 
						t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
 | 
				
			||||||
 | 
						t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
 | 
				
			||||||
 | 
						t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
 | 
				
			||||||
 | 
						t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
 | 
				
			||||||
 | 
						t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkFindActionInstance(b *testing.B) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							FindActionInstance(ActionGoSet, nil)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										84
									
								
								internal/waf/captcha_validator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								internal/waf/captcha_validator.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,84 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/dchest/captcha"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						stringutil "github.com/iwind/TeaGo/utils/string"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var captchaValidator = &CaptchaValidator{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type CaptchaValidator struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
 | 
				
			||||||
 | 
						if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
 | 
				
			||||||
 | 
							this.validate(request, writer)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							this.show(request, writer)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
 | 
				
			||||||
 | 
						// show captcha
 | 
				
			||||||
 | 
						captchaId := captcha.NewLen(6)
 | 
				
			||||||
 | 
						buf := bytes.NewBuffer([]byte{})
 | 
				
			||||||
 | 
						err := captcha.WriteImage(buf, captchaId, 200, 100)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logs.Error(err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, _ = writer.Write([]byte(`<!DOCTYPE html>
 | 
				
			||||||
 | 
					<html>
 | 
				
			||||||
 | 
					<head>
 | 
				
			||||||
 | 
						<title>Verify Yourself</title>
 | 
				
			||||||
 | 
					</head>
 | 
				
			||||||
 | 
					<body>
 | 
				
			||||||
 | 
					<form method="POST">
 | 
				
			||||||
 | 
						<input type="hidden" name="TEAWEB_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
 | 
				
			||||||
 | 
						<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
 | 
				
			||||||
 | 
						<div>
 | 
				
			||||||
 | 
							<p>Input verify code above:</p>
 | 
				
			||||||
 | 
							<input type="text" name="TEAWEB_WAF_CAPTCHA_CODE" maxlength="6" size="18" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px"/>
 | 
				
			||||||
 | 
						</div>
 | 
				
			||||||
 | 
						<div>
 | 
				
			||||||
 | 
							<button type="submit" onclick="window.location = '/webhook'" style="line-height:24px;margin-top:10px">Verify Yourself</button>
 | 
				
			||||||
 | 
						</div>
 | 
				
			||||||
 | 
					</form>
 | 
				
			||||||
 | 
					</body>
 | 
				
			||||||
 | 
					</html>`))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
 | 
				
			||||||
 | 
						captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
 | 
				
			||||||
 | 
						if len(captchaId) > 0 {
 | 
				
			||||||
 | 
							captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
 | 
				
			||||||
 | 
							if captcha.VerifyString(captchaId, captchaCode) {
 | 
				
			||||||
 | 
								// set cookie
 | 
				
			||||||
 | 
								timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
 | 
				
			||||||
 | 
								m := stringutil.Md5(captchaSalt + timestamp)
 | 
				
			||||||
 | 
								http.SetCookie(writer, &http.Cookie{
 | 
				
			||||||
 | 
									Name:   "TEAWEB_WAF_CAPTCHA",
 | 
				
			||||||
 | 
									Value:  m + timestamp,
 | 
				
			||||||
 | 
									MaxAge: CaptchaSeconds,
 | 
				
			||||||
 | 
									Path:   "/", // all of dirs
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								rawURL := request.URL.Query().Get("url")
 | 
				
			||||||
 | 
								http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										246
									
								
								internal/waf/checkpoints/cc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										246
									
								
								internal/waf/checkpoints/cc.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,246 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/grids"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${cc.arg}
 | 
				
			||||||
 | 
					// TODO implement more traffic rules
 | 
				
			||||||
 | 
					type CCCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						grid *grids.Grid
 | 
				
			||||||
 | 
						once sync.Once
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) Init() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) Start() {
 | 
				
			||||||
 | 
						if this.grid != nil {
 | 
				
			||||||
 | 
							this.grid.Destroy()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.grid = grids.NewGrid(32, grids.NewLimitCountOpt(1000_0000))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if this.grid == nil {
 | 
				
			||||||
 | 
							this.once.Do(func() {
 | 
				
			||||||
 | 
								this.Start()
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							if this.grid == nil {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						periodString, ok := options["period"]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						period := types.Int64(periodString)
 | 
				
			||||||
 | 
						if period < 1 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						v, _ := options["userType"]
 | 
				
			||||||
 | 
						userType := types.String(v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						v, _ = options["userField"]
 | 
				
			||||||
 | 
						userField := types.String(v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						v, _ = options["userIndex"]
 | 
				
			||||||
 | 
						userIndex := types.Int(v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if param == "requests" { // requests
 | 
				
			||||||
 | 
							var key = ""
 | 
				
			||||||
 | 
							switch userType {
 | 
				
			||||||
 | 
							case "ip":
 | 
				
			||||||
 | 
								key = this.ip(req)
 | 
				
			||||||
 | 
							case "cookie":
 | 
				
			||||||
 | 
								if len(userField) == 0 {
 | 
				
			||||||
 | 
									key = this.ip(req)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									cookie, _ := req.Cookie(userField)
 | 
				
			||||||
 | 
									if cookie != nil {
 | 
				
			||||||
 | 
										v := cookie.Value
 | 
				
			||||||
 | 
										if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
 | 
											v = v[userIndex:]
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										key = "USER@" + userType + "@" + userField + "@" + v
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case "get":
 | 
				
			||||||
 | 
								if len(userField) == 0 {
 | 
				
			||||||
 | 
									key = this.ip(req)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									v := req.URL.Query().Get(userField)
 | 
				
			||||||
 | 
									if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
 | 
										v = v[userIndex:]
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									key = "USER@" + userType + "@" + userField + "@" + v
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case "post":
 | 
				
			||||||
 | 
								if len(userField) == 0 {
 | 
				
			||||||
 | 
									key = this.ip(req)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									v := req.PostFormValue(userField)
 | 
				
			||||||
 | 
									if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
 | 
										v = v[userIndex:]
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									key = "USER@" + userType + "@" + userField + "@" + v
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case "header":
 | 
				
			||||||
 | 
								if len(userField) == 0 {
 | 
				
			||||||
 | 
									key = this.ip(req)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									v := req.Header.Get(userField)
 | 
				
			||||||
 | 
									if userIndex > 0 && len(v) > userIndex {
 | 
				
			||||||
 | 
										v = v[userIndex:]
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									key = "USER@" + userType + "@" + userField + "@" + v
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								key = this.ip(req)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if len(key) == 0 {
 | 
				
			||||||
 | 
								key = this.ip(req)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = this.grid.IncreaseInt64([]byte(key), 1, period)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) ParamOptions() *ParamOptions {
 | 
				
			||||||
 | 
						option := NewParamOptions()
 | 
				
			||||||
 | 
						option.AddParam("请求数", "requests")
 | 
				
			||||||
 | 
						return option
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) Options() []OptionInterface {
 | 
				
			||||||
 | 
						options := []OptionInterface{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// period
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							option := NewFieldOption("统计周期", "period")
 | 
				
			||||||
 | 
							option.Value = "60"
 | 
				
			||||||
 | 
							option.RightLabel = "秒"
 | 
				
			||||||
 | 
							option.Size = 8
 | 
				
			||||||
 | 
							option.MaxLength = 8
 | 
				
			||||||
 | 
							option.Validate = func(value string) (ok bool, message string) {
 | 
				
			||||||
 | 
								if regexp.MustCompile("^\\d+$").MatchString(value) {
 | 
				
			||||||
 | 
									ok = true
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								message = "周期需要是一个整数数字"
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							options = append(options, option)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// type
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							option := NewOptionsOption("用户识别读取来源", "userType")
 | 
				
			||||||
 | 
							option.Size = 10
 | 
				
			||||||
 | 
							option.SetOptions([]maps.Map{
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									"name":  "IP",
 | 
				
			||||||
 | 
									"value": "ip",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									"name":  "Cookie",
 | 
				
			||||||
 | 
									"value": "cookie",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									"name":  "URL参数",
 | 
				
			||||||
 | 
									"value": "get",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									"name":  "POST参数",
 | 
				
			||||||
 | 
									"value": "post",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									"name":  "HTTP Header",
 | 
				
			||||||
 | 
									"value": "header",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							options = append(options, option)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// user field
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							option := NewFieldOption("用户识别字段", "userField")
 | 
				
			||||||
 | 
							option.Comment = "识别用户的唯一性字段,在用户读取来源不是IP时使用"
 | 
				
			||||||
 | 
							options = append(options, option)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// user value index
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							option := NewFieldOption("字段读取位置", "userIndex")
 | 
				
			||||||
 | 
							option.Size = 5
 | 
				
			||||||
 | 
							option.MaxLength = 5
 | 
				
			||||||
 | 
							option.Comment = "读取用户识别字段的位置,从0开始,比如user12345的数字ID 12345的位置就是5,在用户读取来源不是IP时使用"
 | 
				
			||||||
 | 
							options = append(options, option)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return options
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) Stop() {
 | 
				
			||||||
 | 
						if this.grid != nil {
 | 
				
			||||||
 | 
							this.grid.Destroy()
 | 
				
			||||||
 | 
							this.grid = nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *CCCheckpoint) ip(req *requests.Request) string {
 | 
				
			||||||
 | 
						// X-Forwarded-For
 | 
				
			||||||
 | 
						forwardedFor := req.Header.Get("X-Forwarded-For")
 | 
				
			||||||
 | 
						if len(forwardedFor) > 0 {
 | 
				
			||||||
 | 
							commaIndex := strings.Index(forwardedFor, ",")
 | 
				
			||||||
 | 
							if commaIndex > 0 {
 | 
				
			||||||
 | 
								return forwardedFor[:commaIndex]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return forwardedFor
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Real-IP
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							realIP, ok := req.Header["X-Real-IP"]
 | 
				
			||||||
 | 
							if ok && len(realIP) > 0 {
 | 
				
			||||||
 | 
								return realIP[0]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Real-Ip
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							realIP, ok := req.Header["X-Real-Ip"]
 | 
				
			||||||
 | 
							if ok && len(realIP) > 0 {
 | 
				
			||||||
 | 
								return realIP[0]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Remote-Addr
 | 
				
			||||||
 | 
						host, _, err := net.SplitHostPort(req.RemoteAddr)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							return host
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return req.RemoteAddr
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										42
									
								
								internal/waf/checkpoints/cc_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								internal/waf/checkpoints/cc_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCCCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						raw, err := http.NewRequest(http.MethodGet, "http://teaos.cn/", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(raw)
 | 
				
			||||||
 | 
						req.RemoteAddr = "127.0.0.1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(CCCheckpoint)
 | 
				
			||||||
 | 
						checkpoint.Init()
 | 
				
			||||||
 | 
						checkpoint.Start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						options := map[string]string{
 | 
				
			||||||
 | 
							"period": "5",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.RemoteAddr = "127.0.0.2"
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.RemoteAddr = "127.0.0.1"
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.RemoteAddr = "127.0.0.2"
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.RemoteAddr = "127.0.0.2"
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.RemoteAddr = "127.0.0.2"
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "requests", options))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										28
									
								
								internal/waf/checkpoints/checkpoint.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								internal/waf/checkpoints/checkpoint.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Checkpoint struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Checkpoint) Init() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Checkpoint) IsRequest() bool {
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Checkpoint) ParamOptions() *ParamOptions {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Checkpoint) Options() []OptionInterface {
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Checkpoint) Start() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Checkpoint) Stop() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										10
									
								
								internal/waf/checkpoints/checkpoint_definition.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								internal/waf/checkpoints/checkpoint_definition.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// check point definition
 | 
				
			||||||
 | 
					type CheckpointDefinition struct {
 | 
				
			||||||
 | 
						Name        string
 | 
				
			||||||
 | 
						Description string
 | 
				
			||||||
 | 
						Prefix      string
 | 
				
			||||||
 | 
						HasParams   bool // has sub params
 | 
				
			||||||
 | 
						Instance    CheckpointInterface
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										32
									
								
								internal/waf/checkpoints/checkpoint_interface.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								internal/waf/checkpoints/checkpoint_interface.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Check Point
 | 
				
			||||||
 | 
					type CheckpointInterface interface {
 | 
				
			||||||
 | 
						// initialize
 | 
				
			||||||
 | 
						Init()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// is request?
 | 
				
			||||||
 | 
						IsRequest() bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// get request value
 | 
				
			||||||
 | 
						RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// get response value
 | 
				
			||||||
 | 
						ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// param option list
 | 
				
			||||||
 | 
						ParamOptions() *ParamOptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// options
 | 
				
			||||||
 | 
						Options() []OptionInterface
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// start
 | 
				
			||||||
 | 
						Start()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// stop
 | 
				
			||||||
 | 
						Stop()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										5
									
								
								internal/waf/checkpoints/option.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								internal/waf/checkpoints/option.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OptionInterface interface {
 | 
				
			||||||
 | 
						Type() string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										26
									
								
								internal/waf/checkpoints/option_field.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								internal/waf/checkpoints/option_field.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// attach option
 | 
				
			||||||
 | 
					type FieldOption struct {
 | 
				
			||||||
 | 
						Name        string
 | 
				
			||||||
 | 
						Code        string
 | 
				
			||||||
 | 
						Value       string // default value
 | 
				
			||||||
 | 
						IsRequired  bool
 | 
				
			||||||
 | 
						Size        int
 | 
				
			||||||
 | 
						Comment     string
 | 
				
			||||||
 | 
						Placeholder string
 | 
				
			||||||
 | 
						RightLabel  string
 | 
				
			||||||
 | 
						MaxLength   int
 | 
				
			||||||
 | 
						Validate    func(value string) (ok bool, message string)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewFieldOption(name string, code string) *FieldOption {
 | 
				
			||||||
 | 
						return &FieldOption{
 | 
				
			||||||
 | 
							Name: name,
 | 
				
			||||||
 | 
							Code: code,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *FieldOption) Type() string {
 | 
				
			||||||
 | 
						return "field"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										30
									
								
								internal/waf/checkpoints/option_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								internal/waf/checkpoints/option_options.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OptionsOption struct {
 | 
				
			||||||
 | 
						Name       string
 | 
				
			||||||
 | 
						Code       string
 | 
				
			||||||
 | 
						Value      string // default value
 | 
				
			||||||
 | 
						IsRequired bool
 | 
				
			||||||
 | 
						Size       int
 | 
				
			||||||
 | 
						Comment    string
 | 
				
			||||||
 | 
						RightLabel string
 | 
				
			||||||
 | 
						Validate   func(value string) (ok bool, message string)
 | 
				
			||||||
 | 
						Options    []maps.Map
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewOptionsOption(name string, code string) *OptionsOption {
 | 
				
			||||||
 | 
						return &OptionsOption{
 | 
				
			||||||
 | 
							Name: name,
 | 
				
			||||||
 | 
							Code: code,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *OptionsOption) Type() string {
 | 
				
			||||||
 | 
						return "options"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *OptionsOption) SetOptions(options []maps.Map) {
 | 
				
			||||||
 | 
						this.Options = options
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/param_option.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/param_option.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type KeyValue struct {
 | 
				
			||||||
 | 
						Name  string `json:"name"`
 | 
				
			||||||
 | 
						Value string `json:"value"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ParamOptions struct {
 | 
				
			||||||
 | 
						Options []*KeyValue `json:"options"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewParamOptions() *ParamOptions {
 | 
				
			||||||
 | 
						return &ParamOptions{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ParamOptions) AddParam(name string, value string) {
 | 
				
			||||||
 | 
						this.Options = append(this.Options, &KeyValue{
 | 
				
			||||||
 | 
							Name:  name,
 | 
				
			||||||
 | 
							Value: value,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										46
									
								
								internal/waf/checkpoints/request_all.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								internal/waf/checkpoints/request_all.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${requestAll}
 | 
				
			||||||
 | 
					type RequestAllCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						valueBytes := []byte{}
 | 
				
			||||||
 | 
						if len(req.RequestURI) > 0 {
 | 
				
			||||||
 | 
							valueBytes = append(valueBytes, req.RequestURI...)
 | 
				
			||||||
 | 
						} else if req.URL != nil {
 | 
				
			||||||
 | 
							valueBytes = append(valueBytes, req.URL.RequestURI()...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if req.Body != nil {
 | 
				
			||||||
 | 
							valueBytes = append(valueBytes, ' ')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(req.BodyData) == 0 {
 | 
				
			||||||
 | 
								data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return "", err, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								req.BodyData = data
 | 
				
			||||||
 | 
								req.RestoreBody(data)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							valueBytes = append(valueBytes, req.BodyData...)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value = valueBytes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = ""
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										70
									
								
								internal/waf/checkpoints/request_all_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								internal/waf/checkpoints/request_all_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,70 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer([]byte("123456")))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestAllCheckpoint)
 | 
				
			||||||
 | 
						v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
				
			||||||
 | 
						if sysErr != nil {
 | 
				
			||||||
 | 
							t.Fatal(sysErr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if userErr != nil {
 | 
				
			||||||
 | 
							t.Fatal(userErr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(v)
 | 
				
			||||||
 | 
						t.Log(types.String(v))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(string(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
 | 
				
			||||||
 | 
						req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000))))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestBodyCheckpoint)
 | 
				
			||||||
 | 
						value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log("value bytes:", len(types.String(value)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log("raw bytes:", len(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
 | 
				
			||||||
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer(bytes.Repeat([]byte("HELLO"), 1024)))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							b.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestAllCheckpoint)
 | 
				
			||||||
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
 | 
							_, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										20
									
								
								internal/waf/checkpoints/request_arg.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								internal/waf/checkpoints/request_arg.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestArgCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						return req.URL.Query().Get(param), nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_arg_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_arg_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestArgParam_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/?name=lu", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestArgCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.ResponseValue(req, nil, "name", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "name2", nil))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_args.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_args.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestArgsCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.URL.RawQuery
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										36
									
								
								internal/waf/checkpoints/request_body.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/waf/checkpoints/request_body.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${requestBody}
 | 
				
			||||||
 | 
					type RequestBodyCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if req.Body == nil {
 | 
				
			||||||
 | 
							value = ""
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(req.BodyData) == 0 {
 | 
				
			||||||
 | 
							data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", err, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							req.BodyData = data
 | 
				
			||||||
 | 
							req.RestoreBody(data)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return req.BodyData, nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										47
									
								
								internal/waf/checkpoints/request_body_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								internal/waf/checkpoints/request_body_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,47 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestBodyCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(string(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
 | 
				
			||||||
 | 
						req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000))))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestBodyCheckpoint)
 | 
				
			||||||
 | 
						value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log("value bytes:", len(types.String(value)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log("raw bytes:", len(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_content_type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_content_type.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestContentTypeCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.Header.Get("Content-Type")
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										27
									
								
								internal/waf/checkpoints/request_cookie.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								internal/waf/checkpoints/request_cookie.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestCookieCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						cookie, err := req.Cookie(param)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							value = ""
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value = cookie.Value
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										27
									
								
								internal/waf/checkpoints/request_cookies.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								internal/waf/checkpoints/request_cookies.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestCookiesCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						var cookies = []string{}
 | 
				
			||||||
 | 
						for _, cookie := range req.Cookies() {
 | 
				
			||||||
 | 
							cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						value = strings.Join(cookies, "&")
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										39
									
								
								internal/waf/checkpoints/request_form_arg.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								internal/waf/checkpoints/request_form_arg.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,39 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${requestForm.arg}
 | 
				
			||||||
 | 
					type RequestFormArgCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if req.Body == nil {
 | 
				
			||||||
 | 
							value = ""
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(req.BodyData) == 0 {
 | 
				
			||||||
 | 
							data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", err, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							req.BodyData = data
 | 
				
			||||||
 | 
							req.RestoreBody(data)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO improve performance
 | 
				
			||||||
 | 
						values, _ := url.ParseQuery(string(req.BodyData))
 | 
				
			||||||
 | 
						return values.Get(param), nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										32
									
								
								internal/waf/checkpoints/request_form_arg_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								internal/waf/checkpoints/request_form_arg_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("name=lu&age=20&encoded="+url.QueryEscape("<strong>ENCODED STRING</strong>"))))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestFormArgCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "age", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "Hello", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "encoded", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(string(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										27
									
								
								internal/waf/checkpoints/request_header.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								internal/waf/checkpoints/request_header.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestHeaderCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						v, found := req.Header[param]
 | 
				
			||||||
 | 
						if !found {
 | 
				
			||||||
 | 
							value = ""
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						value = strings.Join(v, ";")
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										30
									
								
								internal/waf/checkpoints/request_headers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								internal/waf/checkpoints/request_headers.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"sort"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestHeadersCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						var headers = []string{}
 | 
				
			||||||
 | 
						for k, v := range req.Header {
 | 
				
			||||||
 | 
							for _, subV := range v {
 | 
				
			||||||
 | 
								headers = append(headers, k+": "+subV)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						sort.Strings(headers)
 | 
				
			||||||
 | 
						value = strings.Join(headers, "\n")
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_host.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_host.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestHostCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.Host
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										20
									
								
								internal/waf/checkpoints/request_host_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								internal/waf/checkpoints/request_host_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestHostCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						req.Header.Set("Host", "cloud.teaos.cn")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestHostCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										44
									
								
								internal/waf/checkpoints/request_json_arg.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								internal/waf/checkpoints/request_json_arg.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${requestJSON.arg}
 | 
				
			||||||
 | 
					type RequestJSONArgCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if len(req.BodyData) == 0 {
 | 
				
			||||||
 | 
							data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return "", err, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							req.BodyData = data
 | 
				
			||||||
 | 
							defer req.RestoreBody(data)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO improve performance
 | 
				
			||||||
 | 
						var m interface{} = nil
 | 
				
			||||||
 | 
						err := json.Unmarshal(req.BodyData, &m)
 | 
				
			||||||
 | 
						if err != nil || m == nil {
 | 
				
			||||||
 | 
							return "", nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value = utils.Get(m, strings.Split(param, "."))
 | 
				
			||||||
 | 
						if value != nil {
 | 
				
			||||||
 | 
							return value, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return "", nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										99
									
								
								internal/waf/checkpoints/request_json_arg_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								internal/waf/checkpoints/request_json_arg_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,99 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
						"name": "lu",
 | 
				
			||||||
 | 
						"age": 20,
 | 
				
			||||||
 | 
						"books": [ "PHP", "Golang", "Python" ]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					`)))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestJSONArgCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "age", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "Hello", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "books", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "books.1", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(string(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
 | 
				
			||||||
 | 
					[{
 | 
				
			||||||
 | 
						"name": "lu",
 | 
				
			||||||
 | 
						"age": 20,
 | 
				
			||||||
 | 
						"books": [ "PHP", "Golang", "Python" ]
 | 
				
			||||||
 | 
					}]
 | 
				
			||||||
 | 
					`)))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestJSONArgCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.name", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.age", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.Hello", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.books", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(string(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
 | 
				
			||||||
 | 
					[{
 | 
				
			||||||
 | 
						"name": "lu",
 | 
				
			||||||
 | 
						"age": 20,
 | 
				
			||||||
 | 
						"books": [ "PHP", "Golang", "Python" ]
 | 
				
			||||||
 | 
					}]
 | 
				
			||||||
 | 
					`)))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestJSONArgCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.name", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.age", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.Hello", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.books", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						body, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(string(body))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_length.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_length.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestLengthCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.ContentLength
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_method.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_method.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestMethodCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.Method
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										20
									
								
								internal/waf/checkpoints/request_path.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								internal/waf/checkpoints/request_path.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestPathCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						return req.URL.Path, nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										18
									
								
								internal/waf/checkpoints/request_path_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								internal/waf/checkpoints/request_path_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestPathCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/index?name=lu", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						checkpoint := new(RequestPathCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_proto.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_proto.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestProtoCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.Proto
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										27
									
								
								internal/waf/checkpoints/request_raw_remote_addr.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								internal/waf/checkpoints/request_raw_remote_addr.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestRawRemoteAddrCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						host, _, err := net.SplitHostPort(req.RemoteAddr)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							value = host
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							value = req.RemoteAddr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_referer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_referer.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestRefererCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.Referer()
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										59
									
								
								internal/waf/checkpoints/request_remote_addr.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								internal/waf/checkpoints/request_remote_addr.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,59 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestRemoteAddrCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						// X-Forwarded-For
 | 
				
			||||||
 | 
						forwardedFor := req.Header.Get("X-Forwarded-For")
 | 
				
			||||||
 | 
						if len(forwardedFor) > 0 {
 | 
				
			||||||
 | 
							commaIndex := strings.Index(forwardedFor, ",")
 | 
				
			||||||
 | 
							if commaIndex > 0 {
 | 
				
			||||||
 | 
								value = forwardedFor[:commaIndex]
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = forwardedFor
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Real-IP
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							realIP, ok := req.Header["X-Real-IP"]
 | 
				
			||||||
 | 
							if ok && len(realIP) > 0 {
 | 
				
			||||||
 | 
								value = realIP[0]
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Real-Ip
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							realIP, ok := req.Header["X-Real-Ip"]
 | 
				
			||||||
 | 
							if ok && len(realIP) > 0 {
 | 
				
			||||||
 | 
								value = realIP[0]
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Remote-Addr
 | 
				
			||||||
 | 
						host, _, err := net.SplitHostPort(req.RemoteAddr)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							value = host
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							value = req.RemoteAddr
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										28
									
								
								internal/waf/checkpoints/request_remote_port.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								internal/waf/checkpoints/request_remote_port.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestRemotePortCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						_, port, err := net.SplitHostPort(req.RemoteAddr)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							value = types.Int(port)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							value = 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										26
									
								
								internal/waf/checkpoints/request_remote_user.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								internal/waf/checkpoints/request_remote_user.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestRemoteUserCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						username, _, ok := req.BasicAuth()
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							value = ""
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						value = username
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_scheme.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_scheme.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestSchemeCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.URL.Scheme
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										18
									
								
								internal/waf/checkpoints/request_scheme_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								internal/waf/checkpoints/request_scheme_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						checkpoint := new(RequestSchemeCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "", nil))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										130
									
								
								internal/waf/checkpoints/request_upload.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								internal/waf/checkpoints/request_upload.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,130 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${requestUpload.arg}
 | 
				
			||||||
 | 
					type RequestUploadCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = ""
 | 
				
			||||||
 | 
						if param == "minSize" || param == "maxSize" {
 | 
				
			||||||
 | 
							value = 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if req.Method != http.MethodPost {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if req.Body == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if req.MultipartForm == nil {
 | 
				
			||||||
 | 
							if len(req.BodyData) == 0 {
 | 
				
			||||||
 | 
								data, err := req.ReadBody(32 * 1024 * 1024)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									sysErr = err
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								req.BodyData = data
 | 
				
			||||||
 | 
								defer req.RestoreBody(data)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							oldBody := req.Body
 | 
				
			||||||
 | 
							req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err := req.ParseMultipartForm(32 * 1024 * 1024)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 还原
 | 
				
			||||||
 | 
							req.Body = oldBody
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								userErr = err
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if req.MultipartForm == nil {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if param == "field" { // field
 | 
				
			||||||
 | 
							fields := []string{}
 | 
				
			||||||
 | 
							for field := range req.MultipartForm.File {
 | 
				
			||||||
 | 
								fields = append(fields, field)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = strings.Join(fields, ",")
 | 
				
			||||||
 | 
						} else if param == "minSize" { // minSize
 | 
				
			||||||
 | 
							minSize := int64(0)
 | 
				
			||||||
 | 
							for _, files := range req.MultipartForm.File {
 | 
				
			||||||
 | 
								for _, file := range files {
 | 
				
			||||||
 | 
									if minSize == 0 || minSize > file.Size {
 | 
				
			||||||
 | 
										minSize = file.Size
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = minSize
 | 
				
			||||||
 | 
						} else if param == "maxSize" { // maxSize
 | 
				
			||||||
 | 
							maxSize := int64(0)
 | 
				
			||||||
 | 
							for _, files := range req.MultipartForm.File {
 | 
				
			||||||
 | 
								for _, file := range files {
 | 
				
			||||||
 | 
									if maxSize < file.Size {
 | 
				
			||||||
 | 
										maxSize = file.Size
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = maxSize
 | 
				
			||||||
 | 
						} else if param == "name" { // name
 | 
				
			||||||
 | 
							names := []string{}
 | 
				
			||||||
 | 
							for _, files := range req.MultipartForm.File {
 | 
				
			||||||
 | 
								for _, file := range files {
 | 
				
			||||||
 | 
									if !lists.ContainsString(names, file.Filename) {
 | 
				
			||||||
 | 
										names = append(names, file.Filename)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = strings.Join(names, ",")
 | 
				
			||||||
 | 
						} else if param == "ext" { // ext
 | 
				
			||||||
 | 
							extensions := []string{}
 | 
				
			||||||
 | 
							for _, files := range req.MultipartForm.File {
 | 
				
			||||||
 | 
								for _, file := range files {
 | 
				
			||||||
 | 
									if len(file.Filename) > 0 {
 | 
				
			||||||
 | 
										exit := strings.ToLower(filepath.Ext(file.Filename))
 | 
				
			||||||
 | 
										if !lists.ContainsString(extensions, exit) {
 | 
				
			||||||
 | 
											extensions = append(extensions, exit)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							value = strings.Join(extensions, ",")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestUploadCheckpoint) ParamOptions() *ParamOptions {
 | 
				
			||||||
 | 
						option := NewParamOptions()
 | 
				
			||||||
 | 
						option.AddParam("最小文件尺寸", "minSize")
 | 
				
			||||||
 | 
						option.AddParam("最大文件尺寸", "maxSize")
 | 
				
			||||||
 | 
						option.AddParam("扩展名(如.txt)", "ext")
 | 
				
			||||||
 | 
						option.AddParam("原始文件名", "name")
 | 
				
			||||||
 | 
						option.AddParam("表单字段名", "field")
 | 
				
			||||||
 | 
						return option
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										81
									
								
								internal/waf/checkpoints/request_upload_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								internal/waf/checkpoints/request_upload_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,81 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"mime/multipart"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
 | 
				
			||||||
 | 
						body := bytes.NewBuffer([]byte{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						writer := multipart.NewWriter(body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							part, err := writer.CreateFormField("name")
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								part.Write([]byte("lu"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							part, err := writer.CreateFormField("age")
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								part.Write([]byte("20"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							part, err := writer.CreateFormFile("myFile", "hello.txt")
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								part.Write([]byte("Hello, World!"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							part, err := writer.CreateFormFile("myFile2", "hello.PHP")
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								part.Write([]byte("Hello, World, PHP!"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							part, err := writer.CreateFormFile("myFile3", "hello.asp")
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								part.Write([]byte("Hello, World, ASP Pages!"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							part, err := writer.CreateFormFile("myFile4", "hello.asp")
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								part.Write([]byte("Hello, World, ASP Pages!"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						writer.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn/", body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req := requests.NewRequest(rawReq)
 | 
				
			||||||
 | 
						req.Header.Add("Content-Type", writer.FormDataContentType())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(RequestUploadCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "field", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "minSize", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "maxSize", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "name", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.RequestValue(req, "ext", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						data, err := ioutil.ReadAll(req.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log(string(data))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										25
									
								
								internal/waf/checkpoints/request_uri.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								internal/waf/checkpoints/request_uri.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestURICheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if len(req.RequestURI) > 0 {
 | 
				
			||||||
 | 
							value = req.RequestURI
 | 
				
			||||||
 | 
						} else if req.URL != nil {
 | 
				
			||||||
 | 
							value = req.URL.RequestURI()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/request_user_agent.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/request_user_agent.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type RequestUserAgentCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = req.UserAgent()
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										41
									
								
								internal/waf/checkpoints/response_body.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								internal/waf/checkpoints/response_body.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${responseBody}
 | 
				
			||||||
 | 
					type ResponseBodyCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseBodyCheckpoint) IsRequest() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = ""
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = ""
 | 
				
			||||||
 | 
						if resp != nil && resp.Body != nil {
 | 
				
			||||||
 | 
							if len(resp.BodyData) > 0 {
 | 
				
			||||||
 | 
								value = string(resp.BodyData)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							body, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								sysErr = err
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							resp.BodyData = body
 | 
				
			||||||
 | 
							_ = resp.Body.Close()
 | 
				
			||||||
 | 
							value = body
 | 
				
			||||||
 | 
							resp.Body = ioutil.NopCloser(bytes.NewBuffer(body))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										29
									
								
								internal/waf/checkpoints/response_body_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								internal/waf/checkpoints/response_body_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestResponseBodyCheckpoint_ResponseValue(t *testing.T) {
 | 
				
			||||||
 | 
						resp := requests.NewResponse(new(http.Response))
 | 
				
			||||||
 | 
						resp.StatusCode = 200
 | 
				
			||||||
 | 
						resp.Header = http.Header{}
 | 
				
			||||||
 | 
						resp.Header.Set("Hello", "World")
 | 
				
			||||||
 | 
						resp.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("Hello, World")))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(ResponseBodyCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
 | 
				
			||||||
 | 
						t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						data, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						t.Log("after read:", string(data))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										27
									
								
								internal/waf/checkpoints/response_bytes_sent.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								internal/waf/checkpoints/response_bytes_sent.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${bytesSent}
 | 
				
			||||||
 | 
					type ResponseBytesSentCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseBytesSentCheckpoint) IsRequest() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = 0
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = 0
 | 
				
			||||||
 | 
						if resp != nil {
 | 
				
			||||||
 | 
							value = resp.ContentLength
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										28
									
								
								internal/waf/checkpoints/response_header.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								internal/waf/checkpoints/response_header.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${responseHeader.arg}
 | 
				
			||||||
 | 
					type ResponseHeaderCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseHeaderCheckpoint) IsRequest() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = ""
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if resp != nil && resp.Header != nil {
 | 
				
			||||||
 | 
							value = resp.Header.Get(param)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							value = ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										17
									
								
								internal/waf/checkpoints/response_header_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								internal/waf/checkpoints/response_header_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestResponseHeaderCheckpoint_ResponseValue(t *testing.T) {
 | 
				
			||||||
 | 
						resp := requests.NewResponse(new(http.Response))
 | 
				
			||||||
 | 
						resp.StatusCode = 200
 | 
				
			||||||
 | 
						resp.Header = http.Header{}
 | 
				
			||||||
 | 
						resp.Header.Set("Hello", "World")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(ResponseHeaderCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.ResponseValue(nil, resp, "Hello", nil))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										26
									
								
								internal/waf/checkpoints/response_status.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								internal/waf/checkpoints/response_status.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ${bytesSent}
 | 
				
			||||||
 | 
					type ResponseStatusCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseStatusCheckpoint) IsRequest() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						value = 0
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if resp != nil {
 | 
				
			||||||
 | 
							value = resp.StatusCode
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										15
									
								
								internal/waf/checkpoints/response_status_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								internal/waf/checkpoints/response_status_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestResponseStatusCheckpoint_ResponseValue(t *testing.T) {
 | 
				
			||||||
 | 
						resp := requests.NewResponse(new(http.Response))
 | 
				
			||||||
 | 
						resp.StatusCode = 200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpoint := new(ResponseStatusCheckpoint)
 | 
				
			||||||
 | 
						t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										21
									
								
								internal/waf/checkpoints/sample_request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								internal/waf/checkpoints/sample_request.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// just a sample checkpoint, copy and change it for your new checkpoint
 | 
				
			||||||
 | 
					type SampleRequestCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						if this.IsRequest() {
 | 
				
			||||||
 | 
							return this.RequestValue(req, param, options)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										22
									
								
								internal/waf/checkpoints/sample_response.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								internal/waf/checkpoints/sample_response.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// just a sample checkpoint, copy and change it for your new checkpoint
 | 
				
			||||||
 | 
					type SampleResponseCheckpoint struct {
 | 
				
			||||||
 | 
						Checkpoint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *SampleResponseCheckpoint) IsRequest() bool {
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *SampleResponseCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *SampleResponseCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										235
									
								
								internal/waf/checkpoints/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										235
									
								
								internal/waf/checkpoints/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,235 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// all check points list
 | 
				
			||||||
 | 
					var AllCheckpoints = []*CheckpointDefinition{
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "客户端地址(IP)",
 | 
				
			||||||
 | 
							Prefix:      "remoteAddr",
 | 
				
			||||||
 | 
							Description: "试图通过分析X-Forwarded-For等Header获取的客户端地址,比如192.168.1.100",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestRemoteAddrCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "客户端源地址(IP)",
 | 
				
			||||||
 | 
							Prefix:      "rawRemoteAddr",
 | 
				
			||||||
 | 
							Description: "直接连接的客户端地址,比如192.168.1.100",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestRawRemoteAddrCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "客户端端口",
 | 
				
			||||||
 | 
							Prefix:      "remotePort",
 | 
				
			||||||
 | 
							Description: "直接连接的客户端地址端口",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestRemotePortCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "客户端用户名",
 | 
				
			||||||
 | 
							Prefix:      "remoteUser",
 | 
				
			||||||
 | 
							Description: "通过BasicAuth登录的客户端用户名",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestRemoteUserCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求URI",
 | 
				
			||||||
 | 
							Prefix:      "requestURI",
 | 
				
			||||||
 | 
							Description: "包含URL参数的请求URI,比如/hello/world?lang=go",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestURICheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求路径",
 | 
				
			||||||
 | 
							Prefix:      "requestPath",
 | 
				
			||||||
 | 
							Description: "不包含URL参数的请求路径,比如/hello/world",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestPathCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求内容长度",
 | 
				
			||||||
 | 
							Prefix:      "requestLength",
 | 
				
			||||||
 | 
							Description: "请求Header中的Content-Length",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestLengthCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求体内容",
 | 
				
			||||||
 | 
							Prefix:      "requestBody",
 | 
				
			||||||
 | 
							Description: "通常在POST或者PUT等操作时会附带请求体,最大限制32M",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestBodyCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求URI和请求体组合",
 | 
				
			||||||
 | 
							Prefix:      "requestAll",
 | 
				
			||||||
 | 
							Description: "${requestURI}和${requestBody}组合",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestAllCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求表单参数",
 | 
				
			||||||
 | 
							Prefix:      "requestForm",
 | 
				
			||||||
 | 
							Description: "获取POST或者其他方法发送的表单参数,最大请求体限制32M",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(RequestFormArgCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "上传文件",
 | 
				
			||||||
 | 
							Prefix:      "requestUpload",
 | 
				
			||||||
 | 
							Description: "获取POST上传的文件信息,最大请求体限制32M",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(RequestUploadCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求JSON参数",
 | 
				
			||||||
 | 
							Prefix:      "requestJSON",
 | 
				
			||||||
 | 
							Description: "获取POST或者其他方法发送的JSON,最大请求体限制32M,使用点(.)符号表示多级数据",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(RequestJSONArgCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求方法",
 | 
				
			||||||
 | 
							Prefix:      "requestMethod",
 | 
				
			||||||
 | 
							Description: "比如GET、POST",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestMethodCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求协议",
 | 
				
			||||||
 | 
							Prefix:      "scheme",
 | 
				
			||||||
 | 
							Description: "比如http或https",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestSchemeCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "HTTP协议版本",
 | 
				
			||||||
 | 
							Prefix:      "proto",
 | 
				
			||||||
 | 
							Description: "比如HTTP/1.1",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestProtoCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "主机名",
 | 
				
			||||||
 | 
							Prefix:      "host",
 | 
				
			||||||
 | 
							Description: "比如teaos.cn",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestHostCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "请求来源URL",
 | 
				
			||||||
 | 
							Prefix:      "referer",
 | 
				
			||||||
 | 
							Description: "请求Header中的Referer值",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestRefererCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "客户端信息",
 | 
				
			||||||
 | 
							Prefix:      "userAgent",
 | 
				
			||||||
 | 
							Description: "比如Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestUserAgentCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "内容类型",
 | 
				
			||||||
 | 
							Prefix:      "contentType",
 | 
				
			||||||
 | 
							Description: "请求Header的Content-Type",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestContentTypeCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "所有cookie组合字符串",
 | 
				
			||||||
 | 
							Prefix:      "cookies",
 | 
				
			||||||
 | 
							Description: "比如sid=IxZVPFhE&city=beijing&uid=18237",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestCookiesCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "单个cookie值",
 | 
				
			||||||
 | 
							Prefix:      "cookie",
 | 
				
			||||||
 | 
							Description: "单个cookie值",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(RequestCookieCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "所有URL参数组合",
 | 
				
			||||||
 | 
							Prefix:      "args",
 | 
				
			||||||
 | 
							Description: "比如name=lu&age=20",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestArgsCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "单个URL参数值",
 | 
				
			||||||
 | 
							Prefix:      "arg",
 | 
				
			||||||
 | 
							Description: "单个URL参数值",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(RequestArgCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "所有Header信息",
 | 
				
			||||||
 | 
							Prefix:      "headers",
 | 
				
			||||||
 | 
							Description: "使用\n隔开的Header信息字符串",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(RequestHeadersCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "单个Header值",
 | 
				
			||||||
 | 
							Prefix:      "header",
 | 
				
			||||||
 | 
							Description: "单个Header值",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(RequestHeaderCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "CC统计",
 | 
				
			||||||
 | 
							Prefix:      "cc",
 | 
				
			||||||
 | 
							Description: "统计某段时间段内的请求信息",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(CCCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "响应状态码",
 | 
				
			||||||
 | 
							Prefix:      "status",
 | 
				
			||||||
 | 
							Description: "响应状态码,比如200、404、500",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(ResponseStatusCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "响应Header",
 | 
				
			||||||
 | 
							Prefix:      "responseHeader",
 | 
				
			||||||
 | 
							Description: "响应Header值",
 | 
				
			||||||
 | 
							HasParams:   true,
 | 
				
			||||||
 | 
							Instance:    new(ResponseHeaderCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "响应内容",
 | 
				
			||||||
 | 
							Prefix:      "responseBody",
 | 
				
			||||||
 | 
							Description: "响应内容字符串",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(ResponseBodyCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Name:        "响应内容长度",
 | 
				
			||||||
 | 
							Prefix:      "bytesSent",
 | 
				
			||||||
 | 
							Description: "响应内容长度,通过响应的Header Content-Length获取",
 | 
				
			||||||
 | 
							HasParams:   false,
 | 
				
			||||||
 | 
							Instance:    new(ResponseBytesSentCheckpoint),
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// find a check point
 | 
				
			||||||
 | 
					func FindCheckpoint(prefix string) CheckpointInterface {
 | 
				
			||||||
 | 
						for _, def := range AllCheckpoints {
 | 
				
			||||||
 | 
							if def.Prefix == prefix {
 | 
				
			||||||
 | 
								return def.Instance
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// find a check point definition
 | 
				
			||||||
 | 
					func FindCheckpointDefinition(prefix string) *CheckpointDefinition {
 | 
				
			||||||
 | 
						for _, def := range AllCheckpoints {
 | 
				
			||||||
 | 
							if def.Prefix == prefix {
 | 
				
			||||||
 | 
								return def
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										31
									
								
								internal/waf/checkpoints/utils_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								internal/waf/checkpoints/utils_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,31 @@
 | 
				
			|||||||
 | 
					package checkpoints
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFindCheckpointDefinition_Markdown(t *testing.T) {
 | 
				
			||||||
 | 
						result := []string{}
 | 
				
			||||||
 | 
						for _, def := range AllCheckpoints {
 | 
				
			||||||
 | 
							row := "## " + def.Name + "\n* 前缀:`${" + def.Prefix + "}`\n* 描述:" + def.Description
 | 
				
			||||||
 | 
							if def.HasParams {
 | 
				
			||||||
 | 
								row += "\n* 是否有子参数:YES"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								paramOptions := def.Instance.ParamOptions()
 | 
				
			||||||
 | 
								if paramOptions != nil && len(paramOptions.Options) > 0 {
 | 
				
			||||||
 | 
									row += "\n* 可选子参数"
 | 
				
			||||||
 | 
									for _, option := range paramOptions.Options {
 | 
				
			||||||
 | 
										row += "\n   * `" + option.Name + "`:值为 `" + option.Value + "`"
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								row += "\n* 是否有子参数:NO"
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							row += "\n"
 | 
				
			||||||
 | 
							result = append(result, row)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Print(strings.Join(result, "\n") + "\n")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										154
									
								
								internal/waf/ip_table.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								internal/waf/ip_table.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,154 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						stringutil "github.com/iwind/TeaGo/utils/string"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type IPAction = string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						IPActionAccept IPAction = "accept"
 | 
				
			||||||
 | 
						IPActionReject IPAction = "reject"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ip table
 | 
				
			||||||
 | 
					type IPTable struct {
 | 
				
			||||||
 | 
						Id       string   `yaml:"id" json:"id"`
 | 
				
			||||||
 | 
						On       bool     `yaml:"on" json:"on"`
 | 
				
			||||||
 | 
						IP       string   `yaml:"ip" json:"ip"`             // single ip, cidr, ip range, TODO support *
 | 
				
			||||||
 | 
						Port     string   `yaml:"port" json:"port"`         // single port, range, *
 | 
				
			||||||
 | 
						Action   IPAction `yaml:"action" json:"action"`     // accept, reject
 | 
				
			||||||
 | 
						TimeFrom int64    `yaml:"timeFrom" json:"timeFrom"` // from timestamp
 | 
				
			||||||
 | 
						TimeTo   int64    `yaml:"timeTo" json:"timeTo"`     // zero means forever
 | 
				
			||||||
 | 
						Remark   string   `yaml:"remark" json:"remark"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// port
 | 
				
			||||||
 | 
						minPort int
 | 
				
			||||||
 | 
						maxPort int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						minPortWildcard bool
 | 
				
			||||||
 | 
						maxPortWildcard bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ports []int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// ip
 | 
				
			||||||
 | 
						ipRange *shared.IPRangeConfig
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewIPTable() *IPTable {
 | 
				
			||||||
 | 
						return &IPTable{
 | 
				
			||||||
 | 
							On: true,
 | 
				
			||||||
 | 
							Id: stringutil.Rand(16),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IPTable) Init() error {
 | 
				
			||||||
 | 
						// parse port
 | 
				
			||||||
 | 
						if RegexpDigitNumber.MatchString(this.Port) {
 | 
				
			||||||
 | 
							this.minPort = types.Int(this.Port)
 | 
				
			||||||
 | 
							this.maxPort = types.Int(this.Port)
 | 
				
			||||||
 | 
						} else if regexp.MustCompile(`[:-]`).MatchString(this.Port) {
 | 
				
			||||||
 | 
							pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2)
 | 
				
			||||||
 | 
							if pieces[0] == "*" {
 | 
				
			||||||
 | 
								this.minPortWildcard = true
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								this.minPort = types.Int(pieces[0])
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if pieces[1] == "*" {
 | 
				
			||||||
 | 
								this.maxPortWildcard = true
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								this.maxPort = types.Int(pieces[1])
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else if strings.Contains(this.Port, ",") {
 | 
				
			||||||
 | 
							pieces := strings.Split(this.Port, ",")
 | 
				
			||||||
 | 
							for _, piece := range pieces {
 | 
				
			||||||
 | 
								piece = strings.TrimSpace(piece)
 | 
				
			||||||
 | 
								if len(piece) > 0 {
 | 
				
			||||||
 | 
									this.ports = append(this.ports, types.Int(piece))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else if this.Port == "*" {
 | 
				
			||||||
 | 
							this.minPortWildcard = true
 | 
				
			||||||
 | 
							this.maxPortWildcard = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// parse ip
 | 
				
			||||||
 | 
						if len(this.IP) > 0 {
 | 
				
			||||||
 | 
							ipRange, err := shared.ParseIPRange(this.IP)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							this.ipRange = ipRange
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// check ip
 | 
				
			||||||
 | 
					func (this *IPTable) Match(ip string, port int) (isMatched bool) {
 | 
				
			||||||
 | 
						if !this.On {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						now := time.Now().Unix()
 | 
				
			||||||
 | 
						if this.TimeFrom > 0 && now < this.TimeFrom {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if this.TimeTo > 0 && now > this.TimeTo {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !this.matchPort(port) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !this.matchIP(ip) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IPTable) matchPort(port int) bool {
 | 
				
			||||||
 | 
						if port == 0 {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if this.minPortWildcard {
 | 
				
			||||||
 | 
							if this.maxPortWildcard {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if this.maxPort >= port {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if this.maxPortWildcard {
 | 
				
			||||||
 | 
							if this.minPortWildcard {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if this.minPort <= port {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(this.ports) > 0 {
 | 
				
			||||||
 | 
							return lists.ContainsInt(this.ports, port)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *IPTable) matchIP(ip string) bool {
 | 
				
			||||||
 | 
						if this.ipRange == nil {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return this.ipRange.Contains(ip)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										142
									
								
								internal/waf/ip_table_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								internal/waf/ip_table_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,142 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIPTable_MatchIP(t *testing.T) {
 | 
				
			||||||
 | 
						a := assert.NewAssertion(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "*"
 | 
				
			||||||
 | 
							table.Port = "8080"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("port:", table.minPort, table.maxPort)
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.1.100", 8081))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "*"
 | 
				
			||||||
 | 
							table.Port = "8080-8082"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("port:", table.minPort, table.maxPort)
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.1.100", 8083))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "*"
 | 
				
			||||||
 | 
							table.Port = "*-8082"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("port:", table.minPort, table.maxPort)
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8079))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.1.100", 8083))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "*"
 | 
				
			||||||
 | 
							table.Port = "8080-*"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("port:", table.minPort, table.maxPort)
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.1.100", 8079))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8083))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "*"
 | 
				
			||||||
 | 
							table.Port = "*"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("port:", table.minPort, table.maxPort)
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8079))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8081))
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8083))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "192.168.1.100"
 | 
				
			||||||
 | 
							table.Port = "*"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("port:", table.minPort, table.maxPort)
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "192.168.1.99-192.168.1.101"
 | 
				
			||||||
 | 
							table.Port = "*"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("port:", table.minPort, table.maxPort)
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "192.168.1.99/24"
 | 
				
			||||||
 | 
							table.Port = "*"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							t.Log("ip:", table.ipRange)
 | 
				
			||||||
 | 
							a.IsTrue(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.2.100", 8080))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							table := NewIPTable()
 | 
				
			||||||
 | 
							table.IP = "192.168.1.99/24"
 | 
				
			||||||
 | 
							table.TimeTo = time.Now().Unix() - 10
 | 
				
			||||||
 | 
							table.Port = "*"
 | 
				
			||||||
 | 
							err := table.Init()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								t.Fatal(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.1.100", 8080))
 | 
				
			||||||
 | 
							a.IsFalse(table.Match("192.168.2.100", 8080))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										35
									
								
								internal/waf/requests/request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								internal/waf/requests/request.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					package requests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Request struct {
 | 
				
			||||||
 | 
						*http.Request
 | 
				
			||||||
 | 
						BodyData []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewRequest(raw *http.Request) *Request {
 | 
				
			||||||
 | 
						return &Request{
 | 
				
			||||||
 | 
							Request: raw,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Request) Raw() *http.Request {
 | 
				
			||||||
 | 
						return this.Request
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Request) ReadBody(max int64) (data []byte, err error) {
 | 
				
			||||||
 | 
						data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max))
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Request) RestoreBody(data []byte) {
 | 
				
			||||||
 | 
						rawReader := bytes.NewBuffer(data)
 | 
				
			||||||
 | 
						buf := make([]byte, 1024)
 | 
				
			||||||
 | 
						io.CopyBuffer(rawReader, this.Request.Body, buf)
 | 
				
			||||||
 | 
						this.Request.Body = ioutil.NopCloser(rawReader)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										15
									
								
								internal/waf/requests/response.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								internal/waf/requests/response.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					package requests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Response struct {
 | 
				
			||||||
 | 
						*http.Response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						BodyData []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewResponse(resp *http.Response) *Response {
 | 
				
			||||||
 | 
						return &Response{
 | 
				
			||||||
 | 
							Response: resp,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										584
									
								
								internal/waf/rule.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										584
									
								
								internal/waf/rule.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,584 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/utils/string"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var singleParamRegexp = regexp.MustCompile("^\\${[\\w.-]+}$")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// rule
 | 
				
			||||||
 | 
					type Rule struct {
 | 
				
			||||||
 | 
						Description       string                 `yaml:"description" json:"description"`
 | 
				
			||||||
 | 
						Param             string                 `yaml:"param" json:"param"`       // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
 | 
				
			||||||
 | 
						Operator          RuleOperator           `yaml:"operator" json:"operator"` // such as contains, gt,  ...
 | 
				
			||||||
 | 
						Value             string                 `yaml:"value" json:"value"`       // compared value
 | 
				
			||||||
 | 
						IsCaseInsensitive bool                   `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
 | 
				
			||||||
 | 
						CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						checkpointFinder func(prefix string) checkpoints.CheckpointInterface
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						singleParam      string                          // real param after prefix
 | 
				
			||||||
 | 
						singleCheckpoint checkpoints.CheckpointInterface // if is single check point
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						multipleCheckpoints map[string]checkpoints.CheckpointInterface
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						isIP    bool
 | 
				
			||||||
 | 
						ipValue net.IP
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						floatValue float64
 | 
				
			||||||
 | 
						reg        *regexp.Regexp
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewRule() *Rule {
 | 
				
			||||||
 | 
						return &Rule{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) Init() error {
 | 
				
			||||||
 | 
						// operator
 | 
				
			||||||
 | 
						switch this.Operator {
 | 
				
			||||||
 | 
						case RuleOperatorGt:
 | 
				
			||||||
 | 
							this.floatValue = types.Float64(this.Value)
 | 
				
			||||||
 | 
						case RuleOperatorGte:
 | 
				
			||||||
 | 
							this.floatValue = types.Float64(this.Value)
 | 
				
			||||||
 | 
						case RuleOperatorLt:
 | 
				
			||||||
 | 
							this.floatValue = types.Float64(this.Value)
 | 
				
			||||||
 | 
						case RuleOperatorLte:
 | 
				
			||||||
 | 
							this.floatValue = types.Float64(this.Value)
 | 
				
			||||||
 | 
						case RuleOperatorEq:
 | 
				
			||||||
 | 
							this.floatValue = types.Float64(this.Value)
 | 
				
			||||||
 | 
						case RuleOperatorNeq:
 | 
				
			||||||
 | 
							this.floatValue = types.Float64(this.Value)
 | 
				
			||||||
 | 
						case RuleOperatorMatch:
 | 
				
			||||||
 | 
							v := this.Value
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
 | 
				
			||||||
 | 
								v = "(?i)" + v
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							v = this.unescape(v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							reg, err := regexp.Compile(v)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							this.reg = reg
 | 
				
			||||||
 | 
						case RuleOperatorNotMatch:
 | 
				
			||||||
 | 
							v := this.Value
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
 | 
				
			||||||
 | 
								v = "(?i)" + v
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							v = this.unescape(v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							reg, err := regexp.Compile(v)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							this.reg = reg
 | 
				
			||||||
 | 
						case RuleOperatorEqIP, RuleOperatorGtIP, RuleOperatorGteIP, RuleOperatorLtIP, RuleOperatorLteIP:
 | 
				
			||||||
 | 
							this.ipValue = net.ParseIP(this.Value)
 | 
				
			||||||
 | 
							this.isIP = this.ipValue != nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if !this.isIP {
 | 
				
			||||||
 | 
								return errors.New("value should be a valid ip")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorIPRange, RuleOperatorNotIPRange:
 | 
				
			||||||
 | 
							if strings.Contains(this.Value, ",") {
 | 
				
			||||||
 | 
								ipList := strings.SplitN(this.Value, ",", 2)
 | 
				
			||||||
 | 
								ipString1 := strings.TrimSpace(ipList[0])
 | 
				
			||||||
 | 
								ipString2 := strings.TrimSpace(ipList[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if len(ipString1) > 0 {
 | 
				
			||||||
 | 
									ip1 := net.ParseIP(ipString1)
 | 
				
			||||||
 | 
									if ip1 == nil {
 | 
				
			||||||
 | 
										return errors.New("start ip is invalid")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if len(ipString2) > 0 {
 | 
				
			||||||
 | 
									ip2 := net.ParseIP(ipString2)
 | 
				
			||||||
 | 
									if ip2 == nil {
 | 
				
			||||||
 | 
										return errors.New("end ip is invalid")
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else if strings.Contains(this.Value, "/") {
 | 
				
			||||||
 | 
								_, _, err := net.ParseCIDR(this.Value)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return errors.New("invalid ip range")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if singleParamRegexp.MatchString(this.Param) {
 | 
				
			||||||
 | 
							param := this.Param[2 : len(this.Param)-1]
 | 
				
			||||||
 | 
							pieces := strings.SplitN(param, ".", 2)
 | 
				
			||||||
 | 
							prefix := pieces[0]
 | 
				
			||||||
 | 
							if len(pieces) == 1 {
 | 
				
			||||||
 | 
								this.singleParam = ""
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								this.singleParam = pieces[1]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if this.checkpointFinder != nil {
 | 
				
			||||||
 | 
								checkpoint := this.checkpointFinder(prefix)
 | 
				
			||||||
 | 
								if checkpoint == nil {
 | 
				
			||||||
 | 
									return errors.New("no check point '" + prefix + "' found")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								this.singleCheckpoint = checkpoint
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								checkpoint := checkpoints.FindCheckpoint(prefix)
 | 
				
			||||||
 | 
								if checkpoint == nil {
 | 
				
			||||||
 | 
									return errors.New("no check point '" + prefix + "' found")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								checkpoint.Init()
 | 
				
			||||||
 | 
								this.singleCheckpoint = checkpoint
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.multipleCheckpoints = map[string]checkpoints.CheckpointInterface{}
 | 
				
			||||||
 | 
						var err error = nil
 | 
				
			||||||
 | 
						configutils.ParseVariables(this.Param, func(varName string) (value string) {
 | 
				
			||||||
 | 
							pieces := strings.SplitN(varName, ".", 2)
 | 
				
			||||||
 | 
							prefix := pieces[0]
 | 
				
			||||||
 | 
							if this.checkpointFinder != nil {
 | 
				
			||||||
 | 
								checkpoint := this.checkpointFinder(prefix)
 | 
				
			||||||
 | 
								if checkpoint == nil {
 | 
				
			||||||
 | 
									err = errors.New("no check point '" + prefix + "' found")
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									this.multipleCheckpoints[prefix] = checkpoint
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								checkpoint := checkpoints.FindCheckpoint(prefix)
 | 
				
			||||||
 | 
								if checkpoint == nil {
 | 
				
			||||||
 | 
									err = errors.New("no check point '" + prefix + "' found")
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									checkpoint.Init()
 | 
				
			||||||
 | 
									this.multipleCheckpoints[prefix] = checkpoint
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
 | 
				
			||||||
 | 
						if this.singleCheckpoint != nil {
 | 
				
			||||||
 | 
							value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return false, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return this.Test(value), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
 | 
				
			||||||
 | 
							pieces := strings.SplitN(varName, ".", 2)
 | 
				
			||||||
 | 
							prefix := pieces[0]
 | 
				
			||||||
 | 
							point, ok := this.multipleCheckpoints[prefix]
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return ""
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(pieces) == 1 {
 | 
				
			||||||
 | 
								value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions)
 | 
				
			||||||
 | 
								if err1 != nil {
 | 
				
			||||||
 | 
									err = err1
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return types.String(value1)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions)
 | 
				
			||||||
 | 
							if err1 != nil {
 | 
				
			||||||
 | 
								err = err1
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return types.String(value1)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return this.Test(value), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
 | 
				
			||||||
 | 
						if this.singleCheckpoint != nil {
 | 
				
			||||||
 | 
							// if is request param
 | 
				
			||||||
 | 
							if this.singleCheckpoint.IsRequest() {
 | 
				
			||||||
 | 
								value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return false, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return this.Test(value), nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// response param
 | 
				
			||||||
 | 
							value, err, _ := this.singleCheckpoint.ResponseValue(req, resp, this.singleParam, this.CheckpointOptions)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return false, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return this.Test(value), nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
 | 
				
			||||||
 | 
							pieces := strings.SplitN(varName, ".", 2)
 | 
				
			||||||
 | 
							prefix := pieces[0]
 | 
				
			||||||
 | 
							point, ok := this.multipleCheckpoints[prefix]
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return ""
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(pieces) == 1 {
 | 
				
			||||||
 | 
								if point.IsRequest() {
 | 
				
			||||||
 | 
									value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions)
 | 
				
			||||||
 | 
									if err1 != nil {
 | 
				
			||||||
 | 
										err = err1
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									return types.String(value1)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									value1, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions)
 | 
				
			||||||
 | 
									if err1 != nil {
 | 
				
			||||||
 | 
										err = err1
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									return types.String(value1)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if point.IsRequest() {
 | 
				
			||||||
 | 
								value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions)
 | 
				
			||||||
 | 
								if err1 != nil {
 | 
				
			||||||
 | 
									err = err1
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return types.String(value1)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								value1, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions)
 | 
				
			||||||
 | 
								if err1 != nil {
 | 
				
			||||||
 | 
									err = err1
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return types.String(value1)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return this.Test(value), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) Test(value interface{}) bool {
 | 
				
			||||||
 | 
						// operator
 | 
				
			||||||
 | 
						switch this.Operator {
 | 
				
			||||||
 | 
						case RuleOperatorGt:
 | 
				
			||||||
 | 
							return types.Float64(value) > this.floatValue
 | 
				
			||||||
 | 
						case RuleOperatorGte:
 | 
				
			||||||
 | 
							return types.Float64(value) >= this.floatValue
 | 
				
			||||||
 | 
						case RuleOperatorLt:
 | 
				
			||||||
 | 
							return types.Float64(value) < this.floatValue
 | 
				
			||||||
 | 
						case RuleOperatorLte:
 | 
				
			||||||
 | 
							return types.Float64(value) <= this.floatValue
 | 
				
			||||||
 | 
						case RuleOperatorEq:
 | 
				
			||||||
 | 
							return types.Float64(value) == this.floatValue
 | 
				
			||||||
 | 
						case RuleOperatorNeq:
 | 
				
			||||||
 | 
							return types.Float64(value) != this.floatValue
 | 
				
			||||||
 | 
						case RuleOperatorEqString:
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
								return strings.ToLower(types.String(value)) == strings.ToLower(this.Value)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return types.String(value) == this.Value
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorNeqString:
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
								return strings.ToLower(types.String(value)) != strings.ToLower(this.Value)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return types.String(value) != this.Value
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorMatch:
 | 
				
			||||||
 | 
							if value == nil {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// strings
 | 
				
			||||||
 | 
							stringList, ok := value.([]string)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								for _, s := range stringList {
 | 
				
			||||||
 | 
									if utils.MatchStringCache(this.reg, s) {
 | 
				
			||||||
 | 
										return true
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// bytes
 | 
				
			||||||
 | 
							byteSlice, ok := value.([]byte)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								return utils.MatchBytesCache(this.reg, byteSlice)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// string
 | 
				
			||||||
 | 
							return utils.MatchStringCache(this.reg, types.String(value))
 | 
				
			||||||
 | 
						case RuleOperatorNotMatch:
 | 
				
			||||||
 | 
							if value == nil {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							stringList, ok := value.([]string)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								for _, s := range stringList {
 | 
				
			||||||
 | 
									if utils.MatchStringCache(this.reg, s) {
 | 
				
			||||||
 | 
										return false
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// bytes
 | 
				
			||||||
 | 
							byteSlice, ok := value.([]byte)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								return !utils.MatchBytesCache(this.reg, byteSlice)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return !utils.MatchStringCache(this.reg, types.String(value))
 | 
				
			||||||
 | 
						case RuleOperatorContains:
 | 
				
			||||||
 | 
							if types.IsSlice(value) {
 | 
				
			||||||
 | 
								ok := false
 | 
				
			||||||
 | 
								lists.Each(value, func(k int, v interface{}) {
 | 
				
			||||||
 | 
									if types.String(v) == this.Value {
 | 
				
			||||||
 | 
										ok = true
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								return ok
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if types.IsMap(value) {
 | 
				
			||||||
 | 
								lowerValue := ""
 | 
				
			||||||
 | 
								if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
									lowerValue = strings.ToLower(this.Value)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								for _, v := range maps.NewMap(value) {
 | 
				
			||||||
 | 
									if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
										if strings.ToLower(types.String(v)) == lowerValue {
 | 
				
			||||||
 | 
											return true
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										if types.String(v) == this.Value {
 | 
				
			||||||
 | 
											return true
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
								return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return strings.Contains(types.String(value), this.Value)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorNotContains:
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
								return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return !strings.Contains(types.String(value), this.Value)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorPrefix:
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
								return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return strings.HasPrefix(types.String(value), this.Value)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorSuffix:
 | 
				
			||||||
 | 
							if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
								return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return strings.HasSuffix(types.String(value), this.Value)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorHasKey:
 | 
				
			||||||
 | 
							if types.IsSlice(value) {
 | 
				
			||||||
 | 
								index := types.Int(this.Value)
 | 
				
			||||||
 | 
								if index < 0 {
 | 
				
			||||||
 | 
									return false
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return reflect.ValueOf(value).Len() > index
 | 
				
			||||||
 | 
							} else if types.IsMap(value) {
 | 
				
			||||||
 | 
								m := maps.NewMap(value)
 | 
				
			||||||
 | 
								if this.IsCaseInsensitive {
 | 
				
			||||||
 | 
									lowerValue := strings.ToLower(this.Value)
 | 
				
			||||||
 | 
									for k := range m {
 | 
				
			||||||
 | 
										if strings.ToLower(k) == lowerValue {
 | 
				
			||||||
 | 
											return true
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									return m.Has(this.Value)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						case RuleOperatorVersionGt:
 | 
				
			||||||
 | 
							return stringutil.VersionCompare(this.Value, types.String(value)) > 0
 | 
				
			||||||
 | 
						case RuleOperatorVersionLt:
 | 
				
			||||||
 | 
							return stringutil.VersionCompare(this.Value, types.String(value)) < 0
 | 
				
			||||||
 | 
						case RuleOperatorVersionRange:
 | 
				
			||||||
 | 
							if strings.Contains(this.Value, ",") {
 | 
				
			||||||
 | 
								versions := strings.SplitN(this.Value, ",", 2)
 | 
				
			||||||
 | 
								version1 := strings.TrimSpace(versions[0])
 | 
				
			||||||
 | 
								version2 := strings.TrimSpace(versions[1])
 | 
				
			||||||
 | 
								if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 {
 | 
				
			||||||
 | 
									return false
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if len(version2) > 0 && stringutil.VersionCompare(types.String(value), version2) > 0 {
 | 
				
			||||||
 | 
									return false
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return stringutil.VersionCompare(types.String(value), this.Value) >= 0
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case RuleOperatorEqIP:
 | 
				
			||||||
 | 
							ip := net.ParseIP(types.String(value))
 | 
				
			||||||
 | 
							if ip == nil {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return this.isIP && bytes.Compare(this.ipValue, ip) == 0
 | 
				
			||||||
 | 
						case RuleOperatorGtIP:
 | 
				
			||||||
 | 
							ip := net.ParseIP(types.String(value))
 | 
				
			||||||
 | 
							if ip == nil {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return this.isIP && bytes.Compare(ip, this.ipValue) > 0
 | 
				
			||||||
 | 
						case RuleOperatorGteIP:
 | 
				
			||||||
 | 
							ip := net.ParseIP(types.String(value))
 | 
				
			||||||
 | 
							if ip == nil {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
 | 
				
			||||||
 | 
						case RuleOperatorLtIP:
 | 
				
			||||||
 | 
							ip := net.ParseIP(types.String(value))
 | 
				
			||||||
 | 
							if ip == nil {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return this.isIP && bytes.Compare(ip, this.ipValue) < 0
 | 
				
			||||||
 | 
						case RuleOperatorLteIP:
 | 
				
			||||||
 | 
							ip := net.ParseIP(types.String(value))
 | 
				
			||||||
 | 
							if ip == nil {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return this.isIP && bytes.Compare(ip, this.ipValue) <= 0
 | 
				
			||||||
 | 
						case RuleOperatorIPRange:
 | 
				
			||||||
 | 
							return this.containsIP(value)
 | 
				
			||||||
 | 
						case RuleOperatorNotIPRange:
 | 
				
			||||||
 | 
							return !this.containsIP(value)
 | 
				
			||||||
 | 
						case RuleOperatorIPMod:
 | 
				
			||||||
 | 
							pieces := strings.SplitN(this.Value, ",", 2)
 | 
				
			||||||
 | 
							if len(pieces) == 1 {
 | 
				
			||||||
 | 
								rem := types.Int64(pieces[0])
 | 
				
			||||||
 | 
								return this.ipToInt64(net.ParseIP(types.String(value)))%10 == rem
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							div := types.Int64(pieces[0])
 | 
				
			||||||
 | 
							if div == 0 {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							rem := types.Int64(pieces[1])
 | 
				
			||||||
 | 
							return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem
 | 
				
			||||||
 | 
						case RuleOperatorIPMod10:
 | 
				
			||||||
 | 
							return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value)
 | 
				
			||||||
 | 
						case RuleOperatorIPMod100:
 | 
				
			||||||
 | 
							return this.ipToInt64(net.ParseIP(types.String(value)))%100 == types.Int64(this.Value)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) IsSingleCheckpoint() bool {
 | 
				
			||||||
 | 
						return this.singleCheckpoint != nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) SetCheckpointFinder(finder func(prefix string) checkpoints.CheckpointInterface) {
 | 
				
			||||||
 | 
						this.checkpointFinder = finder
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) unescape(v string) string {
 | 
				
			||||||
 | 
						//replace urlencoded characters
 | 
				
			||||||
 | 
						v = strings.Replace(v, `\s`, `(\s|%09|%0A|\+)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `\(`, `(\(|%28)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `=`, `(=|%3D)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `<`, `(<|%3C)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `\*`, `(\*|%2A)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `\\`, `(\\|%2F)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `!`, `(!|%21)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `/`, `(/|%2F)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `;`, `(;|%3B)`, -1)
 | 
				
			||||||
 | 
						v = strings.Replace(v, `\+`, `(\+|%20)`, -1)
 | 
				
			||||||
 | 
						return v
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) containsIP(value interface{}) bool {
 | 
				
			||||||
 | 
						ip := net.ParseIP(types.String(value))
 | 
				
			||||||
 | 
						if ip == nil {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 检查IP范围格式
 | 
				
			||||||
 | 
						if strings.Contains(this.Value, ",") {
 | 
				
			||||||
 | 
							ipList := strings.SplitN(this.Value, ",", 2)
 | 
				
			||||||
 | 
							ipString1 := strings.TrimSpace(ipList[0])
 | 
				
			||||||
 | 
							ipString2 := strings.TrimSpace(ipList[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(ipString1) > 0 {
 | 
				
			||||||
 | 
								ip1 := net.ParseIP(ipString1)
 | 
				
			||||||
 | 
								if ip1 == nil {
 | 
				
			||||||
 | 
									return false
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if bytes.Compare(ip, ip1) < 0 {
 | 
				
			||||||
 | 
									return false
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if len(ipString2) > 0 {
 | 
				
			||||||
 | 
								ip2 := net.ParseIP(ipString2)
 | 
				
			||||||
 | 
								if ip2 == nil {
 | 
				
			||||||
 | 
									return false
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if bytes.Compare(ip, ip2) > 0 {
 | 
				
			||||||
 | 
									return false
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						} else if strings.Contains(this.Value, "/") {
 | 
				
			||||||
 | 
							_, ipNet, err := net.ParseCIDR(this.Value)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return ipNet.Contains(ip)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Rule) ipToInt64(ip net.IP) int64 {
 | 
				
			||||||
 | 
						if len(ip) == 0 {
 | 
				
			||||||
 | 
							return 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if len(ip) == 16 {
 | 
				
			||||||
 | 
							return int64(binary.BigEndian.Uint32(ip[12:16]))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return int64(binary.BigEndian.Uint32(ip))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										147
									
								
								internal/waf/rule_group.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								internal/waf/rule_group.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,147 @@
 | 
				
			|||||||
 | 
					package waf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// rule group
 | 
				
			||||||
 | 
					type RuleGroup struct {
 | 
				
			||||||
 | 
						Id          string     `yaml:"id" json:"id"`
 | 
				
			||||||
 | 
						IsOn        bool       `yaml:"isOn" json:"isOn"`
 | 
				
			||||||
 | 
						Name        string     `yaml:"name" json:"name"` // such as SQL Injection
 | 
				
			||||||
 | 
						Description string     `yaml:"description" json:"description"`
 | 
				
			||||||
 | 
						Code        string     `yaml:"code" json:"code"` // identify the group
 | 
				
			||||||
 | 
						RuleSets    []*RuleSet `yaml:"ruleSets" json:"ruleSets"`
 | 
				
			||||||
 | 
						IsInbound   bool       `yaml:"isInbound" json:"isInbound"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						hasRuleSets bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewRuleGroup() *RuleGroup {
 | 
				
			||||||
 | 
						return &RuleGroup{
 | 
				
			||||||
 | 
							IsOn: true,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) Init() error {
 | 
				
			||||||
 | 
						this.hasRuleSets = len(this.RuleSets) > 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if this.hasRuleSets {
 | 
				
			||||||
 | 
							for _, set := range this.RuleSets {
 | 
				
			||||||
 | 
								err := set.Init()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) AddRuleSet(ruleSet *RuleSet) {
 | 
				
			||||||
 | 
						this.RuleSets = append(this.RuleSets, ruleSet)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) FindRuleSet(id string) *RuleSet {
 | 
				
			||||||
 | 
						if len(id) == 0 {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, ruleSet := range this.RuleSets {
 | 
				
			||||||
 | 
							if ruleSet.Id == id {
 | 
				
			||||||
 | 
								return ruleSet
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) FindRuleSetWithCode(code string) *RuleSet {
 | 
				
			||||||
 | 
						if len(code) == 0 {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, ruleSet := range this.RuleSets {
 | 
				
			||||||
 | 
							if ruleSet.Code == code {
 | 
				
			||||||
 | 
								return ruleSet
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) RemoveRuleSet(id string) {
 | 
				
			||||||
 | 
						if len(id) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						result := []*RuleSet{}
 | 
				
			||||||
 | 
						for _, ruleSet := range this.RuleSets {
 | 
				
			||||||
 | 
							if ruleSet.Id == id {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							result = append(result, ruleSet)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.RuleSets = result
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) {
 | 
				
			||||||
 | 
						if !this.hasRuleSets {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, set := range this.RuleSets {
 | 
				
			||||||
 | 
							if !set.IsOn {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							b, err = set.MatchRequest(req)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return false, nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if b {
 | 
				
			||||||
 | 
								return true, set, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
 | 
				
			||||||
 | 
						if !this.hasRuleSets {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, set := range this.RuleSets {
 | 
				
			||||||
 | 
							if !set.IsOn {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							b, err = set.MatchResponse(req, resp)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return false, nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if b {
 | 
				
			||||||
 | 
								return true, set, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *RuleGroup) MoveRuleSet(fromIndex int, toIndex int) {
 | 
				
			||||||
 | 
						if fromIndex < 0 || fromIndex >= len(this.RuleSets) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if toIndex < 0 || toIndex >= len(this.RuleSets) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if fromIndex == toIndex {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						location := this.RuleSets[fromIndex]
 | 
				
			||||||
 | 
						result := []*RuleSet{}
 | 
				
			||||||
 | 
						for i := 0; i < len(this.RuleSets); i++ {
 | 
				
			||||||
 | 
							if i == fromIndex {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if fromIndex > toIndex && i == toIndex {
 | 
				
			||||||
 | 
								result = append(result, location)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							result = append(result, this.RuleSets[i])
 | 
				
			||||||
 | 
							if fromIndex < toIndex && i == toIndex {
 | 
				
			||||||
 | 
								result = append(result, location)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.RuleSets = result
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user