实现WAF

This commit is contained in:
GoEdgeLab
2020-10-08 15:06:42 +08:00
parent b4cfc33875
commit 4245c73c47
110 changed files with 8179 additions and 3 deletions

2
go.mod
View File

@@ -7,6 +7,7 @@ replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon
require (
github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect
github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000
github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f
github.com/dchest/siphash v1.2.1
github.com/go-ole/go-ole v1.2.4 // indirect
github.com/go-yaml/yaml v2.1.0+incompatible
@@ -14,4 +15,5 @@ require (
github.com/shirou/gopsutil v2.20.9+incompatible
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7
google.golang.org/grpc v1.32.0
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776
)

2
go.sum
View File

@@ -14,6 +14,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f h1:q/DpyjJjZs94bziQ7YkBmIlpqbVP7yw179rnzoNVX1M=
github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f/go.mod h1:QGrK8vMWWHQYQ3QU9bw9Y9OPNfxccGzfb41qjvVeXtY=
github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4=
github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4=
github.com/dgryski/go-rendezvous v0.0.0-20200624174652-8d2f3be8b2d9/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=

2
internal/grids/README.md Normal file
View File

@@ -0,0 +1,2 @@
# Memory Grid
Cache items in memory, using partitions and LRU.

186
internal/grids/cell.go Normal file
View File

@@ -0,0 +1,186 @@
package grids
import (
"math"
"sync"
"time"
)
type Cell struct {
LimitSize int64
LimitCount int
mapping map[uint64]*Item // key => item
list *List // { item1, item2, ... }
totalBytes int64
locker sync.RWMutex
}
func NewCell() *Cell {
return &Cell{
mapping: map[uint64]*Item{},
list: NewList(),
}
}
func (this *Cell) Write(hashKey uint64, item *Item) {
if item == nil {
return
}
this.locker.Lock()
oldItem, ok := this.mapping[hashKey]
if ok {
this.list.Remove(oldItem)
if this.LimitSize > 0 {
this.totalBytes -= oldItem.Size()
}
}
// limit count
if this.LimitCount > 0 && len(this.mapping) >= this.LimitCount {
this.locker.Unlock()
return
}
// trim memory
size := item.Size()
shouldTrim := false
if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size {
this.Trim()
shouldTrim = true
}
// compare again
if shouldTrim {
if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size {
this.locker.Unlock()
return
}
}
this.totalBytes += size
this.list.Add(item)
this.mapping[hashKey] = item
this.locker.Unlock()
}
func (this *Cell) Increase64(key []byte, expireAt int64, hashKey uint64, delta int64) (result int64) {
this.locker.Lock()
item, ok := this.mapping[hashKey]
if ok {
// reset to zero if expired
if item.ExpireAt < time.Now().Unix() {
item.ValueInt64 = 0
item.ExpireAt = expireAt
}
item.IncreaseInt64(delta)
result = item.ValueInt64
} else {
item := NewItem(key, ItemInt64)
item.ValueInt64 = delta
item.ExpireAt = expireAt
this.mapping[hashKey] = item
result = delta
}
this.locker.Unlock()
return
}
func (this *Cell) Read(hashKey uint64) *Item {
this.locker.Lock()
item, ok := this.mapping[hashKey]
if ok {
this.list.Remove(item)
this.list.Add(item)
this.locker.Unlock()
if item.ExpireAt < time.Now().Unix() {
return nil
}
return item
}
this.locker.Unlock()
return nil
}
func (this *Cell) Stat() *CellStat {
this.locker.RLock()
defer this.locker.RUnlock()
return &CellStat{
TotalBytes: this.totalBytes,
CountItems: len(this.mapping),
}
}
// trim NOT ACTIVE items
// should called in locker context
func (this *Cell) Trim() {
l := len(this.mapping)
if l == 0 {
return
}
inactiveSize := int(math.Ceil(float64(l) / 10)) // trim 10% items
this.list.Range(func(item *Item) (goNext bool) {
inactiveSize--
delete(this.mapping, item.HashKey())
this.list.Remove(item)
this.totalBytes -= item.Size()
return inactiveSize > 0
})
}
func (this *Cell) Delete(hashKey uint64) {
this.locker.Lock()
item, ok := this.mapping[hashKey]
if ok {
delete(this.mapping, hashKey)
this.list.Remove(item)
this.totalBytes -= item.Size()
}
this.locker.Unlock()
}
// range all items in the cell
func (this *Cell) Range(f func(item *Item)) {
this.locker.Lock()
for _, item := range this.mapping {
f(item)
}
this.locker.Unlock()
}
func (this *Cell) Recycle() {
this.locker.Lock()
if len(this.mapping) == 0 {
this.locker.Unlock()
return
}
timestamp := time.Now().Unix()
for key, item := range this.mapping {
if item.ExpireAt < timestamp {
delete(this.mapping, key)
this.list.Remove(item)
this.totalBytes -= item.Size()
}
}
this.locker.Unlock()
}
func (this *Cell) Reset() {
this.locker.Lock()
this.list.Reset()
this.mapping = map[uint64]*Item{}
this.totalBytes = 0
this.locker.Unlock()
}

View File

@@ -0,0 +1,6 @@
package grids
type CellStat struct {
TotalBytes int64
CountItems int
}

214
internal/grids/cell_test.go Normal file
View File

@@ -0,0 +1,214 @@
package grids
import (
"fmt"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
)
func TestCell_List(t *testing.T) {
cell := NewCell()
cell.Write(1, &Item{
ValueInt64: 1,
})
cell.Write(2, &Item{
ValueInt64: 2,
})
cell.Write(3, &Item{
ValueInt64: 3,
})
{
t.Log("====")
l := cell.list
for e := l.head; e != nil; e = e.Next {
t.Log("element:", e.ValueInt64)
}
}
cell.Write(1, &Item{
ValueInt64: 1,
})
cell.Write(3, &Item{
ValueInt64: 3,
})
cell.Write(3, &Item{
ValueInt64: 3,
})
cell.Write(2, &Item{
ValueInt64: 2,
})
cell.Delete(2)
{
t.Log("====")
l := cell.list
for e := l.head; e != nil; e = e.Next {
t.Log("element:", e.ValueInt64)
}
}
for _, m := range cell.mapping {
t.Log(m.ValueInt64)
}
}
func TestCell_LimitSize(t *testing.T) {
cell := NewCell()
cell.LimitSize = 1024
for i := int64(0); i < 100; i ++ {
key := []byte(fmt.Sprintf("%d", i))
cell.Write(HashKey(key), &Item{
Key: key,
ValueInt64: i,
Type: ItemInt64,
})
}
t.Log("totalBytes:", cell.totalBytes)
{
t.Log("====")
l := cell.list
s := []string{}
for e := l.head; e != nil; e = e.Next {
s = append(s, fmt.Sprintf("%d", e.ValueInt64))
}
t.Log("{ " + strings.Join(s, ", ") + " }")
}
t.Log("mapping:", len(cell.mapping))
s := []string{}
for _, item := range cell.mapping {
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
}
t.Log("{ " + strings.Join(s, ", ") + " }")
}
func TestCell_MemoryUsage(t *testing.T) {
//runtime.GOMAXPROCS(4)
cell := NewCell()
cell.LimitSize = 1024 * 1024 * 1024 * 1
before := time.Now()
wg := sync.WaitGroup{}
wg.Add(4)
for j := 0; j < 4; j ++ {
go func(j int) {
start := j * 50 * 10000
for i := start; i < start+50*10000; i ++ {
key := []byte(strconv.Itoa(i) + "VERY_LONG_STRING")
cell.Write(HashKey(key), &Item{
Key: key,
ValueInt64: int64(i),
Type: ItemInt64,
})
}
wg.Done()
}(j)
}
wg.Wait()
t.Log("items:", len(cell.mapping))
t.Log(time.Since(before).Seconds(), "s", "totalBytes:", cell.totalBytes/1024/1024, "M")
//time.Sleep(10 * time.Second)
}
func BenchmarkCell_Write(b *testing.B) {
runtime.GOMAXPROCS(1)
cell := NewCell()
for i := 0; i < b.N; i ++ {
key := []byte(strconv.Itoa(i) + "_LONG_KEY_LONG_KEY_LONG_KEY_LONG_KEY")
cell.Write(HashKey(key), &Item{
Key: key,
ValueInt64: int64(i),
Type: ItemInt64,
})
}
b.Log("items:", len(cell.mapping))
}
func TestCell_Read(t *testing.T) {
cell := NewCell()
cell.Write(1, &Item{
ValueInt64: 1,
ExpireAt: time.Now().Unix() + 3600,
})
cell.Write(2, &Item{
ValueInt64: 2,
ExpireAt: time.Now().Unix() + 3600,
})
cell.Write(3, &Item{
ValueInt64: 3,
ExpireAt: time.Now().Unix() + 3600,
})
{
s := []string{}
cell.list.Range(func(item *Item) (goNext bool) {
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
return true
})
t.Log("before:", s)
}
t.Log(cell.Read(1).ValueInt64)
{
s := []string{}
cell.list.Range(func(item *Item) (goNext bool) {
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
return true
})
t.Log("after:", s)
}
t.Log(cell.Read(2).ValueInt64)
{
s := []string{}
cell.list.Range(func(item *Item) (goNext bool) {
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
return true
})
t.Log("after:", s)
}
}
func TestCell_Recycle(t *testing.T) {
cell := NewCell()
cell.Write(1, &Item{
ValueInt64: 1,
ExpireAt: time.Now().Unix() - 1,
})
cell.Write(2, &Item{
ValueInt64: 2,
ExpireAt: time.Now().Unix() + 1,
})
cell.Recycle()
{
s := []string{}
cell.list.Range(func(item *Item) (goNext bool) {
s = append(s, fmt.Sprintf("%d", item.ValueInt64))
return true
})
t.Log("after:", s)
}
t.Log(cell.list.Len(), cell.totalBytes)
}

225
internal/grids/grid.go Normal file
View File

@@ -0,0 +1,225 @@
package grids
import (
"bytes"
"compress/gzip"
"github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/timers"
"math"
"time"
)
// Memory Cache Grid
//
// | Grid |
// | cell1, cell2, ..., cell1024 |
// | item1, item2, ..., item1000000 |
type Grid struct {
cells []*Cell
countCells uint64
recycleIndex int
recycleLooper *timers.Looper
recycleInterval int
gzipLevel int
limitSize int64
limitCount int
}
func NewGrid(countCells int, opt ...interface{}) *Grid {
grid := &Grid{
recycleIndex: -1,
}
for _, o := range opt {
switch x := o.(type) {
case *CompressOpt:
grid.gzipLevel = x.Level
case *LimitSizeOpt:
grid.limitSize = x.Size
case *LimitCountOpt:
grid.limitCount = x.Count
case *RecycleIntervalOpt:
grid.recycleInterval = x.Interval
}
}
cells := []*Cell{}
if countCells <= 0 {
countCells = 1
} else if countCells > 100*10000 {
countCells = 100 * 10000
}
for i := 0; i < countCells; i++ {
cell := NewCell()
cell.LimitSize = int64(math.Floor(float64(grid.limitSize) / float64(countCells)))
cell.LimitCount = int(math.Floor(float64(grid.limitCount)) / float64(countCells))
cells = append(cells, cell)
}
grid.cells = cells
grid.countCells = uint64(len(cells))
grid.recycleTimer()
return grid
}
// get all cells in the grid
func (this *Grid) Cells() []*Cell {
return this.cells
}
func (this *Grid) WriteItem(item *Item) {
if this.countCells <= 0 {
return
}
hashKey := item.HashKey()
this.cellForHashKey(hashKey).Write(hashKey, item)
}
func (this *Grid) WriteInt64(key []byte, value int64, lifeSeconds int64) {
this.WriteItem(&Item{
Key: key,
Type: ItemInt64,
ValueInt64: value,
ExpireAt: time.Now().Unix() + lifeSeconds,
})
}
func (this *Grid) IncreaseInt64(key []byte, delta int64, lifeSeconds int64) (result int64) {
hashKey := HashKey(key)
return this.cellForHashKey(hashKey).Increase64(key, time.Now().Unix()+lifeSeconds, hashKey, delta)
}
func (this *Grid) WriteString(key []byte, value string, lifeSeconds int64) {
this.WriteBytes(key, []byte(value), lifeSeconds)
}
func (this *Grid) WriteBytes(key []byte, value []byte, lifeSeconds int64) {
isCompressed := false
if this.gzipLevel != gzip.NoCompression {
buf := bytes.NewBuffer([]byte{})
writer, err := gzip.NewWriterLevel(buf, this.gzipLevel)
if err != nil {
logs.Error(err)
this.WriteItem(&Item{
Key: key,
Type: ItemBytes,
ValueBytes: value,
ExpireAt: time.Now().Unix() + lifeSeconds,
})
return
}
_, err = writer.Write([]byte(value))
if err != nil {
logs.Error(err)
this.WriteItem(&Item{
Key: key,
Type: ItemBytes,
ValueBytes: value,
ExpireAt: time.Now().Unix() + lifeSeconds,
})
return
}
err = writer.Close()
if err != nil {
logs.Error(err)
this.WriteItem(&Item{
Key: key,
Type: ItemBytes,
ValueBytes: value,
ExpireAt: time.Now().Unix() + lifeSeconds,
})
return
}
value = buf.Bytes()
isCompressed = true
}
this.WriteItem(&Item{
Key: key,
Type: ItemBytes,
ValueBytes: value,
ExpireAt: time.Now().Unix() + lifeSeconds,
IsCompressed: isCompressed,
})
}
func (this *Grid) WriteInterface(key []byte, value interface{}, lifeSeconds int64) {
this.WriteItem(&Item{
Key: key,
Type: ItemInterface,
ValueInterface: value,
ExpireAt: time.Now().Unix() + lifeSeconds,
IsCompressed: false,
})
}
func (this *Grid) Read(key []byte) *Item {
if this.countCells <= 0 {
return nil
}
hashKey := HashKey(key)
return this.cellForHashKey(hashKey).Read(hashKey)
}
func (this *Grid) Stat() *Stat {
stat := &Stat{}
for _, cell := range this.cells {
cellStat := cell.Stat()
stat.CountItems += cellStat.CountItems
stat.TotalBytes += cellStat.TotalBytes
}
return stat
}
func (this *Grid) Delete(key []byte) {
if this.countCells <= 0 {
return
}
hashKey := HashKey(key)
this.cellForHashKey(hashKey).Delete(hashKey)
}
func (this *Grid) Reset() {
for _, cell := range this.cells {
cell.Reset()
}
}
func (this *Grid) Destroy() {
if this.recycleLooper != nil {
this.recycleLooper.Stop()
this.recycleLooper = nil
}
this.cells = nil
}
func (this *Grid) cellForHashKey(hashKey uint64) *Cell {
if hashKey < 0 {
return this.cells[-hashKey%this.countCells]
} else {
return this.cells[hashKey%this.countCells]
}
}
func (this *Grid) recycleTimer() {
duration := 1 * time.Minute
if this.recycleInterval > 0 {
duration = time.Duration(this.recycleInterval) * time.Second
}
this.recycleLooper = timers.Loop(duration, func(looper *timers.Looper) {
if this.countCells == 0 {
return
}
this.recycleIndex++
if this.recycleIndex > int(this.countCells-1) {
this.recycleIndex = 0
}
this.cells[this.recycleIndex].Recycle()
})
}

204
internal/grids/grid_test.go Normal file
View File

@@ -0,0 +1,204 @@
package grids
import (
"compress/gzip"
"fmt"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
)
func TestMemoryGrid_Write(t *testing.T) {
grid := NewGrid(5, NewRecycleIntervalOpt(2), NewLimitSizeOpt(10240))
t.Log("123456:", grid.Read([]byte("123456")))
grid.WriteInt64([]byte("abc"), 1, 5)
t.Log(grid.Read([]byte("abc")).ValueInt64)
grid.WriteString([]byte("abc"), "123", 5)
t.Log(string(grid.Read([]byte("abc")).Bytes()))
grid.WriteBytes([]byte("abc"), []byte("123"), 5)
t.Log(grid.Read([]byte("abc")).Bytes())
grid.Delete([]byte("abc"))
t.Log(grid.Read([]byte("abc")))
for i := 0; i < 100; i++ {
grid.WriteInt64([]byte(fmt.Sprintf("%d", i)), 123, 1)
}
t.Log("before recycle:")
for index, cell := range grid.cells {
t.Log("cell:", index, len(cell.mapping), "items")
}
time.Sleep(3 * time.Second)
t.Log("after recycle:")
for index, cell := range grid.cells {
t.Log("cell:", index, len(cell.mapping), "items")
}
grid.Destroy()
}
func TestMemoryGrid_Write_LimitCount(t *testing.T) {
grid := NewGrid(2, NewLimitCountOpt(10))
for i := 0; i < 100; i++ {
grid.WriteInt64([]byte(strconv.Itoa(i)), int64(i), 30)
}
t.Log(grid.Stat().CountItems, "items")
}
func TestMemoryGrid_Compress(t *testing.T) {
grid := NewGrid(5, NewCompressOpt(1))
grid.WriteString([]byte("hello"), strings.Repeat("abcd", 10240), 30)
t.Log(len(string(grid.Read([]byte("hello")).String())))
t.Log(len(grid.Read([]byte("hello")).ValueBytes))
}
func BenchmarkMemoryGrid_Performance(b *testing.B) {
grid := NewGrid(1024)
for i := 0; i < b.N; i++ {
grid.WriteInt64([]byte("key:"+strconv.Itoa(i)), int64(i), 3600)
}
}
func TestMemoryGrid_Performance(t *testing.T) {
runtime.GOMAXPROCS(1)
grid := NewGrid(1024)
now := time.Now()
s := []byte(strings.Repeat("abcd", 10*1024))
for i := 0; i < 100000; i++ {
grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, 1)), s, 3600)
item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, 1)))
if item != nil {
_ = item.String()
}
}
countItems := 0
for _, cell := range grid.cells {
countItems += len(cell.mapping)
}
t.Log(countItems, "items")
t.Log(time.Since(now).Seconds()*1000, "ms")
}
func TestMemoryGrid_Performance_Concurrent(t *testing.T) {
//runtime.GOMAXPROCS(1)
grid := NewGrid(1024)
now := time.Now()
s := []byte(strings.Repeat("abcd", 10*1024))
wg := sync.WaitGroup{}
wg.Add(runtime.NumCPU())
for c := 0; c < runtime.NumCPU(); c++ {
go func(c int) {
defer wg.Done()
for i := 0; i < 50000; i++ {
grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, c)), s, 3600)
item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, c)))
if item != nil {
_ = item.String()
}
}
}(c)
}
wg.Wait()
countItems := 0
for _, cell := range grid.cells {
countItems += len(cell.mapping)
}
t.Log(countItems, "items")
t.Log(time.Since(now).Seconds()*1000, "ms")
}
func TestMemoryGrid_CompressPerformance(t *testing.T) {
runtime.GOMAXPROCS(1)
grid := NewGrid(1024, NewCompressOpt(gzip.BestCompression))
now := time.Now()
data := []byte(strings.Repeat("abcd", 1024))
for i := 0; i < 100000; i++ {
grid.WriteBytes([]byte(fmt.Sprintf("key:%d", i)), data, 3600)
item := grid.Read([]byte(fmt.Sprintf("key:%d", i+100)))
if item != nil {
_ = item.String()
}
}
countItems := 0
for _, cell := range grid.cells {
countItems += len(cell.mapping)
}
t.Log(countItems, "items")
t.Log(time.Since(now).Seconds()*1000, "ms")
}
func TestMemoryGrid_IncreaseInt64(t *testing.T) {
grid := NewGrid(1024)
grid.WriteInt64([]byte("abc"), 123, 10)
grid.IncreaseInt64([]byte("abc"), 123, 10)
grid.IncreaseInt64([]byte("abc"), 123, 10)
item := grid.Read([]byte("abc"))
if item == nil {
t.Fatal("item == nil")
}
if item.ValueInt64 != 369 {
t.Fatal("not 369")
}
}
func TestMemoryGrid_Destroy(t *testing.T) {
grid := NewGrid(1024)
grid.WriteInt64([]byte("abc"), 123, 10)
t.Log(grid.recycleLooper, grid.cells)
grid.Destroy()
t.Log(grid.recycleLooper, grid.cells)
if grid.recycleLooper != nil {
t.Fatal("looper != nil")
}
}
func TestMemoryGrid_Recycle(t *testing.T) {
cell := NewCell()
timestamp := time.Now().Unix()
for i := 0; i < 300_0000; i++ {
cell.Write(uint64(i), &Item{
ExpireAt: timestamp - 30,
})
}
before := time.Now()
cell.Recycle()
t.Log(time.Since(before).Seconds()*1000, "ms")
t.Log(len(cell.mapping))
runtime.GC()
printMem(t)
}
func printMem(t *testing.T) {
mem := &runtime.MemStats{}
runtime.ReadMemStats(mem)
t.Log(mem.Alloc/1024/1024, "M")
}

