diff --git a/internal/conns/map.go b/internal/conns/map.go new file mode 100644 index 0000000..afbaf0a --- /dev/null +++ b/internal/conns/map.go @@ -0,0 +1,117 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package conns + +import ( + "net" + "sync" +) + +var SharedMap = NewMap() + +type Map struct { + m map[string]map[int]net.Conn // ip => { port => Conn } + + locker sync.RWMutex +} + +func NewMap() *Map { + return &Map{ + m: map[string]map[int]net.Conn{}, + } +} + +func (this *Map) Add(conn net.Conn) { + if conn == nil { + return + } + tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + return + } + + var ip = tcpAddr.IP.String() + var port = tcpAddr.Port + + this.locker.Lock() + defer this.locker.Unlock() + connMap, ok := this.m[ip] + if !ok { + this.m[ip] = map[int]net.Conn{ + port: conn, + } + } else { + connMap[port] = conn + } +} + +func (this *Map) Remove(conn net.Conn) { + if conn == nil { + return + } + tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + return + } + + var ip = tcpAddr.IP.String() + var port = tcpAddr.Port + + this.locker.Lock() + defer this.locker.Unlock() + + connMap, ok := this.m[ip] + if !ok { + return + } + delete(connMap, port) + + if len(connMap) == 0 { + delete(this.m, ip) + } +} + +func (this *Map) CountIPConns(ip string) int { + this.locker.RLock() + var l = len(this.m[ip]) + this.locker.RUnlock() + return l +} + +func (this *Map) CloseIPConns(ip string) { + var conns = []net.Conn{} + + this.locker.RLock() + connMap, ok := this.m[ip] + + // 复制,防止在Close时产生并发冲突 + if ok { + for _, conn := range connMap { + conns = append(conns, conn) + } + } + + // 需要在Close之前结束,防止死循环 + this.locker.RUnlock() + + if ok { + for _, conn := range conns { + _ = conn.Close() + } + + // 这里不需要从 m 中删除,因为关闭时会自然触发回调 + } +} + +func (this *Map) AllConns() []net.Conn { + this.locker.RLock() + defer this.locker.RUnlock() + + var result = []net.Conn{} + for _, m := range this.m { + for _, conn := range m { + result = append(result, conn) + } + } + return result +} diff --git a/internal/firewalls/firewall_base.go b/internal/firewalls/firewall_base.go new file mode 100644 index 0000000..aef93a0 --- /dev/null +++ b/internal/firewalls/firewall_base.go @@ -0,0 +1,47 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package firewalls + +import ( + "github.com/iwind/TeaGo/types" + "strings" + "sync" + "time" +) + +type BaseFirewall struct { + locker sync.Mutex + latestIPTimes []string // [ip@time, ....] +} + +// 检查是否在最近添加过 +func (this *BaseFirewall) checkLatestIP(ip string) bool { + this.locker.Lock() + defer this.locker.Unlock() + + var expiredIndex = -1 + for index, ipTime := range this.latestIPTimes { + var pieces = strings.Split(ipTime, "@") + var oldIP = pieces[0] + var oldTimestamp = pieces[1] + if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ { + expiredIndex = index + continue + } + if oldIP == ip { + return true + } + } + + if expiredIndex > -1 { + this.latestIPTimes = this.latestIPTimes[expiredIndex+1:] + } + + this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix())) + const maxLen = 128 + if len(this.latestIPTimes) > maxLen { + this.latestIPTimes = this.latestIPTimes[1:] + } + + return false +} diff --git a/internal/firewalls/firewall_firewalld.go b/internal/firewalls/firewall_firewalld.go index 3cb2880..6ec1a15 100644 --- a/internal/firewalls/firewall_firewalld.go +++ b/internal/firewalls/firewall_firewalld.go @@ -4,6 +4,7 @@ package firewalls import ( "errors" + "github.com/TeaOSLab/EdgeNode/internal/conns" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/iwind/TeaGo/types" @@ -11,15 +12,22 @@ import ( "strings" ) +type firewalldCmd struct { + cmd *exec.Cmd + denyIP string +} + type Firewalld struct { + BaseFirewall + isReady bool exe string - cmdQueue chan *exec.Cmd + cmdQueue chan *firewalldCmd } func NewFirewalld() *Firewalld { var firewalld = &Firewalld{ - cmdQueue: make(chan *exec.Cmd, 4096), + cmdQueue: make(chan *firewalldCmd, 4096), } path, err := exec.LookPath("firewall-cmd") @@ -41,13 +49,19 @@ func NewFirewalld() *Firewalld { func (this *Firewalld) init() { goman.New(func() { - for cmd := range this.cmdQueue { + for c := range this.cmdQueue { + var cmd = c.cmd err := cmd.Run() if err != nil { if strings.HasPrefix(err.Error(), "Warning:") { continue } remotelogs.Warn("FIREWALL", "run command failed '"+cmd.String()+"': "+err.Error()) + } else { + // 关闭连接 + if len(c.denyIP) > 0 { + conns.SharedMap.CloseIPConns(c.denyIP) + } } } }) @@ -72,7 +86,7 @@ func (this *Firewalld) AllowPort(port int, protocol string) error { return nil } var cmd = exec.Command(this.exe, "--add-port="+types.String(port)+"/"+protocol) - this.pushCmd(cmd) + this.pushCmd(cmd, "") return nil } @@ -82,12 +96,12 @@ func (this *Firewalld) AllowPortRangesPermanently(portRanges [][2]int, protocol { var cmd = exec.Command(this.exe, "--add-port="+port, "--permanent") - this.pushCmd(cmd) + this.pushCmd(cmd, "") } { var cmd = exec.Command(this.exe, "--add-port="+port) - this.pushCmd(cmd) + this.pushCmd(cmd, "") } } @@ -99,7 +113,7 @@ func (this *Firewalld) RemovePort(port int, protocol string) error { return nil } var cmd = exec.Command(this.exe, "--remove-port="+types.String(port)+"/"+protocol) - this.pushCmd(cmd) + this.pushCmd(cmd, "") return nil } @@ -108,12 +122,12 @@ func (this *Firewalld) RemovePortRangePermanently(portRange [2]int, protocol str { var cmd = exec.Command(this.exe, "--remove-port="+port, "--permanent") - this.pushCmd(cmd) + this.pushCmd(cmd, "") } { var cmd = exec.Command(this.exe, "--remove-port="+port) - this.pushCmd(cmd) + this.pushCmd(cmd, "") } return nil @@ -131,6 +145,12 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error { if !this.isReady { return nil } + + // 避免短时间内重复添加 + if this.checkLatestIP(ip) { + return nil + } + var family = "ipv4" if strings.Contains(ip, ":") { family = "ipv6" @@ -140,7 +160,7 @@ func (this *Firewalld) RejectSourceIP(ip string, timeoutSeconds int) error { args = append(args, "--timeout="+types.String(timeoutSeconds)+"s") } var cmd = exec.Command(this.exe, args...) - this.pushCmd(cmd) + this.pushCmd(cmd, ip) return nil } @@ -148,6 +168,12 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) e if !this.isReady { return nil } + + // 避免短时间内重复添加 + if this.checkLatestIP(ip) { + return nil + } + var family = "ipv4" if strings.Contains(ip, ":") { family = "ipv6" @@ -158,10 +184,13 @@ func (this *Firewalld) DropSourceIP(ip string, timeoutSeconds int, async bool) e } var cmd = exec.Command(this.exe, args...) if async { - this.pushCmd(cmd) + this.pushCmd(cmd, ip) return nil } + // 关闭连接 + defer conns.SharedMap.CloseIPConns(ip) + err := cmd.Run() if err != nil { return errors.New("run command failed '" + cmd.String() + "': " + err.Error()) @@ -173,6 +202,7 @@ func (this *Firewalld) RemoveSourceIP(ip string) error { if !this.isReady { return nil } + var family = "ipv4" if strings.Contains(ip, ":") { family = "ipv6" @@ -180,14 +210,14 @@ func (this *Firewalld) RemoveSourceIP(ip string) error { for _, action := range []string{"reject", "drop"} { var args = []string{"--remove-rich-rule=rule family='" + family + "' source address='" + ip + "' " + action} var cmd = exec.Command(this.exe, args...) - this.pushCmd(cmd) + this.pushCmd(cmd, "") } return nil } -func (this *Firewalld) pushCmd(cmd *exec.Cmd) { +func (this *Firewalld) pushCmd(cmd *exec.Cmd, denyIP string) { select { - case this.cmdQueue <- cmd: + case this.cmdQueue <- &firewalldCmd{cmd: cmd, denyIP: denyIP}: default: // we discard the command } diff --git a/internal/firewalls/firewall_nftables.go b/internal/firewalls/firewall_nftables.go index f56a566..33517a0 100644 --- a/internal/firewalls/firewall_nftables.go +++ b/internal/firewalls/firewall_nftables.go @@ -7,6 +7,7 @@ package firewalls import ( "bytes" "errors" + "github.com/TeaOSLab/EdgeNode/internal/conns" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" @@ -100,6 +101,8 @@ func NewNFTablesFirewall() (*NFTablesFirewall, error) { } type NFTablesFirewall struct { + BaseFirewall + conn *nftables.Conn isReady bool version string @@ -344,6 +347,14 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async return errors.New("invalid ip '" + ip + "'") } + // 避免短时间内重复添加 + if this.checkLatestIP(ip) { + return nil + } + + // 尝试关闭连接 + conns.SharedMap.CloseIPConns(ip) + if async { select { case this.dropIPQueue <- &blockIPItem{ @@ -357,6 +368,9 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async return nil } + // 再次尝试关闭连接 + defer conns.SharedMap.CloseIPConns(ip) + if strings.Contains(ip, ":") { // ipv6 if this.denyIPv6Set == nil { return errors.New("ipv6 ip set is nil") @@ -433,3 +447,35 @@ func (this *NFTablesFirewall) readVersion(nftPath string) string { } return versionMatches[1] } + +// 检查是否在最近添加过 +func (this *NFTablesFirewall) existLatestIP(ip string) bool { + this.locker.Lock() + defer this.locker.Unlock() + + var expiredIndex = -1 + for index, ipTime := range this.latestIPTimes { + var pieces = strings.Split(ipTime, "@") + var oldIP = pieces[0] + var oldTimestamp = pieces[1] + if types.Int64(oldTimestamp) < time.Now().Unix()-3 /** 3秒外表示过期 **/ { + expiredIndex = index + continue + } + if oldIP == ip { + return true + } + } + + if expiredIndex > -1 { + this.latestIPTimes = this.latestIPTimes[expiredIndex+1:] + } + + this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(time.Now().Unix())) + const maxLen = 128 + if len(this.latestIPTimes) > maxLen { + this.latestIPTimes = this.latestIPTimes[1:] + } + + return false +} diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index 95a5463..d1f97da 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -5,6 +5,7 @@ package nodes import ( "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeNode/internal/conns" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/stats" @@ -52,6 +53,9 @@ func NewClientConn(rawConn net.Conn, isTLS bool, quickClose bool) net.Conn { _ = conn.SetLinger(nodeconfigs.DefaultTCPLinger) } + // 加入到Map + conns.SharedMap.Add(conn) + return conn } @@ -130,6 +134,9 @@ func (this *ClientConn) Close() error { // 不能加条件限制,因为服务配置随时有变化 sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String()) + // 从conn map中移除 + conns.SharedMap.Remove(this) + return err } diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 6af7547..b4d2b10 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -677,7 +677,7 @@ func (this *HTTPRequest) Format(source string) string { case "remoteAddrValue": return this.requestRemoteAddr(false) case "rawRemoteAddr": - addr := this.RawReq.RemoteAddr + var addr = this.RawReq.RemoteAddr host, _, err := net.SplitHostPort(addr) if err == nil { addr = host @@ -1103,7 +1103,7 @@ func (this *HTTPRequest) requestRemoteAddr(supportVar bool) string { } // Remote-Addr - remoteAddr := this.RawReq.RemoteAddr + var remoteAddr = this.RawReq.RemoteAddr host, _, err := net.SplitHostPort(remoteAddr) if err == nil { if supportVar { @@ -1320,7 +1320,7 @@ func (this *HTTPRequest) RemoteAddr() string { } func (this *HTTPRequest) RawRemoteAddr() string { - addr := this.RawReq.RemoteAddr + var addr = this.RawReq.RemoteAddr host, _, err := net.SplitHostPort(addr) if err == nil { addr = host diff --git a/internal/nodes/http_request_metrics.go b/internal/nodes/http_request_metrics.go index 2d6254e..6df4878 100644 --- a/internal/nodes/http_request_metrics.go +++ b/internal/nodes/http_request_metrics.go @@ -34,17 +34,7 @@ func (this *HTTPRequest) MetricValue(value string) (result int64, ok bool) { } return this.RawReq.ContentLength + hl, true case "${countConnection}": - metricNewConnMapLocker.Lock() - _, ok := metricNewConnMap[this.RawReq.RemoteAddr] - if ok { - delete(metricNewConnMap, this.RawReq.RemoteAddr) - } - metricNewConnMapLocker.Unlock() - if ok { - return 1, true - } else { - return 0, false - } + return 1, true } return 0, false } diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 4115ff0..a4d2aae 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -372,6 +372,8 @@ func (this *HTTPRequest) WAFServerId() int64 { // WAFClose 关闭连接 func (this *HTTPRequest) WAFClose() { this.Close() + + // 这里不要强关IP所有连接,避免因为单个服务而影响所有 } func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) { diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index d092acd..2fbfc4f 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" - "github.com/TeaOSLab/EdgeNode/internal/zero" "github.com/iwind/TeaGo/Tea" "golang.org/x/net/http2" "io" @@ -13,14 +12,11 @@ import ( "net" "net/http" "strings" - "sync" "sync/atomic" "time" ) var httpErrorLogger = log.New(io.Discard, "", 0) -var metricNewConnMap = map[string]zero.Zero{} // remoteAddr => bool -var metricNewConnMapLocker = &sync.Mutex{} type contextKey struct { key string @@ -55,23 +51,10 @@ func (this *HTTPListener) Serve() error { switch state { case http.StateNew: atomic.AddInt64(&this.countActiveConnections, 1) - - // 为指标存储连接信息 - if sharedNodeConfig.HasHTTPConnectionMetrics() { - metricNewConnMapLocker.Lock() - metricNewConnMap[conn.RemoteAddr().String()] = zero.New() - metricNewConnMapLocker.Unlock() - } case http.StateActive, http.StateIdle, http.StateHijacked: // Nothing to do case http.StateClosed: atomic.AddInt64(&this.countActiveConnections, -1) - - // 移除指标存储连接信息 - // 因为中途配置可能有改变,所以暂时不添加条件 - metricNewConnMapLocker.Lock() - delete(metricNewConnMap, conn.RemoteAddr().String()) - metricNewConnMapLocker.Unlock() } }, ConnContext: func(ctx context.Context, conn net.Conn) context.Context { diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 65d109f..f93a636 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -11,6 +11,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs" "github.com/TeaOSLab/EdgeNode/internal/caches" "github.com/TeaOSLab/EdgeNode/internal/configs" + "github.com/TeaOSLab/EdgeNode/internal/conns" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/firewalls" @@ -797,13 +798,16 @@ func (this *Node) listenSock() error { }, }) case "conns": - ipConns, serverConns := sharedClientConnLimiter.Conns() + var addrs = []string{} + var connMap = conns.SharedMap.AllConns() + for _, conn := range connMap { + addrs = append(addrs, conn.RemoteAddr().String()) + } _ = cmd.Reply(&gosock.Command{ Params: map[string]interface{}{ - "ipConns": ipConns, - "serverConns": serverConns, - "total": sharedListenerManager.TotalActiveConnections(), + "addrs": addrs, + "total": len(addrs), }, }) case "dropIP": diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 52a08ed..77a0f6a 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -4,6 +4,7 @@ package waf import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeNode/internal/conns" "github.com/TeaOSLab/EdgeNode/internal/firewalls" "github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/iwind/TeaGo/types" @@ -47,7 +48,7 @@ func NewIPList(listType IPListType) *IPList { list.expireList = e e.OnGC(func(itemId uint64) { - list.remove(itemId) + list.remove(itemId) // TODO 使用异步,防止阻塞GC }) return list @@ -115,6 +116,9 @@ func (this *IPList) RecordIP(ipType string, _ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true) } } + + // 关闭所有连接 + conns.SharedMap.CloseIPConns(ip) } }