diff --git a/go.mod b/go.mod index 99f78e9..a848987 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon require ( github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 + github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f github.com/dchest/siphash v1.2.1 github.com/go-ole/go-ole v1.2.4 // indirect github.com/go-yaml/yaml v2.1.0+incompatible @@ -14,4 +15,5 @@ require ( github.com/shirou/gopsutil v2.20.9+incompatible golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7 google.golang.org/grpc v1.32.0 + gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 ) diff --git a/go.sum b/go.sum index 5329e7c..1cde562 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f h1:q/DpyjJjZs94bziQ7YkBmIlpqbVP7yw179rnzoNVX1M= +github.com/dchest/captcha v0.0.0-20200903113550-03f5f0333e1f/go.mod h1:QGrK8vMWWHQYQ3QU9bw9Y9OPNfxccGzfb41qjvVeXtY= github.com/dchest/siphash v1.2.1 h1:4cLinnzVJDKxTCl9B01807Yiy+W7ZzVHj/KIroQRvT4= github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= github.com/dgryski/go-rendezvous v0.0.0-20200624174652-8d2f3be8b2d9/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= diff --git a/internal/grids/README.md b/internal/grids/README.md new file mode 100644 index 0000000..c9cca8f --- /dev/null +++ b/internal/grids/README.md @@ -0,0 +1,2 @@ +# Memory Grid +Cache items in memory, using partitions and LRU. \ No newline at end of file diff --git a/internal/grids/cell.go b/internal/grids/cell.go new file mode 100644 index 0000000..c60388d --- /dev/null +++ b/internal/grids/cell.go @@ -0,0 +1,186 @@ +package grids + +import ( + "math" + "sync" + "time" +) + +type Cell struct { + LimitSize int64 + LimitCount int + + mapping map[uint64]*Item // key => item + list *List // { item1, item2, ... } + totalBytes int64 + locker sync.RWMutex +} + +func NewCell() *Cell { + return &Cell{ + mapping: map[uint64]*Item{}, + list: NewList(), + } +} + +func (this *Cell) Write(hashKey uint64, item *Item) { + if item == nil { + return + } + this.locker.Lock() + + oldItem, ok := this.mapping[hashKey] + if ok { + this.list.Remove(oldItem) + + if this.LimitSize > 0 { + this.totalBytes -= oldItem.Size() + } + } + + // limit count + if this.LimitCount > 0 && len(this.mapping) >= this.LimitCount { + this.locker.Unlock() + return + } + + // trim memory + size := item.Size() + shouldTrim := false + if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size { + this.Trim() + shouldTrim = true + } + + // compare again + if shouldTrim { + if this.LimitSize > 0 && this.LimitSize < this.totalBytes+size { + this.locker.Unlock() + return + } + } + + this.totalBytes += size + + this.list.Add(item) + this.mapping[hashKey] = item + + this.locker.Unlock() +} + +func (this *Cell) Increase64(key []byte, expireAt int64, hashKey uint64, delta int64) (result int64) { + this.locker.Lock() + item, ok := this.mapping[hashKey] + if ok { + // reset to zero if expired + if item.ExpireAt < time.Now().Unix() { + item.ValueInt64 = 0 + item.ExpireAt = expireAt + } + item.IncreaseInt64(delta) + result = item.ValueInt64 + } else { + item := NewItem(key, ItemInt64) + item.ValueInt64 = delta + item.ExpireAt = expireAt + this.mapping[hashKey] = item + result = delta + } + this.locker.Unlock() + return +} + +func (this *Cell) Read(hashKey uint64) *Item { + this.locker.Lock() + + item, ok := this.mapping[hashKey] + if ok { + this.list.Remove(item) + this.list.Add(item) + + this.locker.Unlock() + + if item.ExpireAt < time.Now().Unix() { + return nil + } + return item + } + + this.locker.Unlock() + return nil +} + +func (this *Cell) Stat() *CellStat { + this.locker.RLock() + defer this.locker.RUnlock() + + return &CellStat{ + TotalBytes: this.totalBytes, + CountItems: len(this.mapping), + } +} + +// trim NOT ACTIVE items +// should called in locker context +func (this *Cell) Trim() { + l := len(this.mapping) + if l == 0 { + return + } + + inactiveSize := int(math.Ceil(float64(l) / 10)) // trim 10% items + this.list.Range(func(item *Item) (goNext bool) { + inactiveSize-- + delete(this.mapping, item.HashKey()) + this.list.Remove(item) + this.totalBytes -= item.Size() + return inactiveSize > 0 + }) +} + +func (this *Cell) Delete(hashKey uint64) { + this.locker.Lock() + item, ok := this.mapping[hashKey] + if ok { + delete(this.mapping, hashKey) + this.list.Remove(item) + this.totalBytes -= item.Size() + } + this.locker.Unlock() +} + +// range all items in the cell +func (this *Cell) Range(f func(item *Item)) { + this.locker.Lock() + for _, item := range this.mapping { + f(item) + } + this.locker.Unlock() +} + +func (this *Cell) Recycle() { + this.locker.Lock() + if len(this.mapping) == 0 { + this.locker.Unlock() + return + } + + timestamp := time.Now().Unix() + for key, item := range this.mapping { + if item.ExpireAt < timestamp { + delete(this.mapping, key) + this.list.Remove(item) + this.totalBytes -= item.Size() + } + } + + this.locker.Unlock() +} + +func (this *Cell) Reset() { + this.locker.Lock() + this.list.Reset() + this.mapping = map[uint64]*Item{} + this.totalBytes = 0 + this.locker.Unlock() +} diff --git a/internal/grids/cell_stat.go b/internal/grids/cell_stat.go new file mode 100644 index 0000000..aa37efa --- /dev/null +++ b/internal/grids/cell_stat.go @@ -0,0 +1,6 @@ +package grids + +type CellStat struct { + TotalBytes int64 + CountItems int +} diff --git a/internal/grids/cell_test.go b/internal/grids/cell_test.go new file mode 100644 index 0000000..e499e97 --- /dev/null +++ b/internal/grids/cell_test.go @@ -0,0 +1,214 @@ +package grids + +import ( + "fmt" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +func TestCell_List(t *testing.T) { + cell := NewCell() + cell.Write(1, &Item{ + ValueInt64: 1, + }) + cell.Write(2, &Item{ + ValueInt64: 2, + }) + cell.Write(3, &Item{ + ValueInt64: 3, + }) + + { + t.Log("====") + l := cell.list + for e := l.head; e != nil; e = e.Next { + t.Log("element:", e.ValueInt64) + } + } + + cell.Write(1, &Item{ + ValueInt64: 1, + }) + cell.Write(3, &Item{ + ValueInt64: 3, + }) + cell.Write(3, &Item{ + ValueInt64: 3, + }) + cell.Write(2, &Item{ + ValueInt64: 2, + }) + cell.Delete(2) + + { + t.Log("====") + l := cell.list + for e := l.head; e != nil; e = e.Next { + t.Log("element:", e.ValueInt64) + } + } + + for _, m := range cell.mapping { + t.Log(m.ValueInt64) + } +} + +func TestCell_LimitSize(t *testing.T) { + cell := NewCell() + cell.LimitSize = 1024 + + for i := int64(0); i < 100; i ++ { + key := []byte(fmt.Sprintf("%d", i)) + cell.Write(HashKey(key), &Item{ + Key: key, + ValueInt64: i, + Type: ItemInt64, + }) + } + + t.Log("totalBytes:", cell.totalBytes) + + { + t.Log("====") + l := cell.list + s := []string{} + for e := l.head; e != nil; e = e.Next { + s = append(s, fmt.Sprintf("%d", e.ValueInt64)) + } + t.Log("{ " + strings.Join(s, ", ") + " }") + } + + t.Log("mapping:", len(cell.mapping)) + s := []string{} + for _, item := range cell.mapping { + s = append(s, fmt.Sprintf("%d", item.ValueInt64)) + } + t.Log("{ " + strings.Join(s, ", ") + " }") +} + +func TestCell_MemoryUsage(t *testing.T) { + //runtime.GOMAXPROCS(4) + + cell := NewCell() + cell.LimitSize = 1024 * 1024 * 1024 * 1 + + before := time.Now() + + wg := sync.WaitGroup{} + wg.Add(4) + + for j := 0; j < 4; j ++ { + go func(j int) { + start := j * 50 * 10000 + for i := start; i < start+50*10000; i ++ { + key := []byte(strconv.Itoa(i) + "VERY_LONG_STRING") + cell.Write(HashKey(key), &Item{ + Key: key, + ValueInt64: int64(i), + Type: ItemInt64, + }) + } + wg.Done() + }(j) + } + + wg.Wait() + t.Log("items:", len(cell.mapping)) + t.Log(time.Since(before).Seconds(), "s", "totalBytes:", cell.totalBytes/1024/1024, "M") + //time.Sleep(10 * time.Second) +} + +func BenchmarkCell_Write(b *testing.B) { + runtime.GOMAXPROCS(1) + + cell := NewCell() + + for i := 0; i < b.N; i ++ { + key := []byte(strconv.Itoa(i) + "_LONG_KEY_LONG_KEY_LONG_KEY_LONG_KEY") + cell.Write(HashKey(key), &Item{ + Key: key, + ValueInt64: int64(i), + Type: ItemInt64, + }) + } + + b.Log("items:", len(cell.mapping)) +} + +func TestCell_Read(t *testing.T) { + cell := NewCell() + + cell.Write(1, &Item{ + ValueInt64: 1, + ExpireAt: time.Now().Unix() + 3600, + }) + cell.Write(2, &Item{ + ValueInt64: 2, + ExpireAt: time.Now().Unix() + 3600, + }) + cell.Write(3, &Item{ + ValueInt64: 3, + ExpireAt: time.Now().Unix() + 3600, + }) + + { + s := []string{} + cell.list.Range(func(item *Item) (goNext bool) { + s = append(s, fmt.Sprintf("%d", item.ValueInt64)) + return true + }) + t.Log("before:", s) + } + + t.Log(cell.Read(1).ValueInt64) + + { + s := []string{} + cell.list.Range(func(item *Item) (goNext bool) { + s = append(s, fmt.Sprintf("%d", item.ValueInt64)) + return true + }) + t.Log("after:", s) + } + + t.Log(cell.Read(2).ValueInt64) + + { + s := []string{} + cell.list.Range(func(item *Item) (goNext bool) { + s = append(s, fmt.Sprintf("%d", item.ValueInt64)) + return true + }) + t.Log("after:", s) + } +} + +func TestCell_Recycle(t *testing.T) { + cell := NewCell() + cell.Write(1, &Item{ + ValueInt64: 1, + ExpireAt: time.Now().Unix() - 1, + }) + + cell.Write(2, &Item{ + ValueInt64: 2, + ExpireAt: time.Now().Unix() + 1, + }) + + cell.Recycle() + + { + s := []string{} + cell.list.Range(func(item *Item) (goNext bool) { + s = append(s, fmt.Sprintf("%d", item.ValueInt64)) + return true + }) + t.Log("after:", s) + } + + t.Log(cell.list.Len(), cell.totalBytes) +} diff --git a/internal/grids/grid.go b/internal/grids/grid.go new file mode 100644 index 0000000..37fa5c3 --- /dev/null +++ b/internal/grids/grid.go @@ -0,0 +1,225 @@ +package grids + +import ( + "bytes" + "compress/gzip" + "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/timers" + "math" + "time" +) + +// Memory Cache Grid +// +// | Grid | +// | cell1, cell2, ..., cell1024 | +// | item1, item2, ..., item1000000 | +type Grid struct { + cells []*Cell + countCells uint64 + + recycleIndex int + recycleLooper *timers.Looper + recycleInterval int + + gzipLevel int + + limitSize int64 + limitCount int +} + +func NewGrid(countCells int, opt ...interface{}) *Grid { + grid := &Grid{ + recycleIndex: -1, + } + + for _, o := range opt { + switch x := o.(type) { + case *CompressOpt: + grid.gzipLevel = x.Level + case *LimitSizeOpt: + grid.limitSize = x.Size + case *LimitCountOpt: + grid.limitCount = x.Count + case *RecycleIntervalOpt: + grid.recycleInterval = x.Interval + } + } + + cells := []*Cell{} + if countCells <= 0 { + countCells = 1 + } else if countCells > 100*10000 { + countCells = 100 * 10000 + } + for i := 0; i < countCells; i++ { + cell := NewCell() + cell.LimitSize = int64(math.Floor(float64(grid.limitSize) / float64(countCells))) + cell.LimitCount = int(math.Floor(float64(grid.limitCount)) / float64(countCells)) + + cells = append(cells, cell) + } + grid.cells = cells + grid.countCells = uint64(len(cells)) + + grid.recycleTimer() + return grid +} + +// get all cells in the grid +func (this *Grid) Cells() []*Cell { + return this.cells +} + +func (this *Grid) WriteItem(item *Item) { + if this.countCells <= 0 { + return + } + hashKey := item.HashKey() + this.cellForHashKey(hashKey).Write(hashKey, item) +} + +func (this *Grid) WriteInt64(key []byte, value int64, lifeSeconds int64) { + this.WriteItem(&Item{ + Key: key, + Type: ItemInt64, + ValueInt64: value, + ExpireAt: time.Now().Unix() + lifeSeconds, + }) +} + +func (this *Grid) IncreaseInt64(key []byte, delta int64, lifeSeconds int64) (result int64) { + hashKey := HashKey(key) + return this.cellForHashKey(hashKey).Increase64(key, time.Now().Unix()+lifeSeconds, hashKey, delta) +} + +func (this *Grid) WriteString(key []byte, value string, lifeSeconds int64) { + this.WriteBytes(key, []byte(value), lifeSeconds) +} + +func (this *Grid) WriteBytes(key []byte, value []byte, lifeSeconds int64) { + isCompressed := false + if this.gzipLevel != gzip.NoCompression { + buf := bytes.NewBuffer([]byte{}) + writer, err := gzip.NewWriterLevel(buf, this.gzipLevel) + if err != nil { + logs.Error(err) + this.WriteItem(&Item{ + Key: key, + Type: ItemBytes, + ValueBytes: value, + ExpireAt: time.Now().Unix() + lifeSeconds, + }) + return + } + + _, err = writer.Write([]byte(value)) + if err != nil { + logs.Error(err) + this.WriteItem(&Item{ + Key: key, + Type: ItemBytes, + ValueBytes: value, + ExpireAt: time.Now().Unix() + lifeSeconds, + }) + return + } + + err = writer.Close() + if err != nil { + logs.Error(err) + this.WriteItem(&Item{ + Key: key, + Type: ItemBytes, + ValueBytes: value, + ExpireAt: time.Now().Unix() + lifeSeconds, + }) + return + } + value = buf.Bytes() + isCompressed = true + } + + this.WriteItem(&Item{ + Key: key, + Type: ItemBytes, + ValueBytes: value, + ExpireAt: time.Now().Unix() + lifeSeconds, + IsCompressed: isCompressed, + }) +} + +func (this *Grid) WriteInterface(key []byte, value interface{}, lifeSeconds int64) { + this.WriteItem(&Item{ + Key: key, + Type: ItemInterface, + ValueInterface: value, + ExpireAt: time.Now().Unix() + lifeSeconds, + IsCompressed: false, + }) +} + +func (this *Grid) Read(key []byte) *Item { + if this.countCells <= 0 { + return nil + } + hashKey := HashKey(key) + return this.cellForHashKey(hashKey).Read(hashKey) +} + +func (this *Grid) Stat() *Stat { + stat := &Stat{} + for _, cell := range this.cells { + cellStat := cell.Stat() + stat.CountItems += cellStat.CountItems + stat.TotalBytes += cellStat.TotalBytes + } + return stat +} + +func (this *Grid) Delete(key []byte) { + if this.countCells <= 0 { + return + } + hashKey := HashKey(key) + this.cellForHashKey(hashKey).Delete(hashKey) +} + +func (this *Grid) Reset() { + for _, cell := range this.cells { + cell.Reset() + } +} + +func (this *Grid) Destroy() { + if this.recycleLooper != nil { + this.recycleLooper.Stop() + this.recycleLooper = nil + } + this.cells = nil +} + +func (this *Grid) cellForHashKey(hashKey uint64) *Cell { + if hashKey < 0 { + return this.cells[-hashKey%this.countCells] + } else { + return this.cells[hashKey%this.countCells] + } +} + +func (this *Grid) recycleTimer() { + duration := 1 * time.Minute + if this.recycleInterval > 0 { + duration = time.Duration(this.recycleInterval) * time.Second + } + this.recycleLooper = timers.Loop(duration, func(looper *timers.Looper) { + if this.countCells == 0 { + return + } + this.recycleIndex++ + if this.recycleIndex > int(this.countCells-1) { + this.recycleIndex = 0 + } + this.cells[this.recycleIndex].Recycle() + }) +} diff --git a/internal/grids/grid_test.go b/internal/grids/grid_test.go new file mode 100644 index 0000000..1aa1bea --- /dev/null +++ b/internal/grids/grid_test.go @@ -0,0 +1,204 @@ +package grids + +import ( + "compress/gzip" + "fmt" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +func TestMemoryGrid_Write(t *testing.T) { + grid := NewGrid(5, NewRecycleIntervalOpt(2), NewLimitSizeOpt(10240)) + t.Log("123456:", grid.Read([]byte("123456"))) + + grid.WriteInt64([]byte("abc"), 1, 5) + t.Log(grid.Read([]byte("abc")).ValueInt64) + + grid.WriteString([]byte("abc"), "123", 5) + t.Log(string(grid.Read([]byte("abc")).Bytes())) + + grid.WriteBytes([]byte("abc"), []byte("123"), 5) + t.Log(grid.Read([]byte("abc")).Bytes()) + + grid.Delete([]byte("abc")) + t.Log(grid.Read([]byte("abc"))) + + for i := 0; i < 100; i++ { + grid.WriteInt64([]byte(fmt.Sprintf("%d", i)), 123, 1) + } + + t.Log("before recycle:") + for index, cell := range grid.cells { + t.Log("cell:", index, len(cell.mapping), "items") + } + + time.Sleep(3 * time.Second) + t.Log("after recycle:") + for index, cell := range grid.cells { + t.Log("cell:", index, len(cell.mapping), "items") + } + + grid.Destroy() +} + +func TestMemoryGrid_Write_LimitCount(t *testing.T) { + grid := NewGrid(2, NewLimitCountOpt(10)) + for i := 0; i < 100; i++ { + grid.WriteInt64([]byte(strconv.Itoa(i)), int64(i), 30) + } + t.Log(grid.Stat().CountItems, "items") +} + +func TestMemoryGrid_Compress(t *testing.T) { + grid := NewGrid(5, NewCompressOpt(1)) + grid.WriteString([]byte("hello"), strings.Repeat("abcd", 10240), 30) + t.Log(len(string(grid.Read([]byte("hello")).String()))) + t.Log(len(grid.Read([]byte("hello")).ValueBytes)) +} + +func BenchmarkMemoryGrid_Performance(b *testing.B) { + grid := NewGrid(1024) + for i := 0; i < b.N; i++ { + grid.WriteInt64([]byte("key:"+strconv.Itoa(i)), int64(i), 3600) + } +} + +func TestMemoryGrid_Performance(t *testing.T) { + runtime.GOMAXPROCS(1) + + grid := NewGrid(1024) + + now := time.Now() + + s := []byte(strings.Repeat("abcd", 10*1024)) + + for i := 0; i < 100000; i++ { + grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, 1)), s, 3600) + item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, 1))) + if item != nil { + _ = item.String() + } + } + + countItems := 0 + for _, cell := range grid.cells { + countItems += len(cell.mapping) + } + t.Log(countItems, "items") + + t.Log(time.Since(now).Seconds()*1000, "ms") +} + +func TestMemoryGrid_Performance_Concurrent(t *testing.T) { + //runtime.GOMAXPROCS(1) + + grid := NewGrid(1024) + + now := time.Now() + + s := []byte(strings.Repeat("abcd", 10*1024)) + + wg := sync.WaitGroup{} + wg.Add(runtime.NumCPU()) + + for c := 0; c < runtime.NumCPU(); c++ { + go func(c int) { + defer wg.Done() + for i := 0; i < 50000; i++ { + grid.WriteBytes([]byte(fmt.Sprintf("key:%d_%d", i, c)), s, 3600) + item := grid.Read([]byte(fmt.Sprintf("key:%d_%d", i, c))) + if item != nil { + _ = item.String() + } + } + }(c) + } + + wg.Wait() + countItems := 0 + for _, cell := range grid.cells { + countItems += len(cell.mapping) + } + t.Log(countItems, "items") + + t.Log(time.Since(now).Seconds()*1000, "ms") +} + +func TestMemoryGrid_CompressPerformance(t *testing.T) { + runtime.GOMAXPROCS(1) + + grid := NewGrid(1024, NewCompressOpt(gzip.BestCompression)) + + now := time.Now() + data := []byte(strings.Repeat("abcd", 1024)) + + for i := 0; i < 100000; i++ { + grid.WriteBytes([]byte(fmt.Sprintf("key:%d", i)), data, 3600) + item := grid.Read([]byte(fmt.Sprintf("key:%d", i+100))) + if item != nil { + _ = item.String() + } + } + + countItems := 0 + for _, cell := range grid.cells { + countItems += len(cell.mapping) + } + t.Log(countItems, "items") + + t.Log(time.Since(now).Seconds()*1000, "ms") +} + +func TestMemoryGrid_IncreaseInt64(t *testing.T) { + grid := NewGrid(1024) + grid.WriteInt64([]byte("abc"), 123, 10) + grid.IncreaseInt64([]byte("abc"), 123, 10) + grid.IncreaseInt64([]byte("abc"), 123, 10) + item := grid.Read([]byte("abc")) + if item == nil { + t.Fatal("item == nil") + } + + if item.ValueInt64 != 369 { + t.Fatal("not 369") + } +} + +func TestMemoryGrid_Destroy(t *testing.T) { + grid := NewGrid(1024) + grid.WriteInt64([]byte("abc"), 123, 10) + t.Log(grid.recycleLooper, grid.cells) + grid.Destroy() + t.Log(grid.recycleLooper, grid.cells) + + if grid.recycleLooper != nil { + t.Fatal("looper != nil") + } +} + +func TestMemoryGrid_Recycle(t *testing.T) { + cell := NewCell() + timestamp := time.Now().Unix() + for i := 0; i < 300_0000; i++ { + cell.Write(uint64(i), &Item{ + ExpireAt: timestamp - 30, + }) + } + before := time.Now() + cell.Recycle() + t.Log(time.Since(before).Seconds()*1000, "ms") + t.Log(len(cell.mapping)) + + runtime.GC() + printMem(t) +} + +func printMem(t *testing.T) { + mem := &runtime.MemStats{} + runtime.ReadMemStats(mem) + t.Log(mem.Alloc/1024/1024, "M") +} diff --git a/internal/grids/item.go b/internal/grids/item.go new file mode 100644 index 0000000..99bb6b0 --- /dev/null +++ b/internal/grids/item.go @@ -0,0 +1,88 @@ +package grids + +import ( + "bytes" + "compress/gzip" + "github.com/dchest/siphash" + "github.com/iwind/TeaGo/logs" + "sync/atomic" + "unsafe" +) + +type ItemType = int8 + +const ( + ItemInt64 = 1 + ItemBytes = 2 + ItemInterface = 3 +) + +func HashKey(key []byte) uint64 { + return siphash.Hash(0, 0, key) +} + +type Item struct { + Key []byte + ExpireAt int64 + Type ItemType + ValueInt64 int64 + ValueBytes []byte + ValueInterface interface{} + IsCompressed bool + + // linked list + Prev *Item + Next *Item + + size int64 +} + +func NewItem(key []byte, dataType ItemType) *Item { + return &Item{ + Key: key, + Type: dataType, + } +} + +func (this *Item) HashKey() uint64 { + return HashKey(this.Key) +} + +func (this *Item) IncreaseInt64(delta int64) { + atomic.AddInt64(&this.ValueInt64, delta) +} + +func (this *Item) Bytes() []byte { + if this.IsCompressed { + reader, err := gzip.NewReader(bytes.NewBuffer(this.ValueBytes)) + if err != nil { + logs.Error(err) + return this.ValueBytes + } + + buf := make([]byte, 256) + dataBuf := bytes.NewBuffer([]byte{}) + for { + n, err := reader.Read(buf) + if n > 0 { + dataBuf.Write(buf[:n]) + } + if err != nil { + break + } + } + return dataBuf.Bytes() + } + return this.ValueBytes +} + +func (this *Item) String() string { + return string(this.Bytes()) +} + +func (this *Item) Size() int64 { + if this.size == 0 { + this.size = int64(int(unsafe.Sizeof(*this)) + len(this.Key) + len(this.ValueBytes)) + } + return this.size +} diff --git a/internal/grids/item_test.go b/internal/grids/item_test.go new file mode 100644 index 0000000..5fc3c32 --- /dev/null +++ b/internal/grids/item_test.go @@ -0,0 +1,69 @@ +package grids + +import ( + "crypto/md5" + "github.com/dchest/siphash" + "strconv" + "testing" +) + +func TestItem_Size(t *testing.T) { + item := &Item{ + ValueInt64: 1024, + Key: []byte("123"), + ValueBytes: []byte("Hello, World"), + } + t.Log(item.Size()) +} + +func BenchmarkItem_Size(b *testing.B) { + item := &Item{ + ValueInt64: 1024, + Key: []byte("123"), + ValueBytes: []byte("Hello, World"), + } + for i := 0; i < b.N; i ++ { + _ = item.Size() + } +} + +func TestItem_HashKey(t *testing.T) { + t.Log(HashKey([]byte("2"))) +} + +func TestItem_siphash(t *testing.T) { + result := siphash.Hash(0, 0, []byte("123456")) + t.Log(result) +} + +func TestItem_unique(t *testing.T) { + m := map[uint64]bool{} + for i := 0; i < 1000*10000; i ++ { + s := "Hello,World,LONG KEY,LONG KEY,LONG KEY,LONG KEY" + strconv.Itoa(i) + result := siphash.Hash(0, 0, []byte(s)) + _, ok := m[result] + if ok { + t.Log("found same", i) + break + } else { + m[result] = true + } + } + + t.Log(siphash.Hash(0, 0, []byte("01"))) + t.Log(siphash.Hash(0, 0, []byte("10"))) +} + +func BenchmarkItem_HashKeyMd5(b *testing.B) { + for i := 0; i < b.N; i ++ { + h := md5.New() + h.Write([]byte("HELLO_KEY_" + strconv.Itoa(i))) + _ = h.Sum(nil) + } +} + +func BenchmarkItem_siphash(b *testing.B) { + for i := 0; i < b.N; i ++ { + _ = siphash.Hash(0, 0, []byte("HELLO_KEY_"+strconv.Itoa(i))) + } +} diff --git a/internal/grids/list.go b/internal/grids/list.go new file mode 100644 index 0000000..35c7d6a --- /dev/null +++ b/internal/grids/list.go @@ -0,0 +1,68 @@ +package grids + +type List struct { + head *Item + end *Item +} + +func NewList() *List { + return &List{} +} + +func (this *List) Add(item *Item) { + if item == nil { + return + } + if this.end != nil { + this.end.Next = item + item.Prev = this.end + item.Next = nil + } + this.end = item + if this.head == nil { + this.head = item + } +} + +func (this *List) Remove(item *Item) { + if item == nil { + return + } + if item.Prev != nil { + item.Prev.Next = item.Next + } + if item.Next != nil { + item.Next.Prev = item.Prev + } + if item == this.head { + this.head = item.Next + } + if item == this.end { + this.end = item.Prev + } + + item.Prev = nil + item.Next = nil +} + +func (this *List) Len() int { + l := 0 + for e := this.head; e != nil; e = e.Next { + l ++ + } + return l +} + +func (this *List) Range(f func(item *Item) (goNext bool)) { + for e := this.head; e != nil; e = e.Next { + goNext := f(e) + if !goNext { + break + } + } +} + +func (this *List) Reset() { + this.head = nil + this.end = nil +} diff --git a/internal/grids/list_test.go b/internal/grids/list_test.go new file mode 100644 index 0000000..d7ed42c --- /dev/null +++ b/internal/grids/list_test.go @@ -0,0 +1,64 @@ +package grids + +import "testing" + +func TestList(t *testing.T) { + l := &List{} + + var e1 *Item = nil + { + e := &Item{ + ValueInt64: 1, + } + l.Add(e) + e1 = e + } + + var e2 *Item = nil + { + e := &Item{ + ValueInt64: 2, + } + l.Add(e) + e2 = e + } + + var e3 *Item = nil + { + e := &Item{ + ValueInt64: 3, + } + l.Add(e) + e3 = e + } + + var e4 *Item = nil + { + e := &Item{ + ValueInt64: 4, + } + l.Add(e) + e4 = e + } + + l.Remove(e1) + //l.Remove(e2) + //l.Remove(e3) + l.Remove(e4) + + for e := l.head; e != nil; e = e.Next { + t.Log(e.ValueInt64) + } + + t.Log("e1, e2, e3, e4, head, end:", e1, e2, e3, e4) + if l.head != nil { + t.Log("head:", l.head.ValueInt64) + } else { + t.Log("head: nil") + } + if l.end != nil { + t.Log("end:", l.end.ValueInt64) + } else { + t.Log("end: nil") + } +} diff --git a/internal/grids/opt_compress.go b/internal/grids/opt_compress.go new file mode 100644 index 0000000..3cdc8bf --- /dev/null +++ b/internal/grids/opt_compress.go @@ -0,0 +1,11 @@ +package grids + +type CompressOpt struct { + Level int +} + +func NewCompressOpt(level int) *CompressOpt { + return &CompressOpt{ + Level: level, + } +} diff --git a/internal/grids/opt_limit_count.go b/internal/grids/opt_limit_count.go new file mode 100644 index 0000000..2858b8d --- /dev/null +++ b/internal/grids/opt_limit_count.go @@ -0,0 +1,11 @@ +package grids + +type LimitCountOpt struct { + Count int +} + +func NewLimitCountOpt(count int) *LimitCountOpt { + return &LimitCountOpt{ + Count: count, + } +} diff --git a/internal/grids/opt_limit_size.go b/internal/grids/opt_limit_size.go new file mode 100644 index 0000000..ab513ef --- /dev/null +++ b/internal/grids/opt_limit_size.go @@ -0,0 +1,11 @@ +package grids + +type LimitSizeOpt struct { + Size int64 +} + +func NewLimitSizeOpt(size int64) *LimitSizeOpt { + return &LimitSizeOpt{ + Size: size, + } +} diff --git a/internal/grids/opt_recycle_interval.go b/internal/grids/opt_recycle_interval.go new file mode 100644 index 0000000..73b792b --- /dev/null +++ b/internal/grids/opt_recycle_interval.go @@ -0,0 +1,11 @@ +package grids + +type RecycleIntervalOpt struct { + Interval int +} + +func NewRecycleIntervalOpt(interval int) *RecycleIntervalOpt { + return &RecycleIntervalOpt{ + Interval: interval, + } +} diff --git a/internal/grids/stat.go b/internal/grids/stat.go new file mode 100644 index 0000000..7202811 --- /dev/null +++ b/internal/grids/stat.go @@ -0,0 +1,6 @@ +package grids + +type Stat struct { + TotalBytes int64 + CountItems int +} diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 5dc4c3c..ef007d8 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -96,7 +96,11 @@ func (this *HTTPRequest) Do() { } // WAF - // TODO 需要实现 + if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { + if this.doWAFRequest() { + return + } + } // 访问控制 // TODO 需要实现 @@ -253,6 +257,12 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo this.web.Cache = web.Cache } + // waf + if web.FirewallRef != nil && (web.FirewallRef.IsPrior || isTop) { + this.web.FirewallRef = web.FirewallRef + this.web.FirewallPolicy = web.FirewallPolicy + } + // 重写规则 if len(web.RewriteRefs) > 0 { for index, ref := range web.RewriteRefs { diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 84672ac..0e83f58 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -166,12 +166,26 @@ func (this *HTTPRequest) doReverseProxy() { } // WAF对出站进行检查 - // TODO + if this.web.FirewallRef != nil && this.web.FirewallRef.IsOn && this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { + if this.doWAFResponse(resp) { + err = resp.Body.Close() + if err != nil { + logs.Error(err) + } + return + } + } // TODO 清除源站错误次数 // 特殊页面 - // TODO + if len(this.web.Pages) > 0 && this.doPage(resp.StatusCode) { + err = resp.Body.Close() + if err != nil { + logs.Error(err) + } + return + } // 设置Charset // TODO 这里应该可以设置文本类型的列表,以及是否强制覆盖所有文本类型的字符集 diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go new file mode 100644 index 0000000..af5a7d6 --- /dev/null +++ b/internal/nodes/http_request_waf.go @@ -0,0 +1,51 @@ +package nodes + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf" + "github.com/iwind/TeaGo/logs" + "net/http" +) + +// 调用WAF +func (this *HTTPRequest) doWAFRequest() (blocked bool) { + w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id) + if w == nil { + return + } + + goNext, _, ruleSet, err := w.MatchRequest(this.RawReq, this.writer) + if err != nil { + logs.Error(err) + return + } + + if ruleSet != nil { + if ruleSet.Action != waf.ActionAllow { + // TODO 记录日志 + } + } + + return !goNext +} + +// call response waf +func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) { + w := sharedWAFManager.FindWAF(this.web.FirewallPolicy.Id) + if w == nil { + return + } + + goNext, _, ruleSet, err := w.MatchResponse(this.RawReq, resp, this.writer) + if err != nil { + logs.Error(err) + return + } + + if ruleSet != nil { + if ruleSet.Action != waf.ActionAllow { + // TODO 记录日志 + } + } + + return !goNext +} diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 990f761..8bfb0bd 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -104,6 +104,7 @@ func (this *Node) syncConfig(isFirstTime bool) error { logs.Println("[NODE]reload config ...") nodeconfigs.ResetNodeConfig(nodeConfig) caches.SharedManager.UpdatePolicies(nodeConfig.AllCachePolicies()) + sharedWAFManager.UpdatePolicies(nodeConfig.AllHTTPFirewallPolicies()) sharedNodeConfig = nodeConfig if !isFirstTime { diff --git a/internal/nodes/waf_manager.go b/internal/nodes/waf_manager.go new file mode 100644 index 0000000..5e1782a --- /dev/null +++ b/internal/nodes/waf_manager.go @@ -0,0 +1,175 @@ +package nodes + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeNode/internal/errors" + "github.com/TeaOSLab/EdgeNode/internal/waf" + "github.com/iwind/TeaGo/logs" + "strconv" + "sync" +) + +var sharedWAFManager = NewWAFManager() + +// WAF管理器 +type WAFManager struct { + mapping map[int64]*waf.WAF // policyId => WAF + locker sync.RWMutex +} + +// 获取新对象 +func NewWAFManager() *WAFManager { + return &WAFManager{ + mapping: map[int64]*waf.WAF{}, + } +} + +// 更新策略 +func (this *WAFManager) UpdatePolicies(policies []*firewallconfigs.HTTPFirewallPolicy) { + this.locker.Lock() + defer this.locker.Unlock() + + m := map[int64]*waf.WAF{} + for _, p := range policies { + w, err := this.convertWAF(p) + if err != nil { + logs.Println("[WAF]initialize policy '" + strconv.FormatInt(p.Id, 10) + "' failed: " + err.Error()) + continue + } + if w == nil { + continue + } + m[p.Id] = w + } + this.mapping = m +} + +// 查找WAF +func (this *WAFManager) FindWAF(policyId int64) *waf.WAF { + this.locker.RLock() + w, _ := this.mapping[policyId] + this.locker.RUnlock() + return w +} + +// 判断是否包含int64 +func (this *WAFManager) containsInt64(values []int64, value int64) bool { + for _, v := range values { + if v == value { + return true + } + } + return false +} + +// 将Policy转换为WAF +func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) (*waf.WAF, error) { + if policy == nil { + return nil, errors.New("policy should not be nil") + } + w := &waf.WAF{ + Id: strconv.FormatInt(policy.Id, 10), + IsOn: policy.IsOn, + Name: policy.Name, + } + + // inbound + if policy.Inbound != nil && policy.Inbound.IsOn { + for _, group := range policy.Inbound.Groups { + g := &waf.RuleGroup{ + Id: strconv.FormatInt(group.Id, 10), + IsOn: group.IsOn, + Name: group.Name, + Description: group.Description, + Code: group.Code, + IsInbound: true, + } + + // rule sets + for _, set := range group.Sets { + s := &waf.RuleSet{ + Id: strconv.FormatInt(set.Id, 10), + Code: set.Code, + IsOn: set.IsOn, + Name: set.Name, + Description: set.Description, + Connector: set.Connector, + Action: set.Action, + ActionOptions: set.ActionOptions, + } + + // rules + for _, rule := range set.Rules { + r := &waf.Rule{ + Description: rule.Description, + Param: rule.Param, + Operator: rule.Operator, + Value: rule.Value, + IsCaseInsensitive: rule.IsCaseInsensitive, + CheckpointOptions: rule.CheckpointOptions, + } + s.Rules = append(s.Rules, r) + } + + g.RuleSets = append(g.RuleSets, s) + } + + w.Inbound = append(w.Inbound, g) + } + } + + // outbound + if policy.Outbound != nil && policy.Outbound.IsOn { + for _, group := range policy.Outbound.Groups { + g := &waf.RuleGroup{ + Id: strconv.FormatInt(group.Id, 10), + IsOn: group.IsOn, + Name: group.Name, + Description: group.Description, + Code: group.Code, + IsInbound: true, + } + + // rule sets + for _, set := range group.Sets { + s := &waf.RuleSet{ + Id: strconv.FormatInt(set.Id, 10), + Code: set.Code, + IsOn: set.IsOn, + Name: set.Name, + Description: set.Description, + Connector: set.Connector, + Action: set.Action, + ActionOptions: set.ActionOptions, + } + + // rules + for _, rule := range set.Rules { + r := &waf.Rule{ + Description: rule.Description, + Param: rule.Param, + Operator: rule.Operator, + Value: rule.Value, + IsCaseInsensitive: rule.IsCaseInsensitive, + CheckpointOptions: rule.CheckpointOptions, + } + s.Rules = append(s.Rules, r) + } + + g.RuleSets = append(g.RuleSets, s) + } + + w.Outbound = append(w.Outbound, g) + } + } + + // action + // TODO + + err := w.Init() + if err != nil { + return nil, err + } + + return w, nil +} diff --git a/internal/nodes/waf_manager_test.go b/internal/nodes/waf_manager_test.go new file mode 100644 index 0000000..9aa9f0e --- /dev/null +++ b/internal/nodes/waf_manager_test.go @@ -0,0 +1,44 @@ +package nodes + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/iwind/TeaGo/logs" + "testing" +) + +func TestWAFManager_convert(t *testing.T) { + p := &firewallconfigs.HTTPFirewallPolicy{ + Id: 1, + IsOn: true, + Inbound: &firewallconfigs.HTTPFirewallInboundConfig{ + IsOn: true, + Groups: []*firewallconfigs.HTTPFirewallRuleGroup{ + { + Id: 1, + Sets: []*firewallconfigs.HTTPFirewallRuleSet{ + { + Id: 1, + }, + { + Id: 2, + Rules: []*firewallconfigs.HTTPFirewallRule{ + { + Id: 1, + }, + { + Id: 2, + }, + }, + }, + }, + }, + }, + }, + } + w, err := sharedWAFManager.convertWAF(p) + if err != nil { + t.Fatal(err) + } + + logs.PrintAsJSON(w, t) +} diff --git a/internal/utils/get.go b/internal/utils/get.go new file mode 100644 index 0000000..ecac7ce --- /dev/null +++ b/internal/utils/get.go @@ -0,0 +1,77 @@ +package utils + +import ( + "github.com/iwind/TeaGo/types" + "reflect" + "regexp" +) + +var RegexpDigitNumber = regexp.MustCompile("^\\d+$") + +func Get(object interface{}, keys []string) interface{} { + if len(keys) == 0 { + return object + } + + if object == nil { + return nil + } + + firstKey := keys[0] + keys = keys[1:] + + value := reflect.ValueOf(object) + + if !value.IsValid() { + return nil + } + + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + if value.Kind() == reflect.Struct { + field := value.FieldByName(firstKey) + if !field.IsValid() { + return nil + } + + if len(keys) == 0 { + return field.Interface() + } + + return Get(field.Interface(), keys) + } + + if value.Kind() == reflect.Map { + mapKey := reflect.ValueOf(firstKey) + mapValue := value.MapIndex(mapKey) + if !mapValue.IsValid() { + return nil + } + + if len(keys) == 0 { + return mapValue.Interface() + } + + return Get(mapValue.Interface(), keys) + } + + if value.Kind() == reflect.Slice { + if RegexpDigitNumber.MatchString(firstKey) { + firstKeyInt := types.Int(firstKey) + if value.Len() > firstKeyInt { + result := value.Index(firstKeyInt).Interface() + if len(keys) == 0 { + return result + } + + return Get(result, keys) + } + } + + return nil + } + + return nil +} diff --git a/internal/utils/get_test.go b/internal/utils/get_test.go new file mode 100644 index 0000000..791cdd4 --- /dev/null +++ b/internal/utils/get_test.go @@ -0,0 +1,79 @@ +package utils + +import "testing" + +func TestGetStruct(t *testing.T) { + object := struct { + Name string + Age int + Books []string + Extend struct { + Location struct { + City string + } + } + }{ + Name: "lu", + Age: 20, + Books: []string{"Golang"}, + Extend: struct { + Location struct { + City string + } + }{ + Location: struct { + City string + }{ + City: "Beijing", + }, + }, + } + + if Get(object, []string{"Name"}) != "lu" { + t.Fatal("[ERROR]Name != lu") + } + + if Get(object, []string{"Age"}) != 20 { + t.Fatal("[ERROR]Age != 20") + } + + if Get(object, []string{"Books", "0"}) != "Golang" { + t.Fatal("[ERROR]books.0 != Golang") + } + + t.Log("Extend.Location:", Get(object, []string{"Extend", "Location"})) + + if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" { + t.Fatal("[ERROR]Extend.Location.City != Beijing") + } +} + +func TestGetMap(t *testing.T) { + object := map[string]interface{}{ + "Name": "lu", + "Age": 20, + "Extend": map[string]interface{}{ + "Location": map[string]interface{}{ + "City": "Beijing", + }, + }, + } + + if Get(object, []string{"Name"}) != "lu" { + t.Fatal("[ERROR]Name != lu") + } + + if Get(object, []string{"Age"}) != 20 { + t.Fatal("[ERROR]Age != 20") + } + + if Get(object, []string{"Books", "0"}) != nil { + t.Fatal("[ERROR]books.0 != nil") + } + + t.Log(Get(object, []string{"Extend", "Location"})) + + if Get(object, []string{"Extend", "Location", "City"}) != "Beijing" { + t.Fatal("[ERROR]Extend.Location.City != Beijing") + } +} diff --git a/internal/utils/string.go b/internal/utils/string.go new file mode 100644 index 0000000..3f7f4ca --- /dev/null +++ b/internal/utils/string.go @@ -0,0 +1,37 @@ +package utils + +import ( + "strings" + "unsafe" +) + +// convert bytes to string +func UnsafeBytesToString(bs []byte) string { + return *(*string)(unsafe.Pointer(&bs)) +} + +// convert string to bytes +func UnsafeStringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer(&s)) +} + +// format address +func FormatAddress(addr string) string { + if strings.HasSuffix(addr, "unix:") { + return addr + } + addr = strings.Replace(addr, " ", "", -1) + addr = strings.Replace(addr, "\t", "", -1) + addr = strings.Replace(addr, ":", ":", -1) + addr = strings.TrimSpace(addr) + return addr +} + +// format address list +func FormatAddressList(addrList []string) []string { + result := []string{} + for _, addr := range addrList { + result = append(result, FormatAddress(addr)) + } + return result +} diff --git a/internal/utils/string_test.go b/internal/utils/string_test.go new file mode 100644 index 0000000..14e0ba6 --- /dev/null +++ b/internal/utils/string_test.go @@ -0,0 +1,56 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestBytesToString(t *testing.T) { + t.Log(UnsafeBytesToString([]byte("Hello,World"))) +} + +func TestStringToBytes(t *testing.T) { + t.Log(string(UnsafeStringToBytes("Hello,World"))) +} + +func BenchmarkBytesToString(b *testing.B) { + data := []byte("Hello,World") + for i := 0; i < b.N; i++ { + _ = UnsafeBytesToString(data) + } +} + +func BenchmarkBytesToString2(b *testing.B) { + data := []byte("Hello,World") + for i := 0; i < b.N; i++ { + _ = string(data) + } +} + +func BenchmarkStringToBytes(b *testing.B) { + s := strings.Repeat("Hello,World", 1024) + for i := 0; i < b.N; i++ { + _ = UnsafeStringToBytes(s) + } +} + +func BenchmarkStringToBytes2(b *testing.B) { + s := strings.Repeat("Hello,World", 1024) + for i := 0; i < b.N; i++ { + _ = []byte(s) + } +} + +func TestFormatAddress(t *testing.T) { + t.Log(FormatAddress("127.0.0.1:1234")) + t.Log(FormatAddress("127.0.0.1 : 1234")) + t.Log(FormatAddress("127.0.0.1:1234")) +} + +func TestFormatAddressList(t *testing.T) { + t.Log(FormatAddressList([]string{ + "127.0.0.1:1234", + "127.0.0.1 : 1234", + "127.0.0.1:1234", + })) +} diff --git a/internal/waf/README.md b/internal/waf/README.md new file mode 100644 index 0000000..6e7dafa --- /dev/null +++ b/internal/waf/README.md @@ -0,0 +1,48 @@ +# WAF +A basic WAF for TeaWeb. + +## Config Constructions +~~~ +WAF + Inbound + Rule Groups + Rule Sets + Rules + Checkpoint Param Value + Outbound + Rule Groups + ... +~~~ + +## Apply WAF +~~~ +Request --> WAF --> Backends + / +Response <-- WAF <---- +~~~ + +## Coding +~~~go +waf := teawaf.NewWAF() + +// add rule groups here + +err := waf.Init() +if err != nil { + return +} +waf.Start() + +// match http request +// (req *http.Request, responseWriter http.ResponseWriter) +goNext, ruleSet, _ := waf.MatchRequest(req, responseWriter) +if ruleSet != nil { + log.Println("meet rule set:", ruleSet.Name, "action:", ruleSet.Action) +} +if !goNext { + return +} + +// stop the waf +// waf.Stop() +~~~ \ No newline at end of file diff --git a/internal/waf/action_allow.go b/internal/waf/action_allow.go new file mode 100644 index 0000000..35421ca --- /dev/null +++ b/internal/waf/action_allow.go @@ -0,0 +1,14 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" +) + +type AllowAction struct { +} + +func (this *AllowAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { + // do nothing + return true +} diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go new file mode 100644 index 0000000..e6ba70f --- /dev/null +++ b/internal/waf/action_block.go @@ -0,0 +1,94 @@ +package waf + +import ( + teaconst "github.com/TeaOSLab/EdgeNode/internal/const" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/logs" + "io" + "io/ioutil" + "net/http" + "os" + "path/filepath" + "regexp" + "time" +) + +// url client configure +var urlPrefixReg = regexp.MustCompile("^(?i)(http|https)://") +var httpClient = utils.SharedHttpClient(5 * time.Second) + +type BlockAction struct { + StatusCode int `yaml:"statusCode" json:"statusCode"` + Body string `yaml:"body" json:"body"` // supports HTML + URL string `yaml:"url" json:"url"` +} + +func (this *BlockAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { + if writer != nil { + // if status code eq 444, we close the connection + if this.StatusCode == 444 { + hijack, ok := writer.(http.Hijacker) + if ok { + conn, _, _ := hijack.Hijack() + if conn != nil { + _ = conn.Close() + return + } + } + } + + // output response + if this.StatusCode > 0 { + writer.WriteHeader(this.StatusCode) + } else { + writer.WriteHeader(http.StatusForbidden) + } + if len(this.URL) > 0 { + if urlPrefixReg.MatchString(this.URL) { + req, err := http.NewRequest(http.MethodGet, this.URL, nil) + if err != nil { + logs.Error(err) + return false + } + resp, err := httpClient.Do(req) + if err != nil { + logs.Error(err) + return false + } + defer func() { + _ = resp.Body.Close() + }() + + for k, v := range resp.Header { + for _, v1 := range v { + writer.Header().Add(k, v1) + } + } + + buf := make([]byte, 1024) + _, _ = io.CopyBuffer(writer, resp.Body, buf) + } else { + path := this.URL + if !filepath.IsAbs(this.URL) { + path = Tea.Root + string(os.PathSeparator) + path + } + + data, err := ioutil.ReadFile(path) + if err != nil { + logs.Error(err) + return false + } + _, _ = writer.Write(data) + } + return false + } + if len(this.Body) > 0 { + _, _ = writer.Write([]byte(this.Body)) + } else { + _, _ = writer.Write([]byte("The request is blocked by " + teaconst.ProductName)) + } + } + return false +} diff --git a/internal/waf/action_captcha.go b/internal/waf/action_captcha.go new file mode 100644 index 0000000..120869d --- /dev/null +++ b/internal/waf/action_captcha.go @@ -0,0 +1,39 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/types" + stringutil "github.com/iwind/TeaGo/utils/string" + "net/http" + "net/url" + "time" +) + +var captchaSalt = stringutil.Rand(32) + +const ( + CaptchaSeconds = 600 // 10 minutes +) + +type CaptchaAction struct { +} + +func (this *CaptchaAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { + // TEAWEB_CAPTCHA: + cookie, err := request.Cookie("TEAWEB_WAF_CAPTCHA") + if err == nil && cookie != nil && len(cookie.Value) > 32 { + m := cookie.Value[:32] + timestamp := cookie.Value[32:] + if stringutil.Md5(captchaSalt+timestamp) == m && time.Now().Unix() < types.Int64(timestamp) { // verify md5 + return true + } + } + + refURL := request.URL.String() + if len(request.Referer()) > 0 { + refURL = request.Referer() + } + http.Redirect(writer, request.Raw(), "/WAFCAPTCHA?url="+url.QueryEscape(refURL), http.StatusTemporaryRedirect) + + return false +} diff --git a/internal/waf/action_definition.go b/internal/waf/action_definition.go new file mode 100644 index 0000000..e268742 --- /dev/null +++ b/internal/waf/action_definition.go @@ -0,0 +1,12 @@ +package waf + +import "reflect" + +// action definition +type ActionDefinition struct { + Name string + Code ActionString + Description string + Instance ActionInterface + Type reflect.Type +} diff --git a/internal/waf/action_go_group.go b/internal/waf/action_go_group.go new file mode 100644 index 0000000..446bd0a --- /dev/null +++ b/internal/waf/action_go_group.go @@ -0,0 +1,34 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/logs" + "net/http" +) + +type GoGroupAction struct { + GroupId string `yaml:"groupId" json:"groupId"` +} + +func (this *GoGroupAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { + group := waf.FindRuleGroup(this.GroupId) + if group == nil || !group.IsOn { + return true + } + + b, set, err := group.MatchRequest(request) + if err != nil { + logs.Error(err) + return true + } + + if !b { + return true + } + + actionObject := FindActionInstance(set.Action, set.ActionOptions) + if actionObject == nil { + return true + } + return actionObject.Perform(waf, request, writer) +} diff --git a/internal/waf/action_go_set.go b/internal/waf/action_go_set.go new file mode 100644 index 0000000..ad8b049 --- /dev/null +++ b/internal/waf/action_go_set.go @@ -0,0 +1,37 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/logs" + "net/http" +) + +type GoSetAction struct { + GroupId string `yaml:"groupId" json:"groupId"` + SetId string `yaml:"setId" json:"setId"` +} + +func (this *GoSetAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { + group := waf.FindRuleGroup(this.GroupId) + if group == nil || !group.IsOn { + return true + } + set := group.FindRuleSet(this.SetId) + if set == nil || !set.IsOn { + return true + } + + b, err := set.MatchRequest(request) + if err != nil { + logs.Error(err) + return true + } + if !b { + return true + } + actionObject := FindActionInstance(set.Action, set.ActionOptions) + if actionObject == nil { + return true + } + return actionObject.Perform(waf, request, writer) +} diff --git a/internal/waf/action_instance.go b/internal/waf/action_instance.go new file mode 100644 index 0000000..7aeb04d --- /dev/null +++ b/internal/waf/action_instance.go @@ -0,0 +1,5 @@ +package waf + +type Action struct { + +} diff --git a/internal/waf/action_log.go b/internal/waf/action_log.go new file mode 100644 index 0000000..8b8efcd --- /dev/null +++ b/internal/waf/action_log.go @@ -0,0 +1,13 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" +) + +type LogAction struct { +} + +func (this *LogAction) Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) { + return true +} diff --git a/internal/waf/action_type.go b/internal/waf/action_type.go new file mode 100644 index 0000000..ddf7b8c --- /dev/null +++ b/internal/waf/action_type.go @@ -0,0 +1,21 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" +) + +type ActionString = string + +const ( + ActionLog = "log" // allow and log + ActionBlock = "block" // block + ActionCaptcha = "captcha" // block and show captcha + ActionAllow = "allow" // allow + ActionGoGroup = "go_group" // go to next rule group + ActionGoSet = "go_set" // go to next rule set +) + +type ActionInterface interface { + Perform(waf *WAF, request *requests.Request, writer http.ResponseWriter) (allow bool) +} diff --git a/internal/waf/action_utils.go b/internal/waf/action_utils.go new file mode 100644 index 0000000..d2178e9 --- /dev/null +++ b/internal/waf/action_utils.go @@ -0,0 +1,82 @@ +package waf + +import ( + "github.com/iwind/TeaGo/maps" + "reflect" +) + +var AllActions = []*ActionDefinition{ + { + Name: "阻止", + Code: ActionBlock, + Instance: new(BlockAction), + }, + { + Name: "允许通过", + Code: ActionAllow, + Instance: new(AllowAction), + }, + { + Name: "允许并记录日志", + Code: ActionLog, + Instance: new(LogAction), + }, + { + Name: "Captcha验证码", + Code: ActionCaptcha, + Instance: new(CaptchaAction), + }, + { + Name: "跳到下一个规则分组", + Code: ActionGoGroup, + Instance: new(GoGroupAction), + Type: reflect.TypeOf(new(GoGroupAction)).Elem(), + }, + { + Name: "跳到下一个规则集", + Code: ActionGoSet, + Instance: new(GoSetAction), + Type: reflect.TypeOf(new(GoSetAction)).Elem(), + }, +} + +func FindActionInstance(action ActionString, options maps.Map) ActionInterface { + for _, def := range AllActions { + if def.Code == action { + if def.Type != nil { + // create new instance + ptrValue := reflect.New(def.Type) + instance := ptrValue.Interface().(ActionInterface) + + if len(options) > 0 { + count := def.Type.NumField() + for i := 0; i < count; i++ { + field := def.Type.Field(i) + tag, ok := field.Tag.Lookup("yaml") + if ok { + v, ok := options[tag] + if ok && reflect.TypeOf(v) == field.Type { + ptrValue.Elem().FieldByName(field.Name).Set(reflect.ValueOf(v)) + } + } + } + } + + return instance + } + + // return shared instance + return def.Instance + } + } + return nil +} + +func FindActionName(action ActionString) string { + for _, def := range AllActions { + if def.Code == action { + return def.Name + } + } + return "" +} diff --git a/internal/waf/action_utils_test.go b/internal/waf/action_utils_test.go new file mode 100644 index 0000000..e219f55 --- /dev/null +++ b/internal/waf/action_utils_test.go @@ -0,0 +1,29 @@ +package waf + +import ( + "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/maps" + "runtime" + "testing" +) + +func TestFindActionInstance(t *testing.T) { + a := assert.NewAssertion(t) + + t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil)) + t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil)) + t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil)) + t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil)) + t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) + t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) + t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b",})) + + a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil)) +} + +func BenchmarkFindActionInstance(b *testing.B) { + runtime.GOMAXPROCS(1) + for i := 0; i < b.N; i++ { + FindActionInstance(ActionGoSet, nil) + } +} diff --git a/internal/waf/captcha_validator.go b/internal/waf/captcha_validator.go new file mode 100644 index 0000000..a4b889e --- /dev/null +++ b/internal/waf/captcha_validator.go @@ -0,0 +1,84 @@ +package waf + +import ( + "bytes" + "encoding/base64" + "fmt" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/dchest/captcha" + "github.com/iwind/TeaGo/logs" + stringutil "github.com/iwind/TeaGo/utils/string" + "net/http" + "time" +) + +var captchaValidator = &CaptchaValidator{} + +type CaptchaValidator struct { +} + +func (this *CaptchaValidator) Run(request *requests.Request, writer http.ResponseWriter) { + if request.Method == http.MethodPost && len(request.FormValue("TEAWEB_WAF_CAPTCHA_ID")) > 0 { + this.validate(request, writer) + } else { + this.show(request, writer) + } +} + +func (this *CaptchaValidator) show(request *requests.Request, writer http.ResponseWriter) { + // show captcha + captchaId := captcha.NewLen(6) + buf := bytes.NewBuffer([]byte{}) + err := captcha.WriteImage(buf, captchaId, 200, 100) + if err != nil { + logs.Error(err) + return + } + + _, _ = writer.Write([]byte(` + + + Verify Yourself + + +
+ + ` + ` +
+

Input verify code above:

+ +
+
+ +
+
+ +`)) +} + +func (this *CaptchaValidator) validate(request *requests.Request, writer http.ResponseWriter) (allow bool) { + captchaId := request.FormValue("TEAWEB_WAF_CAPTCHA_ID") + if len(captchaId) > 0 { + captchaCode := request.FormValue("TEAWEB_WAF_CAPTCHA_CODE") + if captcha.VerifyString(captchaId, captchaCode) { + // set cookie + timestamp := fmt.Sprintf("%d", time.Now().Unix()+CaptchaSeconds) + m := stringutil.Md5(captchaSalt + timestamp) + http.SetCookie(writer, &http.Cookie{ + Name: "TEAWEB_WAF_CAPTCHA", + Value: m + timestamp, + MaxAge: CaptchaSeconds, + Path: "/", // all of dirs + }) + + rawURL := request.URL.Query().Get("url") + http.Redirect(writer, request.Raw(), rawURL, http.StatusSeeOther) + + return false + } else { + http.Redirect(writer, request.Raw(), request.URL.String(), http.StatusSeeOther) + } + } + + return true +} diff --git a/internal/waf/checkpoints/cc.go b/internal/waf/checkpoints/cc.go new file mode 100644 index 0000000..89fe6e5 --- /dev/null +++ b/internal/waf/checkpoints/cc.go @@ -0,0 +1,246 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/grids" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" + "github.com/iwind/TeaGo/types" + "net" + "regexp" + "strings" + "sync" +) + +// ${cc.arg} +// TODO implement more traffic rules +type CCCheckpoint struct { + Checkpoint + + grid *grids.Grid + once sync.Once +} + +func (this *CCCheckpoint) Init() { + +} + +func (this *CCCheckpoint) Start() { + if this.grid != nil { + this.grid.Destroy() + } + this.grid = grids.NewGrid(32, grids.NewLimitCountOpt(1000_0000)) +} + +func (this *CCCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = 0 + + if this.grid == nil { + this.once.Do(func() { + this.Start() + }) + if this.grid == nil { + return + } + } + + periodString, ok := options["period"] + if !ok { + return + } + period := types.Int64(periodString) + if period < 1 { + return + } + + v, _ := options["userType"] + userType := types.String(v) + + v, _ = options["userField"] + userField := types.String(v) + + v, _ = options["userIndex"] + userIndex := types.Int(v) + + if param == "requests" { // requests + var key = "" + switch userType { + case "ip": + key = this.ip(req) + case "cookie": + if len(userField) == 0 { + key = this.ip(req) + } else { + cookie, _ := req.Cookie(userField) + if cookie != nil { + v := cookie.Value + if userIndex > 0 && len(v) > userIndex { + v = v[userIndex:] + } + key = "USER@" + userType + "@" + userField + "@" + v + } + } + case "get": + if len(userField) == 0 { + key = this.ip(req) + } else { + v := req.URL.Query().Get(userField) + if userIndex > 0 && len(v) > userIndex { + v = v[userIndex:] + } + key = "USER@" + userType + "@" + userField + "@" + v + } + case "post": + if len(userField) == 0 { + key = this.ip(req) + } else { + v := req.PostFormValue(userField) + if userIndex > 0 && len(v) > userIndex { + v = v[userIndex:] + } + key = "USER@" + userType + "@" + userField + "@" + v + } + case "header": + if len(userField) == 0 { + key = this.ip(req) + } else { + v := req.Header.Get(userField) + if userIndex > 0 && len(v) > userIndex { + v = v[userIndex:] + } + key = "USER@" + userType + "@" + userField + "@" + v + } + default: + key = this.ip(req) + } + if len(key) == 0 { + key = this.ip(req) + } + value = this.grid.IncreaseInt64([]byte(key), 1, period) + } + + return +} + +func (this *CCCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} + +func (this *CCCheckpoint) ParamOptions() *ParamOptions { + option := NewParamOptions() + option.AddParam("请求数", "requests") + return option +} + +func (this *CCCheckpoint) Options() []OptionInterface { + options := []OptionInterface{} + + // period + { + option := NewFieldOption("统计周期", "period") + option.Value = "60" + option.RightLabel = "秒" + option.Size = 8 + option.MaxLength = 8 + option.Validate = func(value string) (ok bool, message string) { + if regexp.MustCompile("^\\d+$").MatchString(value) { + ok = true + return + } + message = "周期需要是一个整数数字" + return + } + options = append(options, option) + } + + // type + { + option := NewOptionsOption("用户识别读取来源", "userType") + option.Size = 10 + option.SetOptions([]maps.Map{ + { + "name": "IP", + "value": "ip", + }, + { + "name": "Cookie", + "value": "cookie", + }, + { + "name": "URL参数", + "value": "get", + }, + { + "name": "POST参数", + "value": "post", + }, + { + "name": "HTTP Header", + "value": "header", + }, + }) + options = append(options, option) + } + + // user field + { + option := NewFieldOption("用户识别字段", "userField") + option.Comment = "识别用户的唯一性字段,在用户读取来源不是IP时使用" + options = append(options, option) + } + + // user value index + { + option := NewFieldOption("字段读取位置", "userIndex") + option.Size = 5 + option.MaxLength = 5 + option.Comment = "读取用户识别字段的位置,从0开始,比如user12345的数字ID 12345的位置就是5,在用户读取来源不是IP时使用" + options = append(options, option) + } + + return options +} + +func (this *CCCheckpoint) Stop() { + if this.grid != nil { + this.grid.Destroy() + this.grid = nil + } +} + +func (this *CCCheckpoint) ip(req *requests.Request) string { + // X-Forwarded-For + forwardedFor := req.Header.Get("X-Forwarded-For") + if len(forwardedFor) > 0 { + commaIndex := strings.Index(forwardedFor, ",") + if commaIndex > 0 { + return forwardedFor[:commaIndex] + } + return forwardedFor + } + + // Real-IP + { + realIP, ok := req.Header["X-Real-IP"] + if ok && len(realIP) > 0 { + return realIP[0] + } + } + + // Real-Ip + { + realIP, ok := req.Header["X-Real-Ip"] + if ok && len(realIP) > 0 { + return realIP[0] + } + } + + // Remote-Addr + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + return host + } + return req.RemoteAddr +} diff --git a/internal/waf/checkpoints/cc_test.go b/internal/waf/checkpoints/cc_test.go new file mode 100644 index 0000000..6245798 --- /dev/null +++ b/internal/waf/checkpoints/cc_test.go @@ -0,0 +1,42 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "testing" +) + +func TestCCCheckpoint_RequestValue(t *testing.T) { + raw, err := http.NewRequest(http.MethodGet, "http://teaos.cn/", nil) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(raw) + req.RemoteAddr = "127.0.0.1" + + checkpoint := new(CCCheckpoint) + checkpoint.Init() + checkpoint.Start() + + options := map[string]string{ + "period": "5", + } + t.Log(checkpoint.RequestValue(req, "requests", options)) + t.Log(checkpoint.RequestValue(req, "requests", options)) + + req.RemoteAddr = "127.0.0.2" + t.Log(checkpoint.RequestValue(req, "requests", options)) + + req.RemoteAddr = "127.0.0.1" + t.Log(checkpoint.RequestValue(req, "requests", options)) + + req.RemoteAddr = "127.0.0.2" + t.Log(checkpoint.RequestValue(req, "requests", options)) + + req.RemoteAddr = "127.0.0.2" + t.Log(checkpoint.RequestValue(req, "requests", options)) + + req.RemoteAddr = "127.0.0.2" + t.Log(checkpoint.RequestValue(req, "requests", options)) +} diff --git a/internal/waf/checkpoints/checkpoint.go b/internal/waf/checkpoints/checkpoint.go new file mode 100644 index 0000000..3b1475b --- /dev/null +++ b/internal/waf/checkpoints/checkpoint.go @@ -0,0 +1,28 @@ +package checkpoints + +type Checkpoint struct { +} + +func (this *Checkpoint) Init() { + +} + +func (this *Checkpoint) IsRequest() bool { + return true +} + +func (this *Checkpoint) ParamOptions() *ParamOptions { + return nil +} + +func (this *Checkpoint) Options() []OptionInterface { + return nil +} + +func (this *Checkpoint) Start() { + +} + +func (this *Checkpoint) Stop() { + +} diff --git a/internal/waf/checkpoints/checkpoint_definition.go b/internal/waf/checkpoints/checkpoint_definition.go new file mode 100644 index 0000000..c65b9a4 --- /dev/null +++ b/internal/waf/checkpoints/checkpoint_definition.go @@ -0,0 +1,10 @@ +package checkpoints + +// check point definition +type CheckpointDefinition struct { + Name string + Description string + Prefix string + HasParams bool // has sub params + Instance CheckpointInterface +} diff --git a/internal/waf/checkpoints/checkpoint_interface.go b/internal/waf/checkpoints/checkpoint_interface.go new file mode 100644 index 0000000..db958f2 --- /dev/null +++ b/internal/waf/checkpoints/checkpoint_interface.go @@ -0,0 +1,32 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// Check Point +type CheckpointInterface interface { + // initialize + Init() + + // is request? + IsRequest() bool + + // get request value + RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) + + // get response value + ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) + + // param option list + ParamOptions() *ParamOptions + + // options + Options() []OptionInterface + + // start + Start() + + // stop + Stop() +} diff --git a/internal/waf/checkpoints/option.go b/internal/waf/checkpoints/option.go new file mode 100644 index 0000000..04850a8 --- /dev/null +++ b/internal/waf/checkpoints/option.go @@ -0,0 +1,5 @@ +package checkpoints + +type OptionInterface interface { + Type() string +} diff --git a/internal/waf/checkpoints/option_field.go b/internal/waf/checkpoints/option_field.go new file mode 100644 index 0000000..c476b20 --- /dev/null +++ b/internal/waf/checkpoints/option_field.go @@ -0,0 +1,26 @@ +package checkpoints + +// attach option +type FieldOption struct { + Name string + Code string + Value string // default value + IsRequired bool + Size int + Comment string + Placeholder string + RightLabel string + MaxLength int + Validate func(value string) (ok bool, message string) +} + +func NewFieldOption(name string, code string) *FieldOption { + return &FieldOption{ + Name: name, + Code: code, + } +} + +func (this *FieldOption) Type() string { + return "field" +} diff --git a/internal/waf/checkpoints/option_options.go b/internal/waf/checkpoints/option_options.go new file mode 100644 index 0000000..dcfa6dc --- /dev/null +++ b/internal/waf/checkpoints/option_options.go @@ -0,0 +1,30 @@ +package checkpoints + +import "github.com/iwind/TeaGo/maps" + +type OptionsOption struct { + Name string + Code string + Value string // default value + IsRequired bool + Size int + Comment string + RightLabel string + Validate func(value string) (ok bool, message string) + Options []maps.Map +} + +func NewOptionsOption(name string, code string) *OptionsOption { + return &OptionsOption{ + Name: name, + Code: code, + } +} + +func (this *OptionsOption) Type() string { + return "options" +} + +func (this *OptionsOption) SetOptions(options []maps.Map) { + this.Options = options +} diff --git a/internal/waf/checkpoints/param_option.go b/internal/waf/checkpoints/param_option.go new file mode 100644 index 0000000..00124e8 --- /dev/null +++ b/internal/waf/checkpoints/param_option.go @@ -0,0 +1,21 @@ +package checkpoints + +type KeyValue struct { + Name string `json:"name"` + Value string `json:"value"` +} + +type ParamOptions struct { + Options []*KeyValue `json:"options"` +} + +func NewParamOptions() *ParamOptions { + return &ParamOptions{} +} + +func (this *ParamOptions) AddParam(name string, value string) { + this.Options = append(this.Options, &KeyValue{ + Name: name, + Value: value, + }) +} diff --git a/internal/waf/checkpoints/request_all.go b/internal/waf/checkpoints/request_all.go new file mode 100644 index 0000000..95d9c68 --- /dev/null +++ b/internal/waf/checkpoints/request_all.go @@ -0,0 +1,46 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// ${requestAll} +type RequestAllCheckpoint struct { + Checkpoint +} + +func (this *RequestAllCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + valueBytes := []byte{} + if len(req.RequestURI) > 0 { + valueBytes = append(valueBytes, req.RequestURI...) + } else if req.URL != nil { + valueBytes = append(valueBytes, req.URL.RequestURI()...) + } + + if req.Body != nil { + valueBytes = append(valueBytes, ' ') + + if len(req.BodyData) == 0 { + data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes + if err != nil { + return "", err, nil + } + + req.BodyData = data + req.RestoreBody(data) + } + valueBytes = append(valueBytes, req.BodyData...) + } + + value = valueBytes + + return +} + +func (this *RequestAllCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = "" + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_all_test.go b/internal/waf/checkpoints/request_all_test.go new file mode 100644 index 0000000..d8a12a0 --- /dev/null +++ b/internal/waf/checkpoints/request_all_test.go @@ -0,0 +1,70 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/types" + "io/ioutil" + "net/http" + "runtime" + "strings" + "testing" +) + +func TestRequestAllCheckpoint_RequestValue(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer([]byte("123456"))) + if err != nil { + t.Fatal(err) + } + + checkpoint := new(RequestAllCheckpoint) + v, sysErr, userErr := checkpoint.RequestValue(requests.NewRequest(req), "", nil) + if sysErr != nil { + t.Fatal(sysErr) + } + if userErr != nil { + t.Fatal(userErr) + } + t.Log(v) + t.Log(types.String(v)) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(body)) +} + +func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000)))) + if err != nil { + t.Fatal(err) + } + + checkpoint := new(RequestBodyCheckpoint) + value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil) + if err != nil { + t.Fatal(err) + } + t.Log("value bytes:", len(types.String(value))) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log("raw bytes:", len(body)) +} + +func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) { + runtime.GOMAXPROCS(1) + + req, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello/world", bytes.NewBuffer(bytes.Repeat([]byte("HELLO"), 1024))) + if err != nil { + b.Fatal(err) + } + + checkpoint := new(RequestAllCheckpoint) + for i := 0; i < b.N; i++ { + _, _, _ = checkpoint.RequestValue(requests.NewRequest(req), "", nil) + } +} diff --git a/internal/waf/checkpoints/request_arg.go b/internal/waf/checkpoints/request_arg.go new file mode 100644 index 0000000..826c8d3 --- /dev/null +++ b/internal/waf/checkpoints/request_arg.go @@ -0,0 +1,20 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestArgCheckpoint struct { + Checkpoint +} + +func (this *RequestArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + return req.URL.Query().Get(param), nil, nil +} + +func (this *RequestArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_arg_test.go b/internal/waf/checkpoints/request_arg_test.go new file mode 100644 index 0000000..a7cdaf3 --- /dev/null +++ b/internal/waf/checkpoints/request_arg_test.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "testing" +) + +func TestArgParam_RequestValue(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/?name=lu", nil) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + + checkpoint := new(RequestArgCheckpoint) + t.Log(checkpoint.RequestValue(req, "name", nil)) + t.Log(checkpoint.ResponseValue(req, nil, "name", nil)) + t.Log(checkpoint.RequestValue(req, "name2", nil)) +} diff --git a/internal/waf/checkpoints/request_args.go b/internal/waf/checkpoints/request_args.go new file mode 100644 index 0000000..6b5f75d --- /dev/null +++ b/internal/waf/checkpoints/request_args.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestArgsCheckpoint struct { + Checkpoint +} + +func (this *RequestArgsCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.URL.RawQuery + return +} + +func (this *RequestArgsCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_body.go b/internal/waf/checkpoints/request_body.go new file mode 100644 index 0000000..5bd3bc3 --- /dev/null +++ b/internal/waf/checkpoints/request_body.go @@ -0,0 +1,36 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// ${requestBody} +type RequestBodyCheckpoint struct { + Checkpoint +} + +func (this *RequestBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if req.Body == nil { + value = "" + return + } + + if len(req.BodyData) == 0 { + data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes + if err != nil { + return "", err, nil + } + + req.BodyData = data + req.RestoreBody(data) + } + + return req.BodyData, nil, nil +} + +func (this *RequestBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_body_test.go b/internal/waf/checkpoints/request_body_test.go new file mode 100644 index 0000000..b1c982d --- /dev/null +++ b/internal/waf/checkpoints/request_body_test.go @@ -0,0 +1,47 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/types" + "io/ioutil" + "net/http" + "strings" + "testing" +) + +func TestRequestBodyCheckpoint_RequestValue(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("123456"))) + if err != nil { + t.Fatal(err) + } + + checkpoint := new(RequestBodyCheckpoint) + t.Log(checkpoint.RequestValue(requests.NewRequest(req), "", nil)) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(body)) +} + +func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(strings.Repeat("123456", 10240000)))) + if err != nil { + t.Fatal(err) + } + + checkpoint := new(RequestBodyCheckpoint) + value, err, _ := checkpoint.RequestValue(requests.NewRequest(req), "", nil) + if err != nil { + t.Fatal(err) + } + t.Log("value bytes:", len(types.String(value))) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log("raw bytes:", len(body)) +} diff --git a/internal/waf/checkpoints/request_content_type.go b/internal/waf/checkpoints/request_content_type.go new file mode 100644 index 0000000..73c1d65 --- /dev/null +++ b/internal/waf/checkpoints/request_content_type.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestContentTypeCheckpoint struct { + Checkpoint +} + +func (this *RequestContentTypeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.Header.Get("Content-Type") + return +} + +func (this *RequestContentTypeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_cookie.go b/internal/waf/checkpoints/request_cookie.go new file mode 100644 index 0000000..a6c6887 --- /dev/null +++ b/internal/waf/checkpoints/request_cookie.go @@ -0,0 +1,27 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestCookieCheckpoint struct { + Checkpoint +} + +func (this *RequestCookieCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + cookie, err := req.Cookie(param) + if err != nil { + value = "" + return + } + + value = cookie.Value + return +} + +func (this *RequestCookieCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_cookies.go b/internal/waf/checkpoints/request_cookies.go new file mode 100644 index 0000000..1a12e2e --- /dev/null +++ b/internal/waf/checkpoints/request_cookies.go @@ -0,0 +1,27 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/url" + "strings" +) + +type RequestCookiesCheckpoint struct { + Checkpoint +} + +func (this *RequestCookiesCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + var cookies = []string{} + for _, cookie := range req.Cookies() { + cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value)) + } + value = strings.Join(cookies, "&") + return +} + +func (this *RequestCookiesCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_form_arg.go b/internal/waf/checkpoints/request_form_arg.go new file mode 100644 index 0000000..3d6535e --- /dev/null +++ b/internal/waf/checkpoints/request_form_arg.go @@ -0,0 +1,39 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/url" +) + +// ${requestForm.arg} +type RequestFormArgCheckpoint struct { + Checkpoint +} + +func (this *RequestFormArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if req.Body == nil { + value = "" + return + } + + if len(req.BodyData) == 0 { + data, err := req.ReadBody(32 * 1024 * 1024) // read 32m bytes + if err != nil { + return "", err, nil + } + + req.BodyData = data + req.RestoreBody(data) + } + + // TODO improve performance + values, _ := url.ParseQuery(string(req.BodyData)) + return values.Get(param), nil, nil +} + +func (this *RequestFormArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_form_arg_test.go b/internal/waf/checkpoints/request_form_arg_test.go new file mode 100644 index 0000000..01c0396 --- /dev/null +++ b/internal/waf/checkpoints/request_form_arg_test.go @@ -0,0 +1,32 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "io/ioutil" + "net/http" + "net/url" + "testing" +) + +func TestRequestFormArgCheckpoint_RequestValue(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte("name=lu&age=20&encoded="+url.QueryEscape("ENCODED STRING")))) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + checkpoint := new(RequestFormArgCheckpoint) + t.Log(checkpoint.RequestValue(req, "name", nil)) + t.Log(checkpoint.RequestValue(req, "age", nil)) + t.Log(checkpoint.RequestValue(req, "Hello", nil)) + t.Log(checkpoint.RequestValue(req, "encoded", nil)) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(body)) +} diff --git a/internal/waf/checkpoints/request_header.go b/internal/waf/checkpoints/request_header.go new file mode 100644 index 0000000..553b515 --- /dev/null +++ b/internal/waf/checkpoints/request_header.go @@ -0,0 +1,27 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "strings" +) + +type RequestHeaderCheckpoint struct { + Checkpoint +} + +func (this *RequestHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + v, found := req.Header[param] + if !found { + value = "" + return + } + value = strings.Join(v, ";") + return +} + +func (this *RequestHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_headers.go b/internal/waf/checkpoints/request_headers.go new file mode 100644 index 0000000..c6acdbf --- /dev/null +++ b/internal/waf/checkpoints/request_headers.go @@ -0,0 +1,30 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "sort" + "strings" +) + +type RequestHeadersCheckpoint struct { + Checkpoint +} + +func (this *RequestHeadersCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + var headers = []string{} + for k, v := range req.Header { + for _, subV := range v { + headers = append(headers, k+": "+subV) + } + } + sort.Strings(headers) + value = strings.Join(headers, "\n") + return +} + +func (this *RequestHeadersCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_host.go b/internal/waf/checkpoints/request_host.go new file mode 100644 index 0000000..68f8e26 --- /dev/null +++ b/internal/waf/checkpoints/request_host.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestHostCheckpoint struct { + Checkpoint +} + +func (this *RequestHostCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.Host + return +} + +func (this *RequestHostCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_host_test.go b/internal/waf/checkpoints/request_host_test.go new file mode 100644 index 0000000..fc1b449 --- /dev/null +++ b/internal/waf/checkpoints/request_host_test.go @@ -0,0 +1,20 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "testing" +) + +func TestRequestHostCheckpoint_RequestValue(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + req.Header.Set("Host", "cloud.teaos.cn") + + checkpoint := new(RequestHostCheckpoint) + t.Log(checkpoint.RequestValue(req, "", nil)) +} diff --git a/internal/waf/checkpoints/request_json_arg.go b/internal/waf/checkpoints/request_json_arg.go new file mode 100644 index 0000000..5841e55 --- /dev/null +++ b/internal/waf/checkpoints/request_json_arg.go @@ -0,0 +1,44 @@ +package checkpoints + +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "strings" +) + +// ${requestJSON.arg} +type RequestJSONArgCheckpoint struct { + Checkpoint +} + +func (this *RequestJSONArgCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if len(req.BodyData) == 0 { + data, err := req.ReadBody(int64(32 * 1024 * 1024)) // read 32m bytes + if err != nil { + return "", err, nil + } + req.BodyData = data + defer req.RestoreBody(data) + } + + // TODO improve performance + var m interface{} = nil + err := json.Unmarshal(req.BodyData, &m) + if err != nil || m == nil { + return "", nil, err + } + + value = utils.Get(m, strings.Split(param, ".")) + if value != nil { + return value, nil, err + } + return "", nil, nil +} + +func (this *RequestJSONArgCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_json_arg_test.go b/internal/waf/checkpoints/request_json_arg_test.go new file mode 100644 index 0000000..00708be --- /dev/null +++ b/internal/waf/checkpoints/request_json_arg_test.go @@ -0,0 +1,99 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "io/ioutil" + "net/http" + "testing" +) + +func TestRequestJSONArgCheckpoint_RequestValue_Map(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(` +{ + "name": "lu", + "age": 20, + "books": [ "PHP", "Golang", "Python" ] +} +`))) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + //req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + checkpoint := new(RequestJSONArgCheckpoint) + t.Log(checkpoint.RequestValue(req, "name", nil)) + t.Log(checkpoint.RequestValue(req, "age", nil)) + t.Log(checkpoint.RequestValue(req, "Hello", nil)) + t.Log(checkpoint.RequestValue(req, "", nil)) + t.Log(checkpoint.RequestValue(req, "books", nil)) + t.Log(checkpoint.RequestValue(req, "books.1", nil)) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(body)) +} + +func TestRequestJSONArgCheckpoint_RequestValue_Array(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(` +[{ + "name": "lu", + "age": 20, + "books": [ "PHP", "Golang", "Python" ] +}] +`))) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + //req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + checkpoint := new(RequestJSONArgCheckpoint) + t.Log(checkpoint.RequestValue(req, "0.name", nil)) + t.Log(checkpoint.RequestValue(req, "0.age", nil)) + t.Log(checkpoint.RequestValue(req, "0.Hello", nil)) + t.Log(checkpoint.RequestValue(req, "", nil)) + t.Log(checkpoint.RequestValue(req, "0.books", nil)) + t.Log(checkpoint.RequestValue(req, "0.books.1", nil)) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(body)) +} + +func TestRequestJSONArgCheckpoint_RequestValue_Error(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn", bytes.NewBuffer([]byte(` +[{ + "name": "lu", + "age": 20, + "books": [ "PHP", "Golang", "Python" ] +}] +`))) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + //req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + checkpoint := new(RequestJSONArgCheckpoint) + t.Log(checkpoint.RequestValue(req, "0.name", nil)) + t.Log(checkpoint.RequestValue(req, "0.age", nil)) + t.Log(checkpoint.RequestValue(req, "0.Hello", nil)) + t.Log(checkpoint.RequestValue(req, "", nil)) + t.Log(checkpoint.RequestValue(req, "0.books", nil)) + t.Log(checkpoint.RequestValue(req, "0.books.1", nil)) + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(body)) +} diff --git a/internal/waf/checkpoints/request_length.go b/internal/waf/checkpoints/request_length.go new file mode 100644 index 0000000..24f8718 --- /dev/null +++ b/internal/waf/checkpoints/request_length.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestLengthCheckpoint struct { + Checkpoint +} + +func (this *RequestLengthCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.ContentLength + return +} + +func (this *RequestLengthCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_method.go b/internal/waf/checkpoints/request_method.go new file mode 100644 index 0000000..a2580be --- /dev/null +++ b/internal/waf/checkpoints/request_method.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestMethodCheckpoint struct { + Checkpoint +} + +func (this *RequestMethodCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.Method + return +} + +func (this *RequestMethodCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_path.go b/internal/waf/checkpoints/request_path.go new file mode 100644 index 0000000..776a6dc --- /dev/null +++ b/internal/waf/checkpoints/request_path.go @@ -0,0 +1,20 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestPathCheckpoint struct { + Checkpoint +} + +func (this *RequestPathCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + return req.URL.Path, nil, nil +} + +func (this *RequestPathCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_path_test.go b/internal/waf/checkpoints/request_path_test.go new file mode 100644 index 0000000..e100602 --- /dev/null +++ b/internal/waf/checkpoints/request_path_test.go @@ -0,0 +1,18 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "testing" +) + +func TestRequestPathCheckpoint_RequestValue(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/index?name=lu", nil) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + checkpoint := new(RequestPathCheckpoint) + t.Log(checkpoint.RequestValue(req, "", nil)) +} diff --git a/internal/waf/checkpoints/request_proto.go b/internal/waf/checkpoints/request_proto.go new file mode 100644 index 0000000..0c144f4 --- /dev/null +++ b/internal/waf/checkpoints/request_proto.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestProtoCheckpoint struct { + Checkpoint +} + +func (this *RequestProtoCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.Proto + return +} + +func (this *RequestProtoCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_raw_remote_addr.go b/internal/waf/checkpoints/request_raw_remote_addr.go new file mode 100644 index 0000000..9cab823 --- /dev/null +++ b/internal/waf/checkpoints/request_raw_remote_addr.go @@ -0,0 +1,27 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net" +) + +type RequestRawRemoteAddrCheckpoint struct { + Checkpoint +} + +func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + value = host + } else { + value = req.RemoteAddr + } + return +} + +func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_referer.go b/internal/waf/checkpoints/request_referer.go new file mode 100644 index 0000000..593784f --- /dev/null +++ b/internal/waf/checkpoints/request_referer.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestRefererCheckpoint struct { + Checkpoint +} + +func (this *RequestRefererCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.Referer() + return +} + +func (this *RequestRefererCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_remote_addr.go b/internal/waf/checkpoints/request_remote_addr.go new file mode 100644 index 0000000..c5a271c --- /dev/null +++ b/internal/waf/checkpoints/request_remote_addr.go @@ -0,0 +1,59 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net" + "strings" +) + +type RequestRemoteAddrCheckpoint struct { + Checkpoint +} + +func (this *RequestRemoteAddrCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + // X-Forwarded-For + forwardedFor := req.Header.Get("X-Forwarded-For") + if len(forwardedFor) > 0 { + commaIndex := strings.Index(forwardedFor, ",") + if commaIndex > 0 { + value = forwardedFor[:commaIndex] + return + } + value = forwardedFor + return + } + + // Real-IP + { + realIP, ok := req.Header["X-Real-IP"] + if ok && len(realIP) > 0 { + value = realIP[0] + return + } + } + + // Real-Ip + { + realIP, ok := req.Header["X-Real-Ip"] + if ok && len(realIP) > 0 { + value = realIP[0] + return + } + } + + // Remote-Addr + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + value = host + } else { + value = req.RemoteAddr + } + return +} + +func (this *RequestRemoteAddrCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_remote_port.go b/internal/waf/checkpoints/request_remote_port.go new file mode 100644 index 0000000..b5b9b1e --- /dev/null +++ b/internal/waf/checkpoints/request_remote_port.go @@ -0,0 +1,28 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/types" + "net" +) + +type RequestRemotePortCheckpoint struct { + Checkpoint +} + +func (this *RequestRemotePortCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + _, port, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + value = types.Int(port) + } else { + value = 0 + } + return +} + +func (this *RequestRemotePortCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_remote_user.go b/internal/waf/checkpoints/request_remote_user.go new file mode 100644 index 0000000..0e95c8f --- /dev/null +++ b/internal/waf/checkpoints/request_remote_user.go @@ -0,0 +1,26 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestRemoteUserCheckpoint struct { + Checkpoint +} + +func (this *RequestRemoteUserCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + username, _, ok := req.BasicAuth() + if !ok { + value = "" + return + } + value = username + return +} + +func (this *RequestRemoteUserCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_scheme.go b/internal/waf/checkpoints/request_scheme.go new file mode 100644 index 0000000..63f2855 --- /dev/null +++ b/internal/waf/checkpoints/request_scheme.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestSchemeCheckpoint struct { + Checkpoint +} + +func (this *RequestSchemeCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.URL.Scheme + return +} + +func (this *RequestSchemeCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_scheme_test.go b/internal/waf/checkpoints/request_scheme_test.go new file mode 100644 index 0000000..461cf23 --- /dev/null +++ b/internal/waf/checkpoints/request_scheme_test.go @@ -0,0 +1,18 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "testing" +) + +func TestRequestSchemeCheckpoint_RequestValue(t *testing.T) { + rawReq, err := http.NewRequest(http.MethodGet, "https://teaos.cn/?name=lu", nil) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + checkpoint := new(RequestSchemeCheckpoint) + t.Log(checkpoint.RequestValue(req, "", nil)) +} diff --git a/internal/waf/checkpoints/request_upload.go b/internal/waf/checkpoints/request_upload.go new file mode 100644 index 0000000..38ba070 --- /dev/null +++ b/internal/waf/checkpoints/request_upload.go @@ -0,0 +1,130 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/lists" + "io/ioutil" + "net/http" + "path/filepath" + "strings" +) + +// ${requestUpload.arg} +type RequestUploadCheckpoint struct { + Checkpoint +} + +func (this *RequestUploadCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = "" + if param == "minSize" || param == "maxSize" { + value = 0 + } + + if req.Method != http.MethodPost { + return + } + + if req.Body == nil { + return + } + + if req.MultipartForm == nil { + if len(req.BodyData) == 0 { + data, err := req.ReadBody(32 * 1024 * 1024) + if err != nil { + sysErr = err + return + } + + req.BodyData = data + defer req.RestoreBody(data) + } + oldBody := req.Body + req.Body = ioutil.NopCloser(bytes.NewBuffer(req.BodyData)) + + err := req.ParseMultipartForm(32 * 1024 * 1024) + + // 还原 + req.Body = oldBody + + if err != nil { + userErr = err + return + } + + if req.MultipartForm == nil { + return + } + } + + if param == "field" { // field + fields := []string{} + for field := range req.MultipartForm.File { + fields = append(fields, field) + } + value = strings.Join(fields, ",") + } else if param == "minSize" { // minSize + minSize := int64(0) + for _, files := range req.MultipartForm.File { + for _, file := range files { + if minSize == 0 || minSize > file.Size { + minSize = file.Size + } + } + } + value = minSize + } else if param == "maxSize" { // maxSize + maxSize := int64(0) + for _, files := range req.MultipartForm.File { + for _, file := range files { + if maxSize < file.Size { + maxSize = file.Size + } + } + } + value = maxSize + } else if param == "name" { // name + names := []string{} + for _, files := range req.MultipartForm.File { + for _, file := range files { + if !lists.ContainsString(names, file.Filename) { + names = append(names, file.Filename) + } + } + } + value = strings.Join(names, ",") + } else if param == "ext" { // ext + extensions := []string{} + for _, files := range req.MultipartForm.File { + for _, file := range files { + if len(file.Filename) > 0 { + exit := strings.ToLower(filepath.Ext(file.Filename)) + if !lists.ContainsString(extensions, exit) { + extensions = append(extensions, exit) + } + } + } + } + value = strings.Join(extensions, ",") + } + + return +} + +func (this *RequestUploadCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} + +func (this *RequestUploadCheckpoint) ParamOptions() *ParamOptions { + option := NewParamOptions() + option.AddParam("最小文件尺寸", "minSize") + option.AddParam("最大文件尺寸", "maxSize") + option.AddParam("扩展名(如.txt)", "ext") + option.AddParam("原始文件名", "name") + option.AddParam("表单字段名", "field") + return option +} diff --git a/internal/waf/checkpoints/request_upload_test.go b/internal/waf/checkpoints/request_upload_test.go new file mode 100644 index 0000000..cc1ab64 --- /dev/null +++ b/internal/waf/checkpoints/request_upload_test.go @@ -0,0 +1,81 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "io/ioutil" + "mime/multipart" + "net/http" + "testing" +) + +func TestRequestUploadCheckpoint_RequestValue(t *testing.T) { + body := bytes.NewBuffer([]byte{}) + + writer := multipart.NewWriter(body) + + { + part, err := writer.CreateFormField("name") + if err == nil { + part.Write([]byte("lu")) + } + } + + { + part, err := writer.CreateFormField("age") + if err == nil { + part.Write([]byte("20")) + } + } + + { + part, err := writer.CreateFormFile("myFile", "hello.txt") + if err == nil { + part.Write([]byte("Hello, World!")) + } + } + + { + part, err := writer.CreateFormFile("myFile2", "hello.PHP") + if err == nil { + part.Write([]byte("Hello, World, PHP!")) + } + } + + { + part, err := writer.CreateFormFile("myFile3", "hello.asp") + if err == nil { + part.Write([]byte("Hello, World, ASP Pages!")) + } + } + + { + part, err := writer.CreateFormFile("myFile4", "hello.asp") + if err == nil { + part.Write([]byte("Hello, World, ASP Pages!")) + } + } + + writer.Close() + + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn/", body) + if err != nil { + t.Fatal() + } + + req := requests.NewRequest(rawReq) + req.Header.Add("Content-Type", writer.FormDataContentType()) + + checkpoint := new(RequestUploadCheckpoint) + t.Log(checkpoint.RequestValue(req, "field", nil)) + t.Log(checkpoint.RequestValue(req, "minSize", nil)) + t.Log(checkpoint.RequestValue(req, "maxSize", nil)) + t.Log(checkpoint.RequestValue(req, "name", nil)) + t.Log(checkpoint.RequestValue(req, "ext", nil)) + + data, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + t.Log(string(data)) +} diff --git a/internal/waf/checkpoints/request_uri.go b/internal/waf/checkpoints/request_uri.go new file mode 100644 index 0000000..b8193f9 --- /dev/null +++ b/internal/waf/checkpoints/request_uri.go @@ -0,0 +1,25 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestURICheckpoint struct { + Checkpoint +} + +func (this *RequestURICheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if len(req.RequestURI) > 0 { + value = req.RequestURI + } else if req.URL != nil { + value = req.URL.RequestURI() + } + return +} + +func (this *RequestURICheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/request_user_agent.go b/internal/waf/checkpoints/request_user_agent.go new file mode 100644 index 0000000..d75b621 --- /dev/null +++ b/internal/waf/checkpoints/request_user_agent.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +type RequestUserAgentCheckpoint struct { + Checkpoint +} + +func (this *RequestUserAgentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = req.UserAgent() + return +} + +func (this *RequestUserAgentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/response_body.go b/internal/waf/checkpoints/response_body.go new file mode 100644 index 0000000..2f50a58 --- /dev/null +++ b/internal/waf/checkpoints/response_body.go @@ -0,0 +1,41 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "io/ioutil" +) + +// ${responseBody} +type ResponseBodyCheckpoint struct { + Checkpoint +} + +func (this *ResponseBodyCheckpoint) IsRequest() bool { + return false +} + +func (this *ResponseBodyCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = "" + return +} + +func (this *ResponseBodyCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = "" + if resp != nil && resp.Body != nil { + if len(resp.BodyData) > 0 { + value = string(resp.BodyData) + return + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + sysErr = err + return + } + resp.BodyData = body + _ = resp.Body.Close() + value = body + resp.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + } + return +} diff --git a/internal/waf/checkpoints/response_body_test.go b/internal/waf/checkpoints/response_body_test.go new file mode 100644 index 0000000..97959fe --- /dev/null +++ b/internal/waf/checkpoints/response_body_test.go @@ -0,0 +1,29 @@ +package checkpoints + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "io/ioutil" + "net/http" + "testing" +) + +func TestResponseBodyCheckpoint_ResponseValue(t *testing.T) { + resp := requests.NewResponse(new(http.Response)) + resp.StatusCode = 200 + resp.Header = http.Header{} + resp.Header.Set("Hello", "World") + resp.Body = ioutil.NopCloser(bytes.NewBuffer([]byte("Hello, World"))) + + checkpoint := new(ResponseBodyCheckpoint) + t.Log(checkpoint.ResponseValue(nil, resp, "", nil)) + t.Log(checkpoint.ResponseValue(nil, resp, "", nil)) + t.Log(checkpoint.ResponseValue(nil, resp, "", nil)) + t.Log(checkpoint.ResponseValue(nil, resp, "", nil)) + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Log("after read:", string(data)) +} diff --git a/internal/waf/checkpoints/response_bytes_sent.go b/internal/waf/checkpoints/response_bytes_sent.go new file mode 100644 index 0000000..4df2b3e --- /dev/null +++ b/internal/waf/checkpoints/response_bytes_sent.go @@ -0,0 +1,27 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// ${bytesSent} +type ResponseBytesSentCheckpoint struct { + Checkpoint +} + +func (this *ResponseBytesSentCheckpoint) IsRequest() bool { + return false +} + +func (this *ResponseBytesSentCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = 0 + return +} + +func (this *ResponseBytesSentCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = 0 + if resp != nil { + value = resp.ContentLength + } + return +} diff --git a/internal/waf/checkpoints/response_header.go b/internal/waf/checkpoints/response_header.go new file mode 100644 index 0000000..4db9150 --- /dev/null +++ b/internal/waf/checkpoints/response_header.go @@ -0,0 +1,28 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// ${responseHeader.arg} +type ResponseHeaderCheckpoint struct { + Checkpoint +} + +func (this *ResponseHeaderCheckpoint) IsRequest() bool { + return false +} + +func (this *ResponseHeaderCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = "" + return +} + +func (this *ResponseHeaderCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if resp != nil && resp.Header != nil { + value = resp.Header.Get(param) + } else { + value = "" + } + return +} diff --git a/internal/waf/checkpoints/response_header_test.go b/internal/waf/checkpoints/response_header_test.go new file mode 100644 index 0000000..60e733b --- /dev/null +++ b/internal/waf/checkpoints/response_header_test.go @@ -0,0 +1,17 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "testing" +) + +func TestResponseHeaderCheckpoint_ResponseValue(t *testing.T) { + resp := requests.NewResponse(new(http.Response)) + resp.StatusCode = 200 + resp.Header = http.Header{} + resp.Header.Set("Hello", "World") + + checkpoint := new(ResponseHeaderCheckpoint) + t.Log(checkpoint.ResponseValue(nil, resp, "Hello", nil)) +} diff --git a/internal/waf/checkpoints/response_status.go b/internal/waf/checkpoints/response_status.go new file mode 100644 index 0000000..f35870f --- /dev/null +++ b/internal/waf/checkpoints/response_status.go @@ -0,0 +1,26 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// ${bytesSent} +type ResponseStatusCheckpoint struct { + Checkpoint +} + +func (this *ResponseStatusCheckpoint) IsRequest() bool { + return false +} + +func (this *ResponseStatusCheckpoint) RequestValue(req *requests.Request, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + value = 0 + return +} + +func (this *ResponseStatusCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]interface{}) (value interface{}, sysErr error, userErr error) { + if resp != nil { + value = resp.StatusCode + } + return +} diff --git a/internal/waf/checkpoints/response_status_test.go b/internal/waf/checkpoints/response_status_test.go new file mode 100644 index 0000000..252360d --- /dev/null +++ b/internal/waf/checkpoints/response_status_test.go @@ -0,0 +1,15 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "net/http" + "testing" +) + +func TestResponseStatusCheckpoint_ResponseValue(t *testing.T) { + resp := requests.NewResponse(new(http.Response)) + resp.StatusCode = 200 + + checkpoint := new(ResponseStatusCheckpoint) + t.Log(checkpoint.ResponseValue(nil, resp, "", nil)) +} diff --git a/internal/waf/checkpoints/sample_request.go b/internal/waf/checkpoints/sample_request.go new file mode 100644 index 0000000..be8b1cb --- /dev/null +++ b/internal/waf/checkpoints/sample_request.go @@ -0,0 +1,21 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// just a sample checkpoint, copy and change it for your new checkpoint +type SampleRequestCheckpoint struct { + Checkpoint +} + +func (this *SampleRequestCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) { + return +} + +func (this *SampleRequestCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) { + if this.IsRequest() { + return this.RequestValue(req, param, options) + } + return +} diff --git a/internal/waf/checkpoints/sample_response.go b/internal/waf/checkpoints/sample_response.go new file mode 100644 index 0000000..27c0609 --- /dev/null +++ b/internal/waf/checkpoints/sample_response.go @@ -0,0 +1,22 @@ +package checkpoints + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// just a sample checkpoint, copy and change it for your new checkpoint +type SampleResponseCheckpoint struct { + Checkpoint +} + +func (this *SampleResponseCheckpoint) IsRequest() bool { + return false +} + +func (this *SampleResponseCheckpoint) RequestValue(req *requests.Request, param string, options map[string]string) (value interface{}, sysErr error, userErr error) { + return +} + +func (this *SampleResponseCheckpoint) ResponseValue(req *requests.Request, resp *requests.Response, param string, options map[string]string) (value interface{}, sysErr error, userErr error) { + return +} diff --git a/internal/waf/checkpoints/utils.go b/internal/waf/checkpoints/utils.go new file mode 100644 index 0000000..b029402 --- /dev/null +++ b/internal/waf/checkpoints/utils.go @@ -0,0 +1,235 @@ +package checkpoints + +// all check points list +var AllCheckpoints = []*CheckpointDefinition{ + { + Name: "客户端地址(IP)", + Prefix: "remoteAddr", + Description: "试图通过分析X-Forwarded-For等Header获取的客户端地址,比如192.168.1.100", + HasParams: false, + Instance: new(RequestRemoteAddrCheckpoint), + }, + { + Name: "客户端源地址(IP)", + Prefix: "rawRemoteAddr", + Description: "直接连接的客户端地址,比如192.168.1.100", + HasParams: false, + Instance: new(RequestRawRemoteAddrCheckpoint), + }, + { + Name: "客户端端口", + Prefix: "remotePort", + Description: "直接连接的客户端地址端口", + HasParams: false, + Instance: new(RequestRemotePortCheckpoint), + }, + { + Name: "客户端用户名", + Prefix: "remoteUser", + Description: "通过BasicAuth登录的客户端用户名", + HasParams: false, + Instance: new(RequestRemoteUserCheckpoint), + }, + { + Name: "请求URI", + Prefix: "requestURI", + Description: "包含URL参数的请求URI,比如/hello/world?lang=go", + HasParams: false, + Instance: new(RequestURICheckpoint), + }, + { + Name: "请求路径", + Prefix: "requestPath", + Description: "不包含URL参数的请求路径,比如/hello/world", + HasParams: false, + Instance: new(RequestPathCheckpoint), + }, + { + Name: "请求内容长度", + Prefix: "requestLength", + Description: "请求Header中的Content-Length", + HasParams: false, + Instance: new(RequestLengthCheckpoint), + }, + { + Name: "请求体内容", + Prefix: "requestBody", + Description: "通常在POST或者PUT等操作时会附带请求体,最大限制32M", + HasParams: false, + Instance: new(RequestBodyCheckpoint), + }, + { + Name: "请求URI和请求体组合", + Prefix: "requestAll", + Description: "${requestURI}和${requestBody}组合", + HasParams: false, + Instance: new(RequestAllCheckpoint), + }, + { + Name: "请求表单参数", + Prefix: "requestForm", + Description: "获取POST或者其他方法发送的表单参数,最大请求体限制32M", + HasParams: true, + Instance: new(RequestFormArgCheckpoint), + }, + { + Name: "上传文件", + Prefix: "requestUpload", + Description: "获取POST上传的文件信息,最大请求体限制32M", + HasParams: true, + Instance: new(RequestUploadCheckpoint), + }, + { + Name: "请求JSON参数", + Prefix: "requestJSON", + Description: "获取POST或者其他方法发送的JSON,最大请求体限制32M,使用点(.)符号表示多级数据", + HasParams: true, + Instance: new(RequestJSONArgCheckpoint), + }, + { + Name: "请求方法", + Prefix: "requestMethod", + Description: "比如GET、POST", + HasParams: false, + Instance: new(RequestMethodCheckpoint), + }, + { + Name: "请求协议", + Prefix: "scheme", + Description: "比如http或https", + HasParams: false, + Instance: new(RequestSchemeCheckpoint), + }, + { + Name: "HTTP协议版本", + Prefix: "proto", + Description: "比如HTTP/1.1", + HasParams: false, + Instance: new(RequestProtoCheckpoint), + }, + { + Name: "主机名", + Prefix: "host", + Description: "比如teaos.cn", + HasParams: false, + Instance: new(RequestHostCheckpoint), + }, + { + Name: "请求来源URL", + Prefix: "referer", + Description: "请求Header中的Referer值", + HasParams: false, + Instance: new(RequestRefererCheckpoint), + }, + { + Name: "客户端信息", + Prefix: "userAgent", + Description: "比如Mozilla/5.0 AppleWebKit/537.36 (KHTML, like Gecko) Chrome/73.0.3683.103", + HasParams: false, + Instance: new(RequestUserAgentCheckpoint), + }, + { + Name: "内容类型", + Prefix: "contentType", + Description: "请求Header的Content-Type", + HasParams: false, + Instance: new(RequestContentTypeCheckpoint), + }, + { + Name: "所有cookie组合字符串", + Prefix: "cookies", + Description: "比如sid=IxZVPFhE&city=beijing&uid=18237", + HasParams: false, + Instance: new(RequestCookiesCheckpoint), + }, + { + Name: "单个cookie值", + Prefix: "cookie", + Description: "单个cookie值", + HasParams: true, + Instance: new(RequestCookieCheckpoint), + }, + { + Name: "所有URL参数组合", + Prefix: "args", + Description: "比如name=lu&age=20", + HasParams: false, + Instance: new(RequestArgsCheckpoint), + }, + { + Name: "单个URL参数值", + Prefix: "arg", + Description: "单个URL参数值", + HasParams: true, + Instance: new(RequestArgCheckpoint), + }, + { + Name: "所有Header信息", + Prefix: "headers", + Description: "使用\n隔开的Header信息字符串", + HasParams: false, + Instance: new(RequestHeadersCheckpoint), + }, + { + Name: "单个Header值", + Prefix: "header", + Description: "单个Header值", + HasParams: true, + Instance: new(RequestHeaderCheckpoint), + }, + { + Name: "CC统计", + Prefix: "cc", + Description: "统计某段时间段内的请求信息", + HasParams: true, + Instance: new(CCCheckpoint), + }, + { + Name: "响应状态码", + Prefix: "status", + Description: "响应状态码,比如200、404、500", + HasParams: false, + Instance: new(ResponseStatusCheckpoint), + }, + { + Name: "响应Header", + Prefix: "responseHeader", + Description: "响应Header值", + HasParams: true, + Instance: new(ResponseHeaderCheckpoint), + }, + { + Name: "响应内容", + Prefix: "responseBody", + Description: "响应内容字符串", + HasParams: false, + Instance: new(ResponseBodyCheckpoint), + }, + { + Name: "响应内容长度", + Prefix: "bytesSent", + Description: "响应内容长度,通过响应的Header Content-Length获取", + HasParams: false, + Instance: new(ResponseBytesSentCheckpoint), + }, +} + +// find a check point +func FindCheckpoint(prefix string) CheckpointInterface { + for _, def := range AllCheckpoints { + if def.Prefix == prefix { + return def.Instance + } + } + return nil +} + +// find a check point definition +func FindCheckpointDefinition(prefix string) *CheckpointDefinition { + for _, def := range AllCheckpoints { + if def.Prefix == prefix { + return def + } + } + return nil +} diff --git a/internal/waf/checkpoints/utils_test.go b/internal/waf/checkpoints/utils_test.go new file mode 100644 index 0000000..f761778 --- /dev/null +++ b/internal/waf/checkpoints/utils_test.go @@ -0,0 +1,31 @@ +package checkpoints + +import ( + "fmt" + "strings" + "testing" +) + +func TestFindCheckpointDefinition_Markdown(t *testing.T) { + result := []string{} + for _, def := range AllCheckpoints { + row := "## " + def.Name + "\n* 前缀:`${" + def.Prefix + "}`\n* 描述:" + def.Description + if def.HasParams { + row += "\n* 是否有子参数:YES" + + paramOptions := def.Instance.ParamOptions() + if paramOptions != nil && len(paramOptions.Options) > 0 { + row += "\n* 可选子参数" + for _, option := range paramOptions.Options { + row += "\n * `" + option.Name + "`:值为 `" + option.Value + "`" + } + } + } else { + row += "\n* 是否有子参数:NO" + } + row += "\n" + result = append(result, row) + } + + fmt.Print(strings.Join(result, "\n") + "\n") +} diff --git a/internal/waf/ip_table.go b/internal/waf/ip_table.go new file mode 100644 index 0000000..7367b84 --- /dev/null +++ b/internal/waf/ip_table.go @@ -0,0 +1,154 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/types" + stringutil "github.com/iwind/TeaGo/utils/string" + "regexp" + "strings" + "time" +) + +type IPAction = string + +var RegexpDigitNumber = regexp.MustCompile("^\\d+$") + +const ( + IPActionAccept IPAction = "accept" + IPActionReject IPAction = "reject" +) + +// ip table +type IPTable struct { + Id string `yaml:"id" json:"id"` + On bool `yaml:"on" json:"on"` + IP string `yaml:"ip" json:"ip"` // single ip, cidr, ip range, TODO support * + Port string `yaml:"port" json:"port"` // single port, range, * + Action IPAction `yaml:"action" json:"action"` // accept, reject + TimeFrom int64 `yaml:"timeFrom" json:"timeFrom"` // from timestamp + TimeTo int64 `yaml:"timeTo" json:"timeTo"` // zero means forever + Remark string `yaml:"remark" json:"remark"` + + // port + minPort int + maxPort int + + minPortWildcard bool + maxPortWildcard bool + + ports []int + + // ip + ipRange *shared.IPRangeConfig +} + +func NewIPTable() *IPTable { + return &IPTable{ + On: true, + Id: stringutil.Rand(16), + } +} + +func (this *IPTable) Init() error { + // parse port + if RegexpDigitNumber.MatchString(this.Port) { + this.minPort = types.Int(this.Port) + this.maxPort = types.Int(this.Port) + } else if regexp.MustCompile(`[:-]`).MatchString(this.Port) { + pieces := regexp.MustCompile(`[:-]`).Split(this.Port, 2) + if pieces[0] == "*" { + this.minPortWildcard = true + } else { + this.minPort = types.Int(pieces[0]) + } + if pieces[1] == "*" { + this.maxPortWildcard = true + } else { + this.maxPort = types.Int(pieces[1]) + } + } else if strings.Contains(this.Port, ",") { + pieces := strings.Split(this.Port, ",") + for _, piece := range pieces { + piece = strings.TrimSpace(piece) + if len(piece) > 0 { + this.ports = append(this.ports, types.Int(piece)) + } + } + } else if this.Port == "*" { + this.minPortWildcard = true + this.maxPortWildcard = true + } + + // parse ip + if len(this.IP) > 0 { + ipRange, err := shared.ParseIPRange(this.IP) + if err != nil { + return err + } + this.ipRange = ipRange + } + + return nil +} + +// check ip +func (this *IPTable) Match(ip string, port int) (isMatched bool) { + if !this.On { + return + } + + now := time.Now().Unix() + if this.TimeFrom > 0 && now < this.TimeFrom { + return + } + if this.TimeTo > 0 && now > this.TimeTo { + return + } + + if !this.matchPort(port) { + return + } + + if !this.matchIP(ip) { + return + } + + return true +} + +func (this *IPTable) matchPort(port int) bool { + if port == 0 { + return false + } + if this.minPortWildcard { + if this.maxPortWildcard { + return true + } + if this.maxPort >= port { + return true + } + } + if this.maxPortWildcard { + if this.minPortWildcard { + return true + } + if this.minPort <= port { + return true + } + } + if (this.minPort > 0 || this.maxPort > 0) && this.minPort <= port && this.maxPort >= port { + return true + } + if len(this.ports) > 0 { + return lists.ContainsInt(this.ports, port) + } + return false +} + +func (this *IPTable) matchIP(ip string) bool { + if this.ipRange == nil { + return false + } + return this.ipRange.Contains(ip) +} diff --git a/internal/waf/ip_table_test.go b/internal/waf/ip_table_test.go new file mode 100644 index 0000000..5fbd3e8 --- /dev/null +++ b/internal/waf/ip_table_test.go @@ -0,0 +1,142 @@ +package waf + +import ( + "github.com/iwind/TeaGo/assert" + "testing" + "time" +) + +func TestIPTable_MatchIP(t *testing.T) { + a := assert.NewAssertion(t) + + { + table := NewIPTable() + err := table.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(table.Match("192.168.1.100", 8080)) + } + + { + table := NewIPTable() + table.IP = "*" + table.Port = "8080" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("port:", table.minPort, table.maxPort) + a.IsTrue(table.Match("192.168.1.100", 8080)) + a.IsFalse(table.Match("192.168.1.100", 8081)) + } + + { + table := NewIPTable() + table.IP = "*" + table.Port = "8080-8082" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("port:", table.minPort, table.maxPort) + a.IsTrue(table.Match("192.168.1.100", 8080)) + a.IsTrue(table.Match("192.168.1.100", 8081)) + a.IsFalse(table.Match("192.168.1.100", 8083)) + } + + { + table := NewIPTable() + table.IP = "*" + table.Port = "*-8082" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("port:", table.minPort, table.maxPort) + a.IsTrue(table.Match("192.168.1.100", 8079)) + a.IsTrue(table.Match("192.168.1.100", 8080)) + a.IsTrue(table.Match("192.168.1.100", 8081)) + a.IsFalse(table.Match("192.168.1.100", 8083)) + } + + { + table := NewIPTable() + table.IP = "*" + table.Port = "8080-*" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("port:", table.minPort, table.maxPort) + a.IsFalse(table.Match("192.168.1.100", 8079)) + a.IsTrue(table.Match("192.168.1.100", 8080)) + a.IsTrue(table.Match("192.168.1.100", 8081)) + a.IsTrue(table.Match("192.168.1.100", 8083)) + } + + { + table := NewIPTable() + table.IP = "*" + table.Port = "*" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("port:", table.minPort, table.maxPort) + a.IsTrue(table.Match("192.168.1.100", 8079)) + a.IsTrue(table.Match("192.168.1.100", 8080)) + a.IsTrue(table.Match("192.168.1.100", 8081)) + a.IsTrue(table.Match("192.168.1.100", 8083)) + } + + { + table := NewIPTable() + table.IP = "192.168.1.100" + table.Port = "*" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("port:", table.minPort, table.maxPort) + a.IsTrue(table.Match("192.168.1.100", 8080)) + } + + { + table := NewIPTable() + table.IP = "192.168.1.99-192.168.1.101" + table.Port = "*" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("port:", table.minPort, table.maxPort) + a.IsTrue(table.Match("192.168.1.100", 8080)) + } + + { + table := NewIPTable() + table.IP = "192.168.1.99/24" + table.Port = "*" + err := table.Init() + if err != nil { + t.Fatal(err) + } + t.Log("ip:", table.ipRange) + a.IsTrue(table.Match("192.168.1.100", 8080)) + a.IsFalse(table.Match("192.168.2.100", 8080)) + } + + { + table := NewIPTable() + table.IP = "192.168.1.99/24" + table.TimeTo = time.Now().Unix() - 10 + table.Port = "*" + err := table.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(table.Match("192.168.1.100", 8080)) + a.IsFalse(table.Match("192.168.2.100", 8080)) + } +} diff --git a/internal/waf/requests/request.go b/internal/waf/requests/request.go new file mode 100644 index 0000000..a1b982b --- /dev/null +++ b/internal/waf/requests/request.go @@ -0,0 +1,35 @@ +package requests + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" +) + +type Request struct { + *http.Request + BodyData []byte +} + +func NewRequest(raw *http.Request) *Request { + return &Request{ + Request: raw, + } +} + +func (this *Request) Raw() *http.Request { + return this.Request +} + +func (this *Request) ReadBody(max int64) (data []byte, err error) { + data, err = ioutil.ReadAll(io.LimitReader(this.Request.Body, max)) + return +} + +func (this *Request) RestoreBody(data []byte) { + rawReader := bytes.NewBuffer(data) + buf := make([]byte, 1024) + io.CopyBuffer(rawReader, this.Request.Body, buf) + this.Request.Body = ioutil.NopCloser(rawReader) +} diff --git a/internal/waf/requests/response.go b/internal/waf/requests/response.go new file mode 100644 index 0000000..2202135 --- /dev/null +++ b/internal/waf/requests/response.go @@ -0,0 +1,15 @@ +package requests + +import "net/http" + +type Response struct { + *http.Response + + BodyData []byte +} + +func NewResponse(resp *http.Response) *Response { + return &Response{ + Response: resp, + } +} diff --git a/internal/waf/rule.go b/internal/waf/rule.go new file mode 100644 index 0000000..8b3fb0b --- /dev/null +++ b/internal/waf/rule.go @@ -0,0 +1,584 @@ +package waf + +import ( + "bytes" + "encoding/binary" + "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/TeaOSLab/EdgeNode/internal/waf/utils" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/maps" + "github.com/iwind/TeaGo/types" + "github.com/iwind/TeaGo/utils/string" + "net" + "reflect" + "regexp" + "strings" +) + +var singleParamRegexp = regexp.MustCompile("^\\${[\\w.-]+}$") + +// rule +type Rule struct { + Description string `yaml:"description" json:"description"` + Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName} + Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ... + Value string `yaml:"value" json:"value"` // compared value + IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"` + CheckpointOptions map[string]interface{} `yaml:"checkpointOptions" json:"checkpointOptions"` + + checkpointFinder func(prefix string) checkpoints.CheckpointInterface + + singleParam string // real param after prefix + singleCheckpoint checkpoints.CheckpointInterface // if is single check point + + multipleCheckpoints map[string]checkpoints.CheckpointInterface + + isIP bool + ipValue net.IP + + floatValue float64 + reg *regexp.Regexp +} + +func NewRule() *Rule { + return &Rule{} +} + +func (this *Rule) Init() error { + // operator + switch this.Operator { + case RuleOperatorGt: + this.floatValue = types.Float64(this.Value) + case RuleOperatorGte: + this.floatValue = types.Float64(this.Value) + case RuleOperatorLt: + this.floatValue = types.Float64(this.Value) + case RuleOperatorLte: + this.floatValue = types.Float64(this.Value) + case RuleOperatorEq: + this.floatValue = types.Float64(this.Value) + case RuleOperatorNeq: + this.floatValue = types.Float64(this.Value) + case RuleOperatorMatch: + v := this.Value + if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") { + v = "(?i)" + v + } + + v = this.unescape(v) + + reg, err := regexp.Compile(v) + if err != nil { + return err + } + this.reg = reg + case RuleOperatorNotMatch: + v := this.Value + if this.IsCaseInsensitive && !strings.HasPrefix(v, "(?i)") { + v = "(?i)" + v + } + + v = this.unescape(v) + + reg, err := regexp.Compile(v) + if err != nil { + return err + } + this.reg = reg + case RuleOperatorEqIP, RuleOperatorGtIP, RuleOperatorGteIP, RuleOperatorLtIP, RuleOperatorLteIP: + this.ipValue = net.ParseIP(this.Value) + this.isIP = this.ipValue != nil + + if !this.isIP { + return errors.New("value should be a valid ip") + } + case RuleOperatorIPRange, RuleOperatorNotIPRange: + if strings.Contains(this.Value, ",") { + ipList := strings.SplitN(this.Value, ",", 2) + ipString1 := strings.TrimSpace(ipList[0]) + ipString2 := strings.TrimSpace(ipList[1]) + + if len(ipString1) > 0 { + ip1 := net.ParseIP(ipString1) + if ip1 == nil { + return errors.New("start ip is invalid") + } + } + + if len(ipString2) > 0 { + ip2 := net.ParseIP(ipString2) + if ip2 == nil { + return errors.New("end ip is invalid") + } + } + } else if strings.Contains(this.Value, "/") { + _, _, err := net.ParseCIDR(this.Value) + if err != nil { + return err + } + } else { + return errors.New("invalid ip range") + } + + } + + if singleParamRegexp.MatchString(this.Param) { + param := this.Param[2 : len(this.Param)-1] + pieces := strings.SplitN(param, ".", 2) + prefix := pieces[0] + if len(pieces) == 1 { + this.singleParam = "" + } else { + this.singleParam = pieces[1] + } + + if this.checkpointFinder != nil { + checkpoint := this.checkpointFinder(prefix) + if checkpoint == nil { + return errors.New("no check point '" + prefix + "' found") + } + this.singleCheckpoint = checkpoint + } else { + checkpoint := checkpoints.FindCheckpoint(prefix) + if checkpoint == nil { + return errors.New("no check point '" + prefix + "' found") + } + checkpoint.Init() + this.singleCheckpoint = checkpoint + } + + return nil + } + + this.multipleCheckpoints = map[string]checkpoints.CheckpointInterface{} + var err error = nil + configutils.ParseVariables(this.Param, func(varName string) (value string) { + pieces := strings.SplitN(varName, ".", 2) + prefix := pieces[0] + if this.checkpointFinder != nil { + checkpoint := this.checkpointFinder(prefix) + if checkpoint == nil { + err = errors.New("no check point '" + prefix + "' found") + } else { + this.multipleCheckpoints[prefix] = checkpoint + } + } else { + checkpoint := checkpoints.FindCheckpoint(prefix) + if checkpoint == nil { + err = errors.New("no check point '" + prefix + "' found") + } else { + checkpoint.Init() + this.multipleCheckpoints[prefix] = checkpoint + } + } + return "" + }) + + return err +} + +func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) { + if this.singleCheckpoint != nil { + value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions) + if err != nil { + return false, err + } + return this.Test(value), nil + } + + value := configutils.ParseVariables(this.Param, func(varName string) (value string) { + pieces := strings.SplitN(varName, ".", 2) + prefix := pieces[0] + point, ok := this.multipleCheckpoints[prefix] + if !ok { + return "" + } + + if len(pieces) == 1 { + value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions) + if err1 != nil { + err = err1 + } + return types.String(value1) + } + + value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions) + if err1 != nil { + err = err1 + } + return types.String(value1) + }) + + if err != nil { + return false, err + } + + return this.Test(value), nil +} + +func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) { + if this.singleCheckpoint != nil { + // if is request param + if this.singleCheckpoint.IsRequest() { + value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions) + if err != nil { + return false, err + } + return this.Test(value), nil + } + + // response param + value, err, _ := this.singleCheckpoint.ResponseValue(req, resp, this.singleParam, this.CheckpointOptions) + if err != nil { + return false, err + } + return this.Test(value), nil + } + + value := configutils.ParseVariables(this.Param, func(varName string) (value string) { + pieces := strings.SplitN(varName, ".", 2) + prefix := pieces[0] + point, ok := this.multipleCheckpoints[prefix] + if !ok { + return "" + } + + if len(pieces) == 1 { + if point.IsRequest() { + value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions) + if err1 != nil { + err = err1 + } + return types.String(value1) + } else { + value1, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions) + if err1 != nil { + err = err1 + } + return types.String(value1) + } + } + + if point.IsRequest() { + value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions) + if err1 != nil { + err = err1 + } + return types.String(value1) + } else { + value1, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions) + if err1 != nil { + err = err1 + } + return types.String(value1) + } + }) + + if err != nil { + return false, err + } + + return this.Test(value), nil +} + +func (this *Rule) Test(value interface{}) bool { + // operator + switch this.Operator { + case RuleOperatorGt: + return types.Float64(value) > this.floatValue + case RuleOperatorGte: + return types.Float64(value) >= this.floatValue + case RuleOperatorLt: + return types.Float64(value) < this.floatValue + case RuleOperatorLte: + return types.Float64(value) <= this.floatValue + case RuleOperatorEq: + return types.Float64(value) == this.floatValue + case RuleOperatorNeq: + return types.Float64(value) != this.floatValue + case RuleOperatorEqString: + if this.IsCaseInsensitive { + return strings.ToLower(types.String(value)) == strings.ToLower(this.Value) + } else { + return types.String(value) == this.Value + } + case RuleOperatorNeqString: + if this.IsCaseInsensitive { + return strings.ToLower(types.String(value)) != strings.ToLower(this.Value) + } else { + return types.String(value) != this.Value + } + case RuleOperatorMatch: + if value == nil { + return false + } + + // strings + stringList, ok := value.([]string) + if ok { + for _, s := range stringList { + if utils.MatchStringCache(this.reg, s) { + return true + } + } + return false + } + + // bytes + byteSlice, ok := value.([]byte) + if ok { + return utils.MatchBytesCache(this.reg, byteSlice) + } + + // string + return utils.MatchStringCache(this.reg, types.String(value)) + case RuleOperatorNotMatch: + if value == nil { + return true + } + stringList, ok := value.([]string) + if ok { + for _, s := range stringList { + if utils.MatchStringCache(this.reg, s) { + return false + } + } + return true + } + + // bytes + byteSlice, ok := value.([]byte) + if ok { + return !utils.MatchBytesCache(this.reg, byteSlice) + } + + return !utils.MatchStringCache(this.reg, types.String(value)) + case RuleOperatorContains: + if types.IsSlice(value) { + ok := false + lists.Each(value, func(k int, v interface{}) { + if types.String(v) == this.Value { + ok = true + } + }) + return ok + } + if types.IsMap(value) { + lowerValue := "" + if this.IsCaseInsensitive { + lowerValue = strings.ToLower(this.Value) + } + for _, v := range maps.NewMap(value) { + if this.IsCaseInsensitive { + if strings.ToLower(types.String(v)) == lowerValue { + return true + } + } else { + if types.String(v) == this.Value { + return true + } + } + } + return false + } + + if this.IsCaseInsensitive { + return strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + } else { + return strings.Contains(types.String(value), this.Value) + } + case RuleOperatorNotContains: + if this.IsCaseInsensitive { + return !strings.Contains(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + } else { + return !strings.Contains(types.String(value), this.Value) + } + case RuleOperatorPrefix: + if this.IsCaseInsensitive { + return strings.HasPrefix(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + } else { + return strings.HasPrefix(types.String(value), this.Value) + } + case RuleOperatorSuffix: + if this.IsCaseInsensitive { + return strings.HasSuffix(strings.ToLower(types.String(value)), strings.ToLower(this.Value)) + } else { + return strings.HasSuffix(types.String(value), this.Value) + } + case RuleOperatorHasKey: + if types.IsSlice(value) { + index := types.Int(this.Value) + if index < 0 { + return false + } + return reflect.ValueOf(value).Len() > index + } else if types.IsMap(value) { + m := maps.NewMap(value) + if this.IsCaseInsensitive { + lowerValue := strings.ToLower(this.Value) + for k := range m { + if strings.ToLower(k) == lowerValue { + return true + } + } + } else { + return m.Has(this.Value) + } + } else { + return false + } + + case RuleOperatorVersionGt: + return stringutil.VersionCompare(this.Value, types.String(value)) > 0 + case RuleOperatorVersionLt: + return stringutil.VersionCompare(this.Value, types.String(value)) < 0 + case RuleOperatorVersionRange: + if strings.Contains(this.Value, ",") { + versions := strings.SplitN(this.Value, ",", 2) + version1 := strings.TrimSpace(versions[0]) + version2 := strings.TrimSpace(versions[1]) + if len(version1) > 0 && stringutil.VersionCompare(types.String(value), version1) < 0 { + return false + } + if len(version2) > 0 && stringutil.VersionCompare(types.String(value), version2) > 0 { + return false + } + return true + } else { + return stringutil.VersionCompare(types.String(value), this.Value) >= 0 + } + case RuleOperatorEqIP: + ip := net.ParseIP(types.String(value)) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(this.ipValue, ip) == 0 + case RuleOperatorGtIP: + ip := net.ParseIP(types.String(value)) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) > 0 + case RuleOperatorGteIP: + ip := net.ParseIP(types.String(value)) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) >= 0 + case RuleOperatorLtIP: + ip := net.ParseIP(types.String(value)) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) < 0 + case RuleOperatorLteIP: + ip := net.ParseIP(types.String(value)) + if ip == nil { + return false + } + return this.isIP && bytes.Compare(ip, this.ipValue) <= 0 + case RuleOperatorIPRange: + return this.containsIP(value) + case RuleOperatorNotIPRange: + return !this.containsIP(value) + case RuleOperatorIPMod: + pieces := strings.SplitN(this.Value, ",", 2) + if len(pieces) == 1 { + rem := types.Int64(pieces[0]) + return this.ipToInt64(net.ParseIP(types.String(value)))%10 == rem + } + div := types.Int64(pieces[0]) + if div == 0 { + return false + } + rem := types.Int64(pieces[1]) + return this.ipToInt64(net.ParseIP(types.String(value)))%div == rem + case RuleOperatorIPMod10: + return this.ipToInt64(net.ParseIP(types.String(value)))%10 == types.Int64(this.Value) + case RuleOperatorIPMod100: + return this.ipToInt64(net.ParseIP(types.String(value)))%100 == types.Int64(this.Value) + } + return false +} + +func (this *Rule) IsSingleCheckpoint() bool { + return this.singleCheckpoint != nil +} + +func (this *Rule) SetCheckpointFinder(finder func(prefix string) checkpoints.CheckpointInterface) { + this.checkpointFinder = finder +} + +func (this *Rule) unescape(v string) string { + //replace urlencoded characters + v = strings.Replace(v, `\s`, `(\s|%09|%0A|\+)`, -1) + v = strings.Replace(v, `\(`, `(\(|%28)`, -1) + v = strings.Replace(v, `=`, `(=|%3D)`, -1) + v = strings.Replace(v, `<`, `(<|%3C)`, -1) + v = strings.Replace(v, `\*`, `(\*|%2A)`, -1) + v = strings.Replace(v, `\\`, `(\\|%2F)`, -1) + v = strings.Replace(v, `!`, `(!|%21)`, -1) + v = strings.Replace(v, `/`, `(/|%2F)`, -1) + v = strings.Replace(v, `;`, `(;|%3B)`, -1) + v = strings.Replace(v, `\+`, `(\+|%20)`, -1) + return v +} + +func (this *Rule) containsIP(value interface{}) bool { + ip := net.ParseIP(types.String(value)) + if ip == nil { + return false + } + + // 检查IP范围格式 + if strings.Contains(this.Value, ",") { + ipList := strings.SplitN(this.Value, ",", 2) + ipString1 := strings.TrimSpace(ipList[0]) + ipString2 := strings.TrimSpace(ipList[1]) + + if len(ipString1) > 0 { + ip1 := net.ParseIP(ipString1) + if ip1 == nil { + return false + } + + if bytes.Compare(ip, ip1) < 0 { + return false + } + } + + if len(ipString2) > 0 { + ip2 := net.ParseIP(ipString2) + if ip2 == nil { + return false + } + + if bytes.Compare(ip, ip2) > 0 { + return false + } + } + + return true + } else if strings.Contains(this.Value, "/") { + _, ipNet, err := net.ParseCIDR(this.Value) + if err != nil { + return false + } + return ipNet.Contains(ip) + } else { + return false + } +} + +func (this *Rule) ipToInt64(ip net.IP) int64 { + if len(ip) == 0 { + return 0 + } + if len(ip) == 16 { + return int64(binary.BigEndian.Uint32(ip[12:16])) + } + return int64(binary.BigEndian.Uint32(ip)) +} diff --git a/internal/waf/rule_group.go b/internal/waf/rule_group.go new file mode 100644 index 0000000..1c7bea5 --- /dev/null +++ b/internal/waf/rule_group.go @@ -0,0 +1,147 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" +) + +// rule group +type RuleGroup struct { + Id string `yaml:"id" json:"id"` + IsOn bool `yaml:"isOn" json:"isOn"` + Name string `yaml:"name" json:"name"` // such as SQL Injection + Description string `yaml:"description" json:"description"` + Code string `yaml:"code" json:"code"` // identify the group + RuleSets []*RuleSet `yaml:"ruleSets" json:"ruleSets"` + IsInbound bool `yaml:"isInbound" json:"isInbound"` + + hasRuleSets bool +} + +func NewRuleGroup() *RuleGroup { + return &RuleGroup{ + IsOn: true, + } +} + +func (this *RuleGroup) Init() error { + this.hasRuleSets = len(this.RuleSets) > 0 + + if this.hasRuleSets { + for _, set := range this.RuleSets { + err := set.Init() + if err != nil { + return err + } + } + } + return nil +} + +func (this *RuleGroup) AddRuleSet(ruleSet *RuleSet) { + this.RuleSets = append(this.RuleSets, ruleSet) +} + +func (this *RuleGroup) FindRuleSet(id string) *RuleSet { + if len(id) == 0 { + return nil + } + for _, ruleSet := range this.RuleSets { + if ruleSet.Id == id { + return ruleSet + } + } + return nil +} + +func (this *RuleGroup) FindRuleSetWithCode(code string) *RuleSet { + if len(code) == 0 { + return nil + } + for _, ruleSet := range this.RuleSets { + if ruleSet.Code == code { + return ruleSet + } + } + return nil +} + +func (this *RuleGroup) RemoveRuleSet(id string) { + if len(id) == 0 { + return + } + result := []*RuleSet{} + for _, ruleSet := range this.RuleSets { + if ruleSet.Id == id { + continue + } + result = append(result, ruleSet) + } + this.RuleSets = result +} + +func (this *RuleGroup) MatchRequest(req *requests.Request) (b bool, set *RuleSet, err error) { + if !this.hasRuleSets { + return + } + for _, set := range this.RuleSets { + if !set.IsOn { + continue + } + b, err = set.MatchRequest(req) + if err != nil { + return false, nil, err + } + if b { + return true, set, nil + } + } + return +} + +func (this *RuleGroup) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) { + if !this.hasRuleSets { + return + } + for _, set := range this.RuleSets { + if !set.IsOn { + continue + } + b, err = set.MatchResponse(req, resp) + if err != nil { + return false, nil, err + } + if b { + return true, set, nil + } + } + return +} + +func (this *RuleGroup) MoveRuleSet(fromIndex int, toIndex int) { + if fromIndex < 0 || fromIndex >= len(this.RuleSets) { + return + } + if toIndex < 0 || toIndex >= len(this.RuleSets) { + return + } + if fromIndex == toIndex { + return + } + + location := this.RuleSets[fromIndex] + result := []*RuleSet{} + for i := 0; i < len(this.RuleSets); i++ { + if i == fromIndex { + continue + } + if fromIndex > toIndex && i == toIndex { + result = append(result, location) + } + result = append(result, this.RuleSets[i]) + if fromIndex < toIndex && i == toIndex { + result = append(result, location) + } + } + + this.RuleSets = result +} diff --git a/internal/waf/rule_operator.go b/internal/waf/rule_operator.go new file mode 100644 index 0000000..d1fc0a6 --- /dev/null +++ b/internal/waf/rule_operator.go @@ -0,0 +1,219 @@ +package waf + +type RuleOperator = string +type RuleCaseInsensitive = string + +const ( + RuleOperatorGt RuleOperator = "gt" + RuleOperatorGte RuleOperator = "gte" + RuleOperatorLt RuleOperator = "lt" + RuleOperatorLte RuleOperator = "lte" + RuleOperatorEq RuleOperator = "eq" + RuleOperatorNeq RuleOperator = "neq" + RuleOperatorEqString RuleOperator = "eq string" + RuleOperatorNeqString RuleOperator = "neq string" + RuleOperatorMatch RuleOperator = "match" + RuleOperatorNotMatch RuleOperator = "not match" + RuleOperatorContains RuleOperator = "contains" + RuleOperatorNotContains RuleOperator = "not contains" + RuleOperatorPrefix RuleOperator = "prefix" + RuleOperatorSuffix RuleOperator = "suffix" + RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map + RuleOperatorVersionGt RuleOperator = "version gt" + RuleOperatorVersionLt RuleOperator = "version lt" + RuleOperatorVersionRange RuleOperator = "version range" + + // ip + RuleOperatorEqIP RuleOperator = "eq ip" + RuleOperatorGtIP RuleOperator = "gt ip" + RuleOperatorGteIP RuleOperator = "gte ip" + RuleOperatorLtIP RuleOperator = "lt ip" + RuleOperatorLteIP RuleOperator = "lte ip" + RuleOperatorIPRange RuleOperator = "ip range" + RuleOperatorNotIPRange RuleOperator = "not ip range" + RuleOperatorIPMod10 RuleOperator = "ip mod 10" + RuleOperatorIPMod100 RuleOperator = "ip mod 100" + RuleOperatorIPMod RuleOperator = "ip mod" + + RuleCaseInsensitiveNone = "none" + RuleCaseInsensitiveYes = "yes" + RuleCaseInsensitiveNo = "no" +) + +type RuleOperatorDefinition struct { + Name string + Code string + Description string + CaseInsensitive RuleCaseInsensitive // default caseInsensitive setting +} + +var AllRuleOperators = []*RuleOperatorDefinition{ + { + Name: "数值大于", + Code: RuleOperatorGt, + Description: "使用数值对比大于", + CaseInsensitive: RuleCaseInsensitiveNone, + }, + { + Name: "数值大于等于", + Code: RuleOperatorGte, + Description: "使用数值对比大于等于", + CaseInsensitive: RuleCaseInsensitiveNone, + }, + { + Name: "数值小于", + Code: RuleOperatorLt, + Description: "使用数值对比小于", + CaseInsensitive: RuleCaseInsensitiveNone, + }, + { + Name: "数值小于等于", + Code: RuleOperatorLte, + Description: "使用数值对比小于等于", + CaseInsensitive: RuleCaseInsensitiveNone, + }, + { + Name: "数值等于", + Code: RuleOperatorEq, + Description: "使用数值对比等于", + CaseInsensitive: RuleCaseInsensitiveNone, + }, + { + Name: "数值不等于", + Code: RuleOperatorNeq, + Description: "使用数值对比不等于", + CaseInsensitive: RuleCaseInsensitiveNone, + }, + { + Name: "字符串等于", + Code: RuleOperatorEqString, + Description: "使用字符串对比等于", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "字符串不等于", + Code: RuleOperatorNeqString, + Description: "使用字符串对比不等于", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "正则匹配", + Code: RuleOperatorMatch, + Description: "使用正则表达式匹配,在头部使用(?i)表示不区分大小写,正则表达式语法 »", + CaseInsensitive: RuleCaseInsensitiveYes, + }, + { + Name: "正则不匹配", + Code: RuleOperatorNotMatch, + Description: "使用正则表达式不匹配,在头部使用(?i)表示不区分大小写,正则表达式语法 »", + CaseInsensitive: RuleCaseInsensitiveYes, + }, + { + Name: "包含字符串", + Code: RuleOperatorContains, + Description: "包含某个字符串", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "不包含字符串", + Code: RuleOperatorNotContains, + Description: "不包含某个字符串", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "包含前缀", + Code: RuleOperatorPrefix, + Description: "包含某个前缀", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "包含后缀", + Code: RuleOperatorSuffix, + Description: "包含某个后缀", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "包含索引", + Code: RuleOperatorHasKey, + Description: "对于一组数据拥有某个键值或者索引", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "版本号大于", + Code: RuleOperatorVersionGt, + Description: "对比版本号大于", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "版本号小于", + Code: RuleOperatorVersionLt, + Description: "对比版本号小于", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "版本号范围", + Code: RuleOperatorVersionRange, + Description: "判断版本号在某个范围内,格式为version1,version2", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP等于", + Code: RuleOperatorEqIP, + Description: "将参数转换为IP进行对比", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP大于", + Code: RuleOperatorGtIP, + Description: "将参数转换为IP进行对比", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP大于等于", + Code: RuleOperatorGteIP, + Description: "将参数转换为IP进行对比", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP小于", + Code: RuleOperatorLtIP, + Description: "将参数转换为IP进行对比", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP小于等于", + Code: RuleOperatorLteIP, + Description: "将参数转换为IP进行对比", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP范围", + Code: RuleOperatorIPRange, + Description: "IP在某个范围之内,范围格式可以是英文逗号分隔的ip1,ip2,或者CIDR格式的ip/bits", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "不在IP范围", + Code: RuleOperatorNotIPRange, + Description: "IP不在某个范围之内,范围格式可以是英文逗号分隔的ip1,ip2,或者CIDR格式的ip/bits", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP取模10", + Code: RuleOperatorIPMod10, + Description: "对IP参数值取模,除数为10,对比值为余数", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP取模100", + Code: RuleOperatorIPMod100, + Description: "对IP参数值取模,除数为100,对比值为余数", + CaseInsensitive: RuleCaseInsensitiveNo, + }, + { + Name: "IP取模", + Code: RuleOperatorIPMod, + Description: "对IP参数值取模,对比值格式为:除数,余数,比如10,1", + CaseInsensitive: RuleCaseInsensitiveNo, + }, +} diff --git a/internal/waf/rule_set.go b/internal/waf/rule_set.go new file mode 100644 index 0000000..431a504 --- /dev/null +++ b/internal/waf/rule_set.go @@ -0,0 +1,135 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/maps" + "github.com/iwind/TeaGo/utils/string" +) + +type RuleConnector = string + +const ( + RuleConnectorAnd = "and" + RuleConnectorOr = "or" +) + +type RuleSet struct { + Id string `yaml:"id" json:"id"` + Code string `yaml:"code" json:"code"` + IsOn bool `yaml:"isOn" json:"isOn"` + Name string `yaml:"name" json:"name"` + Description string `yaml:"description" json:"description"` + Rules []*Rule `yaml:"rules" json:"rules"` + Connector RuleConnector `yaml:"connector" json:"connector"` // rules connector + + Action ActionString `yaml:"action" json:"action"` + ActionOptions maps.Map `yaml:"actionOptions" json:"actionOptions"` // TODO TO BE IMPLEMENTED + + hasRules bool +} + +func NewRuleSet() *RuleSet { + return &RuleSet{ + Id: stringutil.Rand(16), + IsOn: true, + } +} + +func (this *RuleSet) Init() error { + this.hasRules = len(this.Rules) > 0 + if this.hasRules { + for _, rule := range this.Rules { + err := rule.Init() + if err != nil { + return err + } + } + } + return nil +} + +func (this *RuleSet) AddRule(rule ...*Rule) { + this.Rules = append(this.Rules, rule...) +} + +func (this *RuleSet) MatchRequest(req *requests.Request) (b bool, err error) { + if !this.hasRules { + return false, nil + } + switch this.Connector { + case RuleConnectorAnd: + for _, rule := range this.Rules { + b1, err1 := rule.MatchRequest(req) + if err1 != nil { + return false, err1 + } + if !b1 { + return false, nil + } + } + return true, nil + case RuleConnectorOr: + for _, rule := range this.Rules { + b1, err1 := rule.MatchRequest(req) + if err1 != nil { + return false, err1 + } + if b1 { + return true, nil + } + } + default: // same as And + for _, rule := range this.Rules { + b1, err1 := rule.MatchRequest(req) + if err1 != nil { + return false, err1 + } + if !b1 { + return false, nil + } + } + return true, nil + } + return +} + +func (this *RuleSet) MatchResponse(req *requests.Request, resp *requests.Response) (b bool, err error) { + if !this.hasRules { + return false, nil + } + switch this.Connector { + case RuleConnectorAnd: + for _, rule := range this.Rules { + b1, err1 := rule.MatchResponse(req, resp) + if err1 != nil { + return false, err1 + } + if !b1 { + return false, nil + } + } + return true, nil + case RuleConnectorOr: + for _, rule := range this.Rules { + b1, err1 := rule.MatchResponse(req, resp) + if err1 != nil { + return false, err1 + } + if b1 { + return true, nil + } + } + default: // same as And + for _, rule := range this.Rules { + b1, err1 := rule.MatchResponse(req, resp) + if err1 != nil { + return false, err1 + } + if !b1 { + return false, nil + } + } + return true, nil + } + return +} diff --git a/internal/waf/rule_set_test.go b/internal/waf/rule_set_test.go new file mode 100644 index 0000000..4a5fe22 --- /dev/null +++ b/internal/waf/rule_set_test.go @@ -0,0 +1,180 @@ +package waf + +import ( + "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/dchest/siphash" + "github.com/iwind/TeaGo/assert" + "net/http" + "regexp" + "runtime" + "testing" +) + +func TestRuleSet_MatchRequest(t *testing.T) { + set := NewRuleSet() + set.Connector = RuleConnectorAnd + + set.Rules = []*Rule{ + { + Param: "${arg.name}", + Operator: RuleOperatorEqString, + Value: "lu", + }, + { + Param: "${arg.age}", + Operator: RuleOperatorEq, + Value: "20", + }, + } + + err := set.Init() + if err != nil { + t.Fatal(err) + } + + rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) + if err != nil { + t.Fatal(err) + } + req := requests.NewRequest(rawReq) + t.Log(set.MatchRequest(req)) +} + +func TestRuleSet_MatchRequest2(t *testing.T) { + a := assert.NewAssertion(t) + + set := NewRuleSet() + set.Connector = RuleConnectorOr + + set.Rules = []*Rule{ + { + Param: "${arg.name}", + Operator: RuleOperatorEqString, + Value: "lu", + }, + { + Param: "${arg.age}", + Operator: RuleOperatorEq, + Value: "21", + }, + } + + err := set.Init() + if err != nil { + t.Fatal(err) + } + + rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) + if err != nil { + t.Fatal(err) + } + req := requests.NewRequest(rawReq) + a.IsTrue(set.MatchRequest(req)) +} + +func BenchmarkRuleSet_MatchRequest(b *testing.B) { + runtime.GOMAXPROCS(1) + + set := NewRuleSet() + set.Connector = RuleConnectorOr + + set.Rules = []*Rule{ + { + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`, + }, + { + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, + }, + { + Param: "${arg.name}", + Operator: RuleOperatorEqString, + Value: "lu", + }, + { + Param: "${arg.age}", + Operator: RuleOperatorEq, + Value: "21", + }, + } + + err := set.Init() + if err != nil { + b.Fatal(err) + } + + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello?name=lu&age=20", bytes.NewBuffer(bytes.Repeat([]byte("HELLO"), 1024))) + if err != nil { + b.Fatal(err) + } + req := requests.NewRequest(rawReq) + for i := 0; i < b.N; i++ { + _, _ = set.MatchRequest(req) + } +} + +func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) { + runtime.GOMAXPROCS(1) + + set := NewRuleSet() + set.Connector = RuleConnectorOr + + set.Rules = []*Rule{ + { + Param: "${requestBody}", + Operator: RuleOperatorMatch, + Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, + IsCaseInsensitive: false, + }, + } + + err := set.Init() + if err != nil { + b.Fatal(err) + } + + rawReq, err := http.NewRequest(http.MethodPost, "http://teaos.cn/hello?name=lu&age=20", bytes.NewBuffer(bytes.Repeat([]byte("HELLO"), 2048))) + if err != nil { + b.Fatal(err) + } + req := requests.NewRequest(rawReq) + for i := 0; i < b.N; i++ { + _, _ = set.MatchRequest(req) + } +} + +func BenchmarkRuleSet_MatchRequest_Regexp2(b *testing.B) { + reg, err := regexp.Compile(`(?iU)\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\b`) + if err != nil { + b.Fatal(err) + } + + buf := bytes.Repeat([]byte(" HELLO "), 10240) + + for i := 0; i < b.N; i++ { + _ = reg.Match(buf) + } +} + +func BenchmarkRuleSet_MatchRequest_Regexp3(b *testing.B) { + reg, err := regexp.Compile(`(?iU)^(eval|system|exec|execute|passthru|shell_exec|phpinfo)`) + if err != nil { + b.Fatal(err) + } + + buf := bytes.Repeat([]byte(" HELLO "), 1024) + + for i := 0; i < b.N; i++ { + _ = reg.Match(buf) + } +} + +func BenchmarkHash(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = siphash.Hash(0, 0, bytes.Repeat([]byte("HELLO"), 10240)) + } +} diff --git a/internal/waf/rule_test.go b/internal/waf/rule_test.go new file mode 100644 index 0000000..6a7731c --- /dev/null +++ b/internal/waf/rule_test.go @@ -0,0 +1,733 @@ +package waf + +import ( + "github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints" + "github.com/TeaOSLab/EdgeNode/internal/waf/requests" + "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/maps" + "net/http" + "net/url" + "testing" +) + +func TestRule_Init_Single(t *testing.T) { + rule := NewRule() + rule.Param = "${arg.name}" + rule.Operator = RuleOperatorEqString + rule.Value = "lu" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + t.Log(rule.singleParam, rule.singleCheckpoint) + rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) + if err != nil { + t.Fatal(err) + } + + req := requests.NewRequest(rawReq) + t.Log(rule.MatchRequest(req)) +} + +func TestRule_Init_Composite(t *testing.T) { + rule := NewRule() + rule.Param = "${arg.name} ${arg.age}" + rule.Operator = RuleOperatorContains + rule.Value = "lu" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + t.Log(rule.singleParam, rule.singleCheckpoint) + + rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) + if err != nil { + t.Fatal(err) + } + req := requests.NewRequest(rawReq) + t.Log(rule.MatchRequest(req)) +} + +func TestRule_Test(t *testing.T) { + a := assert.NewAssertion(t) + + { + rule := NewRule() + rule.Operator = RuleOperatorGt + rule.Value = "123" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("124")) + a.IsFalse(rule.Test("123")) + a.IsFalse(rule.Test("122")) + a.IsFalse(rule.Test("abcdef")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorGte + rule.Value = "123" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("124")) + a.IsTrue(rule.Test("123")) + a.IsFalse(rule.Test("122")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorLt + rule.Value = "123" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("124")) + a.IsFalse(rule.Test("123")) + a.IsTrue(rule.Test("122")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorLte + rule.Value = "123" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("124")) + a.IsTrue(rule.Test("123")) + a.IsTrue(rule.Test("122")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorEq + rule.Value = "123" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("124")) + a.IsTrue(rule.Test("123")) + a.IsFalse(rule.Test("122")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNeq + rule.Value = "123" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("124")) + a.IsFalse(rule.Test("123")) + a.IsTrue(rule.Test("122")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorEqString + rule.Value = "123" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("124")) + a.IsTrue(rule.Test("123")) + a.IsFalse(rule.Test("122")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorEqString + rule.Value = "abc" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("ABC")) + a.IsTrue(rule.Test("abc")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorEqString + rule.IsCaseInsensitive = true + rule.Value = "abc" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("ABC")) + a.IsTrue(rule.Test("abc")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNeqString + rule.Value = "abc" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("124")) + a.IsFalse(rule.Test("abc")) + a.IsTrue(rule.Test("122")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNeqString + rule.IsCaseInsensitive = true + rule.Value = "abc" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("ABC")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorMatch + rule.Value = "^\\d+" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("123")) + a.IsFalse(rule.Test("abc123")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorMatch + rule.Value = "abc" + rule.IsCaseInsensitive = true + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("ABC")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorMatch + rule.Value = "^\\d+" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test([]string{"123", "456", "abc"})) + a.IsFalse(rule.Test([]string{"abc123"})) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNotMatch + rule.Value = "\\d+" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("123")) + a.IsTrue(rule.Test("abc")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNotMatch + rule.Value = "abc" + rule.IsCaseInsensitive = true + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("ABC")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNotMatch + rule.Value = "^\\d+" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test([]string{"123", "456", "abc"})) + a.IsTrue(rule.Test([]string{"abc123"})) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorMatch + rule.Value = "^(?i)[a-z]+$" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("ABC")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorContains + rule.Value = "Hello" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("Hello, World")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorContains + rule.Value = "hello" + rule.IsCaseInsensitive = true + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("Hello, World")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorContains + rule.Value = "Hello" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test([]string{"Hello", "World"})) + a.IsTrue(rule.Test(maps.Map{ + "a": "World", "b": "Hello", + })) + a.IsFalse(rule.Test(maps.Map{ + "a": "World", "b": "Hello2", + })) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNotContains + rule.Value = "Hello" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("Hello, World")) + a.IsTrue(rule.Test("World")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorNotContains + rule.Value = "hello" + rule.IsCaseInsensitive = true + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("Hello, World")) + a.IsTrue(rule.Test("World")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorPrefix + rule.Value = "Hello" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("Hello, World")) + a.IsFalse(rule.Test("World, Hello")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorPrefix + rule.Value = "hello" + rule.IsCaseInsensitive = true + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsTrue(rule.Test("Hello, World")) + a.IsFalse(rule.Test("World, Hello")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorSuffix + rule.Value = "Hello" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("Hello, World")) + a.IsTrue(rule.Test("World, Hello")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorSuffix + rule.Value = "hello" + rule.IsCaseInsensitive = true + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("Hello, World")) + a.IsTrue(rule.Test("World, Hello")) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorHasKey + rule.Value = "Hello" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("Hello, World")) + a.IsTrue(rule.Test(maps.Map{ + "Hello": "World", + })) + a.IsFalse(rule.Test(maps.Map{ + "Hello1": "World", + })) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorHasKey + rule.Value = "hello" + rule.IsCaseInsensitive = true + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("Hello, World")) + a.IsTrue(rule.Test(maps.Map{ + "Hello": "World", + })) + a.IsFalse(rule.Test(maps.Map{ + "Hello1": "World", + })) + } + + { + rule := NewRule() + rule.Operator = RuleOperatorHasKey + rule.Value = "3" + err := rule.Init() + if err != nil { + t.Fatal(err) + } + a.IsFalse(rule.Test("Hello, World")) + a.IsFalse(rule.Test(maps.Map{ + "Hello": "World", + })) + a.IsTrue(rule.Test([]int{1, 2, 3, 4})) + } +} + +func TestRule_MatchStar(t *testing.T) { + { + rule := NewRule() + rule.Operator = RuleOperatorMatch + rule.Value = `/\*(!|\x00)` + err := rule.Init() + if err != nil { + t.Fatal(err) + } + t.Log(rule.Test("/*!")) + t.Log(rule.Test(url.QueryEscape("/*!"))) + t.Log(url.QueryEscape("/*!")) + } +} + +func TestRule_SetCheckpointFinder(t *testing.T) { + { + rule := NewRule() + rule.Param = "${arg.abc}" + rule.Operator = RuleOperatorMatch + err := rule.Init() + if err != nil { + t.Fatal(err) + } + t.Logf("%#v", rule.singleCheckpoint) + } + + { + rule := NewRule() + rule.Param = "${arg.abc}" + rule.Operator = RuleOperatorMatch + rule.checkpointFinder = func(prefix string) checkpoints.CheckpointInterface { + return new(checkpoints.SampleRequestCheckpoint) + } + err := rule.Init() + if err != nil { + t.Fatal(err) + } + t.Logf("%#v", rule.singleCheckpoint) + } +} + +func TestRule_Version(t *testing.T) { + a := assert.NewAssertion(t) + + { + rule := Rule{ + Operator: RuleOperatorVersionRange, + Value: `1.0,1.1`, + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("1.0")) + } + + { + rule := Rule{ + Operator: RuleOperatorVersionRange, + Value: `1.0,`, + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("1.0")) + } + + { + rule := Rule{ + Operator: RuleOperatorVersionRange, + Value: `,1.1`, + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("1.0")) + } + + { + rule := Rule{ + Operator: RuleOperatorVersionRange, + Value: `1.0,1.1`, + } + a.IsNil(rule.Init()) + a.IsFalse(rule.Test("0.9")) + } + + { + rule := Rule{ + Operator: RuleOperatorVersionRange, + Value: `1.0`, + } + a.IsNil(rule.Init()) + a.IsFalse(rule.Test("0.9")) + } + + { + rule := Rule{ + Operator: RuleOperatorVersionRange, + Value: `1.0`, + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("1.1")) + } +} + +func TestRule_IP(t *testing.T) { + a := assert.NewAssertion(t) + + { + rule := Rule{ + Operator: RuleOperatorEqIP, + Value: "hello", + } + a.IsNotNil(rule.Init()) + a.IsFalse(rule.Test("hello")) + } + + { + rule := Rule{ + Operator: RuleOperatorEqIP, + Value: "hello", + } + a.IsNotNil(rule.Init()) + a.IsFalse(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorEqIP, + Value: "192.168.1.100", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorGtIP, + Value: "192.168.1.90", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorGteIP, + Value: "192.168.1.90", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorLtIP, + Value: "192.168.1.90", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.80")) + } + + { + rule := Rule{ + Operator: RuleOperatorLteIP, + Value: "192.168.1.90", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.0.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPRange, + Value: "192.168.0.90,", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.0.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPRange, + Value: "192.168.0.90,192.168.1.100", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.0.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPRange, + Value: ",192.168.1.100", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.0.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPRange, + Value: "192.168.0.90,192.168.1.99", + } + a.IsNil(rule.Init()) + a.IsFalse(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPRange, + Value: "192.168.0.90/24", + } + a.IsNil(rule.Init()) + a.IsFalse(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPRange, + Value: "192.168.0.90/18", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPRange, + Value: "a/18", + } + a.IsNotNil(rule.Init()) + a.IsFalse(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPMod10, + Value: "6", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPMod100, + Value: "76", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorIPMod, + Value: "10,6", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.1.100")) + } + + { + rule := Rule{ + Operator: RuleOperatorNotIPRange, + Value: "192.168.0.90,192.168.1.100", + } + a.IsNil(rule.Init()) + a.IsFalse(rule.Test("192.168.0.100")) + } + { + rule := Rule{ + Operator: RuleOperatorNotIPRange, + Value: "192.168.0.90,192.168.1.100", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.168.2.100")) + } + { + rule := Rule{ + Operator: RuleOperatorNotIPRange, + Value: "192.168.0.90/8", + } + a.IsNil(rule.Init()) + a.IsFalse(rule.Test("192.168.2.100")) + } + { + rule := Rule{ + Operator: RuleOperatorNotIPRange, + Value: "192.168.0.90/16", + } + a.IsNil(rule.Init()) + a.IsTrue(rule.Test("192.169.2.100")) + } +} diff --git a/internal/waf/template.go b/internal/waf/template.go new file mode 100644 index 0000000..244b523 --- /dev/null +++ b/internal/waf/template.go @@ -0,0 +1,492 @@ +package waf + +// 感谢以下规则来源: +// - Janusec: https://www.janusec.com/ +func Template() *WAF { + waf := NewWAF() + waf.Id = "template" + waf.IsOn = true + + // black list + { + group := NewRuleGroup() + group.IsOn = false + group.IsInbound = true + group.Name = "白名单" + group.Code = "whiteList" + group.Description = "在此名单中的IP地址可以直接跳过防火墙设置" + + { + + set := NewRuleSet() + set.IsOn = true + set.Name = "IP白名单" + set.Code = "9001" + set.Connector = RuleConnectorOr + set.Action = ActionAllow + set.AddRule(&Rule{ + Param: "${remoteAddr}", + Operator: RuleOperatorMatch, + Value: `127\.0\.0\.1|0\.0\.0\.0`, + IsCaseInsensitive: false, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // black list + { + group := NewRuleGroup() + group.IsOn = false + group.IsInbound = true + group.Name = "黑名单" + group.Code = "blackList" + group.Description = "在此名单中的IP地址直接阻止" + + { + + set := NewRuleSet() + set.IsOn = true + set.Name = "IP黑名单" + set.Code = "10001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${remoteAddr}", + Operator: RuleOperatorMatch, + Value: `1\.1\.1\.1|2\.2\.2\.2`, + IsCaseInsensitive: false, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // xss + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "XSS" + group.Code = "xss" + group.Description = "防跨站脚本攻击(Cross Site Scripting)" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "Javascript事件" + set.Code = "1001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestURI}", + Operator: RuleOperatorMatch, + Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`, // TODO more keywords here + IsCaseInsensitive: true, + }) + group.AddRuleSet(set) + } + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "Javascript函数" + set.Code = "1002" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestURI}", + Operator: RuleOperatorMatch, + Value: `(alert|eval|prompt|confirm)\s*\(`, // TODO more keywords here + IsCaseInsensitive: true, + }) + group.AddRuleSet(set) + } + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "HTML标签" + set.Code = "1003" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestURI}", + Operator: RuleOperatorMatch, + Value: `<(script|iframe|link)`, // TODO more keywords here + IsCaseInsensitive: true, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // upload + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "文件上传" + group.Code = "upload" + group.Description = "防止上传可执行脚本文件到服务器" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "上传文件扩展名" + set.Code = "2001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestUpload.ext}", + Operator: RuleOperatorMatch, + Value: `\.(php|jsp|aspx|asp|exe|asa|rb|py)\b`, // TODO more keywords here + IsCaseInsensitive: true, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // web shell + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "Web Shell" + group.Code = "webShell" + group.Description = "防止远程执行服务器命令" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "Web Shell" + set.Code = "3001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, // TODO more keywords here + IsCaseInsensitive: true, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // command injection + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "命令注入" + group.Code = "commandInjection" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "命令注入" + set.Code = "4001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestURI}", + Operator: RuleOperatorMatch, + Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here + IsCaseInsensitive: false, + }) + set.AddRule(&Rule{ + Param: "${requestBody}", + Operator: RuleOperatorMatch, + Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here + IsCaseInsensitive: false, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // path traversal + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "路径穿越" + group.Code = "pathTraversal" + group.Description = "防止读取网站目录之外的其他系统文件" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "路径穿越" + set.Code = "5001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestURI}", + Operator: RuleOperatorMatch, + Value: `((\.+)(/+)){2,}`, // TODO more keywords here + IsCaseInsensitive: false, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // special dirs + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "特殊目录" + group.Code = "denyDirs" + group.Description = "防止通过Web访问到一些特殊目录" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "特殊目录" + set.Code = "6001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${requestPath}", + Operator: RuleOperatorMatch, + Value: `/\.(git|svn|htaccess|idea)\b`, // TODO more keywords here + IsCaseInsensitive: true, + }) + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // sql injection + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "SQL注入" + group.Code = "sqlInjection" + group.Description = "防止SQL注入漏洞" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "Union SQL Injection" + set.Code = "7001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `union[\s/\*]+select`, + IsCaseInsensitive: true, + }) + + group.AddRuleSet(set) + } + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "SQL注释" + set.Code = "7002" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `/\*(!|\x00)`, + IsCaseInsensitive: true, + }) + + group.AddRuleSet(set) + } + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "SQL条件" + set.Code = "7003" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `\s(and|or|rlike)\s+(if|updatexml)\s*\(`, + IsCaseInsensitive: true, + }) + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `\s+(and|or|rlike)\s+(select|case)\s+`, + IsCaseInsensitive: true, + }) + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `\s+(and|or|procedure)\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+(\s|$|--|#)`, + IsCaseInsensitive: true, + }) + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `\(\s*case\s+when\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+\s+then\s+`, + IsCaseInsensitive: true, + }) + + group.AddRuleSet(set) + } + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "SQL函数" + set.Code = "7004" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `(updatexml|extractvalue|ascii|ord|char|chr|count|concat|rand|floor|substr|length|len|user|database|benchmark|analyse)\s*\(`, + IsCaseInsensitive: true, + }) + + group.AddRuleSet(set) + } + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "SQL附加语句" + set.Code = "7005" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + + set.AddRule(&Rule{ + Param: "${requestAll}", + Operator: RuleOperatorMatch, + Value: `;\s*(declare|use|drop|create|exec|delete|update|insert)\s`, + IsCaseInsensitive: true, + }) + + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // bot + { + group := NewRuleGroup() + group.IsOn = false + group.IsInbound = true + group.Name = "网络爬虫" + group.Code = "bot" + group.Description = "禁止一些网络爬虫" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "常见网络爬虫" + set.Code = "20001" + set.Connector = RuleConnectorOr + set.Action = ActionBlock + + set.AddRule(&Rule{ + Param: "${userAgent}", + Operator: RuleOperatorMatch, + Value: `Googlebot|AdsBot|bingbot|BingPreview|facebookexternalhit|Slurp|Sogou|proximic|Baiduspider|yandex|twitterbot|spider|python`, + IsCaseInsensitive: true, + }) + + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // cc + { + group := NewRuleGroup() + group.IsOn = false + group.IsInbound = true + group.Name = "CC攻击" + group.Description = "Challenge Collapsar,防止短时间大量请求涌入,请谨慎开启和设置" + group.Code = "cc" + + { + set := NewRuleSet() + set.IsOn = true + set.Name = "CC请求数" + set.Description = "限制单IP在一定时间内的请求数" + set.Code = "8001" + set.Connector = RuleConnectorAnd + set.Action = ActionBlock + set.AddRule(&Rule{ + Param: "${cc.requests}", + Operator: RuleOperatorGt, + Value: "1000", + CheckpointOptions: map[string]interface{}{ + "period": "60", + }, + IsCaseInsensitive: false, + }) + set.AddRule(&Rule{ + Param: "${remoteAddr}", + Operator: RuleOperatorNotIPRange, + Value: `127.0.0.1/8`, + IsCaseInsensitive: false, + }) + set.AddRule(&Rule{ + Param: "${remoteAddr}", + Operator: RuleOperatorNotIPRange, + Value: `192.168.0.1/16`, + IsCaseInsensitive: false, + }) + set.AddRule(&Rule{ + Param: "${remoteAddr}", + Operator: RuleOperatorNotIPRange, + Value: `10.0.0.1/8`, + IsCaseInsensitive: false, + }) + set.AddRule(&Rule{ + Param: "${remoteAddr}", + Operator: RuleOperatorNotIPRange, + Value: `172.16.0.1/12`, + IsCaseInsensitive: false, + }) + + group.AddRuleSet(set) + } + + waf.AddRuleGroup(group) + } + + // custom + { + group := NewRuleGroup() + group.IsOn = true + group.IsInbound = true + group.Name = "自定义规则分组" + group.Description = "我的自定义规则分组,可以将自定义的规则放在这个分组下" + group.Code = "custom" + waf.AddRuleGroup(group) + } + + return waf +} diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go new file mode 100644 index 0000000..f21a0cb --- /dev/null +++ b/internal/waf/template_test.go @@ -0,0 +1,352 @@ +package waf + +import ( + "bytes" + "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/logs" + "mime/multipart" + "net/http" + "net/url" + "strings" + "testing" + "time" +) + +func Test_Template(t *testing.T) { + a := assert.NewAssertion(t) + + template := Template() + err := template.Init() + if err != nil { + t.Fatal(err) + } + + template.OnAction(func(action ActionString) (goNext bool) { + return action != ActionBlock + }) + + testTemplate1001(a, t, template) + testTemplate1002(a, t, template) + testTemplate1003(a, t, template) + testTemplate2001(a, t, template) + testTemplate3001(a, t, template) + testTemplate4001(a, t, template) + testTemplate5001(a, t, template) + testTemplate6001(a, t, template) + testTemplate7001(a, t, template) + testTemplate20001(a, t, template) +} + +func Test_Template2(t *testing.T) { + reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024))) + req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader) + if err != nil { + t.Fatal(err) + } + + waf := Template() + err = waf.Init() + if err != nil { + t.Fatal(err) + } + + now := time.Now() + goNext, _, set, err := waf.MatchRequest(req, nil) + if err != nil { + t.Fatal(err) + } + t.Log(time.Since(now).Seconds()*1000, "ms") + + if goNext { + t.Log("ok") + return + } + + logs.PrintAsJSON(set, t) +} + +func BenchmarkTemplate(b *testing.B) { + waf := Template() + err := waf.Init() + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + reader := bytes.NewReader([]byte(strings.Repeat("Hello", 1024))) + req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader) + if err != nil { + b.Fatal(err) + } + + _, _, _, _ = waf.MatchRequest(req, nil) + } +} + +func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) { + req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=onmousedown%3D123", nil) + if err != nil { + t.Fatal(err) + } + _, _, result, err := template.MatchRequest(req, nil) + if err != nil { + t.Fatal(err) + } + a.IsNotNil(result) + if result != nil { + a.IsTrue(result.Code == "1001") + } +} + +func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) { + req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=eval%28", nil) + if err != nil { + t.Fatal(err) + } + _, _, result, err := template.MatchRequest(req, nil) + if err != nil { + t.Fatal(err) + } + a.IsNotNil(result) + if result != nil { + a.IsTrue(result.Code == "1002") + } +} + +func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) { + req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=