88
internal/grids/item.go Normal file
View File

@@ -0,0 +1,88 @@
package grids
import (
"bytes"
"compress/gzip"
"github.com/dchest/siphash"
"github.com/iwind/TeaGo/logs"
"sync/atomic"
"unsafe"
)
type ItemType = int8
const (
ItemInt64 = 1
ItemBytes = 2
ItemInterface = 3
)
func HashKey(key []byte) uint64 {
return siphash.Hash(0, 0, key)
}
type Item struct {
Key []byte
ExpireAt int64
Type ItemType
ValueInt64 int64
ValueBytes []byte
ValueInterface interface{}
IsCompressed bool
// linked list
Prev *Item
Next *Item
size int64
}
func NewItem(key []byte, dataType ItemType) *Item {
return &Item{
Key: key,
Type: dataType,
}
}
func (this *Item) HashKey() uint64 {
return HashKey(this.Key)
}
func (this *Item) IncreaseInt64(delta int64) {
atomic.AddInt64(&this.ValueInt64, delta)
}
func (this *Item) Bytes() []byte {
if this.IsCompressed {
reader, err := gzip.NewReader(bytes.NewBuffer(this.ValueBytes))
if err != nil {
logs.Error(err)
return this.ValueBytes
}
buf := make([]byte, 256)
dataBuf := bytes.NewBuffer([]byte{})
for {
n, err := reader.Read(buf)
if n > 0 {
dataBuf.Write(buf[:n])
}
if err != nil {
break
}
}
return dataBuf.Bytes()
}
return this.ValueBytes
}
func (this *Item) String() string {
return string(this.Bytes())
}
func (this *Item) Size() int64 {
if this.size == 0 {
this.size = int64(int(unsafe.Sizeof(*this)) + len(this.Key) + len(this.ValueBytes))
}
return this.size
}

View File

@@ -0,0 +1,69 @@
package grids
import (
"crypto/md5"
"github.com/dchest/siphash"
"strconv"
"testing"
)
func TestItem_Size(t *testing.T) {
item := &Item{
ValueInt64: 1024,
Key: []byte("123"),
ValueBytes: []byte("Hello, World"),
}
t.Log(item.Size())
}
func BenchmarkItem_Size(b *testing.B) {
item := &Item{
ValueInt64: 1024,
Key: []byte("123"),
ValueBytes: []byte("Hello, World"),
}
for i := 0; i < b.N; i ++ {
_ = item.Size()
}
}
func TestItem_HashKey(t *testing.T) {
t.Log(HashKey([]byte("2")))
}
func TestItem_siphash(t *testing.T) {
result := siphash.Hash(0, 0, []byte("123456"))
t.Log(result)
}
func TestItem_unique(t *testing.T) {
m := map[uint64]bool{}
for i := 0; i < 1000*10000; i ++ {
s := "Hello,World,LONG KEY,LONG KEY,LONG KEY,LONG KEY" + strconv.Itoa(i)
result := siphash.Hash(0, 0, []byte(s))
_, ok := m[result]
if ok {
t.Log("found same", i)
break
} else {
m[result] = true
}
}
t.Log(siphash.Hash(0, 0, []byte("01")))
t.Log(siphash.Hash(0, 0, []byte("10")))
}
func BenchmarkItem_HashKeyMd5(b *testing.B) {
for i := 0; i < b.N; i ++ {
h := md5.New()
h.Write([]byte("HELLO_KEY_" + strconv.Itoa(i)))
_ = h.Sum(nil)
}
}
func BenchmarkItem_siphash(b *testing.B) {
for i := 0; i < b.N; i ++ {
_ = siphash.Hash(0, 0, []byte("HELLO_KEY_"+strconv.Itoa(i)))
}
}

68
internal/grids/list.go Normal file
View File

@@ -0,0 +1,68 @@
package grids
type List struct {
head *Item
end *Item
}
func NewList() *List {
return &List{}
}
func (this *List) Add(item *Item) {
if item == nil {
return
}
if this.end != nil {
this.end.Next = item
item.Prev = this.end
item.Next = nil
}
this.end = item
if this.head == nil {
this.head = item
}
}
func (this *List) Remove(item *Item) {
if item == nil {
return
}
if item.Prev != nil {
item.Prev.Next = item.Next
}
if item.Next != nil {
item.Next.Prev = item.Prev
}
if item == this.head {
this.head = item.Next
}
if item == this.end {
this.end = item.Prev
}
item.Prev = nil
item.Next = nil
}
func (this *List) Len() int {
l := 0
for e := this.head; e != nil; e = e.Next {
l ++
}
return l
}
func (this *List) Range(f func(item *Item) (goNext bool)) {
for e := this.head; e != nil; e = e.Next {
goNext := f(e)
if !goNext {
break
}
}
}
func (this *List) Reset() {
this.head = nil
this.end = nil
}

View File

