mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
实现WAF
This commit is contained in:
2
go.mod
2
go.mod
@@ -7,6 +7,7 @@ replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
|
|||||||
require (
|
require (
|
||||||
github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect
|
github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect
|
||||||
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
|
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
|
||||||
|
github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f
|
||||||
github.com/dchest/siphash v1.2.1
|
github.com/dchest/siphash v1.2.1
|
||||||
github.com/go-ole/go-ole v1.2.4 // indirect
|
github.com/go-ole/go-ole v1.2.4 // indirect
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||||
@@ -14,4 +15,5 @@ require (
|
|||||||
github.com/shirou/gopsutil v2.20.9+incompatible
|
github.com/shirou/gopsutil v2.20.9+incompatible
|
||||||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7
|
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7
|
||||||
google.golang.org/grpc v1.32.0
|
google.golang.org/grpc v1.32.0
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776
|
||||||
)
|
)
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -14,6 +14,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f h1:q/DpyjJjZs94bziQ7YkBmIlpqbVP7yw179rnzoNVX1M=
|
||||||
|
github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f/go.mod h1:QGrK8vMWWHQYQ3QU9bw9Y9OPNfxccGzfb41qjvVeXtY=
|
||||||
github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4=
|
github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4=
|
||||||
github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4=
|
github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200624174652-8d2f3be8b2d9/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200624174652-8d2f3be8b2d9/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
|
|||||||
2
internal/grids/README.md
Normal file
2
internal/grids/README.md
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# Memory Grid
|
||||||
|
Cache items in memory, using partitions and LRU.
|
||||||
186
internal/grids/cell.go
Normal file
186
internal/grids/cell.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Cell struct {
|
||||||
|
LimitSize int64
|
||||||
|
LimitCount int
|
||||||
|
|
||||||
|
mapping map[uint64]*Item // key => item
|
||||||
|
list *List // { item1, item2, ... }
|
||||||
|
totalBytes int64
|
||||||
|
locker sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCell() *Cell {
|
||||||
|
return &Cell{
|
||||||
|
mapping: map[uint64]*Item{},
|
||||||
|
list: NewList(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Cell) Write(hashKey uint64, item *Item) {
|
||||||
|
if item == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
this.locker.Lock()
|
||||||
|
|
||||||
|
oldItem, ok := this.mapping[hashKey]
|
||||||
|
if ok {
|
||||||
|
this.list.Remove(oldItem)
|
||||||
|
|
||||||
|
if this.LimitSize > 0 {
|
||||||
|
this.totalBytes -= oldItem.Size()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// limit count
|
||||||
|
if this.LimitCount > 0 && len(this.mapping) >= this.LimitCount {
|
||||||
|
this.locker.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// trim memory
|
||||||
|
size := item.Size()
|
||||||
|
shouldTrim := false
|
||||||
|
if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size {
|
||||||
|
this.Trim()
|
||||||
|
shouldTrim = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare again
|
||||||
|
if shouldTrim {
|
||||||
|
if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size {
|
||||||
|
this.locker.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.totalBytes += size
|
||||||
|
|
||||||
|
this.list.Add(item)
|
||||||
|
this.mapping[hashKey] = item
|
||||||
|
|
||||||
|
this.locker.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Cell) Increase64(key []byte, expireAt int64, hashKey uint64, delta int64) (result int64) {
|
||||||
|
this.locker.Lock()
|
||||||
|
item, ok := this.mapping[hashKey]
|
||||||
|
if ok {
|
||||||
|
// reset to zero if expired
|
||||||
|
if item.ExpireAt < time.Now().Unix() {
|
||||||
|
item.ValueInt64 = 0
|
||||||
|
item.ExpireAt = expireAt
|
||||||
|
}
|
||||||
|
item.IncreaseInt64(delta)
|
||||||
|
result = item.ValueInt64
|
||||||
|
} else {
|
||||||
|
item := NewItem(key, ItemInt64)
|
||||||
|
item.ValueInt64 = delta
|
||||||
|
item.ExpireAt = expireAt
|
||||||
|
this.mapping[hashKey] = item
|
||||||
|
result = delta
|
||||||
|
}
|
||||||
|
this.locker.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Cell) Read(hashKey uint64) *Item {
|
||||||
|
this.locker.Lock()
|
||||||
|
|
||||||
|
item, ok := this.mapping[hashKey]
|
||||||
|
if ok {
|
||||||
|
this.list.Remove(item)
|
||||||
|
this.list.Add(item)
|
||||||
|
|
||||||
|
this.locker.Unlock()
|
||||||
|
|
||||||
|
if item.ExpireAt < time.Now().Unix() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return item
|
||||||
|
}
|
||||||
|
|
||||||
|
this.locker.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Cell) Stat() *CellStat {
|
||||||
|
this.locker.RLock()
|
||||||
|
defer this.locker.RUnlock()
|
||||||
|
|
||||||
|
return &CellStat{
|
||||||
|
TotalBytes: this.totalBytes,
|
||||||
|
CountItems: len(this.mapping),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// trim NOT ACTIVE items
|
||||||
|
// should called in locker context
|
||||||
|
func (this *Cell) Trim() {
|
||||||
|
l := len(this.mapping)
|
||||||
|
if l == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inactiveSize := int(math.Ceil(float64(l) / 10)) // trim 10% items
|
||||||
|
this.list.Range(func(item *Item) (goNext bool) {
|
||||||
|
inactiveSize--
|
||||||
|
delete(this.mapping, item.HashKey())
|
||||||
|
this.list.Remove(item)
|
||||||
|
this.totalBytes -= item.Size()
|
||||||
|
return inactiveSize > 0
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Cell) Delete(hashKey uint64) {
|
||||||
|
this.locker.Lock()
|
||||||
|
item, ok := this.mapping[hashKey]
|
||||||
|
if ok {
|
||||||
|
delete(this.mapping, hashKey)
|
||||||
|
this.list.Remove(item)
|
||||||
|
this.totalBytes -= item.Size()
|
||||||
|
}
|
||||||
|
this.locker.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// range all items in the cell
|
||||||
|
func (this *Cell) Range(f func(item *Item)) {
|
||||||
|
this.locker.Lock()
|
||||||
|
for _, item := range this.mapping {
|
||||||
|
f(item)
|
||||||
|
}
|
||||||
|
this.locker.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Cell) Recycle() {
|
||||||
|
this.locker.Lock()
|
||||||
|
if len(this.mapping) == 0 {
|
||||||
|
this.locker.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp := time.Now().Unix()
|
||||||
|
for key, item := range this.mapping {
|
||||||
|
if item.ExpireAt < timestamp {
|
||||||
|
delete(this.mapping, key)
|
||||||
|
this.list.Remove(item)
|
||||||
|
this.totalBytes -= item.Size()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.locker.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Cell) Reset() {
|
||||||
|
this.locker.Lock()
|
||||||
|
this.list.Reset()
|
||||||
|
this.mapping = map[uint64]*Item{}
|
||||||
|
this.totalBytes = 0
|
||||||
|
this.locker.Unlock()
|
||||||
|
}
|
||||||
6
internal/grids/cell_stat.go
Normal file
6
internal/grids/cell_stat.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
type CellStat struct {
|
||||||
|
TotalBytes int64
|
||||||
|
CountItems int
|
||||||
|
}
|
||||||
214
internal/grids/cell_test.go
Normal file
214
internal/grids/cell_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCell_List(t *testing.T) {
|
||||||
|
cell := NewCell()
|
||||||
|
cell.Write(1, &Item{
|
||||||
|
ValueInt64: 1,
|
||||||
|
})
|
||||||
|
cell.Write(2, &Item{
|
||||||
|
ValueInt64: 2,
|
||||||
|
})
|
||||||
|
cell.Write(3, &Item{
|
||||||
|
ValueInt64: 3,
|
||||||
|
})
|
||||||
|
|
||||||
|
{
|
||||||
|
t.Log("====")
|
||||||
|
l := cell.list
|
||||||
|
for e := l.head; e != nil; e = e.Next {
|
||||||
|
t.Log("element:", e.ValueInt64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cell.Write(1, &Item{
|
||||||
|
ValueInt64: 1,
|
||||||
|
})
|
||||||
|
cell.Write(3, &Item{
|
||||||
|
ValueInt64: 3,
|
||||||
|
})
|
||||||
|
cell.Write(3, &Item{
|
||||||
|
ValueInt64: 3,
|
||||||
|
})
|
||||||
|
cell.Write(2, &Item{
|
||||||
|
ValueInt64: 2,
|
||||||
|
})
|
||||||
|
cell.Delete(2)
|
||||||
|
|
||||||
|
{
|
||||||
|
t.Log("====")
|
||||||
|
l := cell.list
|
||||||
|
for e := l.head; e != nil; e = e.Next {
|
||||||
|
t.Log("element:", e.ValueInt64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range cell.mapping {
|
||||||
|
t.Log(m.ValueInt64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCell_LimitSize(t *testing.T) {
|
||||||
|
cell := NewCell()
|
||||||
|
cell.LimitSize = 1024
|
||||||
|
|
||||||
|
for i := int64(0); i < 100; i ++ {
|
||||||
|
key := []byte(fmt.Sprintf("%d", i))
|
||||||
|
cell.Write(HashKey(key), &Item{
|
||||||
|
Key: key,
|
||||||
|
ValueInt64: i,
|
||||||
|
Type: ItemInt64,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("totalBytes:", cell.totalBytes)
|
||||||
|
|
||||||
|
{
|
||||||
|
t.Log("====")
|
||||||
|
l := cell.list
|
||||||
|
s := []string{}
|
||||||
|
for e := l.head; e != nil; e = e.Next {
|
||||||
|
s = append(s, fmt.Sprintf("%d", e.ValueInt64))
|
||||||
|
}
|
||||||
|
t.Log("{ " + strings.Join(s, ", ") + " }")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("mapping:", len(cell.mapping))
|
||||||
|
s := []string{}
|
||||||
|
for _, item := range cell.mapping {
|
||||||
|
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
|
||||||
|
}
|
||||||
|
t.Log("{ " + strings.Join(s, ", ") + " }")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCell_MemoryUsage(t *testing.T) {
|
||||||
|
//runtime.GOMAXPROCS(4)
|
||||||
|
|
||||||
|
cell := NewCell()
|
||||||
|
cell.LimitSize = 1024 * 1024 * 1024 * 1
|
||||||
|
|
||||||
|
before := time.Now()
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(4)
|
||||||
|
|
||||||
|
for j := 0; j < 4; j ++ {
|
||||||
|
go func(j int) {
|
||||||
|
start := j * 50 * 10000
|
||||||
|
for i := start; i < start+50*10000; i ++ {
|
||||||
|
key := []byte(strconv.Itoa(i) + "VERY_LONG_STRING")
|
||||||
|
cell.Write(HashKey(key), &Item{
|
||||||
|
Key: key,
|
||||||
|
ValueInt64: int64(i),
|
||||||
|
Type: ItemInt64,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}(j)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
t.Log("items:", len(cell.mapping))
|
||||||
|
t.Log(time.Since(before).Seconds(), "s", "totalBytes:", cell.totalBytes/1024/1024, "M")
|
||||||
|
//time.Sleep(10 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCell_Write(b *testing.B) {
|
||||||
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
|
cell := NewCell()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i ++ {
|
||||||
|
key := []byte(strconv.Itoa(i) + "_LONG_KEY_LONG_KEY_LONG_KEY_LONG_KEY")
|
||||||
|
cell.Write(HashKey(key), &Item{
|
||||||
|
Key: key,
|
||||||
|
ValueInt64: int64(i),
|
||||||
|
Type: ItemInt64,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Log("items:", len(cell.mapping))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCell_Read(t *testing.T) {
|
||||||
|
cell := NewCell()
|
||||||
|
|
||||||
|
cell.Write(1, &Item{
|
||||||
|
ValueInt64: 1,
|
||||||
|
ExpireAt: time.Now().Unix() + 3600,
|
||||||
|
})
|
||||||
|
cell.Write(2, &Item{
|
||||||
|
ValueInt64: 2,
|
||||||
|
ExpireAt: time.Now().Unix() + 3600,
|
||||||
|
})
|
||||||
|
cell.Write(3, &Item{
|
||||||
|
ValueInt64: 3,
|
||||||
|
ExpireAt: time.Now().Unix() + 3600,
|
||||||
|
})
|
||||||
|
|
||||||
|
{
|
||||||
|
s := []string{}
|
||||||
|
cell.list.Range(func(item *Item) (goNext bool) {
|
||||||
|
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
t.Log("before:", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(cell.Read(1).ValueInt64)
|
||||||
|
|
||||||
|
{
|
||||||
|
s := []string{}
|
||||||
|
cell.list.Range(func(item *Item) (goNext bool) {
|
||||||
|
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
t.Log("after:", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(cell.Read(2).ValueInt64)
|
||||||
|
|
||||||
|
{
|
||||||
|
s := []string{}
|
||||||
|
cell.list.Range(func(item *Item) (goNext bool) {
|
||||||
|
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
t.Log("after:", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCell_Recycle(t *testing.T) {
|
||||||
|
cell := NewCell()
|
||||||
|
cell.Write(1, &Item{
|
||||||
|
ValueInt64: 1,
|
||||||
|
ExpireAt: time.Now().Unix() - 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
cell.Write(2, &Item{
|
||||||
|
ValueInt64: 2,
|
||||||
|
ExpireAt: time.Now().Unix() + 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
cell.Recycle()
|
||||||
|
|
||||||
|
{
|
||||||
|
s := []string{}
|
||||||
|
cell.list.Range(func(item *Item) (goNext bool) {
|
||||||
|
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
t.Log("after:", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(cell.list.Len(), cell.totalBytes)
|
||||||
|
}
|
||||||
225
internal/grids/grid.go
Normal file
225
internal/grids/grid.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"github.com/iwind/TeaGo/timers"
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Memory Cache Grid
|
||||||
|
//
|
||||||
|
// | Grid |
|
||||||
|
// | cell1, cell2, ..., cell1024 |
|
||||||
|
// | item1, item2, ..., item1000000 |
|
||||||
|
type Grid struct {
|
||||||
|
cells []*Cell
|
||||||
|
countCells uint64
|
||||||
|
|
||||||
|
recycleIndex int
|
||||||
|
recycleLooper *timers.Looper
|
||||||
|
recycleInterval int
|
||||||
|
|
||||||
|
gzipLevel int
|
||||||
|
|
||||||
|
limitSize int64
|
||||||
|
limitCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGrid(countCells int, opt ...interface{}) *Grid {
|
||||||
|
grid := &Grid{
|
||||||
|
recycleIndex: -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, o := range opt {
|
||||||
|
switch x := o.(type) {
|
||||||
|
case *CompressOpt:
|
||||||
|
grid.gzipLevel = x.Level
|
||||||
|
case *LimitSizeOpt:
|
||||||
|
grid.limitSize = x.Size
|
||||||
|
case *LimitCountOpt:
|
||||||
|
grid.limitCount = x.Count
|
||||||
|
case *RecycleIntervalOpt:
|
||||||
|
grid.recycleInterval = x.Interval
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cells := []*Cell{}
|
||||||
|
if countCells <= 0 {
|
||||||
|
countCells = 1
|
||||||
|
} else if countCells > 100*10000 {
|
||||||
|
countCells = 100 * 10000
|
||||||
|
}
|
||||||
|
for i := 0; i < countCells; i++ {
|
||||||
|
cell := NewCell()
|
||||||
|
cell.LimitSize = int64(math.Floor(float64(grid.limitSize) / float64(countCells)))
|
||||||
|
cell.LimitCount = int(math.Floor(float64(grid.limitCount)) / float64(countCells))
|
||||||
|
|
||||||
|
cells = append(cells, cell)
|
||||||
|
}
|
||||||
|
grid.cells = cells
|
||||||
|
grid.countCells = uint64(len(cells))
|
||||||
|
|
||||||
|
grid.recycleTimer()
|
||||||
|
return grid
|
||||||
|
}
|
||||||
|
|
||||||
|
// get all cells in the grid
|
||||||
|
func (this *Grid) Cells() []*Cell {
|
||||||
|
return this.cells
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) WriteItem(item *Item) {
|
||||||
|
if this.countCells <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hashKey := item.HashKey()
|
||||||
|
this.cellForHashKey(hashKey).Write(hashKey, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) WriteInt64(key []byte, value int64, lifeSeconds int64) {
|
||||||
|
this.WriteItem(&Item{
|
||||||
|
Key: key,
|
||||||
|
Type: ItemInt64,
|
||||||
|
ValueInt64: value,
|
||||||
|
ExpireAt: time.Now().Unix() + lifeSeconds,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) IncreaseInt64(key []byte, delta int64, lifeSeconds int64) (result int64) {
|
||||||
|
hashKey := HashKey(key)
|
||||||
|
return this.cellForHashKey(hashKey).Increase64(key, time.Now().Unix()+lifeSeconds, hashKey, delta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) WriteString(key []byte, value string, lifeSeconds int64) {
|
||||||
|
this.WriteBytes(key, []byte(value), lifeSeconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) WriteBytes(key []byte, value []byte, lifeSeconds int64) {
|
||||||
|
isCompressed := false
|
||||||
|
if this.gzipLevel != gzip.NoCompression {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
writer, err := gzip.NewWriterLevel(buf, this.gzipLevel)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
this.WriteItem(&Item{
|
||||||
|
Key: key,
|
||||||
|
Type: ItemBytes,
|
||||||
|
ValueBytes: value,
|
||||||
|
ExpireAt: time.Now().Unix() + lifeSeconds,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = writer.Write([]byte(value))
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
this.WriteItem(&Item{
|
||||||
|
Key: key,
|
||||||
|
Type: ItemBytes,
|
||||||
|
ValueBytes: value,
|
||||||
|
ExpireAt: time.Now().Unix() + lifeSeconds,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = writer.Close()
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
this.WriteItem(&Item{
|
||||||
|
Key: key,
|
||||||
|
Type: ItemBytes,
|
||||||
|
ValueBytes: value,
|
||||||
|
ExpireAt: time.Now().Unix() + lifeSeconds,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value = buf.Bytes()
|
||||||
|
isCompressed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
this.WriteItem(&Item{
|
||||||
|
Key: key,
|
||||||
|
Type: ItemBytes,
|
||||||
|
ValueBytes: value,
|
||||||
|
ExpireAt: time.Now().Unix() + lifeSeconds,
|
||||||
|
IsCompressed: isCompressed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) WriteInterface(key []byte, value interface{}, lifeSeconds int64) {
|
||||||
|
this.WriteItem(&Item{
|
||||||
|
Key: key,
|
||||||
|
Type: ItemInterface,
|
||||||
|
ValueInterface: value,
|
||||||
|
ExpireAt: time.Now().Unix() + lifeSeconds,
|
||||||
|
IsCompressed: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) Read(key []byte) *Item {
|
||||||
|
if this.countCells <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hashKey := HashKey(key)
|
||||||
|
return this.cellForHashKey(hashKey).Read(hashKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) Stat() *Stat {
|
||||||
|
stat := &Stat{}
|
||||||
|
for _, cell := range this.cells {
|
||||||
|
cellStat := cell.Stat()
|
||||||
|
stat.CountItems += cellStat.CountItems
|
||||||
|
stat.TotalBytes += cellStat.TotalBytes
|
||||||
|
}
|
||||||
|
return stat
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) Delete(key []byte) {
|
||||||
|
if this.countCells <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hashKey := HashKey(key)
|
||||||
|
this.cellForHashKey(hashKey).Delete(hashKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) Reset() {
|
||||||
|
for _, cell := range this.cells {
|
||||||
|
cell.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) Destroy() {
|
||||||
|
if this.recycleLooper != nil {
|
||||||
|
this.recycleLooper.Stop()
|
||||||
|
this.recycleLooper = nil
|
||||||
|
}
|
||||||
|
this.cells = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) cellForHashKey(hashKey uint64) *Cell {
|
||||||
|
if hashKey < 0 {
|
||||||
|
return this.cells[-hashKey%this.countCells]
|
||||||
|
} else {
|
||||||
|
return this.cells[hashKey%this.countCells]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Grid) recycleTimer() {
|
||||||
|
duration := 1 * time.Minute
|
||||||
|
if this.recycleInterval > 0 {
|
||||||
|
duration = time.Duration(this.recycleInterval) * time.Second
|
||||||
|
}
|
||||||
|
this.recycleLooper = timers.Loop(duration, func(looper *timers.Looper) {
|
||||||
|
if this.countCells == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
this.recycleIndex++
|
||||||
|
if this.recycleIndex > int(this.countCells-1) {
|
||||||
|
this.recycleIndex = 0
|
||||||
|
}
|
||||||
|
this.cells[this.recycleIndex].Recycle()
|
||||||
|
})
|
||||||
|
}
|
||||||
204
internal/grids/grid_test.go
Normal file
204
internal/grids/grid_test.go
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMemoryGrid_Write(t *testing.T) {
|
||||||
|
grid := NewGrid(5, NewRecycleIntervalOpt(2), NewLimitSizeOpt(10240))
|
||||||
|
t.Log("123456:", grid.Read([]byte("123456")))
|
||||||
|
|
||||||
|
grid.WriteInt64([]byte("abc"), 1, 5)
|
||||||
|
t.Log(grid.Read([]byte("abc")).ValueInt64)
|
||||||
|
|
||||||
|
grid.WriteString([]byte("abc"), "123", 5)
|
||||||
|
t.Log(string(grid.Read([]byte("abc")).Bytes()))
|
||||||
|
|
||||||
|
grid.WriteBytes([]byte("abc"), []byte("123"), 5)
|
||||||
|
t.Log(grid.Read([]byte("abc")).Bytes())
|
||||||
|
|
||||||
|
grid.Delete([]byte("abc"))
|
||||||
|
t.Log(grid.Read([]byte("abc")))
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
grid.WriteInt64([]byte(fmt.Sprintf("%d", i)), 123, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("before recycle:")
|
||||||
|
for index, cell := range grid.cells {
|
||||||
|
t.Log("cell:", index, len(cell.mapping), "items")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
t.Log("after recycle:")
|
||||||
|
for index, cell := range grid.cells {
|
||||||
|
t.Log("cell:", index, len(cell.mapping), "items")
|
||||||
|
}
|
||||||
|
|
||||||
|
grid.Destroy()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_Write_LimitCount(t *testing.T) {
|
||||||
|
grid := NewGrid(2, NewLimitCountOpt(10))
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
grid.WriteInt64([]byte(strconv.Itoa(i)), int64(i), 30)
|
||||||
|
}
|
||||||
|
t.Log(grid.Stat().CountItems, "items")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_Compress(t *testing.T) {
|
||||||
|
grid := NewGrid(5, NewCompressOpt(1))
|
||||||
|
grid.WriteString([]byte("hello"), strings.Repeat("abcd", 10240), 30)
|
||||||
|
t.Log(len(string(grid.Read([]byte("hello")).String())))
|
||||||
|
t.Log(len(grid.Read([]byte("hello")).ValueBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMemoryGrid_Performance(b *testing.B) {
|
||||||
|
grid := NewGrid(1024)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
grid.WriteInt64([]byte("key:"+strconv.Itoa(i)), int64(i), 3600)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_Performance(t *testing.T) {
|
||||||
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
|
grid := NewGrid(1024)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
s := []byte(strings.Repeat("abcd", 10*1024))
|
||||||
|
|
||||||
|
for i := 0; i < 100000; i++ {
|
||||||
|
grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, 1)), s, 3600)
|
||||||
|
item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, 1)))
|
||||||
|
if item != nil {
|
||||||
|
_ = item.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
countItems := 0
|
||||||
|
for _, cell := range grid.cells {
|
||||||
|
countItems += len(cell.mapping)
|
||||||
|
}
|
||||||
|
t.Log(countItems, "items")
|
||||||
|
|
||||||
|
t.Log(time.Since(now).Seconds()*1000, "ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_Performance_Concurrent(t *testing.T) {
|
||||||
|
//runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
|
grid := NewGrid(1024)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
s := []byte(strings.Repeat("abcd", 10*1024))
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(runtime.NumCPU())
|
||||||
|
|
||||||
|
for c := 0; c < runtime.NumCPU(); c++ {
|
||||||
|
go func(c int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < 50000; i++ {
|
||||||
|
grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, c)), s, 3600)
|
||||||
|
item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, c)))
|
||||||
|
if item != nil {
|
||||||
|
_ = item.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
countItems := 0
|
||||||
|
for _, cell := range grid.cells {
|
||||||
|
countItems += len(cell.mapping)
|
||||||
|
}
|
||||||
|
t.Log(countItems, "items")
|
||||||
|
|
||||||
|
t.Log(time.Since(now).Seconds()*1000, "ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_CompressPerformance(t *testing.T) {
|
||||||
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
|
grid := NewGrid(1024, NewCompressOpt(gzip.BestCompression))
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
data := []byte(strings.Repeat("abcd", 1024))
|
||||||
|
|
||||||
|
for i := 0; i < 100000; i++ {
|
||||||
|
grid.WriteBytes([]byte(fmt.Sprintf("key:%d", i)), data, 3600)
|
||||||
|
item := grid.Read([]byte(fmt.Sprintf("key:%d", i+100)))
|
||||||
|
if item != nil {
|
||||||
|
_ = item.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
countItems := 0
|
||||||
|
for _, cell := range grid.cells {
|
||||||
|
countItems += len(cell.mapping)
|
||||||
|
}
|
||||||
|
t.Log(countItems, "items")
|
||||||
|
|
||||||
|
t.Log(time.Since(now).Seconds()*1000, "ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_IncreaseInt64(t *testing.T) {
|
||||||
|
grid := NewGrid(1024)
|
||||||
|
grid.WriteInt64([]byte("abc"), 123, 10)
|
||||||
|
grid.IncreaseInt64([]byte("abc"), 123, 10)
|
||||||
|
grid.IncreaseInt64([]byte("abc"), 123, 10)
|
||||||
|
item := grid.Read([]byte("abc"))
|
||||||
|
if item == nil {
|
||||||
|
t.Fatal("item == nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.ValueInt64 != 369 {
|
||||||
|
t.Fatal("not 369")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_Destroy(t *testing.T) {
|
||||||
|
grid := NewGrid(1024)
|
||||||
|
grid.WriteInt64([]byte("abc"), 123, 10)
|
||||||
|
t.Log(grid.recycleLooper, grid.cells)
|
||||||
|
grid.Destroy()
|
||||||
|
t.Log(grid.recycleLooper, grid.cells)
|
||||||
|
|
||||||
|
if grid.recycleLooper != nil {
|
||||||
|
t.Fatal("looper != nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryGrid_Recycle(t *testing.T) {
|
||||||
|
cell := NewCell()
|
||||||
|
timestamp := time.Now().Unix()
|
||||||
|
for i := 0; i < 300_0000; i++ {
|
||||||
|
cell.Write(uint64(i), &Item{
|
||||||
|
ExpireAt: timestamp - 30,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
before := time.Now()
|
||||||
|
cell.Recycle()
|
||||||
|
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||||
|
t.Log(len(cell.mapping))
|
||||||
|
|
||||||
|
runtime.GC()
|
||||||
|
printMem(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func printMem(t *testing.T) {
|
||||||
|
mem := &runtime.MemStats{}
|
||||||
|
runtime.ReadMemStats(mem)
|
||||||
|
t.Log(mem.Alloc/1024/1024, "M")
|
||||||
|
}
|
||||||
88
internal/grids/item.go
Normal file
88
internal/grids/item.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"github.com/dchest/siphash"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ItemType = int8
|
||||||
|
|
||||||
|
const (
|
||||||
|
ItemInt64 = 1
|
||||||
|
ItemBytes = 2
|
||||||
|
ItemInterface = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
func HashKey(key []byte) uint64 {
|
||||||
|
return siphash.Hash(0, 0, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Item struct {
|
||||||
|
Key []byte
|
||||||
|
ExpireAt int64
|
||||||
|
Type ItemType
|
||||||
|
ValueInt64 int64
|
||||||
|
ValueBytes []byte
|
||||||
|
ValueInterface interface{}
|
||||||
|
IsCompressed bool
|
||||||
|
|
||||||
|
// linked list
|
||||||
|
Prev *Item
|
||||||
|
Next *Item
|
||||||
|
|
||||||
|
size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewItem(key []byte, dataType ItemType) *Item {
|
||||||
|
return &Item{
|
||||||
|
Key: key,
|
||||||
|
Type: dataType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Item) HashKey() uint64 {
|
||||||
|
return HashKey(this.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Item) IncreaseInt64(delta int64) {
|
||||||
|
atomic.AddInt64(&this.ValueInt64, delta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Item) Bytes() []byte {
|
||||||
|
if this.IsCompressed {
|
||||||
|
reader, err := gzip.NewReader(bytes.NewBuffer(this.ValueBytes))
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return this.ValueBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 256)
|
||||||
|
dataBuf := bytes.NewBuffer([]byte{})
|
||||||
|
for {
|
||||||
|
n, err := reader.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
dataBuf.Write(buf[:n])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dataBuf.Bytes()
|
||||||
|
}
|
||||||
|
return this.ValueBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Item) String() string {
|
||||||
|
return string(this.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Item) Size() int64 {
|
||||||
|
if this.size == 0 {
|
||||||
|
this.size = int64(int(unsafe.Sizeof(*this)) + len(this.Key) + len(this.ValueBytes))
|
||||||
|
}
|
||||||
|
return this.size
|
||||||
|
}
|
||||||
69
internal/grids/item_test.go
Normal file
69
internal/grids/item_test.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"github.com/dchest/siphash"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestItem_Size(t *testing.T) {
|
||||||
|
item := &Item{
|
||||||
|
ValueInt64: 1024,
|
||||||
|
Key: []byte("123"),
|
||||||
|
ValueBytes: []byte("Hello, World"),
|
||||||
|
}
|
||||||
|
t.Log(item.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkItem_Size(b *testing.B) {
|
||||||
|
item := &Item{
|
||||||
|
ValueInt64: 1024,
|
||||||
|
Key: []byte("123"),
|
||||||
|
ValueBytes: []byte("Hello, World"),
|
||||||
|
}
|
||||||
|
for i := 0; i < b.N; i ++ {
|
||||||
|
_ = item.Size()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestItem_HashKey(t *testing.T) {
|
||||||
|
t.Log(HashKey([]byte("2")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestItem_siphash(t *testing.T) {
|
||||||
|
result := siphash.Hash(0, 0, []byte("123456"))
|
||||||
|
t.Log(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestItem_unique(t *testing.T) {
|
||||||
|
m := map[uint64]bool{}
|
||||||
|
for i := 0; i < 1000*10000; i ++ {
|
||||||
|
s := "Hello,World,LONG KEY,LONG KEY,LONG KEY,LONG KEY" + strconv.Itoa(i)
|
||||||
|
result := siphash.Hash(0, 0, []byte(s))
|
||||||
|
_, ok := m[result]
|
||||||
|
if ok {
|
||||||
|
t.Log("found same", i)
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
m[result] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(siphash.Hash(0, 0, []byte("01")))
|
||||||
|
t.Log(siphash.Hash(0, 0, []byte("10")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkItem_HashKeyMd5(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i ++ {
|
||||||
|
h := md5.New()
|
||||||
|
h.Write([]byte("HELLO_KEY_" + strconv.Itoa(i)))
|
||||||
|
_ = h.Sum(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkItem_siphash(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i ++ {
|
||||||
|
_ = siphash.Hash(0, 0, []byte("HELLO_KEY_"+strconv.Itoa(i)))
|
||||||
|
}
|
||||||
|
}
|
||||||
68
internal/grids/list.go
Normal file
68
internal/grids/list.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
type List struct {
|
||||||
|
head *Item
|
||||||
|
end *Item
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewList() *List {
|
||||||
|
return &List{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) Add(item *Item) {
|
||||||
|
if item == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if this.end != nil {
|
||||||
|
this.end.Next = item
|
||||||
|
item.Prev = this.end
|
||||||
|
item.Next = nil
|
||||||
|
}
|
||||||
|
this.end = item
|
||||||
|
if this.head == nil {
|
||||||
|
this.head = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) Remove(item *Item) {
|
||||||
|
if item == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if item.Prev != nil {
|
||||||
|
item.Prev.Next = item.Next
|
||||||
|
}
|
||||||
|
if item.Next != nil {
|
||||||
|
item.Next.Prev = item.Prev
|
||||||
|
}
|
||||||
|
if item == this.head {
|
||||||
|
this.head = item.Next
|
||||||
|
}
|
||||||
|
if item == this.end {
|
||||||
|
this.end = item.Prev
|
||||||
|
}
|
||||||
|
|
||||||
|
item.Prev = nil
|
||||||
|
item.Next = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) Len() int {
|
||||||
|
l := 0
|
||||||
|
for e := this.head; e != nil; e = e.Next {
|
||||||
|
l ++
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) Range(f func(item *Item) (goNext bool)) {
|
||||||
|
for e := this.head; e != nil; e = e.Next {
|
||||||
|
goNext := f(e)
|
||||||
|
if !goNext {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) Reset() {
|
||||||
|
this.head = nil
|
||||||
|
this.end = nil
|
||||||
|
}
|
||||||
64
internal/grids/list_test.go
Normal file
64
internal/grids/list_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestList(t *testing.T) {
|
||||||
|
l := &List{}
|
||||||
|
|
||||||
|
var e1 *Item = nil
|
||||||
|
{
|
||||||
|
e := &Item{
|
||||||
|
ValueInt64: 1,
|
||||||
|
}
|
||||||
|
l.Add(e)
|
||||||
|
e1 = e
|
||||||
|
}
|
||||||
|
|
||||||
|
var e2 *Item = nil
|
||||||
|
{
|
||||||
|
e := &Item{
|
||||||
|
ValueInt64: 2,
|
||||||
|
}
|
||||||
|
l.Add(e)
|
||||||
|
e2 = e
|
||||||
|
}
|
||||||
|
|
||||||
|
var e3 *Item = nil
|
||||||
|
{
|
||||||
|
e := &Item{
|
||||||
|
ValueInt64: 3,
|
||||||
|
}
|
||||||
|
l.Add(e)
|
||||||
|
e3 = e
|
||||||
|
}
|
||||||
|
|
||||||
|
var e4 *Item = nil
|
||||||
|
{
|
||||||
|
e := &Item{
|
||||||
|
ValueInt64: 4,
|
||||||
|
}
|
||||||
|
l.Add(e)
|
||||||
|
e4 = e
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Remove(e1)
|
||||||
|
//l.Remove(e2)
|
||||||
|
//l.Remove(e3)
|
||||||
|
l.Remove(e4)
|
||||||
|
|
||||||
|
for e := l.head; e != nil; e = e.Next {
|
||||||
|
t.Log(e.ValueInt64)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("e1, e2, e3, e4, head, end:", e1, e2, e3, e4)
|
||||||
|
if l.head != nil {
|
||||||
|
t.Log("head:", l.head.ValueInt64)
|
||||||
|
} else {
|
||||||
|
t.Log("head: nil")
|
||||||
|
}
|
||||||
|
if l.end != nil {
|
||||||
|
t.Log("end:", l.end.ValueInt64)
|
||||||
|
} else {
|
||||||
|
t.Log("end: nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
11
internal/grids/opt_compress.go
Normal file
11
internal/grids/opt_compress.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
type CompressOpt struct {
|
||||||
|
Level int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCompressOpt(level int) *CompressOpt {
|
||||||
|
return &CompressOpt{
|
||||||
|
Level: level,
|
||||||
|
}
|
||||||
|
}
|
||||||
11
internal/grids/opt_limit_count.go
Normal file
11
internal/grids/opt_limit_count.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
type LimitCountOpt struct {
|
||||||
|
Count int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimitCountOpt(count int) *LimitCountOpt {
|
||||||
|
return &LimitCountOpt{
|
||||||
|
Count: count,
|
||||||
|
}
|
||||||
|
}
|
||||||
11
internal/grids/opt_limit_size.go
Normal file
11
internal/grids/opt_limit_size.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
type LimitSizeOpt struct {
|
||||||
|
Size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimitSizeOpt(size int64) *LimitSizeOpt {
|
||||||
|
return &LimitSizeOpt{
|
||||||
|
Size: size,
|
||||||
|
}
|
||||||
|
}
|
||||||
11
internal/grids/opt_recycle_interval.go
Normal file
11
internal/grids/opt_recycle_interval.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
type RecycleIntervalOpt struct {
|
||||||
|
Interval int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecycleIntervalOpt(interval int) *RecycleIntervalOpt {
|
||||||
|
return &RecycleIntervalOpt{
|
||||||
|
Interval: interval,
|
||||||
|
}
|
||||||
|
}
|
||||||
6
internal/grids/stat.go
Normal file
6
internal/grids/stat.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package grids
|
||||||
|
|
||||||
|
type Stat struct {
|
||||||
|
TotalBytes int64
|
||||||
|
CountItems int
|
||||||
|
}
|
||||||
@@ -96,7 +96,11 @@ func (this *HTTPRequest) Do() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WAF
|
// WAF
|
||||||
// TODO 需要实现
|
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||||
|
if this.doWAFRequest() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 访问控制
|
// 访问控制
|
||||||
// TODO 需要实现
|
// TODO 需要实现
|
||||||
@@ -253,6 +257,12 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
|||||||
this.web.Cache = web.Cache
|
this.web.Cache = web.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waf
|
||||||
|
if web.FirewallRef != nil && (web.FirewallRef.IsPrior || isTop) {
|
||||||
|
this.web.FirewallRef = web.FirewallRef
|
||||||
|
this.web.FirewallPolicy = web.FirewallPolicy
|
||||||
|
}
|
||||||
|
|
||||||
// 重写规则
|
// 重写规则
|
||||||
if len(web.RewriteRefs) > 0 {
|
if len(web.RewriteRefs) > 0 {
|
||||||
for index, ref := range web.RewriteRefs {
|
for index, ref := range web.RewriteRefs {
|
||||||
|
|||||||
@@ -166,12 +166,26 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WAF对出站进行检查
|
// WAF对出站进行检查
|
||||||
// TODO
|
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||||
|
if this.doWAFResponse(resp) {
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO 清除源站错误次数
|
// TODO 清除源站错误次数
|
||||||
|
|
||||||
// 特殊页面
|
// 特殊页面
|
||||||
// TODO
|
if len(this.web.Pages) > 0 && this.doPage(resp.StatusCode) {
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 设置Charset
|
// 设置Charset
|
||||||
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
|
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集
|
||||||
|
|||||||
51
internal/nodes/http_request_waf.go
Normal file
51
internal/nodes/http_request_waf.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 调用WAF
|
||||||
|
func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||||
|
w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id)
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
goNext, _, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ruleSet != nil {
|
||||||
|
if ruleSet.Action != waf.ActionAllow {
|
||||||
|
// TODO 记录日志
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return !goNext
|
||||||
|
}
|
||||||
|
|
||||||
|
// call response waf
|
||||||
|
func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
|
||||||
|
w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id)
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
goNext, _, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ruleSet != nil {
|
||||||
|
if ruleSet.Action != waf.ActionAllow {
|
||||||
|
// TODO 记录日志
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return !goNext
|
||||||
|
}
|
||||||
@@ -104,6 +104,7 @@ func (this *Node) syncConfig(isFirstTime bool) error {
|
|||||||
logs.Println("[NODE]reload config ...")
|
logs.Println("[NODE]reload config ...")
|
||||||
nodeconfigs.ResetNodeConfig(nodeConfig)
|
nodeconfigs.ResetNodeConfig(nodeConfig)
|
||||||
caches.SharedManager.UpdatePolicies(nodeConfig.AllCachePolicies())
|
caches.SharedManager.UpdatePolicies(nodeConfig.AllCachePolicies())
|
||||||
|
sharedWAFManager.UpdatePolicies(nodeConfig.AllHTTPFirewallPolicies())
|
||||||
sharedNodeConfig = nodeConfig
|
sharedNodeConfig = nodeConfig
|
||||||
|
|
||||||
if !isFirstTime {
|
if !isFirstTime {
|
||||||
|
|||||||
175
internal/nodes/waf_manager.go
Normal file
175
internal/nodes/waf_manager.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var sharedWAFManager = NewWAFManager()
|
||||||
|
|
||||||
|
// WAF管理器
|
||||||
|
type WAFManager struct {
|
||||||
|
mapping map[int64]*waf.WAF // policyId => WAF
|
||||||
|
locker sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取新对象
|
||||||
|
func NewWAFManager() *WAFManager {
|
||||||
|
return &WAFManager{
|
||||||
|
mapping: map[int64]*waf.WAF{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新策略
|
||||||
|
func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
|
||||||
|
this.locker.Lock()
|
||||||
|
defer this.locker.Unlock()
|
||||||
|
|
||||||
|
m := map[int64]*waf.WAF{}
|
||||||
|
for _, p := range policies {
|
||||||
|
w, err := this.convertWAF(p)
|
||||||
|
if err != nil {
|
||||||
|
logs.Println("[WAF]initialize policy '" + strconv.FormatInt(p.Id, 10) + "' failed: " + err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if w == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m[p.Id] = w
|
||||||
|
}
|
||||||
|
this.mapping = m
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找WAF
|
||||||
|
func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
|
||||||
|
this.locker.RLock()
|
||||||
|
w, _ := this.mapping[policyId]
|
||||||
|
this.locker.RUnlock()
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
// 判断是否包含int64
|
||||||
|
func (this *WAFManager) containsInt64(values []int64, value int64) bool {
|
||||||
|
for _, v := range values {
|
||||||
|
if v == value {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将Policy转换为WAF
|
||||||
|
func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (*waf.WAF, error) {
|
||||||
|
if policy == nil {
|
||||||
|
return nil, errors.New("policy should not be nil")
|
||||||
|
}
|
||||||
|
w := &waf.WAF{
|
||||||
|
Id: strconv.FormatInt(policy.Id, 10),
|
||||||
|
IsOn: policy.IsOn,
|
||||||
|
Name: policy.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
// inbound
|
||||||
|
if policy.Inbound != nil && policy.Inbound.IsOn {
|
||||||
|
for _, group := range policy.Inbound.Groups {
|
||||||
|
g := &waf.RuleGroup{
|
||||||
|
Id: strconv.FormatInt(group.Id, 10),
|
||||||
|
IsOn: group.IsOn,
|
||||||
|
Name: group.Name,
|
||||||
|
Description: group.Description,
|
||||||
|
Code: group.Code,
|
||||||
|
IsInbound: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// rule sets
|
||||||
|
for _, set := range group.Sets {
|
||||||
|
s := &waf.RuleSet{
|
||||||
|
Id: strconv.FormatInt(set.Id, 10),
|
||||||
|
Code: set.Code,
|
||||||
|
IsOn: set.IsOn,
|
||||||
|
Name: set.Name,
|
||||||
|
Description: set.Description,
|
||||||
|
Connector: set.Connector,
|
||||||
|
Action: set.Action,
|
||||||
|
ActionOptions: set.ActionOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
// rules
|
||||||
|
for _, rule := range set.Rules {
|
||||||
|
r := &waf.Rule{
|
||||||
|
Description: rule.Description,
|
||||||
|
Param: rule.Param,
|
||||||
|
Operator: rule.Operator,
|
||||||
|
Value: rule.Value,
|
||||||
|
IsCaseInsensitive: rule.IsCaseInsensitive,
|
||||||
|
CheckpointOptions: rule.CheckpointOptions,
|
||||||
|
}
|
||||||
|
s.Rules = append(s.Rules, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.RuleSets = append(g.RuleSets, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Inbound = append(w.Inbound, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// outbound
|
||||||
|
if policy.Outbound != nil && policy.Outbound.IsOn {
|
||||||
|
for _, group := range policy.Outbound.Groups {
|
||||||
|
g := &waf.RuleGroup{
|
||||||
|
Id: strconv.FormatInt(group.Id, 10),
|
||||||
|
IsOn: group.IsOn,
|
||||||
|
Name: group.Name,
|
||||||
|
Description: group.Description,
|
||||||
|
Code: group.Code,
|
||||||
|
IsInbound: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// rule sets
|
||||||
|
for _, set := range group.Sets {
|
||||||
|
s := &waf.RuleSet{
|
||||||
|
Id: strconv.FormatInt(set.Id, 10),
|
||||||
|
Code: set.Code,
|
||||||
|
IsOn: set.IsOn,
|
||||||
|
Name: set.Name,
|
||||||
|
Description: set.Description,
|
||||||
|
Connector: set.Connector,
|
||||||
|
Action: set.Action,
|
||||||
|
ActionOptions: set.ActionOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
// rules
|
||||||
|
for _, rule := range set.Rules {
|
||||||
|
r := &waf.Rule{
|
||||||
|
Description: rule.Description,
|
||||||
|
Param: rule.Param,
|
||||||
|
Operator: rule.Operator,
|
||||||
|
Value: rule.Value,
|
||||||
|
IsCaseInsensitive: rule.IsCaseInsensitive,
|
||||||
|
CheckpointOptions: rule.CheckpointOptions,
|
||||||
|
}
|
||||||
|
s.Rules = append(s.Rules, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.RuleSets = append(g.RuleSets, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Outbound = append(w.Outbound, g)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// action
|
||||||
|
// TODO
|
||||||
|
|
||||||
|
err := w.Init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return w, nil
|
||||||
|
}
|
||||||
44
internal/nodes/waf_manager_test.go
Normal file
44
internal/nodes/waf_manager_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWAFManager_convert(t *testing.T) {
|
||||||
|
p := &firewallconfigs.HTTPFirewallPolicy{
|
||||||
|
Id: 1,
|
||||||
|
IsOn: true,
|
||||||
|
Inbound: &firewallconfigs.HTTPFirewallInboundConfig{
|
||||||
|
IsOn: true,
|
||||||
|
Groups: []*firewallconfigs.HTTPFirewallRuleGroup{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
Sets: []*firewallconfigs.HTTPFirewallRuleSet{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 2,
|
||||||
|
Rules: []*firewallconfigs.HTTPFirewallRule{
|
||||||
|
{
|
||||||
|
Id: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w, err := sharedWAFManager.convertWAF(p)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logs.PrintAsJSON(w, t)
|
||||||
|
}
|
||||||
77
internal/utils/get.go
Normal file
77
internal/utils/get.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
|
||||||
|
|
||||||
|
func Get(object interface{}, keys []string) interface{} {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return object
|
||||||
|
}
|
||||||
|
|
||||||
|
if object == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
firstKey := keys[0]
|
||||||
|
keys = keys[1:]
|
||||||
|
|
||||||
|
value := reflect.ValueOf(object)
|
||||||
|
|
||||||
|
if !value.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Kind() == reflect.Ptr {
|
||||||
|
value = value.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Kind() == reflect.Struct {
|
||||||
|
field := value.FieldByName(firstKey)
|
||||||
|
if !field.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return field.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Get(field.Interface(), keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Kind() == reflect.Map {
|
||||||
|
mapKey := reflect.ValueOf(firstKey)
|
||||||
|
mapValue := value.MapIndex(mapKey)
|
||||||
|
if !mapValue.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return mapValue.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Get(mapValue.Interface(), keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Kind() == reflect.Slice {
|
||||||
|
if RegexpDigitNumber.MatchString(firstKey) {
|
||||||
|
firstKeyInt := types.Int(firstKey)
|
||||||
|
if value.Len() > firstKeyInt {
|
||||||
|
result := value.Index(firstKeyInt).Interface()
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
return Get(result, keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
79
internal/utils/get_test.go
Normal file
79
internal/utils/get_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGetStruct(t *testing.T) {
|
||||||
|
object := struct {
|
||||||
|
Name string
|
||||||
|
Age int
|
||||||
|
Books []string
|
||||||
|
Extend struct {
|
||||||
|
Location struct {
|
||||||
|
City string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}{
|
||||||
|
Name: "lu",
|
||||||
|
Age: 20,
|
||||||
|
Books: []string{"Golang"},
|
||||||
|
Extend: struct {
|
||||||
|
Location struct {
|
||||||
|
City string
|
||||||
|
}
|
||||||
|
}{
|
||||||
|
Location: struct {
|
||||||
|
City string
|
||||||
|
}{
|
||||||
|
City: "Beijing",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if Get(object, []string{"Name"}) != "lu" {
|
||||||
|
t.Fatal("[ERROR]Name != lu")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Get(object, []string{"Age"}) != 20 {
|
||||||
|
t.Fatal("[ERROR]Age != 20")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Get(object, []string{"Books", "0"}) != "Golang" {
|
||||||
|
t.Fatal("[ERROR]books.0 != Golang")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Extend.Location:", Get(object, []string{"Extend", "Location"}))
|
||||||
|
|
||||||
|
if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" {
|
||||||
|
t.Fatal("[ERROR]Extend.Location.City != Beijing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMap(t *testing.T) {
|
||||||
|
object := map[string]interface{}{
|
||||||
|
"Name": "lu",
|
||||||
|
"Age": 20,
|
||||||
|
"Extend": map[string]interface{}{
|
||||||
|
"Location": map[string]interface{}{
|
||||||
|
"City": "Beijing",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if Get(object, []string{"Name"}) != "lu" {
|
||||||
|
t.Fatal("[ERROR]Name != lu")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Get(object, []string{"Age"}) != 20 {
|
||||||
|
t.Fatal("[ERROR]Age != 20")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Get(object, []string{"Books", "0"}) != nil {
|
||||||
|
t.Fatal("[ERROR]books.0 != nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log(Get(object, []string{"Extend", "Location"}))
|
||||||
|
|
||||||
|
if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" {
|
||||||
|
t.Fatal("[ERROR]Extend.Location.City != Beijing")
|
||||||
|
}
|
||||||
|
}
|
||||||
37
internal/utils/string.go
Normal file
37
internal/utils/string.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// convert bytes to string
|
||||||
|
func UnsafeBytesToString(bs []byte) string {
|
||||||
|
return *(*string)(unsafe.Pointer(&bs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert string to bytes
|
||||||
|
func UnsafeStringToBytes(s string) []byte {
|
||||||
|
return *(*[]byte)(unsafe.Pointer(&s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// format address
|
||||||
|
func FormatAddress(addr string) string {
|
||||||
|
if strings.HasSuffix(addr, "unix:") {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
addr = strings.Replace(addr, " ", "", -1)
|
||||||
|
addr = strings.Replace(addr, "\t", "", -1)
|
||||||
|
addr = strings.Replace(addr, ":", ":", -1)
|
||||||
|
addr = strings.TrimSpace(addr)
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// format address list
|
||||||
|
func FormatAddressList(addrList []string) []string {
|
||||||
|
result := []string{}
|
||||||
|
for _, addr := range addrList {
|
||||||
|
result = append(result, FormatAddress(addr))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
56
internal/utils/string_test.go
Normal file
56
internal/utils/string_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBytesToString(t *testing.T) {
|
||||||
|
t.Log(UnsafeBytesToString([]byte("Hello,World")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringToBytes(t *testing.T) {
|
||||||
|
t.Log(string(UnsafeStringToBytes("Hello,World")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBytesToString(b *testing.B) {
|
||||||
|
data := []byte("Hello,World")
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = UnsafeBytesToString(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBytesToString2(b *testing.B) {
|
||||||
|
data := []byte("Hello,World")
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = string(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStringToBytes(b *testing.B) {
|
||||||
|
s := strings.Repeat("Hello,World", 1024)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = UnsafeStringToBytes(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStringToBytes2(b *testing.B) {
|
||||||
|
s := strings.Repeat("Hello,World", 1024)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = []byte(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatAddress(t *testing.T) {
|
||||||
|
t.Log(FormatAddress("127.0.0.1:1234"))
|
||||||
|
t.Log(FormatAddress("127.0.0.1 : 1234"))
|
||||||
|
t.Log(FormatAddress("127.0.0.1:1234"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatAddressList(t *testing.T) {
|
||||||
|
t.Log(FormatAddressList([]string{
|
||||||
|
"127.0.0.1:1234",
|
||||||
|
"127.0.0.1 : 1234",
|
||||||
|
"127.0.0.1:1234",
|
||||||
|
}))
|
||||||
|
}
|
||||||
48
internal/waf/README.md
Normal file
48
internal/waf/README.md
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
# WAF
|
||||||
|
A basic WAF for TeaWeb.
|
||||||
|
|
||||||
|
## Config Constructions
|
||||||
|
~~~
|
||||||
|
WAF
|
||||||
|
Inbound
|
||||||
|
Rule Groups
|
||||||
|
Rule Sets
|
||||||
|
Rules
|
||||||
|
Checkpoint Param <Operator> Value
|
||||||
|
Outbound
|
||||||
|
Rule Groups
|
||||||
|
...
|
||||||
|
~~~
|
||||||
|
|
||||||
|
## Apply WAF
|
||||||
|
~~~
|
||||||
|
Request --> WAF --> Backends
|
||||||
|
/
|
||||||
|
Response <-- WAF <----
|
||||||
|
~~~
|
||||||
|
|
||||||
|
## Coding
|
||||||
|
~~~go
|
||||||
|
waf := teawaf.NewWAF()
|
||||||
|
|
||||||
|
// add rule groups here
|
||||||
|
|
||||||
|
err := waf.Init()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
waf.Start()
|
||||||
|
|
||||||
|
// match http request
|
||||||
|
// (req *http.Request, responseWriter http.ResponseWriter)
|
||||||
|
goNext, ruleSet, _ := waf.MatchRequest(req, responseWriter)
|
||||||
|
if ruleSet != nil {
|
||||||
|
log.Println("meet rule set:", ruleSet.Name, "action:", ruleSet.Action)
|
||||||
|
}
|
||||||
|
if !goNext {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop the waf
|
||||||
|
// waf.Stop()
|
||||||
|
~~~
|
||||||
14
internal/waf/action_allow.go
Normal file
14
internal/waf/action_allow.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AllowAction struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||||
|
// do nothing
|
||||||
|
return true
|
||||||
|
}
|
||||||
94
internal/waf/action_block.go
Normal file
94
internal/waf/action_block.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/Tea"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// url client configure
|
||||||
|
var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://")
|
||||||
|
var httpClient = utils.SharedHttpClient(5 * time.Second)
|
||||||
|
|
||||||
|
type BlockAction struct {
|
||||||
|
StatusCode int `yaml:"statusCode" json:"statusCode"`
|
||||||
|
Body string `yaml:"body" json:"body"` // supports HTML
|
||||||
|
URL string `yaml:"url" json:"url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||||
|
if writer != nil {
|
||||||
|
// if status code eq 444, we close the connection
|
||||||
|
if this.StatusCode == 444 {
|
||||||
|
hijack, ok := writer.(http.Hijacker)
|
||||||
|
if ok {
|
||||||
|
conn, _, _ := hijack.Hijack()
|
||||||
|
if conn != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// output response
|
||||||
|
if this.StatusCode > 0 {
|
||||||
|
writer.WriteHeader(this.StatusCode)
|
||||||
|
} else {
|
||||||
|
writer.WriteHeader(http.StatusForbidden)
|
||||||
|
}
|
||||||
|
if len(this.URL) > 0 {
|
||||||
|
if urlPrefixReg.MatchString(this.URL) {
|
||||||
|
req, err := http.NewRequest(http.MethodGet, this.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
for _, v1 := range v {
|
||||||
|
writer.Header().Add(k, v1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
_, _ = io.CopyBuffer(writer, resp.Body, buf)
|
||||||
|
} else {
|
||||||
|
path := this.URL
|
||||||
|
if !filepath.IsAbs(this.URL) {
|
||||||
|
path = Tea.Root + string(os.PathSeparator) + path
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := ioutil.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, _ = writer.Write(data)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(this.Body) > 0 {
|
||||||
|
_, _ = writer.Write([]byte(this.Body))
|
||||||
|
} else {
|
||||||
|
_, _ = writer.Write([]byte("The request is blocked by " + teaconst.ProductName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
39
internal/waf/action_captcha.go
Normal file
39
internal/waf/action_captcha.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var captchaSalt = stringutil.Rand(32)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CaptchaSeconds = 600 // 10 minutes
|
||||||
|
)
|
||||||
|
|
||||||
|
type CaptchaAction struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||||
|
// TEAWEB_CAPTCHA:
|
||||||
|
cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
|
||||||
|
if err == nil && cookie != nil && len(cookie.Value) > 32 {
|
||||||
|
m := cookie.Value[:32]
|
||||||
|
timestamp := cookie.Value[32:]
|
||||||
|
if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
refURL := request.URL.String()
|
||||||
|
if len(request.Referer()) > 0 {
|
||||||
|
refURL = request.Referer()
|
||||||
|
}
|
||||||
|
http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
12
internal/waf/action_definition.go
Normal file
12
internal/waf/action_definition.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
// action definition
|
||||||
|
type ActionDefinition struct {
|
||||||
|
Name string
|
||||||
|
Code ActionString
|
||||||
|
Description string
|
||||||
|
Instance ActionInterface
|
||||||
|
Type reflect.Type
|
||||||
|
}
|
||||||
34
internal/waf/action_go_group.go
Normal file
34
internal/waf/action_go_group.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GoGroupAction struct {
|
||||||
|
GroupId string `yaml:"groupId" json:"groupId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||||
|
group := waf.FindRuleGroup(this.GroupId)
|
||||||
|
if group == nil || !group.IsOn {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
b, set, err := group.MatchRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !b {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||||
|
if actionObject == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return actionObject.Perform(waf, request, writer)
|
||||||
|
}
|
||||||
37
internal/waf/action_go_set.go
Normal file
37
internal/waf/action_go_set.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GoSetAction struct {
|
||||||
|
GroupId string `yaml:"groupId" json:"groupId"`
|
||||||
|
SetId string `yaml:"setId" json:"setId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||||
|
group := waf.FindRuleGroup(this.GroupId)
|
||||||
|
if group == nil || !group.IsOn {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
set := group.FindRuleSet(this.SetId)
|
||||||
|
if set == nil || !set.IsOn {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := set.MatchRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if !b {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
actionObject := FindActionInstance(set.Action, set.ActionOptions)
|
||||||
|
if actionObject == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return actionObject.Perform(waf, request, writer)
|
||||||
|
}
|
||||||
5
internal/waf/action_instance.go
Normal file
5
internal/waf/action_instance.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
type Action struct {
|
||||||
|
|
||||||
|
}
|
||||||
13
internal/waf/action_log.go
Normal file
13
internal/waf/action_log.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LogAction struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
21
internal/waf/action_type.go
Normal file
21
internal/waf/action_type.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ActionString = string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ActionLog = "log" // allow and log
|
||||||
|
ActionBlock = "block" // block
|
||||||
|
ActionCaptcha = "captcha" // block and show captcha
|
||||||
|
ActionAllow = "allow" // allow
|
||||||
|
ActionGoGroup = "go_group" // go to next rule group
|
||||||
|
ActionGoSet = "go_set" // go to next rule set
|
||||||
|
)
|
||||||
|
|
||||||
|
type ActionInterface interface {
|
||||||
|
Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
|
||||||
|
}
|
||||||
82
internal/waf/action_utils.go
Normal file
82
internal/waf/action_utils.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/iwind/TeaGo/maps"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
var AllActions = []*ActionDefinition{
|
||||||
|
{
|
||||||
|
Name: "阻止",
|
||||||
|
Code: ActionBlock,
|
||||||
|
Instance: new(BlockAction),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "允许通过",
|
||||||
|
Code: ActionAllow,
|
||||||
|
Instance: new(AllowAction),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "允许并记录日志",
|
||||||
|
Code: ActionLog,
|
||||||
|
Instance: new(LogAction),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Captcha验证码",
|
||||||
|
Code: ActionCaptcha,
|
||||||
|
Instance: new(CaptchaAction),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "跳到下一个规则分组",
|
||||||
|
Code: ActionGoGroup,
|
||||||
|
Instance: new(GoGroupAction),
|
||||||
|
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "跳到下一个规则集",
|
||||||
|
Code: ActionGoSet,
|
||||||
|
Instance: new(GoSetAction),
|
||||||
|
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
|
||||||
|
for _, def := range AllActions {
|
||||||
|
if def.Code == action {
|
||||||
|
if def.Type != nil {
|
||||||
|
// create new instance
|
||||||
|
ptrValue := reflect.New(def.Type)
|
||||||
|
instance := ptrValue.Interface().(ActionInterface)
|
||||||
|
|
||||||
|
if len(options) > 0 {
|
||||||
|
count := def.Type.NumField()
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
field := def.Type.Field(i)
|
||||||
|
tag, ok := field.Tag.Lookup("yaml")
|
||||||
|
if ok {
|
||||||
|
v, ok := options[tag]
|
||||||
|
if ok && reflect.TypeOf(v) == field.Type {
|
||||||
|
ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return instance
|
||||||
|
}
|
||||||
|
|
||||||
|
// return shared instance
|
||||||
|
return def.Instance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindActionName(action ActionString) string {
|
||||||
|
for _, def := range AllActions {
|
||||||
|
if def.Code == action {
|
||||||
|
return def.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
29
internal/waf/action_utils_test.go
Normal file
29
internal/waf/action_utils_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/iwind/TeaGo/assert"
|
||||||
|
"github.com/iwind/TeaGo/maps"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFindActionInstance(t *testing.T) {
|
||||||
|
a := assert.NewAssertion(t)
|
||||||
|
|
||||||
|
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
|
||||||
|
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
|
||||||
|
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
|
||||||
|
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
|
||||||
|
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
|
||||||
|
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
|
||||||
|
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
|
||||||
|
|
||||||
|
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFindActionInstance(b *testing.B) {
|
||||||
|
runtime.GOMAXPROCS(1)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
FindActionInstance(ActionGoSet, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
84
internal/waf/captcha_validator.go
Normal file
84
internal/waf/captcha_validator.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/dchest/captcha"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var captchaValidator = &CaptchaValidator{}
|
||||||
|
|
||||||
|
type CaptchaValidator struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
|
||||||
|
if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
|
||||||
|
this.validate(request, writer)
|
||||||
|
} else {
|
||||||
|
this.show(request, writer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
|
||||||
|
// show captcha
|
||||||
|
captchaId := captcha.NewLen(6)
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
err := captcha.WriteImage(buf, captchaId, 200, 100)
|
||||||
|
if err != nil {
|
||||||
|
logs.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = writer.Write([]byte(`<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Verify Yourself</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<form method="POST">
|
||||||
|
<input type="hidden" name="TEAWEB_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
|
||||||
|
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
|
||||||
|
<div>
|
||||||
|
<p>Input verify code above:</p>
|
||||||
|
<input type="text" name="TEAWEB_WAF_CAPTCHA_CODE" maxlength="6" size="18" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px"/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<button type="submit" onclick="window.location = '/webhook'" style="line-height:24px;margin-top:10px">Verify Yourself</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
|
||||||
|
captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
|
||||||
|
if len(captchaId) > 0 {
|
||||||
|
captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
|
||||||
|
if captcha.VerifyString(captchaId, captchaCode) {
|
||||||
|
// set cookie
|
||||||
|
timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
|
||||||
|
m := stringutil.Md5(captchaSalt + timestamp)
|
||||||
|
http.SetCookie(writer, &http.Cookie{
|
||||||
|
Name: "TEAWEB_WAF_CAPTCHA",
|
||||||
|
Value: m + timestamp,
|
||||||
|
MaxAge: CaptchaSeconds,
|
||||||
|
Path: "/", // all of dirs
|
||||||
|
})
|
||||||
|
|
||||||
|
rawURL := request.URL.Query().Get("url")
|
||||||
|
http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
|
||||||
|
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
246
internal/waf/checkpoints/cc.go
Normal file
246
internal/waf/checkpoints/cc.go
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/grids"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/maps"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"net"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${cc.arg}
|
||||||
|
// TODO implement more traffic rules
|
||||||
|
type CCCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
|
||||||
|
grid *grids.Grid
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) Init() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) Start() {
|
||||||
|
if this.grid != nil {
|
||||||
|
this.grid.Destroy()
|
||||||
|
}
|
||||||
|
this.grid = grids.NewGrid(32, grids.NewLimitCountOpt(1000_0000))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = 0
|
||||||
|
|
||||||
|
if this.grid == nil {
|
||||||
|
this.once.Do(func() {
|
||||||
|
this.Start()
|
||||||
|
})
|
||||||
|
if this.grid == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
periodString, ok := options["period"]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
period := types.Int64(periodString)
|
||||||
|
if period < 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
v, _ := options["userType"]
|
||||||
|
userType := types.String(v)
|
||||||
|
|
||||||
|
v, _ = options["userField"]
|
||||||
|
userField := types.String(v)
|
||||||
|
|
||||||
|
v, _ = options["userIndex"]
|
||||||
|
userIndex := types.Int(v)
|
||||||
|
|
||||||
|
if param == "requests" { // requests
|
||||||
|
var key = ""
|
||||||
|
switch userType {
|
||||||
|
case "ip":
|
||||||
|
key = this.ip(req)
|
||||||
|
case "cookie":
|
||||||
|
if len(userField) == 0 {
|
||||||
|
key = this.ip(req)
|
||||||
|
} else {
|
||||||
|
cookie, _ := req.Cookie(userField)
|
||||||
|
if cookie != nil {
|
||||||
|
v := cookie.Value
|
||||||
|
if userIndex > 0 && len(v) > userIndex {
|
||||||
|
v = v[userIndex:]
|
||||||
|
}
|
||||||
|
key = "USER@" + userType + "@" + userField + "@" + v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "get":
|
||||||
|
if len(userField) == 0 {
|
||||||
|
key = this.ip(req)
|
||||||
|
} else {
|
||||||
|
v := req.URL.Query().Get(userField)
|
||||||
|
if userIndex > 0 && len(v) > userIndex {
|
||||||
|
v = v[userIndex:]
|
||||||
|
}
|
||||||
|
key = "USER@" + userType + "@" + userField + "@" + v
|
||||||
|
}
|
||||||
|
case "post":
|
||||||
|
if len(userField) == 0 {
|
||||||
|
key = this.ip(req)
|
||||||
|
} else {
|
||||||
|
v := req.PostFormValue(userField)
|
||||||
|
if userIndex > 0 && len(v) > userIndex {
|
||||||
|
v = v[userIndex:]
|
||||||
|
}
|
||||||
|
key = "USER@" + userType + "@" + userField + "@" + v
|
||||||
|
}
|
||||||
|
case "header":
|
||||||
|
if len(userField) == 0 {
|
||||||
|
key = this.ip(req)
|
||||||
|
} else {
|
||||||
|
v := req.Header.Get(userField)
|
||||||
|
if userIndex > 0 && len(v) > userIndex {
|
||||||
|
v = v[userIndex:]
|
||||||
|
}
|
||||||
|
key = "USER@" + userType + "@" + userField + "@" + v
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
key = this.ip(req)
|
||||||
|
}
|
||||||
|
if len(key) == 0 {
|
||||||
|
key = this.ip(req)
|
||||||
|
}
|
||||||
|
value = this.grid.IncreaseInt64([]byte(key), 1, period)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) ParamOptions() *ParamOptions {
|
||||||
|
option := NewParamOptions()
|
||||||
|
option.AddParam("请求数", "requests")
|
||||||
|
return option
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) Options() []OptionInterface {
|
||||||
|
options := []OptionInterface{}
|
||||||
|
|
||||||
|
// period
|
||||||
|
{
|
||||||
|
option := NewFieldOption("统计周期", "period")
|
||||||
|
option.Value = "60"
|
||||||
|
option.RightLabel = "秒"
|
||||||
|
option.Size = 8
|
||||||
|
option.MaxLength = 8
|
||||||
|
option.Validate = func(value string) (ok bool, message string) {
|
||||||
|
if regexp.MustCompile("^\\d+$").MatchString(value) {
|
||||||
|
ok = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
message = "周期需要是一个整数数字"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
options = append(options, option)
|
||||||
|
}
|
||||||
|
|
||||||
|
// type
|
||||||
|
{
|
||||||
|
option := NewOptionsOption("用户识别读取来源", "userType")
|
||||||
|
option.Size = 10
|
||||||
|
option.SetOptions([]maps.Map{
|
||||||
|
{
|
||||||
|
"name": "IP",
|
||||||
|
"value": "ip",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Cookie",
|
||||||
|
"value": "cookie",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "URL参数",
|
||||||
|
"value": "get",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "POST参数",
|
||||||
|
"value": "post",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "HTTP Header",
|
||||||
|
"value": "header",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
options = append(options, option)
|
||||||
|
}
|
||||||
|
|
||||||
|
// user field
|
||||||
|
{
|
||||||
|
option := NewFieldOption("用户识别字段", "userField")
|
||||||
|
option.Comment = "识别用户的唯一性字段,在用户读取来源不是IP时使用"
|
||||||
|
options = append(options, option)
|
||||||
|
}
|
||||||
|
|
||||||
|
// user value index
|
||||||
|
{
|
||||||
|
option := NewFieldOption("字段读取位置", "userIndex")
|
||||||
|
option.Size = 5
|
||||||
|
option.MaxLength = 5
|
||||||
|
option.Comment = "读取用户识别字段的位置,从0开始,比如user12345的数字ID 12345的位置就是5,在用户读取来源不是IP时使用"
|
||||||
|
options = append(options, option)
|
||||||
|
}
|
||||||
|
|
||||||
|
return options
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) Stop() {
|
||||||
|
if this.grid != nil {
|
||||||
|
this.grid.Destroy()
|
||||||
|
this.grid = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *CCCheckpoint) ip(req *requests.Request) string {
|
||||||
|
// X-Forwarded-For
|
||||||
|
forwardedFor := req.Header.Get("X-Forwarded-For")
|
||||||
|
if len(forwardedFor) > 0 {
|
||||||
|
commaIndex := strings.Index(forwardedFor, ",")
|
||||||
|
if commaIndex > 0 {
|
||||||
|
return forwardedFor[:commaIndex]
|
||||||
|
}
|
||||||
|
return forwardedFor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Real-IP
|
||||||
|
{
|
||||||
|
realIP, ok := req.Header["X-Real-IP"]
|
||||||
|
if ok && len(realIP) > 0 {
|
||||||
|
return realIP[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Real-Ip
|
||||||
|
{
|
||||||
|
realIP, ok := req.Header["X-Real-Ip"]
|
||||||
|
if ok && len(realIP) > 0 {
|
||||||
|
return realIP[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remote-Addr
|
||||||
|
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return req.RemoteAddr
|
||||||
|
}
|
||||||
42
internal/waf/checkpoints/cc_test.go
Normal file
42
internal/waf/checkpoints/cc_test.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCCCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
raw, err := http.NewRequest(http.MethodGet, "http://teaos.cn/", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(raw)
|
||||||
|
req.RemoteAddr = "127.0.0.1"
|
||||||
|
|
||||||
|
checkpoint := new(CCCheckpoint)
|
||||||
|
checkpoint.Init()
|
||||||
|
checkpoint.Start()
|
||||||
|
|
||||||
|
options := map[string]string{
|
||||||
|
"period": "5",
|
||||||
|
}
|
||||||
|
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||||
|
|
||||||
|
req.RemoteAddr = "127.0.0.2"
|
||||||
|
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||||
|
|
||||||
|
req.RemoteAddr = "127.0.0.1"
|
||||||
|
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||||
|
|
||||||
|
req.RemoteAddr = "127.0.0.2"
|
||||||
|
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||||
|
|
||||||
|
req.RemoteAddr = "127.0.0.2"
|
||||||
|
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||||
|
|
||||||
|
req.RemoteAddr = "127.0.0.2"
|
||||||
|
t.Log(checkpoint.RequestValue(req, "requests", options))
|
||||||
|
}
|
||||||
28
internal/waf/checkpoints/checkpoint.go
Normal file
28
internal/waf/checkpoints/checkpoint.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
type Checkpoint struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Checkpoint) Init() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Checkpoint) IsRequest() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Checkpoint) ParamOptions() *ParamOptions {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Checkpoint) Options() []OptionInterface {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Checkpoint) Start() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Checkpoint) Stop() {
|
||||||
|
|
||||||
|
}
|
||||||
10
internal/waf/checkpoints/checkpoint_definition.go
Normal file
10
internal/waf/checkpoints/checkpoint_definition.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
// check point definition
|
||||||
|
type CheckpointDefinition struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Prefix string
|
||||||
|
HasParams bool // has sub params
|
||||||
|
Instance CheckpointInterface
|
||||||
|
}
|
||||||
32
internal/waf/checkpoints/checkpoint_interface.go
Normal file
32
internal/waf/checkpoints/checkpoint_interface.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check Point
|
||||||
|
type CheckpointInterface interface {
|
||||||
|
// initialize
|
||||||
|
Init()
|
||||||
|
|
||||||
|
// is request?
|
||||||
|
IsRequest() bool
|
||||||
|
|
||||||
|
// get request value
|
||||||
|
RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error)
|
||||||
|
|
||||||
|
// get response value
|
||||||
|
ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error)
|
||||||
|
|
||||||
|
// param option list
|
||||||
|
ParamOptions() *ParamOptions
|
||||||
|
|
||||||
|
// options
|
||||||
|
Options() []OptionInterface
|
||||||
|
|
||||||
|
// start
|
||||||
|
Start()
|
||||||
|
|
||||||
|
// stop
|
||||||
|
Stop()
|
||||||
|
}
|
||||||
5
internal/waf/checkpoints/option.go
Normal file
5
internal/waf/checkpoints/option.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
type OptionInterface interface {
|
||||||
|
Type() string
|
||||||
|
}
|
||||||
26
internal/waf/checkpoints/option_field.go
Normal file
26
internal/waf/checkpoints/option_field.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
// attach option
|
||||||
|
type FieldOption struct {
|
||||||
|
Name string
|
||||||
|
Code string
|
||||||
|
Value string // default value
|
||||||
|
IsRequired bool
|
||||||
|
Size int
|
||||||
|
Comment string
|
||||||
|
Placeholder string
|
||||||
|
RightLabel string
|
||||||
|
MaxLength int
|
||||||
|
Validate func(value string) (ok bool, message string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFieldOption(name string, code string) *FieldOption {
|
||||||
|
return &FieldOption{
|
||||||
|
Name: name,
|
||||||
|
Code: code,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *FieldOption) Type() string {
|
||||||
|
return "field"
|
||||||
|
}
|
||||||
30
internal/waf/checkpoints/option_options.go
Normal file
30
internal/waf/checkpoints/option_options.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import "github.com/iwind/TeaGo/maps"
|
||||||
|
|
||||||
|
type OptionsOption struct {
|
||||||
|
Name string
|
||||||
|
Code string
|
||||||
|
Value string // default value
|
||||||
|
IsRequired bool
|
||||||
|
Size int
|
||||||
|
Comment string
|
||||||
|
RightLabel string
|
||||||
|
Validate func(value string) (ok bool, message string)
|
||||||
|
Options []maps.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOptionsOption(name string, code string) *OptionsOption {
|
||||||
|
return &OptionsOption{
|
||||||
|
Name: name,
|
||||||
|
Code: code,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *OptionsOption) Type() string {
|
||||||
|
return "options"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *OptionsOption) SetOptions(options []maps.Map) {
|
||||||
|
this.Options = options
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/param_option.go
Normal file
21
internal/waf/checkpoints/param_option.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
type KeyValue struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ParamOptions struct {
|
||||||
|
Options []*KeyValue `json:"options"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewParamOptions() *ParamOptions {
|
||||||
|
return &ParamOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ParamOptions) AddParam(name string, value string) {
|
||||||
|
this.Options = append(this.Options, &KeyValue{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
})
|
||||||
|
}
|
||||||
46
internal/waf/checkpoints/request_all.go
Normal file
46
internal/waf/checkpoints/request_all.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${requestAll}
|
||||||
|
type RequestAllCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
valueBytes := []byte{}
|
||||||
|
if len(req.RequestURI) > 0 {
|
||||||
|
valueBytes = append(valueBytes, req.RequestURI...)
|
||||||
|
} else if req.URL != nil {
|
||||||
|
valueBytes = append(valueBytes, req.URL.RequestURI()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Body != nil {
|
||||||
|
valueBytes = append(valueBytes, ' ')
|
||||||
|
|
||||||
|
if len(req.BodyData) == 0 {
|
||||||
|
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||||
|
if err != nil {
|
||||||
|
return "", err, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req.BodyData = data
|
||||||
|
req.RestoreBody(data)
|
||||||
|
}
|
||||||
|
valueBytes = append(valueBytes, req.BodyData...)
|
||||||
|
}
|
||||||
|
|
||||||
|
value = valueBytes
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = ""
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
70
internal/waf/checkpoints/request_all_test.go
Normal file
70
internal/waf/checkpoints/request_all_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer([]byte("123456")))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint := new(RequestAllCheckpoint)
|
||||||
|
v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||||
|
if sysErr != nil {
|
||||||
|
t.Fatal(sysErr)
|
||||||
|
}
|
||||||
|
if userErr != nil {
|
||||||
|
t.Fatal(userErr)
|
||||||
|
}
|
||||||
|
t.Log(v)
|
||||||
|
t.Log(types.String(v))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000))))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint := new(RequestBodyCheckpoint)
|
||||||
|
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("value bytes:", len(types.String(value)))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("raw bytes:", len(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
|
||||||
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer(bytes.Repeat([]byte("HELLO"), 1024)))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint := new(RequestAllCheckpoint)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
20
internal/waf/checkpoints/request_arg.go
Normal file
20
internal/waf/checkpoints/request_arg.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestArgCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
return req.URL.Query().Get(param), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_arg_test.go
Normal file
21
internal/waf/checkpoints/request_arg_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestArgParam_RequestValue(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/?name=lu", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
|
||||||
|
checkpoint := new(RequestArgCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||||
|
t.Log(checkpoint.ResponseValue(req, nil, "name", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "name2", nil))
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_args.go
Normal file
21
internal/waf/checkpoints/request_args.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestArgsCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.URL.RawQuery
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
36
internal/waf/checkpoints/request_body.go
Normal file
36
internal/waf/checkpoints/request_body.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${requestBody}
|
||||||
|
type RequestBodyCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if req.Body == nil {
|
||||||
|
value = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.BodyData) == 0 {
|
||||||
|
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||||
|
if err != nil {
|
||||||
|
return "", err, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req.BodyData = data
|
||||||
|
req.RestoreBody(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return req.BodyData, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
47
internal/waf/checkpoints/request_body_test.go
Normal file
47
internal/waf/checkpoints/request_body_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint := new(RequestBodyCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000))))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint := new(RequestBodyCheckpoint)
|
||||||
|
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("value bytes:", len(types.String(value)))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("raw bytes:", len(body))
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_content_type.go
Normal file
21
internal/waf/checkpoints/request_content_type.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestContentTypeCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.Header.Get("Content-Type")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
27
internal/waf/checkpoints/request_cookie.go
Normal file
27
internal/waf/checkpoints/request_cookie.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestCookieCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
cookie, err := req.Cookie(param)
|
||||||
|
if err != nil {
|
||||||
|
value = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
value = cookie.Value
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
27
internal/waf/checkpoints/request_cookies.go
Normal file
27
internal/waf/checkpoints/request_cookies.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestCookiesCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
var cookies = []string{}
|
||||||
|
for _, cookie := range req.Cookies() {
|
||||||
|
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
|
||||||
|
}
|
||||||
|
value = strings.Join(cookies, "&")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
39
internal/waf/checkpoints/request_form_arg.go
Normal file
39
internal/waf/checkpoints/request_form_arg.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${requestForm.arg}
|
||||||
|
type RequestFormArgCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if req.Body == nil {
|
||||||
|
value = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.BodyData) == 0 {
|
||||||
|
data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
|
||||||
|
if err != nil {
|
||||||
|
return "", err, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req.BodyData = data
|
||||||
|
req.RestoreBody(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO improve performance
|
||||||
|
values, _ := url.ParseQuery(string(req.BodyData))
|
||||||
|
return values.Get(param), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
32
internal/waf/checkpoints/request_form_arg_test.go
Normal file
32
internal/waf/checkpoints/request_form_arg_test.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("name=lu&age=20&encoded="+url.QueryEscape("<strong>ENCODED STRING</strong>"))))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
checkpoint := new(RequestFormArgCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "age", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "Hello", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "encoded", nil))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(string(body))
|
||||||
|
}
|
||||||
27
internal/waf/checkpoints/request_header.go
Normal file
27
internal/waf/checkpoints/request_header.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestHeaderCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
v, found := req.Header[param]
|
||||||
|
if !found {
|
||||||
|
value = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value = strings.Join(v, ";")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
30
internal/waf/checkpoints/request_headers.go
Normal file
30
internal/waf/checkpoints/request_headers.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestHeadersCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
var headers = []string{}
|
||||||
|
for k, v := range req.Header {
|
||||||
|
for _, subV := range v {
|
||||||
|
headers = append(headers, k+": "+subV)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(headers)
|
||||||
|
value = strings.Join(headers, "\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_host.go
Normal file
21
internal/waf/checkpoints/request_host.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestHostCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.Host
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
20
internal/waf/checkpoints/request_host_test.go
Normal file
20
internal/waf/checkpoints/request_host_test.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestHostCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
req.Header.Set("Host", "cloud.teaos.cn")
|
||||||
|
|
||||||
|
checkpoint := new(RequestHostCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||||
|
}
|
||||||
44
internal/waf/checkpoints/request_json_arg.go
Normal file
44
internal/waf/checkpoints/request_json_arg.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${requestJSON.arg}
|
||||||
|
type RequestJSONArgCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if len(req.BodyData) == 0 {
|
||||||
|
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
|
||||||
|
if err != nil {
|
||||||
|
return "", err, nil
|
||||||
|
}
|
||||||
|
req.BodyData = data
|
||||||
|
defer req.RestoreBody(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO improve performance
|
||||||
|
var m interface{} = nil
|
||||||
|
err := json.Unmarshal(req.BodyData, &m)
|
||||||
|
if err != nil || m == nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
value = utils.Get(m, strings.Split(param, "."))
|
||||||
|
if value != nil {
|
||||||
|
return value, nil, err
|
||||||
|
}
|
||||||
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
99
internal/waf/checkpoints/request_json_arg_test.go
Normal file
99
internal/waf/checkpoints/request_json_arg_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
|
||||||
|
{
|
||||||
|
"name": "lu",
|
||||||
|
"age": 20,
|
||||||
|
"books": [ "PHP", "Golang", "Python" ]
|
||||||
|
}
|
||||||
|
`)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
checkpoint := new(RequestJSONArgCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "age", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "Hello", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "books", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "books.1", nil))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
|
||||||
|
[{
|
||||||
|
"name": "lu",
|
||||||
|
"age": 20,
|
||||||
|
"books": [ "PHP", "Golang", "Python" ]
|
||||||
|
}]
|
||||||
|
`)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
checkpoint := new(RequestJSONArgCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.name", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.age", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.Hello", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.books", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
|
||||||
|
[{
|
||||||
|
"name": "lu",
|
||||||
|
"age": 20,
|
||||||
|
"books": [ "PHP", "Golang", "Python" ]
|
||||||
|
}]
|
||||||
|
`)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
checkpoint := new(RequestJSONArgCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.name", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.age", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.Hello", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.books", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(string(body))
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_length.go
Normal file
21
internal/waf/checkpoints/request_length.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestLengthCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.ContentLength
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_method.go
Normal file
21
internal/waf/checkpoints/request_method.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestMethodCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.Method
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
20
internal/waf/checkpoints/request_path.go
Normal file
20
internal/waf/checkpoints/request_path.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestPathCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
return req.URL.Path, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
18
internal/waf/checkpoints/request_path_test.go
Normal file
18
internal/waf/checkpoints/request_path_test.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestPathCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/index?name=lu", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
checkpoint := new(RequestPathCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_proto.go
Normal file
21
internal/waf/checkpoints/request_proto.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestProtoCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.Proto
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
27
internal/waf/checkpoints/request_raw_remote_addr.go
Normal file
27
internal/waf/checkpoints/request_raw_remote_addr.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestRawRemoteAddrCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
value = host
|
||||||
|
} else {
|
||||||
|
value = req.RemoteAddr
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_referer.go
Normal file
21
internal/waf/checkpoints/request_referer.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestRefererCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.Referer()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
59
internal/waf/checkpoints/request_remote_addr.go
Normal file
59
internal/waf/checkpoints/request_remote_addr.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestRemoteAddrCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
// X-Forwarded-For
|
||||||
|
forwardedFor := req.Header.Get("X-Forwarded-For")
|
||||||
|
if len(forwardedFor) > 0 {
|
||||||
|
commaIndex := strings.Index(forwardedFor, ",")
|
||||||
|
if commaIndex > 0 {
|
||||||
|
value = forwardedFor[:commaIndex]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value = forwardedFor
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Real-IP
|
||||||
|
{
|
||||||
|
realIP, ok := req.Header["X-Real-IP"]
|
||||||
|
if ok && len(realIP) > 0 {
|
||||||
|
value = realIP[0]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Real-Ip
|
||||||
|
{
|
||||||
|
realIP, ok := req.Header["X-Real-Ip"]
|
||||||
|
if ok && len(realIP) > 0 {
|
||||||
|
value = realIP[0]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remote-Addr
|
||||||
|
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
value = host
|
||||||
|
} else {
|
||||||
|
value = req.RemoteAddr
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
28
internal/waf/checkpoints/request_remote_port.go
Normal file
28
internal/waf/checkpoints/request_remote_port.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestRemotePortCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
_, port, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
value = types.Int(port)
|
||||||
|
} else {
|
||||||
|
value = 0
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
26
internal/waf/checkpoints/request_remote_user.go
Normal file
26
internal/waf/checkpoints/request_remote_user.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestRemoteUserCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
username, _, ok := req.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
value = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value = username
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_scheme.go
Normal file
21
internal/waf/checkpoints/request_scheme.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestSchemeCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.URL.Scheme
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
18
internal/waf/checkpoints/request_scheme_test.go
Normal file
18
internal/waf/checkpoints/request_scheme_test.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
checkpoint := new(RequestSchemeCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "", nil))
|
||||||
|
}
|
||||||
130
internal/waf/checkpoints/request_upload.go
Normal file
130
internal/waf/checkpoints/request_upload.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/iwind/TeaGo/lists"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${requestUpload.arg}
|
||||||
|
type RequestUploadCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = ""
|
||||||
|
if param == "minSize" || param == "maxSize" {
|
||||||
|
value = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Method != http.MethodPost {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Body == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MultipartForm == nil {
|
||||||
|
if len(req.BodyData) == 0 {
|
||||||
|
data, err := req.ReadBody(32 * 1024 * 1024)
|
||||||
|
if err != nil {
|
||||||
|
sysErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.BodyData = data
|
||||||
|
defer req.RestoreBody(data)
|
||||||
|
}
|
||||||
|
oldBody := req.Body
|
||||||
|
req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData))
|
||||||
|
|
||||||
|
err := req.ParseMultipartForm(32 * 1024 * 1024)
|
||||||
|
|
||||||
|
// 还原
|
||||||
|
req.Body = oldBody
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
userErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.MultipartForm == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if param == "field" { // field
|
||||||
|
fields := []string{}
|
||||||
|
for field := range req.MultipartForm.File {
|
||||||
|
fields = append(fields, field)
|
||||||
|
}
|
||||||
|
value = strings.Join(fields, ",")
|
||||||
|
} else if param == "minSize" { // minSize
|
||||||
|
minSize := int64(0)
|
||||||
|
for _, files := range req.MultipartForm.File {
|
||||||
|
for _, file := range files {
|
||||||
|
if minSize == 0 || minSize > file.Size {
|
||||||
|
minSize = file.Size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
value = minSize
|
||||||
|
} else if param == "maxSize" { // maxSize
|
||||||
|
maxSize := int64(0)
|
||||||
|
for _, files := range req.MultipartForm.File {
|
||||||
|
for _, file := range files {
|
||||||
|
if maxSize < file.Size {
|
||||||
|
maxSize = file.Size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
value = maxSize
|
||||||
|
} else if param == "name" { // name
|
||||||
|
names := []string{}
|
||||||
|
for _, files := range req.MultipartForm.File {
|
||||||
|
for _, file := range files {
|
||||||
|
if !lists.ContainsString(names, file.Filename) {
|
||||||
|
names = append(names, file.Filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
value = strings.Join(names, ",")
|
||||||
|
} else if param == "ext" { // ext
|
||||||
|
extensions := []string{}
|
||||||
|
for _, files := range req.MultipartForm.File {
|
||||||
|
for _, file := range files {
|
||||||
|
if len(file.Filename) > 0 {
|
||||||
|
exit := strings.ToLower(filepath.Ext(file.Filename))
|
||||||
|
if !lists.ContainsString(extensions, exit) {
|
||||||
|
extensions = append(extensions, exit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
value = strings.Join(extensions, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestUploadCheckpoint) ParamOptions() *ParamOptions {
|
||||||
|
option := NewParamOptions()
|
||||||
|
option.AddParam("最小文件尺寸", "minSize")
|
||||||
|
option.AddParam("最大文件尺寸", "maxSize")
|
||||||
|
option.AddParam("扩展名(如.txt)", "ext")
|
||||||
|
option.AddParam("原始文件名", "name")
|
||||||
|
option.AddParam("表单字段名", "field")
|
||||||
|
return option
|
||||||
|
}
|
||||||
81
internal/waf/checkpoints/request_upload_test.go
Normal file
81
internal/waf/checkpoints/request_upload_test.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"io/ioutil"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
|
||||||
|
body := bytes.NewBuffer([]byte{})
|
||||||
|
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
|
||||||
|
{
|
||||||
|
part, err := writer.CreateFormField("name")
|
||||||
|
if err == nil {
|
||||||
|
part.Write([]byte("lu"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
part, err := writer.CreateFormField("age")
|
||||||
|
if err == nil {
|
||||||
|
part.Write([]byte("20"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
part, err := writer.CreateFormFile("myFile", "hello.txt")
|
||||||
|
if err == nil {
|
||||||
|
part.Write([]byte("Hello, World!"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
part, err := writer.CreateFormFile("myFile2", "hello.PHP")
|
||||||
|
if err == nil {
|
||||||
|
part.Write([]byte("Hello, World, PHP!"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
part, err := writer.CreateFormFile("myFile3", "hello.asp")
|
||||||
|
if err == nil {
|
||||||
|
part.Write([]byte("Hello, World, ASP Pages!"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
part, err := writer.CreateFormFile("myFile4", "hello.asp")
|
||||||
|
if err == nil {
|
||||||
|
part.Write([]byte("Hello, World, ASP Pages!"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.Close()
|
||||||
|
|
||||||
|
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn/", body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal()
|
||||||
|
}
|
||||||
|
|
||||||
|
req := requests.NewRequest(rawReq)
|
||||||
|
req.Header.Add("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
checkpoint := new(RequestUploadCheckpoint)
|
||||||
|
t.Log(checkpoint.RequestValue(req, "field", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "minSize", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "maxSize", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "name", nil))
|
||||||
|
t.Log(checkpoint.RequestValue(req, "ext", nil))
|
||||||
|
|
||||||
|
data, err := ioutil.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(string(data))
|
||||||
|
}
|
||||||
25
internal/waf/checkpoints/request_uri.go
Normal file
25
internal/waf/checkpoints/request_uri.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestURICheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if len(req.RequestURI) > 0 {
|
||||||
|
value = req.RequestURI
|
||||||
|
} else if req.URL != nil {
|
||||||
|
value = req.URL.RequestURI()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/request_user_agent.go
Normal file
21
internal/waf/checkpoints/request_user_agent.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestUserAgentCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = req.UserAgent()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
41
internal/waf/checkpoints/response_body.go
Normal file
41
internal/waf/checkpoints/response_body.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"io/ioutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${responseBody}
|
||||||
|
type ResponseBodyCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseBodyCheckpoint) IsRequest() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = ""
|
||||||
|
if resp != nil && resp.Body != nil {
|
||||||
|
if len(resp.BodyData) > 0 {
|
||||||
|
value = string(resp.BodyData)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
sysErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.BodyData = body
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
value = body
|
||||||
|
resp.Body = ioutil.NopCloser(bytes.NewBuffer(body))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
29
internal/waf/checkpoints/response_body_test.go
Normal file
29
internal/waf/checkpoints/response_body_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseBodyCheckpoint_ResponseValue(t *testing.T) {
|
||||||
|
resp := requests.NewResponse(new(http.Response))
|
||||||
|
resp.StatusCode = 200
|
||||||
|
resp.Header = http.Header{}
|
||||||
|
resp.Header.Set("Hello", "World")
|
||||||
|
resp.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("Hello, World")))
|
||||||
|
|
||||||
|
checkpoint := new(ResponseBodyCheckpoint)
|
||||||
|
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
|
||||||
|
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
|
||||||
|
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
|
||||||
|
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
|
||||||
|
|
||||||
|
data, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("after read:", string(data))
|
||||||
|
}
|
||||||
27
internal/waf/checkpoints/response_bytes_sent.go
Normal file
27
internal/waf/checkpoints/response_bytes_sent.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${bytesSent}
|
||||||
|
type ResponseBytesSentCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseBytesSentCheckpoint) IsRequest() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = 0
|
||||||
|
if resp != nil {
|
||||||
|
value = resp.ContentLength
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
28
internal/waf/checkpoints/response_header.go
Normal file
28
internal/waf/checkpoints/response_header.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${responseHeader.arg}
|
||||||
|
type ResponseHeaderCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseHeaderCheckpoint) IsRequest() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if resp != nil && resp.Header != nil {
|
||||||
|
value = resp.Header.Get(param)
|
||||||
|
} else {
|
||||||
|
value = ""
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
17
internal/waf/checkpoints/response_header_test.go
Normal file
17
internal/waf/checkpoints/response_header_test.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseHeaderCheckpoint_ResponseValue(t *testing.T) {
|
||||||
|
resp := requests.NewResponse(new(http.Response))
|
||||||
|
resp.StatusCode = 200
|
||||||
|
resp.Header = http.Header{}
|
||||||
|
resp.Header.Set("Hello", "World")
|
||||||
|
|
||||||
|
checkpoint := new(ResponseHeaderCheckpoint)
|
||||||
|
t.Log(checkpoint.ResponseValue(nil, resp, "Hello", nil))
|
||||||
|
}
|
||||||
26
internal/waf/checkpoints/response_status.go
Normal file
26
internal/waf/checkpoints/response_status.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ${bytesSent}
|
||||||
|
type ResponseStatusCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseStatusCheckpoint) IsRequest() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
value = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if resp != nil {
|
||||||
|
value = resp.StatusCode
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
15
internal/waf/checkpoints/response_status_test.go
Normal file
15
internal/waf/checkpoints/response_status_test.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestResponseStatusCheckpoint_ResponseValue(t *testing.T) {
|
||||||
|
resp := requests.NewResponse(new(http.Response))
|
||||||
|
resp.StatusCode = 200
|
||||||
|
|
||||||
|
checkpoint := new(ResponseStatusCheckpoint)
|
||||||
|
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
|
||||||
|
}
|
||||||
21
internal/waf/checkpoints/sample_request.go
Normal file
21
internal/waf/checkpoints/sample_request.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// just a sample checkpoint, copy and change it for your new checkpoint
|
||||||
|
type SampleRequestCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
|
||||||
|
if this.IsRequest() {
|
||||||
|
return this.RequestValue(req, param, options)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
22
internal/waf/checkpoints/sample_response.go
Normal file
22
internal/waf/checkpoints/sample_response.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// just a sample checkpoint, copy and change it for your new checkpoint
|
||||||
|
type SampleResponseCheckpoint struct {
|
||||||
|
Checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SampleResponseCheckpoint) IsRequest() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SampleResponseCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SampleResponseCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
235
internal/waf/checkpoints/utils.go
Normal file
235
internal/waf/checkpoints/utils.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
// all check points list
|
||||||
|
var AllCheckpoints = []*CheckpointDefinition{
|
||||||
|
{
|
||||||
|
Name: "客户端地址(IP)",
|
||||||
|
Prefix: "remoteAddr",
|
||||||
|
Description: "试图通过分析X-Forwarded-For等Header获取的客户端地址,比如192.168.1.100",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestRemoteAddrCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "客户端源地址(IP)",
|
||||||
|
Prefix: "rawRemoteAddr",
|
||||||
|
Description: "直接连接的客户端地址,比如192.168.1.100",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestRawRemoteAddrCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "客户端端口",
|
||||||
|
Prefix: "remotePort",
|
||||||
|
Description: "直接连接的客户端地址端口",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestRemotePortCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "客户端用户名",
|
||||||
|
Prefix: "remoteUser",
|
||||||
|
Description: "通过BasicAuth登录的客户端用户名",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestRemoteUserCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求URI",
|
||||||
|
Prefix: "requestURI",
|
||||||
|
Description: "包含URL参数的请求URI,比如/hello/world?lang=go",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestURICheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求路径",
|
||||||
|
Prefix: "requestPath",
|
||||||
|
Description: "不包含URL参数的请求路径,比如/hello/world",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestPathCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求内容长度",
|
||||||
|
Prefix: "requestLength",
|
||||||
|
Description: "请求Header中的Content-Length",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestLengthCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求体内容",
|
||||||
|
Prefix: "requestBody",
|
||||||
|
Description: "通常在POST或者PUT等操作时会附带请求体,最大限制32M",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestBodyCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求URI和请求体组合",
|
||||||
|
Prefix: "requestAll",
|
||||||
|
Description: "${requestURI}和${requestBody}组合",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestAllCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求表单参数",
|
||||||
|
Prefix: "requestForm",
|
||||||
|
Description: "获取POST或者其他方法发送的表单参数,最大请求体限制32M",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(RequestFormArgCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "上传文件",
|
||||||
|
Prefix: "requestUpload",
|
||||||
|
Description: "获取POST上传的文件信息,最大请求体限制32M",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(RequestUploadCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求JSON参数",
|
||||||
|
Prefix: "requestJSON",
|
||||||
|
Description: "获取POST或者其他方法发送的JSON,最大请求体限制32M,使用点(.)符号表示多级数据",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(RequestJSONArgCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求方法",
|
||||||
|
Prefix: "requestMethod",
|
||||||
|
Description: "比如GET、POST",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestMethodCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求协议",
|
||||||
|
Prefix: "scheme",
|
||||||
|
Description: "比如http或https",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestSchemeCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "HTTP协议版本",
|
||||||
|
Prefix: "proto",
|
||||||
|
Description: "比如HTTP/1.1",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestProtoCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "主机名",
|
||||||
|
Prefix: "host",
|
||||||
|
Description: "比如teaos.cn",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestHostCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "请求来源URL",
|
||||||
|
Prefix: "referer",
|
||||||
|
Description: "请求Header中的Referer值",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestRefererCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "客户端信息",
|
||||||
|
Prefix: "userAgent",
|
||||||
|
Description: "比如Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestUserAgentCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "内容类型",
|
||||||
|
Prefix: "contentType",
|
||||||
|
Description: "请求Header的Content-Type",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestContentTypeCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "所有cookie组合字符串",
|
||||||
|
Prefix: "cookies",
|
||||||
|
Description: "比如sid=IxZVPFhE&city=beijing&uid=18237",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestCookiesCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "单个cookie值",
|
||||||
|
Prefix: "cookie",
|
||||||
|
Description: "单个cookie值",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(RequestCookieCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "所有URL参数组合",
|
||||||
|
Prefix: "args",
|
||||||
|
Description: "比如name=lu&age=20",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestArgsCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "单个URL参数值",
|
||||||
|
Prefix: "arg",
|
||||||
|
Description: "单个URL参数值",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(RequestArgCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "所有Header信息",
|
||||||
|
Prefix: "headers",
|
||||||
|
Description: "使用\n隔开的Header信息字符串",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(RequestHeadersCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "单个Header值",
|
||||||
|
Prefix: "header",
|
||||||
|
Description: "单个Header值",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(RequestHeaderCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "CC统计",
|
||||||
|
Prefix: "cc",
|
||||||
|
Description: "统计某段时间段内的请求信息",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(CCCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "响应状态码",
|
||||||
|
Prefix: "status",
|
||||||
|
Description: "响应状态码,比如200、404、500",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(ResponseStatusCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "响应Header",
|
||||||
|
Prefix: "responseHeader",
|
||||||
|
Description: "响应Header值",
|
||||||
|
HasParams: true,
|
||||||
|
Instance: new(ResponseHeaderCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "响应内容",
|
||||||
|
Prefix: "responseBody",
|
||||||
|
Description: "响应内容字符串",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(ResponseBodyCheckpoint),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "响应内容长度",
|
||||||
|
Prefix: "bytesSent",
|
||||||
|
Description: "响应内容长度,通过响应的Header Content-Length获取",
|
||||||
|
HasParams: false,
|
||||||
|
Instance: new(ResponseBytesSentCheckpoint),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// find a check point
|
||||||
|
func FindCheckpoint(prefix string) CheckpointInterface {
|
||||||
|
for _, def := range AllCheckpoints {
|
||||||
|
if def.Prefix == prefix {
|
||||||
|
return def.Instance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// find a check point definition
|
||||||
|
func FindCheckpointDefinition(prefix string) *CheckpointDefinition {
|
||||||
|
for _, def := range AllCheckpoints {
|
||||||
|
if def.Prefix == prefix {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
31
internal/waf/checkpoints/utils_test.go
Normal file
31
internal/waf/checkpoints/utils_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package checkpoints
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFindCheckpointDefinition_Markdown(t *testing.T) {
|
||||||
|
result := []string{}
|
||||||
|
for _, def := range AllCheckpoints {
|
||||||
|
row := "## " + def.Name + "\n* 前缀:`${" + def.Prefix + "}`\n* 描述:" + def.Description
|
||||||
|
if def.HasParams {
|
||||||
|
row += "\n* 是否有子参数:YES"
|
||||||
|
|
||||||
|
paramOptions := def.Instance.ParamOptions()
|
||||||
|
if paramOptions != nil && len(paramOptions.Options) > 0 {
|
||||||
|
row += "\n* 可选子参数"
|
||||||
|
for _, option := range paramOptions.Options {
|
||||||
|
row += "\n * `" + option.Name + "`:值为 `" + option.Value + "`"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
row += "\n* 是否有子参数:NO"
|
||||||
|
}
|
||||||
|
row += "\n"
|
||||||
|
result = append(result, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Print(strings.Join(result, "\n") + "\n")
|
||||||
|
}
|
||||||
154
internal/waf/ip_table.go
Normal file
154
internal/waf/ip_table.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
|
||||||
|
"github.com/iwind/TeaGo/lists"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
stringutil "github.com/iwind/TeaGo/utils/string"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IPAction = string
|
||||||
|
|
||||||
|
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
|
||||||
|
|
||||||
|
const (
|
||||||
|
IPActionAccept IPAction = "accept"
|
||||||
|
IPActionReject IPAction = "reject"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ip table
|
||||||
|
type IPTable struct {
|
||||||
|
Id string `yaml:"id" json:"id"`
|
||||||
|
On bool `yaml:"on" json:"on"`
|
||||||
|
IP string `yaml:"ip" json:"ip"` // single ip, cidr, ip range, TODO support *
|
||||||
|
Port string `yaml:"port" json:"port"` // single port, range, *
|
||||||
|
Action IPAction `yaml:"action" json:"action"` // accept, reject
|
||||||
|
TimeFrom int64 `yaml:"timeFrom" json:"timeFrom"` // from timestamp
|
||||||
|
TimeTo int64 `yaml:"timeTo" json:"timeTo"` // zero means forever
|
||||||
|
Remark string `yaml:"remark" json:"remark"`
|
||||||
|
|
||||||
|
// port
|
||||||
|
minPort int
|
||||||
|
maxPort int
|
||||||
|
|
||||||
|
minPortWildcard bool
|
||||||
|
maxPortWildcard bool
|
||||||
|
|
||||||
|
ports []int
|
||||||
|
|
||||||
|
// ip
|
||||||
|
ipRange *shared.IPRangeConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIPTable() *IPTable {
|
||||||
|
return &IPTable{
|
||||||
|
On: true,
|
||||||
|
Id: stringutil.Rand(16),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *IPTable) Init() error {
|
||||||
|
// parse port
|
||||||
|
if RegexpDigitNumber.MatchString(this.Port) {
|
||||||
|
this.minPort = types.Int(this.Port)
|
||||||
|
this.maxPort = types.Int(this.Port)
|
||||||
|
} else if regexp.MustCompile(`[:-]`).MatchString(this.Port) {
|
||||||
|
pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2)
|
||||||
|
if pieces[0] == "*" {
|
||||||
|
this.minPortWildcard = true
|
||||||
|
} else {
|
||||||
|
this.minPort = types.Int(pieces[0])
|
||||||
|
}
|
||||||
|
if pieces[1] == "*" {
|
||||||
|
this.maxPortWildcard = true
|
||||||
|
} else {
|
||||||
|
this.maxPort = types.Int(pieces[1])
|
||||||
|
}
|
||||||
|
} else if strings.Contains(this.Port, ",") {
|
||||||
|
pieces := strings.Split(this.Port, ",")
|
||||||
|
for _, piece := range pieces {
|
||||||
|
piece = strings.TrimSpace(piece)
|
||||||
|
if len(piece) > 0 {
|
||||||
|
this.ports = append(this.ports, types.Int(piece))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if this.Port == "*" {
|
||||||
|
this.minPortWildcard = true
|
||||||
|
this.maxPortWildcard = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse ip
|
||||||
|
if len(this.IP) > 0 {
|
||||||
|
ipRange, err := shared.ParseIPRange(this.IP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
this.ipRange = ipRange
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check ip
|
||||||
|
func (this *IPTable) Match(ip string, port int) (isMatched bool) {
|
||||||
|
if !this.On {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if this.TimeFrom > 0 && now < this.TimeFrom {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if this.TimeTo > 0 && now > this.TimeTo {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !this.matchPort(port) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !this.matchIP(ip) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *IPTable) matchPort(port int) bool {
|
||||||
|
if port == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if this.minPortWildcard {
|
||||||
|
if this.maxPortWildcard {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if this.maxPort >= port {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if this.maxPortWildcard {
|
||||||
|
if this.minPortWildcard {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if this.minPort <= port {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(this.ports) > 0 {
|
||||||
|
return lists.ContainsInt(this.ports, port)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *IPTable) matchIP(ip string) bool {
|
||||||
|
if this.ipRange == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return this.ipRange.Contains(ip)
|
||||||
|
}
|
||||||
142
internal/waf/ip_table_test.go
Normal file
142
internal/waf/ip_table_test.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/iwind/TeaGo/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIPTable_MatchIP(t *testing.T) {
|
||||||
|
a := assert.NewAssertion(t)
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a.IsFalse(table.Match("192.168.1.100", 8080))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "*"
|
||||||
|
table.Port = "8080"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("port:", table.minPort, table.maxPort)
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
a.IsFalse(table.Match("192.168.1.100", 8081))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "*"
|
||||||
|
table.Port = "8080-8082"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("port:", table.minPort, table.maxPort)
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||||
|
a.IsFalse(table.Match("192.168.1.100", 8083))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "*"
|
||||||
|
table.Port = "*-8082"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("port:", table.minPort, table.maxPort)
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8079))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||||
|
a.IsFalse(table.Match("192.168.1.100", 8083))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "*"
|
||||||
|
table.Port = "8080-*"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("port:", table.minPort, table.maxPort)
|
||||||
|
a.IsFalse(table.Match("192.168.1.100", 8079))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8083))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "*"
|
||||||
|
table.Port = "*"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("port:", table.minPort, table.maxPort)
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8079))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8081))
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8083))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "192.168.1.100"
|
||||||
|
table.Port = "*"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("port:", table.minPort, table.maxPort)
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "192.168.1.99-192.168.1.101"
|
||||||
|
table.Port = "*"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("port:", table.minPort, table.maxPort)
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "192.168.1.99/24"
|
||||||
|
table.Port = "*"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ip:", table.ipRange)
|
||||||
|
a.IsTrue(table.Match("192.168.1.100", 8080))
|
||||||
|
a.IsFalse(table.Match("192.168.2.100", 8080))
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
table := NewIPTable()
|
||||||
|
table.IP = "192.168.1.99/24"
|
||||||
|
table.TimeTo = time.Now().Unix() - 10
|
||||||
|
table.Port = "*"
|
||||||
|
err := table.Init()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a.IsFalse(table.Match("192.168.1.100", 8080))
|
||||||
|
a.IsFalse(table.Match("192.168.2.100", 8080))
|
||||||
|
}
|
||||||
|
}
|
||||||
35
internal/waf/requests/request.go
Normal file
35
internal/waf/requests/request.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package requests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
*http.Request
|
||||||
|
BodyData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequest(raw *http.Request) *Request {
|
||||||
|
return &Request{
|
||||||
|
Request: raw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Request) Raw() *http.Request {
|
||||||
|
return this.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Request) ReadBody(max int64) (data []byte, err error) {
|
||||||
|
data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Request) RestoreBody(data []byte) {
|
||||||
|
rawReader := bytes.NewBuffer(data)
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
io.CopyBuffer(rawReader, this.Request.Body, buf)
|
||||||
|
this.Request.Body = ioutil.NopCloser(rawReader)
|
||||||
|
}
|
||||||
15
internal/waf/requests/response.go
Normal file
15
internal/waf/requests/response.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package requests
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
*http.Response
|
||||||
|
|
||||||
|
BodyData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewResponse(resp *http.Response) *Response {
|
||||||
|
return &Response{
|
||||||
|
Response: resp,
|
||||||
|
}
|
||||||
|
}
|
||||||
584
internal/waf/rule.go
Normal file
584
internal/waf/rule.go
Normal file
@@ -0,0 +1,584 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
|
||||||
|
"github.com/iwind/TeaGo/lists"
|
||||||
|
"github.com/iwind/TeaGo/maps"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"github.com/iwind/TeaGo/utils/string"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var singleParamRegexp = regexp.MustCompile("^\\${[\\w.-]+}$")
|
||||||
|
|
||||||
|
// rule
|
||||||
|
type Rule struct {
|
||||||
|
Description string `yaml:"description" json:"description"`
|
||||||
|
Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
|
||||||
|
Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ...
|
||||||
|
Value string `yaml:"value" json:"value"` // compared value
|
||||||
|
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
|
||||||
|
CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"`
|
||||||
|
|
||||||
|
checkpointFinder func(prefix string) checkpoints.CheckpointInterface
|
||||||
|
|
||||||
|
singleParam string // real param after prefix
|
||||||
|
singleCheckpoint checkpoints.CheckpointInterface // if is single check point
|
||||||
|
|
||||||
|
multipleCheckpoints map[string]checkpoints.CheckpointInterface
|
||||||
|
|
||||||
|
isIP bool
|
||||||
|
ipValue net.IP
|
||||||
|
|
||||||
|
floatValue float64
|
||||||
|
reg *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRule() *Rule {
|
||||||
|
return &Rule{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) Init() error {
|
||||||
|
// operator
|
||||||
|
switch this.Operator {
|
||||||
|
case RuleOperatorGt:
|
||||||
|
this.floatValue = types.Float64(this.Value)
|
||||||
|
case RuleOperatorGte:
|
||||||
|
this.floatValue = types.Float64(this.Value)
|
||||||
|
case RuleOperatorLt:
|
||||||
|
this.floatValue = types.Float64(this.Value)
|
||||||
|
case RuleOperatorLte:
|
||||||
|
this.floatValue = types.Float64(this.Value)
|
||||||
|
case RuleOperatorEq:
|
||||||
|
this.floatValue = types.Float64(this.Value)
|
||||||
|
case RuleOperatorNeq:
|
||||||
|
this.floatValue = types.Float64(this.Value)
|
||||||
|
case RuleOperatorMatch:
|
||||||
|
v := this.Value
|
||||||
|
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
|
||||||
|
v = "(?i)" + v
|
||||||
|
}
|
||||||
|
|
||||||
|
v = this.unescape(v)
|
||||||
|
|
||||||
|
reg, err := regexp.Compile(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
this.reg = reg
|
||||||
|
case RuleOperatorNotMatch:
|
||||||
|
v := this.Value
|
||||||
|
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
|
||||||
|
v = "(?i)" + v
|
||||||
|
}
|
||||||
|
|
||||||
|
v = this.unescape(v)
|
||||||
|
|
||||||
|
reg, err := regexp.Compile(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
this.reg = reg
|
||||||
|
case RuleOperatorEqIP, RuleOperatorGtIP, RuleOperatorGteIP, RuleOperatorLtIP, RuleOperatorLteIP:
|
||||||
|
this.ipValue = net.ParseIP(this.Value)
|
||||||
|
this.isIP = this.ipValue != nil
|
||||||
|
|
||||||
|
if !this.isIP {
|
||||||
|
return errors.New("value should be a valid ip")
|
||||||
|
}
|
||||||
|
case RuleOperatorIPRange, RuleOperatorNotIPRange:
|
||||||
|
if strings.Contains(this.Value, ",") {
|
||||||
|
ipList := strings.SplitN(this.Value, ",", 2)
|
||||||
|
ipString1 := strings.TrimSpace(ipList[0])
|
||||||
|
ipString2 := strings.TrimSpace(ipList[1])
|
||||||
|
|
||||||
|
if len(ipString1) > 0 {
|
||||||
|
ip1 := net.ParseIP(ipString1)
|
||||||
|
if ip1 == nil {
|
||||||
|
return errors.New("start ip is invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ipString2) > 0 {
|
||||||
|
ip2 := net.ParseIP(ipString2)
|
||||||
|
if ip2 == nil {
|
||||||
|
return errors.New("end ip is invalid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if strings.Contains(this.Value, "/") {
|
||||||
|
_, _, err := net.ParseCIDR(this.Value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.New("invalid ip range")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
if singleParamRegexp.MatchString(this.Param) {
|
||||||
|
param := this.Param[2 : len(this.Param)-1]
|
||||||
|
pieces := strings.SplitN(param, ".", 2)
|
||||||
|
prefix := pieces[0]
|
||||||
|
if len(pieces) == 1 {
|
||||||
|
this.singleParam = ""
|
||||||
|
} else {
|
||||||
|
this.singleParam = pieces[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if this.checkpointFinder != nil {
|
||||||
|
checkpoint := this.checkpointFinder(prefix)
|
||||||
|
if checkpoint == nil {
|
||||||
|
return errors.New("no check point '" + prefix + "' found")
|
||||||
|
}
|
||||||
|
this.singleCheckpoint = checkpoint
|
||||||
|
} else {
|
||||||
|
checkpoint := checkpoints.FindCheckpoint(prefix)
|
||||||
|
if checkpoint == nil {
|
||||||
|
return errors.New("no check point '" + prefix + "' found")
|
||||||
|
}
|
||||||
|
checkpoint.Init()
|
||||||
|
this.singleCheckpoint = checkpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
this.multipleCheckpoints = map[string]checkpoints.CheckpointInterface{}
|
||||||
|
var err error = nil
|
||||||
|
configutils.ParseVariables(this.Param, func(varName string) (value string) {
|
||||||
|
pieces := strings.SplitN(varName, ".", 2)
|
||||||
|
prefix := pieces[0]
|
||||||
|
if this.checkpointFinder != nil {
|
||||||
|
checkpoint := this.checkpointFinder(prefix)
|
||||||
|
if checkpoint == nil {
|
||||||
|
err = errors.New("no check point '" + prefix + "' found")
|
||||||
|
} else {
|
||||||
|
this.multipleCheckpoints[prefix] = checkpoint
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
checkpoint := checkpoints.FindCheckpoint(prefix)
|
||||||
|
if checkpoint == nil {
|
||||||
|
err = errors.New("no check point '" + prefix + "' found")
|
||||||
|
} else {
|
||||||
|
checkpoint.Init()
|
||||||
|
this.multipleCheckpoints[prefix] = checkpoint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
|
||||||
|
if this.singleCheckpoint != nil {
|
||||||
|
value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return this.Test(value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
|
||||||
|
pieces := strings.SplitN(varName, ".", 2)
|
||||||
|
prefix := pieces[0]
|
||||||
|
point, ok := this.multipleCheckpoints[prefix]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pieces) == 1 {
|
||||||
|
value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions)
|
||||||
|
if err1 != nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
return types.String(value1)
|
||||||
|
}
|
||||||
|
|
||||||
|
value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions)
|
||||||
|
if err1 != nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
return types.String(value1)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.Test(value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
|
||||||
|
if this.singleCheckpoint != nil {
|
||||||
|
// if is request param
|
||||||
|
if this.singleCheckpoint.IsRequest() {
|
||||||
|
value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return this.Test(value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// response param
|
||||||
|
value, err, _ := this.singleCheckpoint.ResponseValue(req, resp, this.singleParam, this.CheckpointOptions)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return this.Test(value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
|
||||||
|
pieces := strings.SplitN(varName, ".", 2)
|
||||||
|
prefix := pieces[0]
|
||||||
|
point, ok := this.multipleCheckpoints[prefix]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pieces) == 1 {
|
||||||
|
if point.IsRequest() {
|
||||||
|
value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions)
|
||||||
|
if err1 != nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
return types.String(value1)
|
||||||
|
} else {
|
||||||
|
value1, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions)
|
||||||
|
if err1 != nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
return types.String(value1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if point.IsRequest() {
|
||||||
|
value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions)
|
||||||
|
if err1 != nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
return types.String(value1)
|
||||||
|
} else {
|
||||||
|
value1, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions)
|
||||||
|
if err1 != nil {
|
||||||
|
err = err1
|
||||||
|
}
|
||||||
|
return types.String(value1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.Test(value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) Test(value interface{}) bool {
|
||||||
|
// operator
|
||||||
|
switch this.Operator {
|
||||||
|
case RuleOperatorGt:
|
||||||
|
return types.Float64(value) > this.floatValue
|
||||||
|
case RuleOperatorGte:
|
||||||
|
return types.Float64(value) >= this.floatValue
|
||||||
|
case RuleOperatorLt:
|
||||||
|
return types.Float64(value) < this.floatValue
|
||||||
|
case RuleOperatorLte:
|
||||||
|
return types.Float64(value) <= this.floatValue
|
||||||
|
case RuleOperatorEq:
|
||||||
|
return types.Float64(value) == this.floatValue
|
||||||
|
case RuleOperatorNeq:
|
||||||
|
return types.Float64(value) != this.floatValue
|
||||||
|
case RuleOperatorEqString:
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
return strings.ToLower(types.String(value)) == strings.ToLower(this.Value)
|
||||||
|
} else {
|
||||||
|
return types.String(value) == this.Value
|
||||||
|
}
|
||||||
|
case RuleOperatorNeqString:
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
return strings.ToLower(types.String(value)) != strings.ToLower(this.Value)
|
||||||
|
} else {
|
||||||
|
return types.String(value) != this.Value
|
||||||
|
}
|
||||||
|
case RuleOperatorMatch:
|
||||||
|
if value == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// strings
|
||||||
|
stringList, ok := value.([]string)
|
||||||
|
if ok {
|
||||||
|
for _, s := range stringList {
|
||||||
|
if utils.MatchStringCache(this.reg, s) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// bytes
|
||||||
|
byteSlice, ok := value.([]byte)
|
||||||
|
if ok {
|
||||||
|
return utils.MatchBytesCache(this.reg, byteSlice)
|
||||||
|
}
|
||||||
|
|
||||||
|
// string
|
||||||
|
return utils.MatchStringCache(this.reg, types.String(value))
|
||||||
|
case RuleOperatorNotMatch:
|
||||||
|
if value == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
stringList, ok := value.([]string)
|
||||||
|
if ok {
|
||||||
|
for _, s := range stringList {
|
||||||
|
if utils.MatchStringCache(this.reg, s) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// bytes
|
||||||
|
byteSlice, ok := value.([]byte)
|
||||||
|
if ok {
|
||||||
|
return !utils.MatchBytesCache(this.reg, byteSlice)
|
||||||
|
}
|
||||||
|
|
||||||
|
return !utils.MatchStringCache(this.reg, types.String(value))
|
||||||
|
case RuleOperatorContains:
|
||||||
|
if types.IsSlice(value) {
|
||||||
|
ok := false
|
||||||
|
lists.Each(value, func(k int, v interface{}) {
|
||||||
|
if types.String(v) == this.Value {
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
if types.IsMap(value) {
|
||||||
|
lowerValue := ""
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
lowerValue = strings.ToLower(this.Value)
|
||||||
|
}
|
||||||
|
for _, v := range maps.NewMap(value) {
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
if strings.ToLower(types.String(v)) == lowerValue {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if types.String(v) == this.Value {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
|
||||||
|
} else {
|
||||||
|
return strings.Contains(types.String(value), this.Value)
|
||||||
|
}
|
||||||
|
case RuleOperatorNotContains:
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
|
||||||
|
} else {
|
||||||
|
return !strings.Contains(types.String(value), this.Value)
|
||||||
|
}
|
||||||
|
case RuleOperatorPrefix:
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
|
||||||
|
} else {
|
||||||
|
return strings.HasPrefix(types.String(value), this.Value)
|
||||||
|
}
|
||||||
|
case RuleOperatorSuffix:
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
|
||||||
|
} else {
|
||||||
|
return strings.HasSuffix(types.String(value), this.Value)
|
||||||
|
}
|
||||||
|
case RuleOperatorHasKey:
|
||||||
|
if types.IsSlice(value) {
|
||||||
|
index := types.Int(this.Value)
|
||||||
|
if index < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(value).Len() > index
|
||||||
|
} else if types.IsMap(value) {
|
||||||
|
m := maps.NewMap(value)
|
||||||
|
if this.IsCaseInsensitive {
|
||||||
|
lowerValue := strings.ToLower(this.Value)
|
||||||
|
for k := range m {
|
||||||
|
if strings.ToLower(k) == lowerValue {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return m.Has(this.Value)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
case RuleOperatorVersionGt:
|
||||||
|
return stringutil.VersionCompare(this.Value, types.String(value)) > 0
|
||||||
|
case RuleOperatorVersionLt:
|
||||||
|
return stringutil.VersionCompare(this.Value, types.String(value)) < 0
|
||||||
|
case RuleOperatorVersionRange:
|
||||||
|
if strings.Contains(this.Value, ",") {
|
||||||
|
versions := strings.SplitN(this.Value, ",", 2)
|
||||||
|
version1 := strings.TrimSpace(versions[0])
|
||||||
|
version2 := strings.TrimSpace(versions[1])
|
||||||
|
if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(version2) > 0 && stringutil.VersionCompare(types.String(value), version2) > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
} else {
|
||||||
|
return stringutil.VersionCompare(types.String(value), this.Value) >= 0
|
||||||
|
}
|
||||||
|
case RuleOperatorEqIP:
|
||||||
|
ip := net.ParseIP(types.String(value))
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return this.isIP && bytes.Compare(this.ipValue, ip) == 0
|
||||||
|
case RuleOperatorGtIP:
|
||||||
|
ip := net.ParseIP(types.String(value))
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return this.isIP && bytes.Compare(ip, this.ipValue) > 0
|
||||||
|
case RuleOperatorGteIP:
|
||||||
|
ip := net.ParseIP(types.String(value))
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
|
||||||
|
case RuleOperatorLtIP:
|
||||||
|
ip := net.ParseIP(types.String(value))
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return this.isIP && bytes.Compare(ip, this.ipValue) < 0
|
||||||
|
case RuleOperatorLteIP:
|
||||||
|
ip := net.ParseIP(types.String(value))
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return this.isIP && bytes.Compare(ip, this.ipValue) <= 0
|
||||||
|
case RuleOperatorIPRange:
|
||||||
|
return this.containsIP(value)
|
||||||
|
case RuleOperatorNotIPRange:
|
||||||
|
return !this.containsIP(value)
|
||||||
|
case RuleOperatorIPMod:
|
||||||
|
pieces := strings.SplitN(this.Value, ",", 2)
|
||||||
|
if len(pieces) == 1 {
|
||||||
|
rem := types.Int64(pieces[0])
|
||||||
|
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == rem
|
||||||
|
}
|
||||||
|
div := types.Int64(pieces[0])
|
||||||
|
if div == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
rem := types.Int64(pieces[1])
|
||||||
|
return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem
|
||||||
|
case RuleOperatorIPMod10:
|
||||||
|
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value)
|
||||||
|
case RuleOperatorIPMod100:
|
||||||
|
return this.ipToInt64(net.ParseIP(types.String(value)))%100 == types.Int64(this.Value)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) IsSingleCheckpoint() bool {
|
||||||
|
return this.singleCheckpoint != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) SetCheckpointFinder(finder func(prefix string) checkpoints.CheckpointInterface) {
|
||||||
|
this.checkpointFinder = finder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) unescape(v string) string {
|
||||||
|
//replace urlencoded characters
|
||||||
|
v = strings.Replace(v, `\s`, `(\s|%09|%0A|\+)`, -1)
|
||||||
|
v = strings.Replace(v, `\(`, `(\(|%28)`, -1)
|
||||||
|
v = strings.Replace(v, `=`, `(=|%3D)`, -1)
|
||||||
|
v = strings.Replace(v, `<`, `(<|%3C)`, -1)
|
||||||
|
v = strings.Replace(v, `\*`, `(\*|%2A)`, -1)
|
||||||
|
v = strings.Replace(v, `\\`, `(\\|%2F)`, -1)
|
||||||
|
v = strings.Replace(v, `!`, `(!|%21)`, -1)
|
||||||
|
v = strings.Replace(v, `/`, `(/|%2F)`, -1)
|
||||||
|
v = strings.Replace(v, `;`, `(;|%3B)`, -1)
|
||||||
|
v = strings.Replace(v, `\+`, `(\+|%20)`, -1)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) containsIP(value interface{}) bool {
|
||||||
|
ip := net.ParseIP(types.String(value))
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查IP范围格式
|
||||||
|
if strings.Contains(this.Value, ",") {
|
||||||
|
ipList := strings.SplitN(this.Value, ",", 2)
|
||||||
|
ipString1 := strings.TrimSpace(ipList[0])
|
||||||
|
ipString2 := strings.TrimSpace(ipList[1])
|
||||||
|
|
||||||
|
if len(ipString1) > 0 {
|
||||||
|
ip1 := net.ParseIP(ipString1)
|
||||||
|
if ip1 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Compare(ip, ip1) < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ipString2) > 0 {
|
||||||
|
ip2 := net.ParseIP(ipString2)
|
||||||
|
if ip2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Compare(ip, ip2) > 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
} else if strings.Contains(this.Value, "/") {
|
||||||
|
_, ipNet, err := net.ParseCIDR(this.Value)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return ipNet.Contains(ip)
|
||||||
|
} else {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) ipToInt64(ip net.IP) int64 {
|
||||||
|
if len(ip) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if len(ip) == 16 {
|
||||||
|
return int64(binary.BigEndian.Uint32(ip[12:16]))
|
||||||
|
}
|
||||||
|
return int64(binary.BigEndian.Uint32(ip))
|
||||||
|
}
|
||||||
147
internal/waf/rule_group.go
Normal file
147
internal/waf/rule_group.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package waf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
// rule group
|
||||||
|
type RuleGroup struct {
|
||||||
|
Id string `yaml:"id" json:"id"`
|
||||||
|
IsOn bool `yaml:"isOn" json:"isOn"`
|
||||||
|
Name string `yaml:"name" json:"name"` // such as SQL Injection
|
||||||
|
Description string `yaml:"description" json:"description"`
|
||||||
|
Code string `yaml:"code" json:"code"` // identify the group
|
||||||
|
RuleSets []*RuleSet `yaml:"ruleSets" json:"ruleSets"`
|
||||||
|
IsInbound bool `yaml:"isInbound" json:"isInbound"`
|
||||||
|
|
||||||
|
hasRuleSets bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRuleGroup() *RuleGroup {
|
||||||
|
return &RuleGroup{
|
||||||
|
IsOn: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) Init() error {
|
||||||
|
this.hasRuleSets = len(this.RuleSets) > 0
|
||||||
|
|
||||||
|
if this.hasRuleSets {
|
||||||
|
for _, set := range this.RuleSets {
|
||||||
|
err := set.Init()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) AddRuleSet(ruleSet *RuleSet) {
|
||||||
|
this.RuleSets = append(this.RuleSets, ruleSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) FindRuleSet(id string) *RuleSet {
|
||||||
|
if len(id) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, ruleSet := range this.RuleSets {
|
||||||
|
if ruleSet.Id == id {
|
||||||
|
return ruleSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) FindRuleSetWithCode(code string) *RuleSet {
|
||||||
|
if len(code) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, ruleSet := range this.RuleSets {
|
||||||
|
if ruleSet.Code == code {
|
||||||
|
return ruleSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) RemoveRuleSet(id string) {
|
||||||
|
if len(id) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result := []*RuleSet{}
|
||||||
|
for _, ruleSet := range this.RuleSets {
|
||||||
|
if ruleSet.Id == id {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, ruleSet)
|
||||||
|
}
|
||||||
|
this.RuleSets = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) {
|
||||||
|
if !this.hasRuleSets {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, set := range this.RuleSets {
|
||||||
|
if !set.IsOn {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b, err = set.MatchRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
if b {
|
||||||
|
return true, set, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
|
||||||
|
if !this.hasRuleSets {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, set := range this.RuleSets {
|
||||||
|
if !set.IsOn {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b, err = set.MatchResponse(req, resp)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, err
|
||||||
|
}
|
||||||
|
if b {
|
||||||
|
return true, set, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *RuleGroup) MoveRuleSet(fromIndex int, toIndex int) {
|
||||||
|
if fromIndex < 0 || fromIndex >= len(this.RuleSets) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if toIndex < 0 || toIndex >= len(this.RuleSets) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fromIndex == toIndex {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
location := this.RuleSets[fromIndex]
|
||||||
|
result := []*RuleSet{}
|
||||||
|
for i := 0; i < len(this.RuleSets); i++ {
|
||||||
|
if i == fromIndex {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fromIndex > toIndex && i == toIndex {
|
||||||
|
result = append(result, location)
|
||||||
|
}
|
||||||
|
result = append(result, this.RuleSets[i])
|
||||||
|
if fromIndex < toIndex && i == toIndex {
|
||||||
|
result = append(result, location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.RuleSets = result
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user