mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	将IP加入黑名单时,同时也会关闭此IP相关的连接
This commit is contained in:
		
							
								
								
									
										117
									
								
								internal/conns/map.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								internal/conns/map.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,117 @@
 | 
				
			|||||||
 | 
					// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package conns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var SharedMap = NewMap()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Map struct {
 | 
				
			||||||
 | 
						m map[string]map[int]net.Conn // ip => { port => Conn }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						locker sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewMap() *Map {
 | 
				
			||||||
 | 
						return &Map{
 | 
				
			||||||
 | 
							m: map[string]map[int]net.Conn{},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Map) Add(conn net.Conn) {
 | 
				
			||||||
 | 
						if conn == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var ip = tcpAddr.IP.String()
 | 
				
			||||||
 | 
						var port = tcpAddr.Port
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
						connMap, ok := this.m[ip]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							this.m[ip] = map[int]net.Conn{
 | 
				
			||||||
 | 
								port: conn,
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							connMap[port] = conn
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Map) Remove(conn net.Conn) {
 | 
				
			||||||
 | 
						if conn == nil {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var ip = tcpAddr.IP.String()
 | 
				
			||||||
 | 
						var port = tcpAddr.Port
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						connMap, ok := this.m[ip]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						delete(connMap, port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(connMap) == 0 {
 | 
				
			||||||
 | 
							delete(this.m, ip)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Map) CountIPConns(ip string) int {
 | 
				
			||||||
 | 
						this.locker.RLock()
 | 
				
			||||||
 | 
						var l = len(this.m[ip])
 | 
				
			||||||
 | 
						this.locker.RUnlock()
 | 
				
			||||||
 | 
						return l
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Map) CloseIPConns(ip string) {
 | 
				
			||||||
 | 
						var conns = []net.Conn{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.RLock()
 | 
				
			||||||
 | 
						connMap, ok := this.m[ip]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 复制,防止在Close时产生并发冲突
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							for _, conn := range connMap {
 | 
				
			||||||
 | 
								conns = append(conns, conn)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 需要在Close之前结束,防止死循环
 | 
				
			||||||
 | 
						this.locker.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							for _, conn := range conns {
 | 
				
			||||||
 | 
								_ = conn.Close()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 这里不需要从 m 中删除,因为关闭时会自然触发回调
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *Map) AllConns() []net.Conn {
 | 
				
			||||||
 | 
						this.locker.RLock()
 | 
				
			||||||
 | 
						defer this.locker.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var result = []net.Conn{}
 | 
				
			||||||
 | 
						for _, m := range this.m {
 | 
				
			||||||
 | 
							for _, conn := range m {
 | 
				
			||||||
 | 
								result = append(result, conn)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return result
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										47
									
								
								internal/firewalls/firewall_base.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								internal/firewalls/firewall_base.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,47 @@
 | 
				
			|||||||
 | 
					// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package firewalls
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type BaseFirewall struct {
 | 
				
			||||||
 | 
						locker        sync.Mutex
 | 
				
			||||||
 | 
						latestIPTimes []string // [ip@time, ....]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 检查是否在最近添加过
 | 
				
			||||||
 | 
					func (this *BaseFirewall) checkLatestIP(ip string) bool {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var expiredIndex = -1
 | 
				
			||||||
 | 
						for index, ipTime := range this.latestIPTimes {
 | 
				
			||||||
 | 
							var pieces = strings.Split(ipTime, "@")
 | 
				
			||||||
 | 
							var oldIP = pieces[0]
 | 
				
			||||||
 | 
							var oldTimestamp = pieces[1]
 | 
				
			||||||
 | 
							if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ {
 | 
				
			||||||
 | 
								expiredIndex = index
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if oldIP == ip {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if expiredIndex > -1 {
 | 
				
			||||||
 | 
							this.latestIPTimes = this.latestIPTimes[expiredIndex+1:]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix()))
 | 
				
			||||||
 | 
						const maxLen = 128
 | 
				
			||||||
 | 
						if len(this.latestIPTimes) > maxLen {
 | 
				
			||||||
 | 
							this.latestIPTimes = this.latestIPTimes[1:]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -4,6 +4,7 @@ package firewalls
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
@@ -11,15 +12,22 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type firewalldCmd struct {
 | 
				
			||||||
 | 
						cmd    *exec.Cmd
 | 
				
			||||||
 | 
						denyIP string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Firewalld struct {
 | 
					type Firewalld struct {
 | 
				
			||||||
 | 
						BaseFirewall
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	isReady  bool
 | 
						isReady  bool
 | 
				
			||||||
	exe      string
 | 
						exe      string
 | 
				
			||||||
	cmdQueue chan *exec.Cmd
 | 
						cmdQueue chan *firewalldCmd
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewFirewalld() *Firewalld {
 | 
					func NewFirewalld() *Firewalld {
 | 
				
			||||||
	var firewalld = &Firewalld{
 | 
						var firewalld = &Firewalld{
 | 
				
			||||||
		cmdQueue: make(chan *exec.Cmd, 4096),
 | 
							cmdQueue: make(chan *firewalldCmd, 4096),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	path, err := exec.LookPath("firewall-cmd")
 | 
						path, err := exec.LookPath("firewall-cmd")
 | 
				
			||||||
@@ -41,13 +49,19 @@ func NewFirewalld() *Firewalld {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (this *Firewalld) init() {
 | 
					func (this *Firewalld) init() {
 | 
				
			||||||
	goman.New(func() {
 | 
						goman.New(func() {
 | 
				
			||||||
		for cmd := range this.cmdQueue {
 | 
							for c := range this.cmdQueue {
 | 
				
			||||||
 | 
								var cmd = c.cmd
 | 
				
			||||||
			err := cmd.Run()
 | 
								err := cmd.Run()
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				if strings.HasPrefix(err.Error(), "Warning:") {
 | 
									if strings.HasPrefix(err.Error(), "Warning:") {
 | 
				
			||||||
					continue
 | 
										continue
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				remotelogs.Warn("FIREWALL", "run command failed '"+cmd.String()+"': "+err.Error())
 | 
									remotelogs.Warn("FIREWALL", "run command failed '"+cmd.String()+"': "+err.Error())
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									// 关闭连接
 | 
				
			||||||
 | 
									if len(c.denyIP) > 0 {
 | 
				
			||||||
 | 
										conns.SharedMap.CloseIPConns(c.denyIP)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
@@ -72,7 +86,7 @@ func (this *Firewalld) AllowPort(port int, protocol string) error {
 | 
				
			|||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol)
 | 
						var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol)
 | 
				
			||||||
	this.pushCmd(cmd)
 | 
						this.pushCmd(cmd, "")
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -82,12 +96,12 @@ func (this *Firewalld) AllowPortRangesPermanently(portRanges [][2]int, protocol
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent")
 | 
								var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent")
 | 
				
			||||||
			this.pushCmd(cmd)
 | 
								this.pushCmd(cmd, "")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			var cmd = exec.Command(this.exe, "--add-port="+port)
 | 
								var cmd = exec.Command(this.exe, "--add-port="+port)
 | 
				
			||||||
			this.pushCmd(cmd)
 | 
								this.pushCmd(cmd, "")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -99,7 +113,7 @@ func (this *Firewalld) RemovePort(port int, protocol string) error {
 | 
				
			|||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol)
 | 
						var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol)
 | 
				
			||||||
	this.pushCmd(cmd)
 | 
						this.pushCmd(cmd, "")
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -108,12 +122,12 @@ func (this *Firewalld) RemovePortRangePermanently(portRange [2]int, protocol str
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent")
 | 
							var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent")
 | 
				
			||||||
		this.pushCmd(cmd)
 | 
							this.pushCmd(cmd, "")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		var cmd = exec.Command(this.exe, "--remove-port="+port)
 | 
							var cmd = exec.Command(this.exe, "--remove-port="+port)
 | 
				
			||||||
		this.pushCmd(cmd)
 | 
							this.pushCmd(cmd, "")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
@@ -131,6 +145,12 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
 | 
				
			|||||||
	if !this.isReady {
 | 
						if !this.isReady {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 避免短时间内重复添加
 | 
				
			||||||
 | 
						if this.checkLatestIP(ip) {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var family = "ipv4"
 | 
						var family = "ipv4"
 | 
				
			||||||
	if strings.Contains(ip, ":") {
 | 
						if strings.Contains(ip, ":") {
 | 
				
			||||||
		family = "ipv6"
 | 
							family = "ipv6"
 | 
				
			||||||
@@ -140,7 +160,7 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error {
 | 
				
			|||||||
		args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
 | 
							args = append(args, "--timeout="+types.String(timeoutSeconds)+"s")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	var cmd = exec.Command(this.exe, args...)
 | 
						var cmd = exec.Command(this.exe, args...)
 | 
				
			||||||
	this.pushCmd(cmd)
 | 
						this.pushCmd(cmd, ip)
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -148,6 +168,12 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) e
 | 
				
			|||||||
	if !this.isReady {
 | 
						if !this.isReady {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 避免短时间内重复添加
 | 
				
			||||||
 | 
						if this.checkLatestIP(ip) {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var family = "ipv4"
 | 
						var family = "ipv4"
 | 
				
			||||||
	if strings.Contains(ip, ":") {
 | 
						if strings.Contains(ip, ":") {
 | 
				
			||||||
		family = "ipv6"
 | 
							family = "ipv6"
 | 
				
			||||||
@@ -158,10 +184,13 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) e
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	var cmd = exec.Command(this.exe, args...)
 | 
						var cmd = exec.Command(this.exe, args...)
 | 
				
			||||||
	if async {
 | 
						if async {
 | 
				
			||||||
		this.pushCmd(cmd)
 | 
							this.pushCmd(cmd, ip)
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 关闭连接
 | 
				
			||||||
 | 
						defer conns.SharedMap.CloseIPConns(ip)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := cmd.Run()
 | 
						err := cmd.Run()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return errors.New("run command failed '" + cmd.String() + "': " + err.Error())
 | 
							return errors.New("run command failed '" + cmd.String() + "': " + err.Error())
 | 
				
			||||||
@@ -173,6 +202,7 @@ func (this *Firewalld) RemoveSourceIP(ip string) error {
 | 
				
			|||||||
	if !this.isReady {
 | 
						if !this.isReady {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var family = "ipv4"
 | 
						var family = "ipv4"
 | 
				
			||||||
	if strings.Contains(ip, ":") {
 | 
						if strings.Contains(ip, ":") {
 | 
				
			||||||
		family = "ipv6"
 | 
							family = "ipv6"
 | 
				
			||||||
@@ -180,14 +210,14 @@ func (this *Firewalld) RemoveSourceIP(ip string) error {
 | 
				
			|||||||
	for _, action := range []string{"reject", "drop"} {
 | 
						for _, action := range []string{"reject", "drop"} {
 | 
				
			||||||
		var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action}
 | 
							var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action}
 | 
				
			||||||
		var cmd = exec.Command(this.exe, args...)
 | 
							var cmd = exec.Command(this.exe, args...)
 | 
				
			||||||
		this.pushCmd(cmd)
 | 
							this.pushCmd(cmd, "")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Firewalld) pushCmd(cmd *exec.Cmd) {
 | 
					func (this *Firewalld) pushCmd(cmd *exec.Cmd, denyIP string) {
 | 
				
			||||||
	select {
 | 
						select {
 | 
				
			||||||
	case this.cmdQueue <- cmd:
 | 
						case this.cmdQueue <- &firewalldCmd{cmd: cmd, denyIP: denyIP}:
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		// we discard the command
 | 
							// we discard the command
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,6 +7,7 @@ package firewalls
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
				
			||||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
						teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
				
			||||||
@@ -100,6 +101,8 @@ func NewNFTablesFirewall() (*NFTablesFirewall, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type NFTablesFirewall struct {
 | 
					type NFTablesFirewall struct {
 | 
				
			||||||
 | 
						BaseFirewall
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	conn    *nftables.Conn
 | 
						conn    *nftables.Conn
 | 
				
			||||||
	isReady bool
 | 
						isReady bool
 | 
				
			||||||
	version string
 | 
						version string
 | 
				
			||||||
@@ -344,6 +347,14 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
 | 
				
			|||||||
		return errors.New("invalid ip '" + ip + "'")
 | 
							return errors.New("invalid ip '" + ip + "'")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 避免短时间内重复添加
 | 
				
			||||||
 | 
						if this.checkLatestIP(ip) {
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 尝试关闭连接
 | 
				
			||||||
 | 
						conns.SharedMap.CloseIPConns(ip)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if async {
 | 
						if async {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case this.dropIPQueue <- &blockIPItem{
 | 
							case this.dropIPQueue <- &blockIPItem{
 | 
				
			||||||
@@ -357,6 +368,9 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
 | 
				
			|||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 再次尝试关闭连接
 | 
				
			||||||
 | 
						defer conns.SharedMap.CloseIPConns(ip)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if strings.Contains(ip, ":") { // ipv6
 | 
						if strings.Contains(ip, ":") { // ipv6
 | 
				
			||||||
		if this.denyIPv6Set == nil {
 | 
							if this.denyIPv6Set == nil {
 | 
				
			||||||
			return errors.New("ipv6 ip set is nil")
 | 
								return errors.New("ipv6 ip set is nil")
 | 
				
			||||||
@@ -433,3 +447,35 @@ func (this *NFTablesFirewall) readVersion(nftPath string) string {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return versionMatches[1]
 | 
						return versionMatches[1]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 检查是否在最近添加过
 | 
				
			||||||
 | 
					func (this *NFTablesFirewall) existLatestIP(ip string) bool {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var expiredIndex = -1
 | 
				
			||||||
 | 
						for index, ipTime := range this.latestIPTimes {
 | 
				
			||||||
 | 
							var pieces = strings.Split(ipTime, "@")
 | 
				
			||||||
 | 
							var oldIP = pieces[0]
 | 
				
			||||||
 | 
							var oldTimestamp = pieces[1]
 | 
				
			||||||
 | 
							if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ {
 | 
				
			||||||
 | 
								expiredIndex = index
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if oldIP == ip {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if expiredIndex > -1 {
 | 
				
			||||||
 | 
							this.latestIPTimes = this.latestIPTimes[expiredIndex+1:]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix()))
 | 
				
			||||||
 | 
						const maxLen = 128
 | 
				
			||||||
 | 
						if len(this.latestIPTimes) > maxLen {
 | 
				
			||||||
 | 
							this.latestIPTimes = this.latestIPTimes[1:]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,6 +5,7 @@ package nodes
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
				
			||||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
						teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/stats"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/stats"
 | 
				
			||||||
@@ -52,6 +53,9 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool) net.Conn {
 | 
				
			|||||||
		_ = conn.SetLinger(nodeconfigs.DefaultTCPLinger)
 | 
							_ = conn.SetLinger(nodeconfigs.DefaultTCPLinger)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 加入到Map
 | 
				
			||||||
 | 
						conns.SharedMap.Add(conn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return conn
 | 
						return conn
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -130,6 +134,9 @@ func (this *ClientConn) Close() error {
 | 
				
			|||||||
	// 不能加条件限制,因为服务配置随时有变化
 | 
						// 不能加条件限制,因为服务配置随时有变化
 | 
				
			||||||
	sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
 | 
						sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 从conn map中移除
 | 
				
			||||||
 | 
						conns.SharedMap.Remove(this)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -677,7 +677,7 @@ func (this *HTTPRequest) Format(source string) string {
 | 
				
			|||||||
		case "remoteAddrValue":
 | 
							case "remoteAddrValue":
 | 
				
			||||||
			return this.requestRemoteAddr(false)
 | 
								return this.requestRemoteAddr(false)
 | 
				
			||||||
		case "rawRemoteAddr":
 | 
							case "rawRemoteAddr":
 | 
				
			||||||
			addr := this.RawReq.RemoteAddr
 | 
								var addr = this.RawReq.RemoteAddr
 | 
				
			||||||
			host, _, err := net.SplitHostPort(addr)
 | 
								host, _, err := net.SplitHostPort(addr)
 | 
				
			||||||
			if err == nil {
 | 
								if err == nil {
 | 
				
			||||||
				addr = host
 | 
									addr = host
 | 
				
			||||||
@@ -1103,7 +1103,7 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Remote-Addr
 | 
						// Remote-Addr
 | 
				
			||||||
	remoteAddr := this.RawReq.RemoteAddr
 | 
						var remoteAddr = this.RawReq.RemoteAddr
 | 
				
			||||||
	host, _, err := net.SplitHostPort(remoteAddr)
 | 
						host, _, err := net.SplitHostPort(remoteAddr)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		if supportVar {
 | 
							if supportVar {
 | 
				
			||||||
@@ -1320,7 +1320,7 @@ func (this *HTTPRequest) RemoteAddr() string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *HTTPRequest) RawRemoteAddr() string {
 | 
					func (this *HTTPRequest) RawRemoteAddr() string {
 | 
				
			||||||
	addr := this.RawReq.RemoteAddr
 | 
						var addr = this.RawReq.RemoteAddr
 | 
				
			||||||
	host, _, err := net.SplitHostPort(addr)
 | 
						host, _, err := net.SplitHostPort(addr)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		addr = host
 | 
							addr = host
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -34,17 +34,7 @@ func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		return this.RawReq.ContentLength + hl, true
 | 
							return this.RawReq.ContentLength + hl, true
 | 
				
			||||||
	case "${countConnection}":
 | 
						case "${countConnection}":
 | 
				
			||||||
		metricNewConnMapLocker.Lock()
 | 
							return 1, true
 | 
				
			||||||
		_, ok := metricNewConnMap[this.RawReq.RemoteAddr]
 | 
					 | 
				
			||||||
		if ok {
 | 
					 | 
				
			||||||
			delete(metricNewConnMap, this.RawReq.RemoteAddr)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		metricNewConnMapLocker.Unlock()
 | 
					 | 
				
			||||||
		if ok {
 | 
					 | 
				
			||||||
			return 1, true
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			return 0, false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return 0, false
 | 
						return 0, false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -372,6 +372,8 @@ func (this *HTTPRequest) WAFServerId() int64 {
 | 
				
			|||||||
// WAFClose 关闭连接
 | 
					// WAFClose 关闭连接
 | 
				
			||||||
func (this *HTTPRequest) WAFClose() {
 | 
					func (this *HTTPRequest) WAFClose() {
 | 
				
			||||||
	this.Close()
 | 
						this.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 这里不要强关IP所有连接,避免因为单个服务而影响所有
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) {
 | 
					func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,7 +5,6 @@ import (
 | 
				
			|||||||
	"crypto/tls"
 | 
						"crypto/tls"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/zero"
 | 
					 | 
				
			||||||
	"github.com/iwind/TeaGo/Tea"
 | 
						"github.com/iwind/TeaGo/Tea"
 | 
				
			||||||
	"golang.org/x/net/http2"
 | 
						"golang.org/x/net/http2"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
@@ -13,14 +12,11 @@ import (
 | 
				
			|||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"sync/atomic"
 | 
						"sync/atomic"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var httpErrorLogger = log.New(io.Discard, "", 0)
 | 
					var httpErrorLogger = log.New(io.Discard, "", 0)
 | 
				
			||||||
var metricNewConnMap = map[string]zero.Zero{} // remoteAddr => bool
 | 
					 | 
				
			||||||
var metricNewConnMapLocker = &sync.Mutex{}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
type contextKey struct {
 | 
					type contextKey struct {
 | 
				
			||||||
	key string
 | 
						key string
 | 
				
			||||||
@@ -55,23 +51,10 @@ func (this *HTTPListener) Serve() error {
 | 
				
			|||||||
			switch state {
 | 
								switch state {
 | 
				
			||||||
			case http.StateNew:
 | 
								case http.StateNew:
 | 
				
			||||||
				atomic.AddInt64(&this.countActiveConnections, 1)
 | 
									atomic.AddInt64(&this.countActiveConnections, 1)
 | 
				
			||||||
 | 
					 | 
				
			||||||
				// 为指标存储连接信息
 | 
					 | 
				
			||||||
				if sharedNodeConfig.HasHTTPConnectionMetrics() {
 | 
					 | 
				
			||||||
					metricNewConnMapLocker.Lock()
 | 
					 | 
				
			||||||
					metricNewConnMap[conn.RemoteAddr().String()] = zero.New()
 | 
					 | 
				
			||||||
					metricNewConnMapLocker.Unlock()
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			case http.StateActive, http.StateIdle, http.StateHijacked:
 | 
								case http.StateActive, http.StateIdle, http.StateHijacked:
 | 
				
			||||||
				// Nothing to do
 | 
									// Nothing to do
 | 
				
			||||||
			case http.StateClosed:
 | 
								case http.StateClosed:
 | 
				
			||||||
				atomic.AddInt64(&this.countActiveConnections, -1)
 | 
									atomic.AddInt64(&this.countActiveConnections, -1)
 | 
				
			||||||
 | 
					 | 
				
			||||||
				// 移除指标存储连接信息
 | 
					 | 
				
			||||||
				// 因为中途配置可能有改变,所以暂时不添加条件
 | 
					 | 
				
			||||||
				metricNewConnMapLocker.Lock()
 | 
					 | 
				
			||||||
				delete(metricNewConnMap, conn.RemoteAddr().String())
 | 
					 | 
				
			||||||
				metricNewConnMapLocker.Unlock()
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
 | 
							ConnContext: func(ctx context.Context, conn net.Conn) context.Context {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,6 +11,7 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/caches"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/caches"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/configs"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/configs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
				
			||||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
						teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
				
			||||||
@@ -797,13 +798,16 @@ func (this *Node) listenSock() error {
 | 
				
			|||||||
					},
 | 
										},
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
			case "conns":
 | 
								case "conns":
 | 
				
			||||||
				ipConns, serverConns := sharedClientConnLimiter.Conns()
 | 
									var addrs = []string{}
 | 
				
			||||||
 | 
									var connMap = conns.SharedMap.AllConns()
 | 
				
			||||||
 | 
									for _, conn := range connMap {
 | 
				
			||||||
 | 
										addrs = append(addrs, conn.RemoteAddr().String())
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				_ = cmd.Reply(&gosock.Command{
 | 
									_ = cmd.Reply(&gosock.Command{
 | 
				
			||||||
					Params: map[string]interface{}{
 | 
										Params: map[string]interface{}{
 | 
				
			||||||
						"ipConns":     ipConns,
 | 
											"addrs": addrs,
 | 
				
			||||||
						"serverConns": serverConns,
 | 
											"total": len(addrs),
 | 
				
			||||||
						"total":       sharedListenerManager.TotalActiveConnections(),
 | 
					 | 
				
			||||||
					},
 | 
										},
 | 
				
			||||||
				})
 | 
									})
 | 
				
			||||||
			case "dropIP":
 | 
								case "dropIP":
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@ package waf
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
@@ -47,7 +48,7 @@ func NewIPList(listType IPListType) *IPList {
 | 
				
			|||||||
	list.expireList = e
 | 
						list.expireList = e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	e.OnGC(func(itemId uint64) {
 | 
						e.OnGC(func(itemId uint64) {
 | 
				
			||||||
		list.remove(itemId)
 | 
							list.remove(itemId) // TODO 使用异步,防止阻塞GC
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return list
 | 
						return list
 | 
				
			||||||
@@ -115,6 +116,9 @@ func (this *IPList) RecordIP(ipType string,
 | 
				
			|||||||
				_ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true)
 | 
									_ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// 关闭所有连接
 | 
				
			||||||
 | 
							conns.SharedMap.CloseIPConns(ip)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user