@@ -0,0 +1,64 @@
package grids
import "testing"
func TestList(t *testing.T) {
l := &List{}
var e1 *Item = nil
{
e := &Item{
ValueInt64: 1,
}
l.Add(e)
e1 = e
}
var e2 *Item = nil
{
e := &Item{
ValueInt64: 2,
}
l.Add(e)
e2 = e
}
var e3 *Item = nil
{
e := &Item{
ValueInt64: 3,
}
l.Add(e)
e3 = e
}
var e4 *Item = nil
{
e := &Item{
ValueInt64: 4,
}
l.Add(e)
e4 = e
}
l.Remove(e1)
//l.Remove(e2)
//l.Remove(e3)
l.Remove(e4)
for e := l.head; e != nil; e = e.Next {
t.Log(e.ValueInt64)
}
t.Log("e1, e2, e3, e4, head, end:", e1, e2, e3, e4)
if l.head != nil {
t.Log("head:", l.head.ValueInt64)
} else {
t.Log("head: nil")
}
if l.end != nil {
t.Log("end:", l.end.ValueInt64)
} else {
t.Log("end: nil")
}
}

View File

@@ -0,0 +1,11 @@
package grids
type CompressOpt struct {
Level int
}
func NewCompressOpt(level int) *CompressOpt {
return &CompressOpt{
Level: level,
}
}

View File

@@ -0,0 +1,11 @@
package grids
type LimitCountOpt struct {
Count int
}
func NewLimitCountOpt(count int) *LimitCountOpt {
return &LimitCountOpt{
Count: count,
}
}

View File

@@ -0,0 +1,11 @@
package grids
type LimitSizeOpt struct {
Size int64
}
func NewLimitSizeOpt(size int64) *LimitSizeOpt {
return &LimitSizeOpt{
Size: size,
}
}

View File

@@ -0,0 +1,11 @@
package grids
type RecycleIntervalOpt struct {
Interval int
}
func NewRecycleIntervalOpt(interval int) *RecycleIntervalOpt {
return &RecycleIntervalOpt{
Interval: interval,
}
}

6
internal/grids/stat.go Normal file
View File

@@ -0,0 +1,6 @@
package grids
type Stat struct {
TotalBytes int64
CountItems int
}

View File

@@ -96,7 +96,11 @@ func (this *HTTPRequest) Do() {
}
// WAF
// TODO 需要实现
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
if this.doWAFRequest() {
return
}
}
// 访问控制
// TODO 需要实现
@@ -253,6 +257,12 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
this.web.Cache = web.Cache
}
// waf
if web.FirewallRef != nil && (web.FirewallRef.IsPrior || isTop) {
this.web.FirewallRef = web.FirewallRef
this.web.FirewallPolicy = web.FirewallPolicy
}
// 重写规则
if len(web.RewriteRefs) > 0 {
for index, ref := range web.RewriteRefs {

View File

@@ -166,12 +166,26 @@ func (this *HTTPRequest) doReverseProxy() {
}
// WAF对出站进行检查
// TODO
if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
if this.doWAFResponse(resp) {
err = resp.Body.Close()
if err != nil {
logs.Error(err)
}
return
}
}
// TODO 清除源站错误次数
// 特殊页面
// TODO
if len(this.web.Pages) > 0 && this.doPage(resp.StatusCode) {
err = resp.Body.Close()
if err != nil {
logs.Error(err)
}
return
}
// 设置Charset
// TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集

View File

@@ -0,0 +1,51 @@
package nodes
import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/logs"
"net/http"
)
// 调用WAF
func (this *HTTPRequest) doWAFRequest() (blocked bool) {
w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id)
if w == nil {
return
}
goNext, _, ruleSet, err := w.MatchRequest(this.RawReq, this.writer)
if err != nil {
logs.Error(err)
return
}
if ruleSet != nil {
if ruleSet.Action != waf.ActionAllow {
// TODO 记录日志
}
}
return !goNext
}
// call response waf
func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id)
if w == nil {
return
}
goNext, _, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer)
if err != nil {
logs.Error(err)
return
}
if ruleSet != nil {
if ruleSet.Action != waf.ActionAllow {
// TODO 记录日志
}
}
return !goNext
}

View File

@@ -104,6 +104,7 @@ func (this *Node) syncConfig(isFirstTime bool) error {
logs.Println("[NODE]reload config ...")
nodeconfigs.ResetNodeConfig(nodeConfig)
caches.SharedManager.UpdatePolicies(nodeConfig.AllCachePolicies())
sharedWAFManager.UpdatePolicies(nodeConfig.AllHTTPFirewallPolicies())
sharedNodeConfig = nodeConfig
if !isFirstTime {

View File

@@ -0,0 +1,175 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/errors"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/logs"
"strconv"
"sync"
)
var sharedWAFManager = NewWAFManager()
// WAF管理器
type WAFManager struct {
mapping map[int64]*waf.WAF // policyId => WAF
locker sync.RWMutex
}
// 获取新对象
func NewWAFManager() *WAFManager {
return &WAFManager{
mapping: map[int64]*waf.WAF{},
}
}
// 更新策略
func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) {
this.locker.Lock()
defer this.locker.Unlock()
m := map[int64]*waf.WAF{}
for _, p := range policies {
w, err := this.convertWAF(p)
if err != nil {
logs.Println("[WAF]initialize policy '" + strconv.FormatInt(p.Id, 10) + "' failed: " + err.Error())
continue
}
if w == nil {
continue
}
m[p.Id] = w
}
this.mapping = m
}
// 查找WAF
func (this *WAFManager) FindWAF(policyId int64) *waf.WAF {
this.locker.RLock()
w, _ := this.mapping[policyId]
this.locker.RUnlock()
return w
}
// 判断是否包含int64
func (this *WAFManager) containsInt64(values []int64, value int64) bool {
for _, v := range values {
if v == value {
return true
}
}
return false
}
// 将Policy转换为WAF
func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (*waf.WAF, error) {
if policy == nil {
return nil, errors.New("policy should not be nil")
}
w := &waf.WAF{
Id: strconv.FormatInt(policy.Id, 10),
IsOn: policy.IsOn,
Name: policy.Name,
}
// inbound
if policy.Inbound != nil && policy.Inbound.IsOn {
for _, group := range policy.Inbound.Groups {
g := &waf.RuleGroup{
Id: strconv.FormatInt(group.Id, 10),
IsOn: group.IsOn,
Name: group.Name,
Description: group.Description,
Code: group.Code,
IsInbound: true,
}
// rule sets
for _, set := range group.Sets {
s := &waf.RuleSet{
Id: strconv.FormatInt(set.Id, 10),
Code: set.Code,
IsOn: set.IsOn,
Name: set.Name,
Description: set.Description,
Connector: set.Connector,
Action: set.Action,
ActionOptions: set.ActionOptions,
}
// rules
for _, rule := range set.Rules {
r := &waf.Rule{
Description: rule.Description,
Param: rule.Param,
Operator: rule.Operator,
Value: rule.Value,
IsCaseInsensitive: rule.IsCaseInsensitive,
CheckpointOptions: rule.CheckpointOptions,
}
s.Rules = append(s.Rules, r)
}
g.RuleSets = append(g.RuleSets, s)
}
w.Inbound = append(w.Inbound, g)
}
}
// outbound
if policy.Outbound != nil && policy.Outbound.IsOn {
for _, group := range policy.Outbound.Groups {
g := &waf.RuleGroup{
Id: strconv.FormatInt(group.Id, 10),
IsOn: group.IsOn,
Name: group.Name,
Description: group.Description,
Code: group.Code,
IsInbound: true,
}
// rule sets
for _, set := range group.Sets {
s := &waf.RuleSet{
Id: strconv.FormatInt(set.Id, 10),
Code: set.Code,
IsOn: set.IsOn,
Name: set.Name,
Description: set.Description,
Connector: set.Connector,
Action: set.Action,
ActionOptions: set.ActionOptions,
}
// rules
for _, rule := range set.Rules {
r := &waf.Rule{
Description: rule.Description,
Param: rule.Param,
Operator: rule.Operator,
Value: rule.Value,
IsCaseInsensitive: rule.IsCaseInsensitive,
CheckpointOptions: rule.CheckpointOptions,
}
s.Rules = append(s.Rules, r)
}
g.RuleSets = append(g.RuleSets, s)
}
w.Outbound = append(w.Outbound, g)
}
}
// action
// TODO
err := w.Init()
if err != nil {
return nil, err
}
return w, nil
}

View File

@@ -0,0 +1,44 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/logs"
"testing"
)
func TestWAFManager_convert(t *testing.T) {
p := &firewallconfigs.HTTPFirewallPolicy{
Id: 1,
IsOn: true,
Inbound: &firewallconfigs.HTTPFirewallInboundConfig{
IsOn: true,
Groups: []*firewallconfigs.HTTPFirewallRuleGroup{
{
Id: 1,
Sets: []*firewallconfigs.HTTPFirewallRuleSet{
{
Id: 1,
},
{
Id: 2,
Rules: []*firewallconfigs.HTTPFirewallRule{
{
Id: 1,
},
{
Id: 2,
},
},
},
},
},
},
},
}
w, err := sharedWAFManager.convertWAF(p)
if err != nil {
t.Fatal(err)
}
logs.PrintAsJSON(w, t)
}

77
internal/utils/get.go Normal file
View File

@@ -0,0 +1,77 @@
package utils
import (
"github.com/iwind/TeaGo/types"
"reflect"
"regexp"
)
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
func Get(object interface{}, keys []string) interface{} {
if len(keys) == 0 {
return object
}
if object == nil {
return nil
}
firstKey := keys[0]
keys = keys[1:]
value := reflect.ValueOf(object)
if !value.IsValid() {
return nil
}
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() == reflect.Struct {
field := value.FieldByName(firstKey)
if !field.IsValid() {
return nil
}
if len(keys) == 0 {
return field.Interface()
}
return Get(field.Interface(), keys)
}
if value.Kind() == reflect.Map {
mapKey := reflect.ValueOf(firstKey)
mapValue := value.MapIndex(mapKey)
if !mapValue.IsValid() {
return nil
}
if len(keys) == 0 {
return mapValue.Interface()
}
return Get(mapValue.Interface(), keys)
}
if value.Kind() == reflect.Slice {
if RegexpDigitNumber.MatchString(firstKey) {
firstKeyInt := types.Int(firstKey)
if value.Len() > firstKeyInt {
result := value.Index(firstKeyInt).Interface()
if len(keys) == 0 {
return result
}
return Get(result, keys)
}
}
return nil
}
return nil
}

View File

@@ -0,0 +1,79 @@
package utils
import "testing"
func TestGetStruct(t *testing.T) {
object := struct {
Name string
Age int
Books []string
Extend struct {
Location struct {
City string
}
}
}{
Name: "lu",
Age: 20,
Books: []string{"Golang"},
Extend: struct {
Location struct {
City string
}
}{
Location: struct {
City string
}{
City: "Beijing",
},
},
}
if Get(object, []string{"Name"}) != "lu" {
t.Fatal("[ERROR]Name != lu")
}
if Get(object, []string{"Age"}) != 20 {
t.Fatal("[ERROR]Age != 20")
}
if Get(object, []string{"Books", "0"}) != "Golang" {
t.Fatal("[ERROR]books.0 != Golang")
}
t.Log("Extend.Location:", Get(object, []string{"Extend", "Location"}))
if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" {
t.Fatal("[ERROR]Extend.Location.City != Beijing")
}
}
func TestGetMap(t *testing.T) {
object := map[string]interface{}{
"Name": "lu",
"Age": 20,
"Extend": map[string]interface{}{
"Location": map[string]interface{}{
"City": "Beijing",
},
},
}
if Get(object, []string{"Name"}) != "lu" {
t.Fatal("[ERROR]Name != lu")
}
if Get(object, []string{"Age"}) != 20 {
t.Fatal("[ERROR]Age != 20")
}
if Get(object, []string{"Books", "0"}) != nil {
t.Fatal("[ERROR]books.0 != nil")
}
t.Log(Get(object, []string{"Extend", "Location"}))
if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" {
t.Fatal("[ERROR]Extend.Location.City != Beijing")
}
}

37
internal/utils/string.go Normal file
View File

@@ -0,0 +1,37 @@
package utils
import (
"strings"
"unsafe"
)
// convert bytes to string
func UnsafeBytesToString(bs []byte) string {
return *(*string)(unsafe.Pointer(&bs))
}
// convert string to bytes
func UnsafeStringToBytes(s string) []byte {
return *(*[]byte)(unsafe.Pointer(&s))
}
// format address
func FormatAddress(addr string) string {
if strings.HasSuffix(addr, "unix:") {
return addr
}
addr = strings.Replace(addr, " ", "", -1)
addr = strings.Replace(addr, "\t", "", -1)
addr = strings.Replace(addr, "", ":", -1)
addr = strings.TrimSpace(addr)
return addr
}
// format address list
func FormatAddressList(addrList []string) []string {
result := []string{}
for _, addr := range addrList {
result = append(result, FormatAddress(addr))
}
return result
}

View File

@@ -0,0 +1,56 @@
package utils
import (
"strings"
"testing"
)
func TestBytesToString(t *testing.T) {
t.Log(UnsafeBytesToString([]byte("Hello,World")))
}
func TestStringToBytes(t *testing.T) {
t.Log(string(UnsafeStringToBytes("Hello,World")))
}
func BenchmarkBytesToString(b *testing.B) {
data := []byte("Hello,World")
for i := 0; i < b.N; i++ {
_ = UnsafeBytesToString(data)
}
}
func BenchmarkBytesToString2(b *testing.B) {
data := []byte("Hello,World")
for i := 0; i < b.N; i++ {
_ = string(data)
}
}
func BenchmarkStringToBytes(b *testing.B) {
s := strings.Repeat("Hello,World", 1024)
for i := 0; i < b.N; i++ {
_ = UnsafeStringToBytes(s)
}
}
func BenchmarkStringToBytes2(b *testing.B) {
s := strings.Repeat("Hello,World", 1024)
for i := 0; i < b.N; i++ {
_ = []byte(s)
}
}
func TestFormatAddress(t *testing.T) {
t.Log(FormatAddress("127.0.0.1:1234"))
t.Log(FormatAddress("127.0.0.1 : 1234"))
t.Log(FormatAddress("127.0.0.11234"))
}
func TestFormatAddressList(t *testing.T) {
t.Log(FormatAddressList([]string{
"127.0.0.1:1234",
"127.0.0.1 : 1234",
"127.0.0.11234",
}))
}

48
internal/waf/README.md Normal file
View File

@@ -0,0 +1,48 @@
# WAF
A basic WAF for TeaWeb.
## Config Constructions
~~~
WAF
Inbound
Rule Groups
Rule Sets
Rules
Checkpoint Param <Operator> Value
Outbound
Rule Groups
...
~~~
## Apply WAF
~~~
Request --> WAF --> Backends
/
Response <-- WAF <----
~~~
## Coding
~~~go
waf := teawaf.NewWAF()
// add rule groups here
err := waf.Init()
if err != nil {
return
}
waf.Start()
// match http request
// (req *http.Request, responseWriter http.ResponseWriter)
goNext, ruleSet, _ := waf.MatchRequest(req, responseWriter)
if ruleSet != nil {
log.Println("meet rule set:", ruleSet.Name, "action:", ruleSet.Action)
}
if !goNext {
return
}
// stop the waf
// waf.Stop()
~~~

View File

@@ -0,0 +1,14 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
)
type AllowAction struct {
}
func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
// do nothing
return true
}

View File

@@ -0,0 +1,94 @@
package waf
import (
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/logs"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"regexp"
"time"
)
// url client configure
var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://")
var httpClient = utils.SharedHttpClient(5 * time.Second)
type BlockAction struct {
StatusCode int `yaml:"statusCode" json:"statusCode"`
Body string `yaml:"body" json:"body"` // supports HTML
URL string `yaml:"url" json:"url"`
}
func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
if writer != nil {
// if status code eq 444, we close the connection
if this.StatusCode == 444 {
hijack, ok := writer.(http.Hijacker)
if ok {
conn, _, _ := hijack.Hijack()
if conn != nil {
_ = conn.Close()
return
}
}
}
// output response
if this.StatusCode > 0 {
writer.WriteHeader(this.StatusCode)
} else {
writer.WriteHeader(http.StatusForbidden)
}
if len(this.URL) > 0 {
if urlPrefixReg.MatchString(this.URL) {
req, err := http.NewRequest(http.MethodGet, this.URL, nil)
if err != nil {
logs.Error(err)
return false
}
resp, err := httpClient.Do(req)
if err != nil {
logs.Error(err)
return false
}
defer func() {
_ = resp.Body.Close()
}()
for k, v := range resp.Header {
for _, v1 := range v {
writer.Header().Add(k, v1)
}
}
buf := make([]byte, 1024)
_, _ = io.CopyBuffer(writer, resp.Body, buf)
} else {
path := this.URL
if !filepath.IsAbs(this.URL) {
path = Tea.Root + string(os.PathSeparator) + path
}
data, err := ioutil.ReadFile(path)
if err != nil {
logs.Error(err)
return false
}
_, _ = writer.Write(data)
}
return false
}
if len(this.Body) > 0 {
_, _ = writer.Write([]byte(this.Body))
} else {
_, _ = writer.Write([]byte("The request is blocked by " + teaconst.ProductName))
}
}
return false
}

View File

@@ -0,0 +1,39 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"net/http"
"net/url"
"time"
)
var captchaSalt = stringutil.Rand(32)
const (
CaptchaSeconds = 600 // 10 minutes
)
type CaptchaAction struct {
}
func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
// TEAWEB_CAPTCHA:
cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA")
if err == nil && cookie != nil && len(cookie.Value) > 32 {
m := cookie.Value[:32]
timestamp := cookie.Value[32:]
if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5
return true
}
}
refURL := request.URL.String()
if len(request.Referer()) > 0 {
refURL = request.Referer()
}
http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect)
return false
}

View File

@@ -0,0 +1,12 @@
package waf
import "reflect"
// action definition
type ActionDefinition struct {
Name string
Code ActionString
Description string
Instance ActionInterface
Type reflect.Type
}

View File

@@ -0,0 +1,34 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/logs"
"net/http"
)
type GoGroupAction struct {
GroupId string `yaml:"groupId" json:"groupId"`
}
func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
group := waf.FindRuleGroup(this.GroupId)
if group == nil || !group.IsOn {
return true
}
b, set, err := group.MatchRequest(request)
if err != nil {
logs.Error(err)
return true
}
if !b {
return true
}
actionObject := FindActionInstance(set.Action, set.ActionOptions)
if actionObject == nil {
return true
}
return actionObject.Perform(waf, request, writer)
}

View File

@@ -0,0 +1,37 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/logs"
"net/http"
)
type GoSetAction struct {
GroupId string `yaml:"groupId" json:"groupId"`
SetId string `yaml:"setId" json:"setId"`
}
func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
group := waf.FindRuleGroup(this.GroupId)
if group == nil || !group.IsOn {
return true
}
set := group.FindRuleSet(this.SetId)
if set == nil || !set.IsOn {
return true
}
b, err := set.MatchRequest(request)
if err != nil {
logs.Error(err)
return true
}
if !b {
return true
}
actionObject := FindActionInstance(set.Action, set.ActionOptions)
if actionObject == nil {
return true
}
return actionObject.Perform(waf, request, writer)
}

View File

@@ -0,0 +1,5 @@
package waf
type Action struct {
}

View File

@@ -0,0 +1,13 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
)
type LogAction struct {
}
func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) {
return true
}

View File

@@ -0,0 +1,21 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
)
type ActionString = string
const (
ActionLog = "log" // allow and log
ActionBlock = "block" // block
ActionCaptcha = "captcha" // block and show captcha
ActionAllow = "allow" // allow
ActionGoGroup = "go_group" // go to next rule group
ActionGoSet = "go_set" // go to next rule set
)
type ActionInterface interface {
Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool)
}

View File

@@ -0,0 +1,82 @@
package waf
import (
"github.com/iwind/TeaGo/maps"
"reflect"
)
var AllActions = []*ActionDefinition{
{
Name: "阻止",
Code: ActionBlock,
Instance: new(BlockAction),
},
{
Name: "允许通过",
Code: ActionAllow,
Instance: new(AllowAction),
},
{
Name: "允许并记录日志",
Code: ActionLog,
Instance: new(LogAction),
},
{
Name: "Captcha验证码",
Code: ActionCaptcha,
Instance: new(CaptchaAction),
},
{
Name: "跳到下一个规则分组",
Code: ActionGoGroup,
Instance: new(GoGroupAction),
Type: reflect.TypeOf(new(GoGroupAction)).Elem(),
},
{
Name: "跳到下一个规则集",
Code: ActionGoSet,
Instance: new(GoSetAction),
Type: reflect.TypeOf(new(GoSetAction)).Elem(),
},
}
func FindActionInstance(action ActionString, options maps.Map) ActionInterface {
for _, def := range AllActions {
if def.Code == action {
if def.Type != nil {
// create new instance
ptrValue := reflect.New(def.Type)
instance := ptrValue.Interface().(ActionInterface)
if len(options) > 0 {
count := def.Type.NumField()
for i := 0; i < count; i++ {
field := def.Type.Field(i)
tag, ok := field.Tag.Lookup("yaml")
if ok {
v, ok := options[tag]
if ok && reflect.TypeOf(v) == field.Type {
ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v))
}
}
}
}
return instance
}
// return shared instance
return def.Instance
}
}
return nil
}
func FindActionName(action ActionString) string {
for _, def := range AllActions {
if def.Code == action {
return def.Name
}
}
return ""
}

View File

@@ -0,0 +1,29 @@
package waf
import (
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
"runtime"
"testing"
)
func TestFindActionInstance(t *testing.T) {
a := assert.NewAssertion(t)
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil))
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",}))
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil))
}
func BenchmarkFindActionInstance(b *testing.B) {
runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ {
FindActionInstance(ActionGoSet, nil)
}
}

View File

@@ -0,0 +1,84 @@
package waf
import (
"bytes"
"encoding/base64"
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/dchest/captcha"
"github.com/iwind/TeaGo/logs"
stringutil "github.com/iwind/TeaGo/utils/string"
"net/http"
"time"
)
var captchaValidator = &CaptchaValidator{}
type CaptchaValidator struct {
}
func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) {
if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 {
this.validate(request, writer)
} else {
this.show(request, writer)
}
}
func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) {
// show captcha
captchaId := captcha.NewLen(6)
buf := bytes.NewBuffer([]byte{})
err := captcha.WriteImage(buf, captchaId, 200, 100)
if err != nil {
logs.Error(err)
return
}
_, _ = writer.Write([]byte(`<!DOCTYPE html>
<html>
<head>
<title>Verify Yourself</title>
</head>
<body>
<form method="POST">
<input type="hidden" name="TEAWEB_WAF_CAPTCHA_ID" value="` + captchaId + `"/>
<img src="data:image/png;base64, ` + base64.StdEncoding.EncodeToString(buf.Bytes()) + `"/>` + `
<div>
<p>Input verify code above:</p>
<input type="text" name="TEAWEB_WAF_CAPTCHA_CODE" maxlength="6" size="18" autocomplete="off" z-index="1" style="font-size:16px;line-height:24px; letter-spacing: 15px; padding-left: 4px"/>
</div>
<div>
<button type="submit" onclick="window.location = '/webhook'" style="line-height:24px;margin-top:10px">Verify Yourself</button>
</div>
</form>
</body>
</html>`))
}
func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) {
captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID")
if len(captchaId) > 0 {
captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE")
if captcha.VerifyString(captchaId, captchaCode) {
// set cookie
timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds)
m := stringutil.Md5(captchaSalt + timestamp)
http.SetCookie(writer, &http.Cookie{
Name: "TEAWEB_WAF_CAPTCHA",
Value: m + timestamp,
MaxAge: CaptchaSeconds,
Path: "/", // all of dirs
})
rawURL := request.URL.Query().Get("url")
http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther)
return false
} else {
http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther)
}
}
return true
}

View File

@@ -0,0 +1,246 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/grids"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"net"
"regexp"
"strings"
"sync"
)
// ${cc.arg}
// TODO implement more traffic rules
type CCCheckpoint struct {
Checkpoint
grid *grids.Grid
once sync.Once
}
func (this *CCCheckpoint) Init() {
}
func (this *CCCheckpoint) Start() {
if this.grid != nil {
this.grid.Destroy()
}
this.grid = grids.NewGrid(32, grids.NewLimitCountOpt(1000_0000))
}
func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = 0
if this.grid == nil {
this.once.Do(func() {
this.Start()
})
if this.grid == nil {
return
}
}
periodString, ok := options["period"]
if !ok {
return
}
period := types.Int64(periodString)
if period < 1 {
return
}
v, _ := options["userType"]
userType := types.String(v)
v, _ = options["userField"]
userField := types.String(v)
v, _ = options["userIndex"]
userIndex := types.Int(v)
if param == "requests" { // requests
var key = ""
switch userType {
case "ip":
key = this.ip(req)
case "cookie":
if len(userField) == 0 {
key = this.ip(req)
} else {
cookie, _ := req.Cookie(userField)
if cookie != nil {
v := cookie.Value
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
key = "USER@" + userType + "@" + userField + "@" + v
}
}
case "get":
if len(userField) == 0 {
key = this.ip(req)
} else {
v := req.URL.Query().Get(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
key = "USER@" + userType + "@" + userField + "@" + v
}
case "post":
if len(userField) == 0 {
key = this.ip(req)
} else {
v := req.PostFormValue(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
key = "USER@" + userType + "@" + userField + "@" + v
}
case "header":
if len(userField) == 0 {
key = this.ip(req)
} else {
v := req.Header.Get(userField)
if userIndex > 0 && len(v) > userIndex {
v = v[userIndex:]
}
key = "USER@" + userType + "@" + userField + "@" + v
}
default:
key = this.ip(req)
}
if len(key) == 0 {
key = this.ip(req)
}
value = this.grid.IncreaseInt64([]byte(key), 1, period)
}
return
}
func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}
func (this *CCCheckpoint) ParamOptions() *ParamOptions {
option := NewParamOptions()
option.AddParam("请求数", "requests")
return option
}
func (this *CCCheckpoint) Options() []OptionInterface {
options := []OptionInterface{}
// period
{
option := NewFieldOption("统计周期", "period")
option.Value = "60"
option.RightLabel = "秒"
option.Size = 8
option.MaxLength = 8
option.Validate = func(value string) (ok bool, message string) {
if regexp.MustCompile("^\\d+$").MatchString(value) {
ok = true
return
}
message = "周期需要是一个整数数字"
return
}
options = append(options, option)
}
// type
{
option := NewOptionsOption("用户识别读取来源", "userType")
option.Size = 10
option.SetOptions([]maps.Map{
{
"name": "IP",
"value": "ip",
},
{
"name": "Cookie",
"value": "cookie",
},
{
"name": "URL参数",
"value": "get",
},
{
"name": "POST参数",
"value": "post",
},
{
"name": "HTTP Header",
"value": "header",
},
})
options = append(options, option)
}
// user field
{
option := NewFieldOption("用户识别字段", "userField")
option.Comment = "识别用户的唯一性字段在用户读取来源不是IP时使用"
options = append(options, option)
}
// user value index
{
option := NewFieldOption("字段读取位置", "userIndex")
option.Size = 5
option.MaxLength = 5
option.Comment = "读取用户识别字段的位置从0开始比如user12345的数字ID 12345的位置就是5在用户读取来源不是IP时使用"
options = append(options, option)
}
return options
}
func (this *CCCheckpoint) Stop() {
if this.grid != nil {
this.grid.Destroy()
this.grid = nil
}
}
func (this *CCCheckpoint) ip(req *requests.Request) string {
// X-Forwarded-For
forwardedFor := req.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
commaIndex := strings.Index(forwardedFor, ",")
if commaIndex > 0 {
return forwardedFor[:commaIndex]
}
return forwardedFor
}
// Real-IP
{
realIP, ok := req.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
return realIP[0]
}
}
// Real-Ip
{
realIP, ok := req.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
return realIP[0]
}
}
// Remote-Addr
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
return host
}
return req.RemoteAddr
}

View File

@@ -0,0 +1,42 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestCCCheckpoint_RequestValue(t *testing.T) {
raw, err := http.NewRequest(http.MethodGet, "http://teaos.cn/", nil)
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(raw)
req.RemoteAddr = "127.0.0.1"
checkpoint := new(CCCheckpoint)
checkpoint.Init()
checkpoint.Start()
options := map[string]string{
"period": "5",
}
t.Log(checkpoint.RequestValue(req, "requests", options))
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.1"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
req.RemoteAddr = "127.0.0.2"
t.Log(checkpoint.RequestValue(req, "requests", options))
}

View File

@@ -0,0 +1,28 @@
package checkpoints
type Checkpoint struct {
}
func (this *Checkpoint) Init() {
}
func (this *Checkpoint) IsRequest() bool {
return true
}
func (this *Checkpoint) ParamOptions() *ParamOptions {
return nil
}
func (this *Checkpoint) Options() []OptionInterface {
return nil
}
func (this *Checkpoint) Start() {
}
func (this *Checkpoint) Stop() {
}

View File

@@ -0,0 +1,10 @@
package checkpoints
// check point definition
type CheckpointDefinition struct {
Name string
Description string
Prefix string
HasParams bool // has sub params
Instance CheckpointInterface
}

View File

@@ -0,0 +1,32 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// Check Point
type CheckpointInterface interface {
// initialize
Init()
// is request?
IsRequest() bool
// get request value
RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error)
// get response value
ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error)
// param option list
ParamOptions() *ParamOptions
// options
Options() []OptionInterface
// start
Start()
// stop
Stop()
}

View File

@@ -0,0 +1,5 @@
package checkpoints
type OptionInterface interface {
Type() string
}

View File

@@ -0,0 +1,26 @@
package checkpoints
// attach option
type FieldOption struct {
Name string
Code string
Value string // default value
IsRequired bool
Size int
Comment string
Placeholder string
RightLabel string
MaxLength int
Validate func(value string) (ok bool, message string)
}
func NewFieldOption(name string, code string) *FieldOption {
return &FieldOption{
Name: name,
Code: code,
}
}
func (this *FieldOption) Type() string {
return "field"
}

View File

@@ -0,0 +1,30 @@
package checkpoints
import "github.com/iwind/TeaGo/maps"
type OptionsOption struct {
Name string
Code string
Value string // default value
IsRequired bool
Size int
Comment string
RightLabel string
Validate func(value string) (ok bool, message string)
Options []maps.Map
}
func NewOptionsOption(name string, code string) *OptionsOption {
return &OptionsOption{
Name: name,
Code: code,
}
}
func (this *OptionsOption) Type() string {
return "options"
}
func (this *OptionsOption) SetOptions(options []maps.Map) {
this.Options = options
}

View File

@@ -0,0 +1,21 @@
package checkpoints
type KeyValue struct {
Name string `json:"name"`
Value string `json:"value"`
}
type ParamOptions struct {
Options []*KeyValue `json:"options"`
}
func NewParamOptions() *ParamOptions {
return &ParamOptions{}
}
func (this *ParamOptions) AddParam(name string, value string) {
this.Options = append(this.Options, &KeyValue{
Name: name,
Value: value,
})
}

View File

@@ -0,0 +1,46 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// ${requestAll}
type RequestAllCheckpoint struct {
Checkpoint
}
func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
valueBytes := []byte{}
if len(req.RequestURI) > 0 {
valueBytes = append(valueBytes, req.RequestURI...)
} else if req.URL != nil {
valueBytes = append(valueBytes, req.URL.RequestURI()...)
}
if req.Body != nil {
valueBytes = append(valueBytes, ' ')
if len(req.BodyData) == 0 {
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
req.RestoreBody(data)
}
valueBytes = append(valueBytes, req.BodyData...)
}
value = valueBytes
return
}
func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = ""
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,70 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
"io/ioutil"
"net/http"
"runtime"
"strings"
"testing"
)
func TestRequestAllCheckpoint_RequestValue(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer([]byte("123456")))
if err != nil {
t.Fatal(err)
}
checkpoint := new(RequestAllCheckpoint)
v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
if sysErr != nil {
t.Fatal(sysErr)
}
if userErr != nil {
t.Fatal(userErr)
}
t.Log(v)
t.Log(types.String(v))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
}
func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000))))
if err != nil {
t.Fatal(err)
}
checkpoint := new(RequestBodyCheckpoint)
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
if err != nil {
t.Fatal(err)
}
t.Log("value bytes:", len(types.String(value)))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log("raw bytes:", len(body))
}
func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) {
runtime.GOMAXPROCS(1)
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer(bytes.Repeat([]byte("HELLO"), 1024)))
if err != nil {
b.Fatal(err)
}
checkpoint := new(RequestAllCheckpoint)
for i := 0; i < b.N; i++ {
_, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil)
}
}

View File

@@ -0,0 +1,20 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestArgCheckpoint struct {
Checkpoint
}
func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
return req.URL.Query().Get(param), nil, nil
}
func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestArgParam_RequestValue(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/?name=lu", nil)
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
checkpoint := new(RequestArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "name", nil))
t.Log(checkpoint.ResponseValue(req, nil, "name", nil))
t.Log(checkpoint.RequestValue(req, "name2", nil))
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestArgsCheckpoint struct {
Checkpoint
}
func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.URL.RawQuery
return
}
func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,36 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// ${requestBody}
type RequestBodyCheckpoint struct {
Checkpoint
}
func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if req.Body == nil {
value = ""
return
}
if len(req.BodyData) == 0 {
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
req.RestoreBody(data)
}
return req.BodyData, nil, nil
}
func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,47 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
"io/ioutil"
"net/http"
"strings"
"testing"
)
func TestRequestBodyCheckpoint_RequestValue(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456")))
if err != nil {
t.Fatal(err)
}
checkpoint := new(RequestBodyCheckpoint)
t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
}
func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000))))
if err != nil {
t.Fatal(err)
}
checkpoint := new(RequestBodyCheckpoint)
value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil)
if err != nil {
t.Fatal(err)
}
t.Log("value bytes:", len(types.String(value)))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log("raw bytes:", len(body))
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestContentTypeCheckpoint struct {
Checkpoint
}
func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.Header.Get("Content-Type")
return
}
func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,27 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestCookieCheckpoint struct {
Checkpoint
}
func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
cookie, err := req.Cookie(param)
if err != nil {
value = ""
return
}
value = cookie.Value
return
}
func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,27 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/url"
"strings"
)
type RequestCookiesCheckpoint struct {
Checkpoint
}
func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
var cookies = []string{}
for _, cookie := range req.Cookies() {
cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value))
}
value = strings.Join(cookies, "&")
return
}
func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,39 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/url"
)
// ${requestForm.arg}
type RequestFormArgCheckpoint struct {
Checkpoint
}
func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if req.Body == nil {
value = ""
return
}
if len(req.BodyData) == 0 {
data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
req.RestoreBody(data)
}
// TODO improve performance
values, _ := url.ParseQuery(string(req.BodyData))
return values.Get(param), nil, nil
}
func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,32 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"io/ioutil"
"net/http"
"net/url"
"testing"
)
func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("name=lu&age=20&encoded="+url.QueryEscape("<strong>ENCODED STRING</strong>"))))
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestFormArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "name", nil))
t.Log(checkpoint.RequestValue(req, "age", nil))
t.Log(checkpoint.RequestValue(req, "Hello", nil))
t.Log(checkpoint.RequestValue(req, "encoded", nil))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
}

View File

@@ -0,0 +1,27 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"strings"
)
type RequestHeaderCheckpoint struct {
Checkpoint
}
func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
v, found := req.Header[param]
if !found {
value = ""
return
}
value = strings.Join(v, ";")
return
}
func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,30 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"sort"
"strings"
)
type RequestHeadersCheckpoint struct {
Checkpoint
}
func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
var headers = []string{}
for k, v := range req.Header {
for _, subV := range v {
headers = append(headers, k+": "+subV)
}
}
sort.Strings(headers)
value = strings.Join(headers, "\n")
return
}
func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestHostCheckpoint struct {
Checkpoint
}
func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.Host
return
}
func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,20 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestRequestHostCheckpoint_RequestValue(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil)
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
req.Header.Set("Host", "cloud.teaos.cn")
checkpoint := new(RequestHostCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
}

View File

@@ -0,0 +1,44 @@
package checkpoints
import (
"encoding/json"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"strings"
)
// ${requestJSON.arg}
type RequestJSONArgCheckpoint struct {
Checkpoint
}
func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if len(req.BodyData) == 0 {
data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes
if err != nil {
return "", err, nil
}
req.BodyData = data
defer req.RestoreBody(data)
}
// TODO improve performance
var m interface{} = nil
err := json.Unmarshal(req.BodyData, &m)
if err != nil || m == nil {
return "", nil, err
}
value = utils.Get(m, strings.Split(param, "."))
if value != nil {
return value, nil, err
}
return "", nil, nil
}
func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,99 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"io/ioutil"
"net/http"
"testing"
)
func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
{
"name": "lu",
"age": 20,
"books": [ "PHP", "Golang", "Python" ]
}
`)))
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "name", nil))
t.Log(checkpoint.RequestValue(req, "age", nil))
t.Log(checkpoint.RequestValue(req, "Hello", nil))
t.Log(checkpoint.RequestValue(req, "", nil))
t.Log(checkpoint.RequestValue(req, "books", nil))
t.Log(checkpoint.RequestValue(req, "books.1", nil))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
}
func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
[{
"name": "lu",
"age": 20,
"books": [ "PHP", "Golang", "Python" ]
}]
`)))
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "0.name", nil))
t.Log(checkpoint.RequestValue(req, "0.age", nil))
t.Log(checkpoint.RequestValue(req, "0.Hello", nil))
t.Log(checkpoint.RequestValue(req, "", nil))
t.Log(checkpoint.RequestValue(req, "0.books", nil))
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
}
func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(`
[{
"name": "lu",
"age": 20,
"books": [ "PHP", "Golang", "Python" ]
}]
`)))
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
//req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
checkpoint := new(RequestJSONArgCheckpoint)
t.Log(checkpoint.RequestValue(req, "0.name", nil))
t.Log(checkpoint.RequestValue(req, "0.age", nil))
t.Log(checkpoint.RequestValue(req, "0.Hello", nil))
t.Log(checkpoint.RequestValue(req, "", nil))
t.Log(checkpoint.RequestValue(req, "0.books", nil))
t.Log(checkpoint.RequestValue(req, "0.books.1", nil))
body, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(body))
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestLengthCheckpoint struct {
Checkpoint
}
func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.ContentLength
return
}
func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestMethodCheckpoint struct {
Checkpoint
}
func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.Method
return
}
func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,20 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestPathCheckpoint struct {
Checkpoint
}
func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
return req.URL.Path, nil, nil
}
func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,18 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestRequestPathCheckpoint_RequestValue(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/index?name=lu", nil)
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
checkpoint := new(RequestPathCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestProtoCheckpoint struct {
Checkpoint
}
func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.Proto
return
}
func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,27 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net"
)
type RequestRawRemoteAddrCheckpoint struct {
Checkpoint
}
func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
value = host
} else {
value = req.RemoteAddr
}
return
}
func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestRefererCheckpoint struct {
Checkpoint
}
func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.Referer()
return
}
func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,59 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net"
"strings"
)
type RequestRemoteAddrCheckpoint struct {
Checkpoint
}
func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
// X-Forwarded-For
forwardedFor := req.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
commaIndex := strings.Index(forwardedFor, ",")
if commaIndex > 0 {
value = forwardedFor[:commaIndex]
return
}
value = forwardedFor
return
}
// Real-IP
{
realIP, ok := req.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
value = realIP[0]
return
}
}
// Real-Ip
{
realIP, ok := req.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
value = realIP[0]
return
}
}
// Remote-Addr
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
value = host
} else {
value = req.RemoteAddr
}
return
}
func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,28 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/types"
"net"
)
type RequestRemotePortCheckpoint struct {
Checkpoint
}
func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
_, port, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
value = types.Int(port)
} else {
value = 0
}
return
}
func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,26 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestRemoteUserCheckpoint struct {
Checkpoint
}
func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
username, _, ok := req.BasicAuth()
if !ok {
value = ""
return
}
value = username
return
}
func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestSchemeCheckpoint struct {
Checkpoint
}
func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.URL.Scheme
return
}
func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,18 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) {
rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil)
if err != nil {
t.Fatal(err)
}
req := requests.NewRequest(rawReq)
checkpoint := new(RequestSchemeCheckpoint)
t.Log(checkpoint.RequestValue(req, "", nil))
}

View File

@@ -0,0 +1,130 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/lists"
"io/ioutil"
"net/http"
"path/filepath"
"strings"
)
// ${requestUpload.arg}
type RequestUploadCheckpoint struct {
Checkpoint
}
func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = ""
if param == "minSize" || param == "maxSize" {
value = 0
}
if req.Method != http.MethodPost {
return
}
if req.Body == nil {
return
}
if req.MultipartForm == nil {
if len(req.BodyData) == 0 {
data, err := req.ReadBody(32 * 1024 * 1024)
if err != nil {
sysErr = err
return
}
req.BodyData = data
defer req.RestoreBody(data)
}
oldBody := req.Body
req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData))
err := req.ParseMultipartForm(32 * 1024 * 1024)
// 还原
req.Body = oldBody
if err != nil {
userErr = err
return
}
if req.MultipartForm == nil {
return
}
}
if param == "field" { // field
fields := []string{}
for field := range req.MultipartForm.File {
fields = append(fields, field)
}
value = strings.Join(fields, ",")
} else if param == "minSize" { // minSize
minSize := int64(0)
for _, files := range req.MultipartForm.File {
for _, file := range files {
if minSize == 0 || minSize > file.Size {
minSize = file.Size
}
}
}
value = minSize
} else if param == "maxSize" { // maxSize
maxSize := int64(0)
for _, files := range req.MultipartForm.File {
for _, file := range files {
if maxSize < file.Size {
maxSize = file.Size
}
}
}
value = maxSize
} else if param == "name" { // name
names := []string{}
for _, files := range req.MultipartForm.File {
for _, file := range files {
if !lists.ContainsString(names, file.Filename) {
names = append(names, file.Filename)
}
}
}
value = strings.Join(names, ",")
} else if param == "ext" { // ext
extensions := []string{}
for _, files := range req.MultipartForm.File {
for _, file := range files {
if len(file.Filename) > 0 {
exit := strings.ToLower(filepath.Ext(file.Filename))
if !lists.ContainsString(extensions, exit) {
extensions = append(extensions, exit)
}
}
}
}
value = strings.Join(extensions, ",")
}
return
}
func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}
func (this *RequestUploadCheckpoint) ParamOptions() *ParamOptions {
option := NewParamOptions()
option.AddParam("最小文件尺寸", "minSize")
option.AddParam("最大文件尺寸", "maxSize")
option.AddParam("扩展名(如.txt)", "ext")
option.AddParam("原始文件名", "name")
option.AddParam("表单字段名", "field")
return option
}

View File

@@ -0,0 +1,81 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"io/ioutil"
"mime/multipart"
"net/http"
"testing"
)
func TestRequestUploadCheckpoint_RequestValue(t *testing.T) {
body := bytes.NewBuffer([]byte{})
writer := multipart.NewWriter(body)
{
part, err := writer.CreateFormField("name")
if err == nil {
part.Write([]byte("lu"))
}
}
{
part, err := writer.CreateFormField("age")
if err == nil {
part.Write([]byte("20"))
}
}
{
part, err := writer.CreateFormFile("myFile", "hello.txt")
if err == nil {
part.Write([]byte("Hello, World!"))
}
}
{
part, err := writer.CreateFormFile("myFile2", "hello.PHP")
if err == nil {
part.Write([]byte("Hello, World, PHP!"))
}
}
{
part, err := writer.CreateFormFile("myFile3", "hello.asp")
if err == nil {
part.Write([]byte("Hello, World, ASP Pages!"))
}
}
{
part, err := writer.CreateFormFile("myFile4", "hello.asp")
if err == nil {
part.Write([]byte("Hello, World, ASP Pages!"))
}
}
writer.Close()
rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn/", body)
if err != nil {
t.Fatal()
}
req := requests.NewRequest(rawReq)
req.Header.Add("Content-Type", writer.FormDataContentType())
checkpoint := new(RequestUploadCheckpoint)
t.Log(checkpoint.RequestValue(req, "field", nil))
t.Log(checkpoint.RequestValue(req, "minSize", nil))
t.Log(checkpoint.RequestValue(req, "maxSize", nil))
t.Log(checkpoint.RequestValue(req, "name", nil))
t.Log(checkpoint.RequestValue(req, "ext", nil))
data, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
t.Log(string(data))
}

View File

@@ -0,0 +1,25 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestURICheckpoint struct {
Checkpoint
}
func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if len(req.RequestURI) > 0 {
value = req.RequestURI
} else if req.URL != nil {
value = req.URL.RequestURI()
}
return
}
func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
type RequestUserAgentCheckpoint struct {
Checkpoint
}
func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = req.UserAgent()
return
}
func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,41 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"io/ioutil"
)
// ${responseBody}
type ResponseBodyCheckpoint struct {
Checkpoint
}
func (this *ResponseBodyCheckpoint) IsRequest() bool {
return false
}
func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = ""
return
}
func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = ""
if resp != nil && resp.Body != nil {
if len(resp.BodyData) > 0 {
value = string(resp.BodyData)
return
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
sysErr = err
return
}
resp.BodyData = body
_ = resp.Body.Close()
value = body
resp.Body = ioutil.NopCloser(bytes.NewBuffer(body))
}
return
}

View File

@@ -0,0 +1,29 @@
package checkpoints
import (
"bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"io/ioutil"
"net/http"
"testing"
)
func TestResponseBodyCheckpoint_ResponseValue(t *testing.T) {
resp := requests.NewResponse(new(http.Response))
resp.StatusCode = 200
resp.Header = http.Header{}
resp.Header.Set("Hello", "World")
resp.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("Hello, World")))
checkpoint := new(ResponseBodyCheckpoint)
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Log("after read:", string(data))
}

View File

@@ -0,0 +1,27 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// ${bytesSent}
type ResponseBytesSentCheckpoint struct {
Checkpoint
}
func (this *ResponseBytesSentCheckpoint) IsRequest() bool {
return false
}
func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = 0
return
}
func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = 0
if resp != nil {
value = resp.ContentLength
}
return
}

View File

@@ -0,0 +1,28 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// ${responseHeader.arg}
type ResponseHeaderCheckpoint struct {
Checkpoint
}
func (this *ResponseHeaderCheckpoint) IsRequest() bool {
return false
}
func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = ""
return
}
func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if resp != nil && resp.Header != nil {
value = resp.Header.Get(param)
} else {
value = ""
}
return
}

View File

@@ -0,0 +1,17 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestResponseHeaderCheckpoint_ResponseValue(t *testing.T) {
resp := requests.NewResponse(new(http.Response))
resp.StatusCode = 200
resp.Header = http.Header{}
resp.Header.Set("Hello", "World")
checkpoint := new(ResponseHeaderCheckpoint)
t.Log(checkpoint.ResponseValue(nil, resp, "Hello", nil))
}

View File

@@ -0,0 +1,26 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// ${bytesSent}
type ResponseStatusCheckpoint struct {
Checkpoint
}
func (this *ResponseStatusCheckpoint) IsRequest() bool {
return false
}
func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
value = 0
return
}
func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) {
if resp != nil {
value = resp.StatusCode
}
return
}

View File

@@ -0,0 +1,15 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"net/http"
"testing"
)
func TestResponseStatusCheckpoint_ResponseValue(t *testing.T) {
resp := requests.NewResponse(new(http.Response))
resp.StatusCode = 200
checkpoint := new(ResponseStatusCheckpoint)
t.Log(checkpoint.ResponseValue(nil, resp, "", nil))
}

View File

@@ -0,0 +1,21 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// just a sample checkpoint, copy and change it for your new checkpoint
type SampleRequestCheckpoint struct {
Checkpoint
}
func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
return
}
func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
if this.IsRequest() {
return this.RequestValue(req, param, options)
}
return
}

View File

@@ -0,0 +1,22 @@
package checkpoints
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// just a sample checkpoint, copy and change it for your new checkpoint
type SampleResponseCheckpoint struct {
Checkpoint
}
func (this *SampleResponseCheckpoint) IsRequest() bool {
return false
}
func (this *SampleResponseCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
return
}
func (this *SampleResponseCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) {
return
}

View File

@@ -0,0 +1,235 @@
package checkpoints
// all check points list
var AllCheckpoints = []*CheckpointDefinition{
{
Name: "客户端地址IP",
Prefix: "remoteAddr",
Description: "试图通过分析X-Forwarded-For等Header获取的客户端地址比如192.168.1.100",
HasParams: false,
Instance: new(RequestRemoteAddrCheckpoint),
},
{
Name: "客户端源地址IP",
Prefix: "rawRemoteAddr",
Description: "直接连接的客户端地址比如192.168.1.100",
HasParams: false,
Instance: new(RequestRawRemoteAddrCheckpoint),
},
{
Name: "客户端端口",
Prefix: "remotePort",
Description: "直接连接的客户端地址端口",
HasParams: false,
Instance: new(RequestRemotePortCheckpoint),
},
{
Name: "客户端用户名",
Prefix: "remoteUser",
Description: "通过BasicAuth登录的客户端用户名",
HasParams: false,
Instance: new(RequestRemoteUserCheckpoint),
},
{
Name: "请求URI",
Prefix: "requestURI",
Description: "包含URL参数的请求URI比如/hello/world?lang=go",
HasParams: false,
Instance: new(RequestURICheckpoint),
},
{
Name: "请求路径",
Prefix: "requestPath",
Description: "不包含URL参数的请求路径比如/hello/world",
HasParams: false,
Instance: new(RequestPathCheckpoint),
},
{
Name: "请求内容长度",
Prefix: "requestLength",
Description: "请求Header中的Content-Length",
HasParams: false,
Instance: new(RequestLengthCheckpoint),
},
{
Name: "请求体内容",
Prefix: "requestBody",
Description: "通常在POST或者PUT等操作时会附带请求体最大限制32M",
HasParams: false,
Instance: new(RequestBodyCheckpoint),
},
{
Name: "请求URI和请求体组合",
Prefix: "requestAll",
Description: "${requestURI}和${requestBody}组合",
HasParams: false,
Instance: new(RequestAllCheckpoint),
},
{
Name: "请求表单参数",
Prefix: "requestForm",
Description: "获取POST或者其他方法发送的表单参数最大请求体限制32M",
HasParams: true,
Instance: new(RequestFormArgCheckpoint),
},
{
Name: "上传文件",
Prefix: "requestUpload",
Description: "获取POST上传的文件信息最大请求体限制32M",
HasParams: true,
Instance: new(RequestUploadCheckpoint),
},
{
Name: "请求JSON参数",
Prefix: "requestJSON",
Description: "获取POST或者其他方法发送的JSON最大请求体限制32M使用点.)符号表示多级数据",
HasParams: true,
Instance: new(RequestJSONArgCheckpoint),
},
{
Name: "请求方法",
Prefix: "requestMethod",
Description: "比如GET、POST",
HasParams: false,
Instance: new(RequestMethodCheckpoint),
},
{
Name: "请求协议",
Prefix: "scheme",
Description: "比如http或https",
HasParams: false,
Instance: new(RequestSchemeCheckpoint),
},
{
Name: "HTTP协议版本",
Prefix: "proto",
Description: "比如HTTP/1.1",
HasParams: false,
Instance: new(RequestProtoCheckpoint),
},
{
Name: "主机名",
Prefix: "host",
Description: "比如teaos.cn",
HasParams: false,
Instance: new(RequestHostCheckpoint),
},
{
Name: "请求来源URL",
Prefix: "referer",
Description: "请求Header中的Referer值",
HasParams: false,
Instance: new(RequestRefererCheckpoint),
},
{
Name: "客户端信息",
Prefix: "userAgent",
Description: "比如Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103",
HasParams: false,
Instance: new(RequestUserAgentCheckpoint),
},
{
Name: "内容类型",
Prefix: "contentType",
Description: "请求Header的Content-Type",
HasParams: false,
Instance: new(RequestContentTypeCheckpoint),
},
{
Name: "所有cookie组合字符串",
Prefix: "cookies",
Description: "比如sid=IxZVPFhE&city=beijing&uid=18237",
HasParams: false,
Instance: new(RequestCookiesCheckpoint),
},
{
Name: "单个cookie值",
Prefix: "cookie",
Description: "单个cookie值",
HasParams: true,
Instance: new(RequestCookieCheckpoint),
},
{
Name: "所有URL参数组合",
Prefix: "args",
Description: "比如name=lu&age=20",
HasParams: false,
Instance: new(RequestArgsCheckpoint),
},
{
Name: "单个URL参数值",
Prefix: "arg",
Description: "单个URL参数值",
HasParams: true,
Instance: new(RequestArgCheckpoint),
},
{
Name: "所有Header信息",
Prefix: "headers",
Description: "使用\n隔开的Header信息字符串",
HasParams: false,
Instance: new(RequestHeadersCheckpoint),
},
{
Name: "单个Header值",
Prefix: "header",
Description: "单个Header值",
HasParams: true,
Instance: new(RequestHeaderCheckpoint),
},
{
Name: "CC统计",
Prefix: "cc",
Description: "统计某段时间段内的请求信息",
HasParams: true,
Instance: new(CCCheckpoint),
},
{
Name: "响应状态码",
Prefix: "status",
Description: "响应状态码比如200、404、500",
HasParams: false,
Instance: new(ResponseStatusCheckpoint),
},
{
Name: "响应Header",
Prefix: "responseHeader",
Description: "响应Header值",
HasParams: true,
Instance: new(ResponseHeaderCheckpoint),
},
{
Name: "响应内容",
Prefix: "responseBody",
Description: "响应内容字符串",
HasParams: false,
Instance: new(ResponseBodyCheckpoint),
},
{
Name: "响应内容长度",
Prefix: "bytesSent",
Description: "响应内容长度通过响应的Header Content-Length获取",
HasParams: false,
Instance: new(ResponseBytesSentCheckpoint),
},
}
// find a check point
func FindCheckpoint(prefix string) CheckpointInterface {
for _, def := range AllCheckpoints {
if def.Prefix == prefix {
return def.Instance
}
}
return nil
}
// find a check point definition
func FindCheckpointDefinition(prefix string) *CheckpointDefinition {
for _, def := range AllCheckpoints {
if def.Prefix == prefix {
return def
}
}
return nil
}

View File

@@ -0,0 +1,31 @@
package checkpoints
import (
"fmt"
"strings"
"testing"
)
func TestFindCheckpointDefinition_Markdown(t *testing.T) {
result := []string{}
for _, def := range AllCheckpoints {
row := "## " + def.Name + "\n* 前缀:`${" + def.Prefix + "}`\n* 描述:" + def.Description
if def.HasParams {
row += "\n* 是否有子参数YES"
paramOptions := def.Instance.ParamOptions()
if paramOptions != nil && len(paramOptions.Options) > 0 {
row += "\n* 可选子参数"
for _, option := range paramOptions.Options {
row += "\n * `" + option.Name + "`:值为 `" + option.Value + "`"
}
}
} else {
row += "\n* 是否有子参数NO"
}
row += "\n"
result = append(result, row)
}
fmt.Print(strings.Join(result, "\n") + "\n")
}

154
internal/waf/ip_table.go Normal file
View File

@@ -0,0 +1,154 @@
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string"
"regexp"
"strings"
"time"
)
type IPAction = string
var RegexpDigitNumber = regexp.MustCompile("^\\d+$")
const (
IPActionAccept IPAction = "accept"
IPActionReject IPAction = "reject"
)
// ip table
type IPTable struct {
Id string `yaml:"id" json:"id"`
On bool `yaml:"on" json:"on"`
IP string `yaml:"ip" json:"ip"` // single ip, cidr, ip range, TODO support *
Port string `yaml:"port" json:"port"` // single port, range, *
Action IPAction `yaml:"action" json:"action"` // accept, reject
TimeFrom int64 `yaml:"timeFrom" json:"timeFrom"` // from timestamp
TimeTo int64 `yaml:"timeTo" json:"timeTo"` // zero means forever
Remark string `yaml:"remark" json:"remark"`
// port
minPort int
maxPort int
minPortWildcard bool
maxPortWildcard bool
ports []int
// ip
ipRange *shared.IPRangeConfig
}
func NewIPTable() *IPTable {
return &IPTable{
On: true,
Id: stringutil.Rand(16),
}
}
func (this *IPTable) Init() error {
// parse port
if RegexpDigitNumber.MatchString(this.Port) {
this.minPort = types.Int(this.Port)
this.maxPort = types.Int(this.Port)
} else if regexp.MustCompile(`[:-]`).MatchString(this.Port) {
pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2)
if pieces[0] == "*" {
this.minPortWildcard = true
} else {
this.minPort = types.Int(pieces[0])
}
if pieces[1] == "*" {
this.maxPortWildcard = true
} else {
this.maxPort = types.Int(pieces[1])
}
} else if strings.Contains(this.Port, ",") {
pieces := strings.Split(this.Port, ",")
for _, piece := range pieces {
piece = strings.TrimSpace(piece)
if len(piece) > 0 {
this.ports = append(this.ports, types.Int(piece))
}
}
} else if this.Port == "*" {
this.minPortWildcard = true
this.maxPortWildcard = true
}
// parse ip
if len(this.IP) > 0 {
ipRange, err := shared.ParseIPRange(this.IP)
if err != nil {
return err
}
this.ipRange = ipRange
}
return nil
}
// check ip
func (this *IPTable) Match(ip string, port int) (isMatched bool) {
if !this.On {
return
}
now := time.Now().Unix()
if this.TimeFrom > 0 && now < this.TimeFrom {
return
}
if this.TimeTo > 0 && now > this.TimeTo {
return
}
if !this.matchPort(port) {
return
}
if !this.matchIP(ip) {
return
}
return true
}
func (this *IPTable) matchPort(port int) bool {
if port == 0 {
return false
}
if this.minPortWildcard {
if this.maxPortWildcard {
return true
}
if this.maxPort >= port {
return true
}
}
if this.maxPortWildcard {
if this.minPortWildcard {
return true
}
if this.minPort <= port {
return true
}
}
if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port {
return true
}
if len(this.ports) > 0 {
return lists.ContainsInt(this.ports, port)
}
return false
}
func (this *IPTable) matchIP(ip string) bool {
if this.ipRange == nil {
return false
}
return this.ipRange.Contains(ip)
}

View File

@@ -0,0 +1,142 @@
package waf
import (
"github.com/iwind/TeaGo/assert"
"testing"
"time"
)
func TestIPTable_MatchIP(t *testing.T) {
a := assert.NewAssertion(t)
{
table := NewIPTable()
err := table.Init()
if err != nil {
t.Fatal(err)
}
a.IsFalse(table.Match("192.168.1.100", 8080))
}
{
table := NewIPTable()
table.IP = "*"
table.Port = "8080"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("port:", table.minPort, table.maxPort)
a.IsTrue(table.Match("192.168.1.100", 8080))
a.IsFalse(table.Match("192.168.1.100", 8081))
}
{
table := NewIPTable()
table.IP = "*"
table.Port = "8080-8082"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("port:", table.minPort, table.maxPort)
a.IsTrue(table.Match("192.168.1.100", 8080))
a.IsTrue(table.Match("192.168.1.100", 8081))
a.IsFalse(table.Match("192.168.1.100", 8083))
}
{
table := NewIPTable()
table.IP = "*"
table.Port = "*-8082"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("port:", table.minPort, table.maxPort)
a.IsTrue(table.Match("192.168.1.100", 8079))
a.IsTrue(table.Match("192.168.1.100", 8080))
a.IsTrue(table.Match("192.168.1.100", 8081))
a.IsFalse(table.Match("192.168.1.100", 8083))
}
{
table := NewIPTable()
table.IP = "*"
table.Port = "8080-*"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("port:", table.minPort, table.maxPort)
a.IsFalse(table.Match("192.168.1.100", 8079))
a.IsTrue(table.Match("192.168.1.100", 8080))
a.IsTrue(table.Match("192.168.1.100", 8081))
a.IsTrue(table.Match("192.168.1.100", 8083))
}
{
table := NewIPTable()
table.IP = "*"
table.Port = "*"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("port:", table.minPort, table.maxPort)
a.IsTrue(table.Match("192.168.1.100", 8079))
a.IsTrue(table.Match("192.168.1.100", 8080))
a.IsTrue(table.Match("192.168.1.100", 8081))
a.IsTrue(table.Match("192.168.1.100", 8083))
}
{
table := NewIPTable()
table.IP = "192.168.1.100"
table.Port = "*"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("port:", table.minPort, table.maxPort)
a.IsTrue(table.Match("192.168.1.100", 8080))
}
{
table := NewIPTable()
table.IP = "192.168.1.99-192.168.1.101"
table.Port = "*"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("port:", table.minPort, table.maxPort)
a.IsTrue(table.Match("192.168.1.100", 8080))
}
{
table := NewIPTable()
table.IP = "192.168.1.99/24"
table.Port = "*"
err := table.Init()
if err != nil {
t.Fatal(err)
}
t.Log("ip:", table.ipRange)
a.IsTrue(table.Match("192.168.1.100", 8080))
a.IsFalse(table.Match("192.168.2.100", 8080))
}
{
table := NewIPTable()
table.IP = "192.168.1.99/24"
table.TimeTo = time.Now().Unix() - 10
table.Port = "*"
err := table.Init()
if err != nil {
t.Fatal(err)
}
a.IsFalse(table.Match("192.168.1.100", 8080))
a.IsFalse(table.Match("192.168.2.100", 8080))
}
}

View File

@@ -0,0 +1,35 @@
package requests
import (
"bytes"
"io"
"io/ioutil"
"net/http"
)
type Request struct {
*http.Request
BodyData []byte
}
func NewRequest(raw *http.Request) *Request {
return &Request{
Request: raw,
}
}
func (this *Request) Raw() *http.Request {
return this.Request
}
func (this *Request) ReadBody(max int64) (data []byte, err error) {
data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max))
return
}
func (this *Request) RestoreBody(data []byte) {
rawReader := bytes.NewBuffer(data)
buf := make([]byte, 1024)
io.CopyBuffer(rawReader, this.Request.Body, buf)
this.Request.Body = ioutil.NopCloser(rawReader)
}

View File

@@ -0,0 +1,15 @@
package requests
import "net/http"
type Response struct {
*http.Response
BodyData []byte
}
func NewResponse(resp *http.Response) *Response {
return &Response{
Response: resp,
}
}

584
internal/waf/rule.go Normal file
View File

@@ -0,0 +1,584 @@
package waf
import (
"bytes"
"encoding/binary"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
"github.com/iwind/TeaGo/utils/string"
"net"
"reflect"
"regexp"
"strings"
)
var singleParamRegexp = regexp.MustCompile("^\\${[\\w.-]+}$")
// rule
type Rule struct {
Description string `yaml:"description" json:"description"`
Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName}
Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ...
Value string `yaml:"value" json:"value"` // compared value
IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"`
CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"`
checkpointFinder func(prefix string) checkpoints.CheckpointInterface
singleParam string // real param after prefix
singleCheckpoint checkpoints.CheckpointInterface // if is single check point
multipleCheckpoints map[string]checkpoints.CheckpointInterface
isIP bool
ipValue net.IP
floatValue float64
reg *regexp.Regexp
}
func NewRule() *Rule {
return &Rule{}
}
func (this *Rule) Init() error {
// operator
switch this.Operator {
case RuleOperatorGt:
this.floatValue = types.Float64(this.Value)
case RuleOperatorGte:
this.floatValue = types.Float64(this.Value)
case RuleOperatorLt:
this.floatValue = types.Float64(this.Value)
case RuleOperatorLte:
this.floatValue = types.Float64(this.Value)
case RuleOperatorEq:
this.floatValue = types.Float64(this.Value)
case RuleOperatorNeq:
this.floatValue = types.Float64(this.Value)
case RuleOperatorMatch:
v := this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v
}
v = this.unescape(v)
reg, err := regexp.Compile(v)
if err != nil {
return err
}
this.reg = reg
case RuleOperatorNotMatch:
v := this.Value
if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") {
v = "(?i)" + v
}
v = this.unescape(v)
reg, err := regexp.Compile(v)
if err != nil {
return err
}
this.reg = reg
case RuleOperatorEqIP, RuleOperatorGtIP, RuleOperatorGteIP, RuleOperatorLtIP, RuleOperatorLteIP:
this.ipValue = net.ParseIP(this.Value)
this.isIP = this.ipValue != nil
if !this.isIP {
return errors.New("value should be a valid ip")
}
case RuleOperatorIPRange, RuleOperatorNotIPRange:
if strings.Contains(this.Value, ",") {
ipList := strings.SplitN(this.Value, ",", 2)
ipString1 := strings.TrimSpace(ipList[0])
ipString2 := strings.TrimSpace(ipList[1])
if len(ipString1) > 0 {
ip1 := net.ParseIP(ipString1)
if ip1 == nil {
return errors.New("start ip is invalid")
}
}
if len(ipString2) > 0 {
ip2 := net.ParseIP(ipString2)
if ip2 == nil {
return errors.New("end ip is invalid")
}
}
} else if strings.Contains(this.Value, "/") {
_, _, err := net.ParseCIDR(this.Value)
if err != nil {
return err
}
} else {
return errors.New("invalid ip range")
}
}
if singleParamRegexp.MatchString(this.Param) {
param := this.Param[2 : len(this.Param)-1]
pieces := strings.SplitN(param, ".", 2)
prefix := pieces[0]
if len(pieces) == 1 {
this.singleParam = ""
} else {
this.singleParam = pieces[1]
}
if this.checkpointFinder != nil {
checkpoint := this.checkpointFinder(prefix)
if checkpoint == nil {
return errors.New("no check point '" + prefix + "' found")
}
this.singleCheckpoint = checkpoint
} else {
checkpoint := checkpoints.FindCheckpoint(prefix)
if checkpoint == nil {
return errors.New("no check point '" + prefix + "' found")
}
checkpoint.Init()
this.singleCheckpoint = checkpoint
}
return nil
}
this.multipleCheckpoints = map[string]checkpoints.CheckpointInterface{}
var err error = nil
configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
if this.checkpointFinder != nil {
checkpoint := this.checkpointFinder(prefix)
if checkpoint == nil {
err = errors.New("no check point '" + prefix + "' found")
} else {
this.multipleCheckpoints[prefix] = checkpoint
}
} else {
checkpoint := checkpoints.FindCheckpoint(prefix)
if checkpoint == nil {
err = errors.New("no check point '" + prefix + "' found")
} else {
checkpoint.Init()
this.multipleCheckpoints[prefix] = checkpoint
}
}
return ""
})
return err
}
func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) {
if this.singleCheckpoint != nil {
value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
if err != nil {
return false, err
}
return this.Test(value), nil
}
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
point, ok := this.multipleCheckpoints[prefix]
if !ok {
return ""
}
if len(pieces) == 1 {
value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions)
if err1 != nil {
err = err1
}
return types.String(value1)
}
value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions)
if err1 != nil {
err = err1
}
return types.String(value1)
})
if err != nil {
return false, err
}
return this.Test(value), nil
}
func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) {
if this.singleCheckpoint != nil {
// if is request param
if this.singleCheckpoint.IsRequest() {
value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions)
if err != nil {
return false, err
}
return this.Test(value), nil
}
// response param
value, err, _ := this.singleCheckpoint.ResponseValue(req, resp, this.singleParam, this.CheckpointOptions)
if err != nil {
return false, err
}
return this.Test(value), nil
}
value := configutils.ParseVariables(this.Param, func(varName string) (value string) {
pieces := strings.SplitN(varName, ".", 2)
prefix := pieces[0]
point, ok := this.multipleCheckpoints[prefix]
if !ok {
return ""
}
if len(pieces) == 1 {
if point.IsRequest() {
value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions)
if err1 != nil {
err = err1
}
return types.String(value1)
} else {
value1, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions)
if err1 != nil {
err = err1
}
return types.String(value1)
}
}
if point.IsRequest() {
value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions)
if err1 != nil {
err = err1
}
return types.String(value1)
} else {
value1, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions)
if err1 != nil {
err = err1
}
return types.String(value1)
}
})
if err != nil {
return false, err
}
return this.Test(value), nil
}
func (this *Rule) Test(value interface{}) bool {
// operator
switch this.Operator {
case RuleOperatorGt:
return types.Float64(value) > this.floatValue
case RuleOperatorGte:
return types.Float64(value) >= this.floatValue
case RuleOperatorLt:
return types.Float64(value) < this.floatValue
case RuleOperatorLte:
return types.Float64(value) <= this.floatValue
case RuleOperatorEq:
return types.Float64(value) == this.floatValue
case RuleOperatorNeq:
return types.Float64(value) != this.floatValue
case RuleOperatorEqString:
if this.IsCaseInsensitive {
return strings.ToLower(types.String(value)) == strings.ToLower(this.Value)
} else {
return types.String(value) == this.Value
}
case RuleOperatorNeqString:
if this.IsCaseInsensitive {
return strings.ToLower(types.String(value)) != strings.ToLower(this.Value)
} else {
return types.String(value) != this.Value
}
case RuleOperatorMatch:
if value == nil {
return false
}
// strings
stringList, ok := value.([]string)
if ok {
for _, s := range stringList {
if utils.MatchStringCache(this.reg, s) {
return true
}
}
return false
}
// bytes
byteSlice, ok := value.([]byte)
if ok {
return utils.MatchBytesCache(this.reg, byteSlice)
}
// string
return utils.MatchStringCache(this.reg, types.String(value))
case RuleOperatorNotMatch:
if value == nil {
return true
}
stringList, ok := value.([]string)
if ok {
for _, s := range stringList {
if utils.MatchStringCache(this.reg, s) {
return false
}
}
return true
}
// bytes
byteSlice, ok := value.([]byte)
if ok {
return !utils.MatchBytesCache(this.reg, byteSlice)
}
return !utils.MatchStringCache(this.reg, types.String(value))
case RuleOperatorContains:
if types.IsSlice(value) {
ok := false
lists.Each(value, func(k int, v interface{}) {
if types.String(v) == this.Value {
ok = true
}
})
return ok
}
if types.IsMap(value) {
lowerValue := ""
if this.IsCaseInsensitive {
lowerValue = strings.ToLower(this.Value)
}
for _, v := range maps.NewMap(value) {
if this.IsCaseInsensitive {
if strings.ToLower(types.String(v)) == lowerValue {
return true
}
} else {
if types.String(v) == this.Value {
return true
}
}
}
return false
}
if this.IsCaseInsensitive {
return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
} else {
return strings.Contains(types.String(value), this.Value)
}
case RuleOperatorNotContains:
if this.IsCaseInsensitive {
return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
} else {
return !strings.Contains(types.String(value), this.Value)
}
case RuleOperatorPrefix:
if this.IsCaseInsensitive {
return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
} else {
return strings.HasPrefix(types.String(value), this.Value)
}
case RuleOperatorSuffix:
if this.IsCaseInsensitive {
return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value))
} else {
return strings.HasSuffix(types.String(value), this.Value)
}
case RuleOperatorHasKey:
if types.IsSlice(value) {
index := types.Int(this.Value)
if index < 0 {
return false
}
return reflect.ValueOf(value).Len() > index
} else if types.IsMap(value) {
m := maps.NewMap(value)
if this.IsCaseInsensitive {
lowerValue := strings.ToLower(this.Value)
for k := range m {
if strings.ToLower(k) == lowerValue {
return true
}
}
} else {
return m.Has(this.Value)
}
} else {
return false
}
case RuleOperatorVersionGt:
return stringutil.VersionCompare(this.Value, types.String(value)) > 0
case RuleOperatorVersionLt:
return stringutil.VersionCompare(this.Value, types.String(value)) < 0
case RuleOperatorVersionRange:
if strings.Contains(this.Value, ",") {
versions := strings.SplitN(this.Value, ",", 2)
version1 := strings.TrimSpace(versions[0])
version2 := strings.TrimSpace(versions[1])
if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 {
return false
}
if len(version2) > 0 && stringutil.VersionCompare(types.String(value), version2) > 0 {
return false
}
return true
} else {
return stringutil.VersionCompare(types.String(value), this.Value) >= 0
}
case RuleOperatorEqIP:
ip := net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(this.ipValue, ip) == 0
case RuleOperatorGtIP:
ip := net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) > 0
case RuleOperatorGteIP:
ip := net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) >= 0
case RuleOperatorLtIP:
ip := net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) < 0
case RuleOperatorLteIP:
ip := net.ParseIP(types.String(value))
if ip == nil {
return false
}
return this.isIP && bytes.Compare(ip, this.ipValue) <= 0
case RuleOperatorIPRange:
return this.containsIP(value)
case RuleOperatorNotIPRange:
return !this.containsIP(value)
case RuleOperatorIPMod:
pieces := strings.SplitN(this.Value, ",", 2)
if len(pieces) == 1 {
rem := types.Int64(pieces[0])
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == rem
}
div := types.Int64(pieces[0])
if div == 0 {
return false
}
rem := types.Int64(pieces[1])
return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem
case RuleOperatorIPMod10:
return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value)
case RuleOperatorIPMod100:
return this.ipToInt64(net.ParseIP(types.String(value)))%100 == types.Int64(this.Value)
}
return false
}
func (this *Rule) IsSingleCheckpoint() bool {
return this.singleCheckpoint != nil
}
func (this *Rule) SetCheckpointFinder(finder func(prefix string) checkpoints.CheckpointInterface) {
this.checkpointFinder = finder
}
func (this *Rule) unescape(v string) string {
//replace urlencoded characters
v = strings.Replace(v, `\s`, `(\s|%09|%0A|\+)`, -1)
v = strings.Replace(v, `\(`, `(\(|%28)`, -1)
v = strings.Replace(v, `=`, `(=|%3D)`, -1)
v = strings.Replace(v, `<`, `(<|%3C)`, -1)
v = strings.Replace(v, `\*`, `(\*|%2A)`, -1)
v = strings.Replace(v, `\\`, `(\\|%2F)`, -1)
v = strings.Replace(v, `!`, `(!|%21)`, -1)
v = strings.Replace(v, `/`, `(/|%2F)`, -1)
v = strings.Replace(v, `;`, `(;|%3B)`, -1)
v = strings.Replace(v, `\+`, `(\+|%20)`, -1)
return v
}
func (this *Rule) containsIP(value interface{}) bool {
ip := net.ParseIP(types.String(value))
if ip == nil {
return false
}
// 检查IP范围格式
if strings.Contains(this.Value, ",") {
ipList := strings.SplitN(this.Value, ",", 2)
ipString1 := strings.TrimSpace(ipList[0])
ipString2 := strings.TrimSpace(ipList[1])
if len(ipString1) > 0 {
ip1 := net.ParseIP(ipString1)
if ip1 == nil {
return false
}
if bytes.Compare(ip, ip1) < 0 {
return false
}
}
if len(ipString2) > 0 {
ip2 := net.ParseIP(ipString2)
if ip2 == nil {
return false
}
if bytes.Compare(ip, ip2) > 0 {
return false
}
}
return true
} else if strings.Contains(this.Value, "/") {
_, ipNet, err := net.ParseCIDR(this.Value)
if err != nil {
return false
}
return ipNet.Contains(ip)
} else {
return false
}
}
func (this *Rule) ipToInt64(ip net.IP) int64 {
if len(ip) == 0 {
return 0
}
if len(ip) == 16 {
return int64(binary.BigEndian.Uint32(ip[12:16]))
}
return int64(binary.BigEndian.Uint32(ip))
}

147
internal/waf/rule_group.go Normal file
View File

@@ -0,0 +1,147 @@
package waf
import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
)
// rule group
type RuleGroup struct {
Id string `yaml:"id" json:"id"`
IsOn bool `yaml:"isOn" json:"isOn"`
Name string `yaml:"name" json:"name"` // such as SQL Injection
Description string `yaml:"description" json:"description"`
Code string `yaml:"code" json:"code"` // identify the group
RuleSets []*RuleSet `yaml:"ruleSets" json:"ruleSets"`
IsInbound bool `yaml:"isInbound" json:"isInbound"`
hasRuleSets bool
}
func NewRuleGroup() *RuleGroup {
return &RuleGroup{
IsOn: true,
}
}
func (this *RuleGroup) Init() error {
this.hasRuleSets = len(this.RuleSets) > 0
if this.hasRuleSets {
for _, set := range this.RuleSets {
err := set.Init()
if err != nil {
return err
}
}
}
return nil
}
func (this *RuleGroup) AddRuleSet(ruleSet *RuleSet) {
this.RuleSets = append(this.RuleSets, ruleSet)
}
func (this *RuleGroup) FindRuleSet(id string) *RuleSet {
if len(id) == 0 {
return nil
}
for _, ruleSet := range this.RuleSets {
if ruleSet.Id == id {
return ruleSet
}
}
return nil
}
func (this *RuleGroup) FindRuleSetWithCode(code string) *RuleSet {
if len(code) == 0 {
return nil
}
for _, ruleSet := range this.RuleSets {
if ruleSet.Code == code {
return ruleSet
}
}
return nil
}
func (this *RuleGroup) RemoveRuleSet(id string) {
if len(id) == 0 {
return
}
result := []*RuleSet{}
for _, ruleSet := range this.RuleSets {
if ruleSet.Id == id {
continue
}
result = append(result, ruleSet)
}
this.RuleSets = result
}
func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) {
if !this.hasRuleSets {
return
}
for _, set := range this.RuleSets {
if !set.IsOn {
continue
}
b, err = set.MatchRequest(req)
if err != nil {
return false, nil, err
}
if b {
return true, set, nil
}
}
return
}
func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) {
if !this.hasRuleSets {
return
}
for _, set := range this.RuleSets {
if !set.IsOn {
continue
}
b, err = set.MatchResponse(req, resp)
if err != nil {
return false, nil, err
}
if b {
return true, set, nil
}
}
return
}
func (this *RuleGroup) MoveRuleSet(fromIndex int, toIndex int) {
if fromIndex < 0 || fromIndex >= len(this.RuleSets) {
return
}
if toIndex < 0 || toIndex >= len(this.RuleSets) {
return
}
if fromIndex == toIndex {
return
}
location := this.RuleSets[fromIndex]
result := []*RuleSet{}
for i := 0; i < len(this.RuleSets); i++ {
if i == fromIndex {
continue
}
if fromIndex > toIndex && i == toIndex {
result = append(result, location)
}
result = append(result, this.RuleSets[i])
if fromIndex < toIndex && i == toIndex {
result = append(result, location)
}
}
this.RuleSets = result
}

Some files were not shown because too many files have changed in this diff Show More