mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 07:40:56 +08:00 
			
		
		
		
	实现基础的DDoS防护
This commit is contained in:
		@@ -3,9 +3,10 @@ package events
 | 
			
		||||
type Event = string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	EventStart      Event = "start"      // start loading
 | 
			
		||||
	EventLoaded     Event = "loaded"     // first load
 | 
			
		||||
	EventQuit       Event = "quit"       // quit node gracefully
 | 
			
		||||
	EventReload     Event = "reload"     // reload config
 | 
			
		||||
	EventTerminated Event = "terminated" // process terminated
 | 
			
		||||
	EventStart         Event = "start"         // start loading
 | 
			
		||||
	EventLoaded        Event = "loaded"        // first load
 | 
			
		||||
	EventQuit          Event = "quit"          // quit node gracefully
 | 
			
		||||
	EventReload        Event = "reload"        // reload config
 | 
			
		||||
	EventTerminated    Event = "terminated"    // process terminated
 | 
			
		||||
	EventNFTablesReady Event = "nftablesReady" // nftables ready
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								internal/firewalls/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								internal/firewalls/.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1 +0,0 @@
 | 
			
		||||
firewall_nftables_test.go
 | 
			
		||||
							
								
								
									
										494
									
								
								internal/firewalls/ddos_protection.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										494
									
								
								internal/firewalls/ddos_protection.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,494 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package firewalls
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/zero"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	events.On(events.EventReload, func() {
 | 
			
		||||
		if nftablesInstance == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		nodeConfig, _ := nodeconfigs.SharedNodeConfig()
 | 
			
		||||
		if nodeConfig != nil {
 | 
			
		||||
			err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	events.On(events.EventNFTablesReady, func() {
 | 
			
		||||
		nodeConfig, _ := nodeconfigs.SharedNodeConfig()
 | 
			
		||||
		if nodeConfig != nil {
 | 
			
		||||
			err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DDoSProtectionManager DDoS防护
 | 
			
		||||
type DDoSProtectionManager struct {
 | 
			
		||||
	nftPath string
 | 
			
		||||
 | 
			
		||||
	lastAllowIPList []string
 | 
			
		||||
	lastConfig      []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDDoSProtectionManager 获取新对象
 | 
			
		||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
 | 
			
		||||
	nftPath, _ := exec.LookPath("nft")
 | 
			
		||||
 | 
			
		||||
	return &DDoSProtectionManager{
 | 
			
		||||
		nftPath: nftPath,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Apply 应用配置
 | 
			
		||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
 | 
			
		||||
	// 同集群节点IP白名单
 | 
			
		||||
	var allowIPListChanged = false
 | 
			
		||||
	nodeConfig, _ := nodeconfigs.SharedNodeConfig()
 | 
			
		||||
	if nodeConfig != nil {
 | 
			
		||||
		var allowIPList = nodeConfig.AllowedIPs
 | 
			
		||||
		if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) {
 | 
			
		||||
			allowIPListChanged = true
 | 
			
		||||
			this.lastAllowIPList = allowIPList
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 对比配置
 | 
			
		||||
	configJSON, err := json.Marshal(config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("encode config to json failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	remotelogs.Println("FIREWALL", "change DDoS protection config")
 | 
			
		||||
 | 
			
		||||
	if len(this.nftPath) == 0 {
 | 
			
		||||
		return errors.New("can not find nft command")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if nftablesInstance == nil {
 | 
			
		||||
		return errors.New("nftables instance should not be nil")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config == nil {
 | 
			
		||||
		// TCP
 | 
			
		||||
		err := this.removeTCPRules()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// TODO other protocols
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TCP
 | 
			
		||||
	if config.TCP == nil {
 | 
			
		||||
		err := this.removeTCPRules()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		// allow ip list
 | 
			
		||||
		var allowIPList = []string{}
 | 
			
		||||
		for _, ipConfig := range config.TCP.AllowIPList {
 | 
			
		||||
			allowIPList = append(allowIPList, ipConfig.IP)
 | 
			
		||||
		}
 | 
			
		||||
		for _, ip := range this.lastAllowIPList {
 | 
			
		||||
			if !lists.ContainsString(allowIPList, ip) {
 | 
			
		||||
				allowIPList = append(allowIPList, ip)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		err = this.updateAllowIPList(allowIPList)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// tcp
 | 
			
		||||
		if config.TCP.IsOn {
 | 
			
		||||
			err := this.addTCPRules(config.TCP)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			err := this.removeTCPRules()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.lastConfig = configJSON
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 添加TCP规则
 | 
			
		||||
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
 | 
			
		||||
	var ports = []int32{}
 | 
			
		||||
	for _, portConfig := range tcpConfig.Ports {
 | 
			
		||||
		if !lists.ContainsInt32(ports, portConfig.Port) {
 | 
			
		||||
			ports = append(ports, portConfig.Port)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if len(ports) == 0 {
 | 
			
		||||
		ports = []int32{80, 443}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, filter := range nftablesFilters {
 | 
			
		||||
		chain, oldRules, err := this.getRules(filter)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("get old rules failed: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var protocol = filter.protocol()
 | 
			
		||||
 | 
			
		||||
		// max connections
 | 
			
		||||
		var maxConnections = tcpConfig.MaxConnections
 | 
			
		||||
		if maxConnections <= 0 {
 | 
			
		||||
			maxConnections = nodeconfigs.DefaultTCPMaxConnections
 | 
			
		||||
			if maxConnections <= 0 {
 | 
			
		||||
				maxConnections = 100000
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// max connections per ip
 | 
			
		||||
		var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
 | 
			
		||||
		if maxConnectionsPerIP <= 0 {
 | 
			
		||||
			maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP
 | 
			
		||||
			if maxConnectionsPerIP <= 0 {
 | 
			
		||||
				maxConnectionsPerIP = 100000
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// new connections rate
 | 
			
		||||
		var newConnectionsRate = tcpConfig.NewConnectionsRate
 | 
			
		||||
		if newConnectionsRate <= 0 {
 | 
			
		||||
			newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate
 | 
			
		||||
			if newConnectionsRate <= 0 {
 | 
			
		||||
				newConnectionsRate = 100000
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 检查是否有变化
 | 
			
		||||
		var hasChanges = false
 | 
			
		||||
		for _, port := range ports {
 | 
			
		||||
			if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
 | 
			
		||||
				hasChanges = true
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
 | 
			
		||||
				hasChanges = true
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) {
 | 
			
		||||
				hasChanges = true
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !hasChanges {
 | 
			
		||||
			// 检查是否有多余的端口
 | 
			
		||||
			var oldPorts = this.getTCPPorts(oldRules)
 | 
			
		||||
			if !this.eqPorts(ports, oldPorts) {
 | 
			
		||||
				hasChanges = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !hasChanges {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 先清空所有相关规则
 | 
			
		||||
		err = this.removeOldTCPRules(chain, oldRules)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("delete old rules failed: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 添加新规则
 | 
			
		||||
		for _, port := range ports {
 | 
			
		||||
			if maxConnections > 0 {
 | 
			
		||||
				var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
 | 
			
		||||
				err := cmd.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if maxConnectionsPerIP > 0 {
 | 
			
		||||
				var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
 | 
			
		||||
				var stderr = &bytes.Buffer{}
 | 
			
		||||
				cmd.Stderr = stderr
 | 
			
		||||
				err := cmd.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if newConnectionsRate > 0 {
 | 
			
		||||
				// TODO 思考是否有惩罚机制
 | 
			
		||||
				var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}))
 | 
			
		||||
				var stderr = &bytes.Buffer{}
 | 
			
		||||
				cmd.Stderr = stderr
 | 
			
		||||
				err := cmd.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 删除TCP规则
 | 
			
		||||
func (this *DDoSProtectionManager) removeTCPRules() error {
 | 
			
		||||
	for _, filter := range nftablesFilters {
 | 
			
		||||
		chain, rules, err := this.getRules(filter)
 | 
			
		||||
 | 
			
		||||
		// TCP
 | 
			
		||||
		err = this.removeOldTCPRules(chain, rules)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 组合user data
 | 
			
		||||
// 数据中不能包含字母、数字、下划线以外的数据
 | 
			
		||||
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
 | 
			
		||||
	if attrs == nil {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return "ZZ" + strings.Join(attrs, "_") + "ZZ"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 解码user data
 | 
			
		||||
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
 | 
			
		||||
	if len(data) == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var dataCopy = make([]byte, len(data))
 | 
			
		||||
	copy(dataCopy, data)
 | 
			
		||||
 | 
			
		||||
	var separatorLen = 2
 | 
			
		||||
	var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
 | 
			
		||||
	if index1 < 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dataCopy = dataCopy[index1+separatorLen:]
 | 
			
		||||
	var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
 | 
			
		||||
	if index2 < 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var s = string(dataCopy[:index2])
 | 
			
		||||
	var pieces = strings.Split(s, "_")
 | 
			
		||||
	for index, piece := range pieces {
 | 
			
		||||
		pieces[index] = strings.TrimSpace(piece)
 | 
			
		||||
	}
 | 
			
		||||
	return pieces
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 清除规则
 | 
			
		||||
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
 | 
			
		||||
	for _, rule := range rules {
 | 
			
		||||
		var pieces = this.decodeUserData(rule.UserData())
 | 
			
		||||
		if len(pieces) != 4 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if pieces[0] != "tcp" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		switch pieces[2] {
 | 
			
		||||
		case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate":
 | 
			
		||||
			err := chain.DeleteRule(rule)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据参数检查规则是否存在
 | 
			
		||||
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
 | 
			
		||||
	if len(attrs) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	for _, oldRule := range rules {
 | 
			
		||||
		var pieces = this.decodeUserData(oldRule.UserData())
 | 
			
		||||
		if len(attrs) != len(pieces) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		var isSame = true
 | 
			
		||||
		for index, piece := range pieces {
 | 
			
		||||
			if strings.TrimSpace(piece) != attrs[index] {
 | 
			
		||||
				isSame = false
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if isSame {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取规则中的端口号
 | 
			
		||||
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
 | 
			
		||||
	var ports = []int32{}
 | 
			
		||||
	for _, rule := range rules {
 | 
			
		||||
		var pieces = this.decodeUserData(rule.UserData())
 | 
			
		||||
		if len(pieces) != 4 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if pieces[0] != "tcp" {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		var port = types.Int32(pieces[1])
 | 
			
		||||
		if port > 0 && !lists.ContainsInt32(ports, port) {
 | 
			
		||||
			ports = append(ports, port)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return ports
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查端口是否一样
 | 
			
		||||
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
 | 
			
		||||
	if len(ports1) != len(ports2) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var portMap = map[int32]bool{}
 | 
			
		||||
	for _, port := range ports2 {
 | 
			
		||||
		portMap[port] = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, port := range ports1 {
 | 
			
		||||
		_, ok := portMap[port]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找Table
 | 
			
		||||
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
 | 
			
		||||
	var family nftables.TableFamily
 | 
			
		||||
	if filter.IsIPv4 {
 | 
			
		||||
		family = nftables.TableFamilyIPv4
 | 
			
		||||
	} else if filter.IsIPv6 {
 | 
			
		||||
		family = nftables.TableFamilyIPv6
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
 | 
			
		||||
	}
 | 
			
		||||
	return nftablesInstance.conn.GetTable(filter.Name, family)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 查找所有规则
 | 
			
		||||
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
 | 
			
		||||
	table, err := this.getTable(filter)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, errors.New("get table failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	chain, err := table.GetChain(nftablesChainName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, errors.New("get chain failed: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	rules, err := chain.GetRules()
 | 
			
		||||
	return chain, rules, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 更新白名单
 | 
			
		||||
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
 | 
			
		||||
	if nftablesInstance == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var allMap = map[string]zero.Zero{}
 | 
			
		||||
	for _, ip := range allIPList {
 | 
			
		||||
		allMap[ip] = zero.New()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
 | 
			
		||||
		var isIPv4 = set == nftablesInstance.allowIPv4Set
 | 
			
		||||
		var isIPv6 = !isIPv4
 | 
			
		||||
 | 
			
		||||
		// 现有的
 | 
			
		||||
		oldList, err := set.GetIPElements()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		var oldMap = map[string]zero.Zero{} // ip=> zero
 | 
			
		||||
		for _, ip := range oldList {
 | 
			
		||||
			oldMap[ip] = zero.New()
 | 
			
		||||
 | 
			
		||||
			if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
 | 
			
		||||
				_, ok := allMap[ip]
 | 
			
		||||
				if !ok {
 | 
			
		||||
					// 不存在则删除
 | 
			
		||||
					err = set.DeleteIPElement(ip)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.New("delete ip element '" + ip + "' failed: " + err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 新增的
 | 
			
		||||
		for _, ip := range allIPList {
 | 
			
		||||
			var ipObj = net.ParseIP(ip)
 | 
			
		||||
			if ipObj == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
 | 
			
		||||
				_, ok := oldMap[ip]
 | 
			
		||||
				if !ok {
 | 
			
		||||
					// 不存在则添加
 | 
			
		||||
					err = set.AddIPElement(ip, nil)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.New("add ip '" + ip + "' failed: " + err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										23
									
								
								internal/firewalls/ddos_protection_others.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								internal/firewalls/ddos_protection_others.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build !linux
 | 
			
		||||
// +build !linux
 | 
			
		||||
 | 
			
		||||
package firewalls
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
 | 
			
		||||
 | 
			
		||||
type DDoSProtectionManager struct {
 | 
			
		||||
	nftPath string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
 | 
			
		||||
	return &DDoSProtectionManager{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,6 +1,4 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build !plus
 | 
			
		||||
// +build !plus
 | 
			
		||||
 | 
			
		||||
package firewalls
 | 
			
		||||
 | 
			
		||||
@@ -8,9 +6,11 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var currentFirewall FirewallInterface
 | 
			
		||||
var firewallLocker = &sync.Mutex{}
 | 
			
		||||
 | 
			
		||||
// 初始化
 | 
			
		||||
func init() {
 | 
			
		||||
@@ -24,10 +24,28 @@ func init() {
 | 
			
		||||
 | 
			
		||||
// Firewall 查找当前系统中最适合的防火墙
 | 
			
		||||
func Firewall() FirewallInterface {
 | 
			
		||||
	firewallLocker.Lock()
 | 
			
		||||
	defer firewallLocker.Unlock()
 | 
			
		||||
	if currentFirewall != nil {
 | 
			
		||||
		return currentFirewall
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// nftables
 | 
			
		||||
	if runtime.GOOS == "linux" {
 | 
			
		||||
		nftables, err := NewNFTablesFirewall()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
 | 
			
		||||
		} else {
 | 
			
		||||
			if nftables.IsReady() {
 | 
			
		||||
				currentFirewall = nftables
 | 
			
		||||
				events.Notify(events.EventNFTablesReady)
 | 
			
		||||
				return nftables
 | 
			
		||||
			} else {
 | 
			
		||||
				remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// firewalld
 | 
			
		||||
	if runtime.GOOS == "linux" {
 | 
			
		||||
		var firewalld = NewFirewalld()
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										373
									
								
								internal/firewalls/firewall_nftables.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										373
									
								
								internal/firewalls/firewall_nftables.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,373 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package firewalls
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// check nft status, if being enabled we load it automatically
 | 
			
		||||
func init() {
 | 
			
		||||
	if runtime.GOOS == "linux" {
 | 
			
		||||
		var ticker = time.NewTicker(3 * time.Minute)
 | 
			
		||||
		go func() {
 | 
			
		||||
			for range ticker.C {
 | 
			
		||||
				// if already ready, we break
 | 
			
		||||
				if nftablesIsReady {
 | 
			
		||||
					ticker.Stop()
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				_, err := exec.LookPath("nft")
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					nftablesFirewall, err := NewNFTablesFirewall()
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					currentFirewall = nftablesFirewall
 | 
			
		||||
					remotelogs.Println("FIREWALL", "nftables is ready")
 | 
			
		||||
 | 
			
		||||
					// fire event
 | 
			
		||||
					if nftablesFirewall.IsReady() {
 | 
			
		||||
						events.Notify(events.EventNFTablesReady)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					ticker.Stop()
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var nftablesInstance *NFTablesFirewall
 | 
			
		||||
var nftablesIsReady = false
 | 
			
		||||
var nftablesFilters = []*nftablesTableDefinition{
 | 
			
		||||
	// we shorten the name for table name length restriction
 | 
			
		||||
	{Name: "edge_dft_v4", IsIPv4: true},
 | 
			
		||||
	{Name: "edge_dft_v6", IsIPv6: true},
 | 
			
		||||
}
 | 
			
		||||
var nftablesChainName = "input"
 | 
			
		||||
 | 
			
		||||
type nftablesTableDefinition struct {
 | 
			
		||||
	Name   string
 | 
			
		||||
	IsIPv4 bool
 | 
			
		||||
	IsIPv6 bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *nftablesTableDefinition) protocol() string {
 | 
			
		||||
	if this.IsIPv6 {
 | 
			
		||||
		return "ip6"
 | 
			
		||||
	}
 | 
			
		||||
	return "ip"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
 | 
			
		||||
	var firewall = &NFTablesFirewall{
 | 
			
		||||
		conn: nftables.NewConn(),
 | 
			
		||||
	}
 | 
			
		||||
	err := firewall.init()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return firewall, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NFTablesFirewall struct {
 | 
			
		||||
	conn    *nftables.Conn
 | 
			
		||||
	isReady bool
 | 
			
		||||
 | 
			
		||||
	allowIPv4Set *nftables.Set
 | 
			
		||||
	allowIPv6Set *nftables.Set
 | 
			
		||||
 | 
			
		||||
	denyIPv4Set *nftables.Set
 | 
			
		||||
	denyIPv6Set *nftables.Set
 | 
			
		||||
 | 
			
		||||
	firewalld *Firewalld
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *NFTablesFirewall) init() error {
 | 
			
		||||
	// check nft
 | 
			
		||||
	_, err := exec.LookPath("nft")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.New("nft not found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// table
 | 
			
		||||
	for _, tableDef := range nftablesFilters {
 | 
			
		||||
		var family nftables.TableFamily
 | 
			
		||||
		if tableDef.IsIPv4 {
 | 
			
		||||
			family = nftables.TableFamilyIPv4
 | 
			
		||||
		} else if tableDef.IsIPv6 {
 | 
			
		||||
			family = nftables.TableFamilyIPv6
 | 
			
		||||
		} else {
 | 
			
		||||
			return errors.New("invalid table family: " + types.String(tableDef))
 | 
			
		||||
		}
 | 
			
		||||
		table, err := this.conn.GetTable(tableDef.Name, family)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if nftables.IsNotFound(err) {
 | 
			
		||||
				if tableDef.IsIPv4 {
 | 
			
		||||
					table, err = this.conn.AddIPv4Table(tableDef.Name)
 | 
			
		||||
				} else if tableDef.IsIPv6 {
 | 
			
		||||
					table, err = this.conn.AddIPv6Table(tableDef.Name)
 | 
			
		||||
				}
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if table == nil {
 | 
			
		||||
			return errors.New("can not create table '" + tableDef.Name + "'")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// chain
 | 
			
		||||
		var chainName = nftablesChainName
 | 
			
		||||
		chain, err := table.GetChain(chainName)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if nftables.IsNotFound(err) {
 | 
			
		||||
				chain, err = table.AddAcceptChain(chainName)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errors.New("create chain '" + chainName + "' failed: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				return errors.New("get chain '" + chainName + "' failed: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if chain == nil {
 | 
			
		||||
			return errors.New("can not create chain '" + chainName + "'")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// allow lo
 | 
			
		||||
		var loRuleName = []byte("lo")
 | 
			
		||||
		_, err = chain.GetRuleWithUserData(loRuleName)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if nftables.IsNotFound(err) {
 | 
			
		||||
				_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return errors.New("add 'lo' rule failed: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// allow set
 | 
			
		||||
		// "allow" should be always first
 | 
			
		||||
		for _, setAction := range []string{"allow", "deny"} {
 | 
			
		||||
			var setName = setAction + "_set"
 | 
			
		||||
 | 
			
		||||
			set, err := table.GetSet(setName)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if nftables.IsNotFound(err) {
 | 
			
		||||
					var keyType nftables.SetDataType
 | 
			
		||||
					if tableDef.IsIPv4 {
 | 
			
		||||
						keyType = nftables.TypeIPAddr
 | 
			
		||||
					} else if tableDef.IsIPv6 {
 | 
			
		||||
						keyType = nftables.TypeIP6Addr
 | 
			
		||||
					}
 | 
			
		||||
					set, err = table.AddSet(setName, &nftables.SetOptions{
 | 
			
		||||
						KeyType:    keyType,
 | 
			
		||||
						HasTimeout: true,
 | 
			
		||||
					})
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.New("create set '" + setName + "' failed: " + err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					return errors.New("get set '" + setName + "' failed: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if set == nil {
 | 
			
		||||
				return errors.New("can not create set '" + setName + "'")
 | 
			
		||||
			}
 | 
			
		||||
			if tableDef.IsIPv4 {
 | 
			
		||||
				if setAction == "allow" {
 | 
			
		||||
					this.allowIPv4Set = set
 | 
			
		||||
				} else {
 | 
			
		||||
					this.denyIPv4Set = set
 | 
			
		||||
				}
 | 
			
		||||
			} else if tableDef.IsIPv6 {
 | 
			
		||||
				if setAction == "allow" {
 | 
			
		||||
					this.allowIPv6Set = set
 | 
			
		||||
				} else {
 | 
			
		||||
					this.denyIPv6Set = set
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// rule
 | 
			
		||||
			var ruleName = []byte(setAction)
 | 
			
		||||
			rule, err := chain.GetRuleWithUserData(ruleName)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if nftables.IsNotFound(err) {
 | 
			
		||||
					if tableDef.IsIPv4 {
 | 
			
		||||
						if setAction == "allow" {
 | 
			
		||||
							rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
 | 
			
		||||
						} else {
 | 
			
		||||
							rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
 | 
			
		||||
						}
 | 
			
		||||
					} else if tableDef.IsIPv6 {
 | 
			
		||||
						if setAction == "allow" {
 | 
			
		||||
							rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
 | 
			
		||||
						} else {
 | 
			
		||||
							rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.New("add rule failed: " + err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					return errors.New("get rule failed: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if rule == nil {
 | 
			
		||||
				return errors.New("can not create rule '" + string(ruleName) + "'")
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.isReady = true
 | 
			
		||||
	nftablesIsReady = true
 | 
			
		||||
	nftablesInstance = this
 | 
			
		||||
 | 
			
		||||
	// load firewalld
 | 
			
		||||
	var firewalld = NewFirewalld()
 | 
			
		||||
	if firewalld.IsReady() {
 | 
			
		||||
		this.firewalld = firewalld
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Name 名称
 | 
			
		||||
func (this *NFTablesFirewall) Name() string {
 | 
			
		||||
	return "nftables"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsReady 是否已准备被调用
 | 
			
		||||
func (this *NFTablesFirewall) IsReady() bool {
 | 
			
		||||
	return this.isReady
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsMock 是否为模拟
 | 
			
		||||
func (this *NFTablesFirewall) IsMock() bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowPort 允许端口
 | 
			
		||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
 | 
			
		||||
	if this.firewalld != nil {
 | 
			
		||||
		return this.firewalld.AllowPort(port, protocol)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemovePort 删除端口
 | 
			
		||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
 | 
			
		||||
	if this.firewalld != nil {
 | 
			
		||||
		return this.firewalld.RemovePort(port, protocol)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowSourceIP Allow把IP加入白名单
 | 
			
		||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
 | 
			
		||||
	var data = net.ParseIP(ip)
 | 
			
		||||
	if data == nil {
 | 
			
		||||
		return errors.New("invalid ip '" + ip + "'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.Contains(ip, ":") { // ipv6
 | 
			
		||||
		if this.allowIPv6Set == nil {
 | 
			
		||||
			return errors.New("ipv6 ip set is nil")
 | 
			
		||||
		}
 | 
			
		||||
		return this.allowIPv6Set.AddElement(data.To16(), nil)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ipv4
 | 
			
		||||
	if this.allowIPv4Set == nil {
 | 
			
		||||
		return errors.New("ipv4 ip set is nil")
 | 
			
		||||
	}
 | 
			
		||||
	return this.allowIPv4Set.AddElement(data.To4(), nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RejectSourceIP 拒绝某个源IP连接
 | 
			
		||||
// we did not create set for drop ip, so we reuse DropSourceIP() method here
 | 
			
		||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
 | 
			
		||||
	return this.DropSourceIP(ip, timeoutSeconds)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DropSourceIP 丢弃某个源IP数据
 | 
			
		||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
 | 
			
		||||
	var data = net.ParseIP(ip)
 | 
			
		||||
	if data == nil {
 | 
			
		||||
		return errors.New("invalid ip '" + ip + "'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.Contains(ip, ":") { // ipv6
 | 
			
		||||
		if this.denyIPv6Set == nil {
 | 
			
		||||
			return errors.New("ipv6 ip set is nil")
 | 
			
		||||
		}
 | 
			
		||||
		return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
 | 
			
		||||
			Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ipv4
 | 
			
		||||
	if this.denyIPv4Set == nil {
 | 
			
		||||
		return errors.New("ipv4 ip set is nil")
 | 
			
		||||
	}
 | 
			
		||||
	return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
 | 
			
		||||
		Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveSourceIP 删除某个源IP
 | 
			
		||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
 | 
			
		||||
	var data = net.ParseIP(ip)
 | 
			
		||||
	if data == nil {
 | 
			
		||||
		return errors.New("invalid ip '" + ip + "'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if strings.Contains(ip, ":") { // ipv6
 | 
			
		||||
		if this.denyIPv6Set != nil {
 | 
			
		||||
			err := this.denyIPv6Set.DeleteElement(data.To16())
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if this.allowIPv6Set != nil {
 | 
			
		||||
			err := this.allowIPv6Set.DeleteElement(data.To16())
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ipv4
 | 
			
		||||
	if this.allowIPv4Set != nil {
 | 
			
		||||
		err := this.denyIPv4Set.DeleteElement(data.To4())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = this.allowIPv4Set.DeleteElement(data.To4())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										61
									
								
								internal/firewalls/firewall_nftables_others.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								internal/firewalls/firewall_nftables_others.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build !linux
 | 
			
		||||
// +build !linux
 | 
			
		||||
 | 
			
		||||
package firewalls
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NFTablesFirewall struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Name 名称
 | 
			
		||||
func (this *NFTablesFirewall) Name() string {
 | 
			
		||||
	return "nftables"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsReady 是否已准备被调用
 | 
			
		||||
func (this *NFTablesFirewall) IsReady() bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsMock 是否为模拟
 | 
			
		||||
func (this *NFTablesFirewall) IsMock() bool {
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowPort 允许端口
 | 
			
		||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemovePort 删除端口
 | 
			
		||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowSourceIP Allow把IP加入白名单
 | 
			
		||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RejectSourceIP 拒绝某个源IP连接
 | 
			
		||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DropSourceIP 丢弃某个源IP数据
 | 
			
		||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveSourceIP 删除某个源IP
 | 
			
		||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										1
									
								
								internal/firewalls/nftables/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								internal/firewalls/nftables/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
build_remote.sh
 | 
			
		||||
							
								
								
									
										370
									
								
								internal/firewalls/nftables/chain.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										370
									
								
								internal/firewalls/nftables/chain.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,370 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"errors"
 | 
			
		||||
	nft "github.com/google/nftables"
 | 
			
		||||
	"github.com/google/nftables/expr"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const MaxChainNameLength = 31
 | 
			
		||||
 | 
			
		||||
type RuleOptions struct {
 | 
			
		||||
	Exprs    []expr.Any
 | 
			
		||||
	UserData []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Chain chain object in table
 | 
			
		||||
type Chain struct {
 | 
			
		||||
	conn     *Conn
 | 
			
		||||
	rawTable *nft.Table
 | 
			
		||||
	rawChain *nft.Chain
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
 | 
			
		||||
	return &Chain{
 | 
			
		||||
		conn:     conn,
 | 
			
		||||
		rawTable: rawTable,
 | 
			
		||||
		rawChain: rawChain,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) Raw() *nft.Chain {
 | 
			
		||||
	return this.rawChain
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) Name() string {
 | 
			
		||||
	return this.rawChain.Name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
 | 
			
		||||
	var rawRule = this.conn.Raw().AddRule(&nft.Rule{
 | 
			
		||||
		Table:    this.rawTable,
 | 
			
		||||
		Chain:    this.rawChain,
 | 
			
		||||
		Exprs:    options.Exprs,
 | 
			
		||||
		UserData: options.UserData,
 | 
			
		||||
	})
 | 
			
		||||
	err := this.conn.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return NewRule(rawRule), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       12,
 | 
			
		||||
				Len:          4,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Cmp{
 | 
			
		||||
				Op:       expr.CmpOpEq,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
				Data:     ip,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictAccept,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       8,
 | 
			
		||||
				Len:          16,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Cmp{
 | 
			
		||||
				Op:       expr.CmpOpEq,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
				Data:     ip,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictAccept,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       12,
 | 
			
		||||
				Len:          4,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Cmp{
 | 
			
		||||
				Op:       expr.CmpOpEq,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
				Data:     ip,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictDrop,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       8,
 | 
			
		||||
				Len:          16,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Cmp{
 | 
			
		||||
				Op:       expr.CmpOpEq,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
				Data:     ip,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictDrop,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       12,
 | 
			
		||||
				Len:          4,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Cmp{
 | 
			
		||||
				Op:       expr.CmpOpEq,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
				Data:     ip,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Reject{},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       8,
 | 
			
		||||
				Len:          16,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Cmp{
 | 
			
		||||
				Op:       expr.CmpOpEq,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
				Data:     ip,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Reject{},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       12,
 | 
			
		||||
				Len:          4,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Lookup{
 | 
			
		||||
				SourceRegister: 1,
 | 
			
		||||
				SetName:        setName,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictAccept,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       8,
 | 
			
		||||
				Len:          16,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Lookup{
 | 
			
		||||
				SourceRegister: 1,
 | 
			
		||||
				SetName:        setName,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictAccept,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       12,
 | 
			
		||||
				Len:          4,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Lookup{
 | 
			
		||||
				SourceRegister: 1,
 | 
			
		||||
				SetName:        setName,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictDrop,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       8,
 | 
			
		||||
				Len:          16,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Lookup{
 | 
			
		||||
				SourceRegister: 1,
 | 
			
		||||
				SetName:        setName,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictDrop,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       12,
 | 
			
		||||
				Len:          4,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Lookup{
 | 
			
		||||
				SourceRegister: 1,
 | 
			
		||||
				SetName:        setName,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Reject{},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Payload{
 | 
			
		||||
				DestRegister: 1,
 | 
			
		||||
				Base:         expr.PayloadBaseNetworkHeader,
 | 
			
		||||
				Offset:       8,
 | 
			
		||||
				Len:          16,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Lookup{
 | 
			
		||||
				SourceRegister: 1,
 | 
			
		||||
				SetName:        setName,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Reject{},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
 | 
			
		||||
	if len(interfaceName) >= 16 {
 | 
			
		||||
		return nil, errors.New("invalid interface name '" + interfaceName + "'")
 | 
			
		||||
	}
 | 
			
		||||
	var ifname = make([]byte, 16)
 | 
			
		||||
	copy(ifname, interfaceName+"\x00")
 | 
			
		||||
 | 
			
		||||
	return this.AddRule(&RuleOptions{
 | 
			
		||||
		Exprs: []expr.Any{
 | 
			
		||||
			&expr.Meta{
 | 
			
		||||
				Key:      expr.MetaKeyIIFNAME,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Cmp{
 | 
			
		||||
				Op:       expr.CmpOpEq,
 | 
			
		||||
				Register: 1,
 | 
			
		||||
				Data:     ifname,
 | 
			
		||||
			},
 | 
			
		||||
			&expr.Verdict{
 | 
			
		||||
				Kind: expr.VerdictAccept,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		UserData: userData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
 | 
			
		||||
	rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	for _, rawRule := range rawRules {
 | 
			
		||||
		if bytes.Compare(rawRule.UserData, userData) == 0 {
 | 
			
		||||
			return NewRule(rawRule), nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, ErrRuleNotFound
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) GetRules() ([]*Rule, error) {
 | 
			
		||||
	rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	var result = []*Rule{}
 | 
			
		||||
	for _, rawRule := range rawRules {
 | 
			
		||||
		result = append(result, NewRule(rawRule))
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) DeleteRule(rule *Rule) error {
 | 
			
		||||
	err := this.conn.Raw().DelRule(rule.Raw())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return this.conn.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Chain) Flush() error {
 | 
			
		||||
	this.conn.Raw().FlushChain(this.rawChain)
 | 
			
		||||
	return this.conn.Commit()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										13
									
								
								internal/firewalls/nftables/chain_policy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								internal/firewalls/nftables/chain_policy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import nft "github.com/google/nftables"
 | 
			
		||||
 | 
			
		||||
type ChainPolicy = nft.ChainPolicy
 | 
			
		||||
 | 
			
		||||
// Possible ChainPolicy values.
 | 
			
		||||
const (
 | 
			
		||||
	ChainPolicyDrop   = nft.ChainPolicyDrop
 | 
			
		||||
	ChainPolicyAccept = nft.ChainPolicyAccept
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										130
									
								
								internal/firewalls/nftables/chain_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								internal/firewalls/nftables/chain_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,130 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
			
		||||
	"net"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getIPv4Chain(t *testing.T) *nftables.Chain {
 | 
			
		||||
	var conn = nftables.NewConn()
 | 
			
		||||
	table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrTableNotFound {
 | 
			
		||||
			table, err = conn.AddIPv4Table("test_ipv4")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	chain, err := table.GetChain("test_chain")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrChainNotFound {
 | 
			
		||||
			chain, err = table.AddAcceptChain("test_chain")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return chain
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_AddAcceptIPRule(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_AddDropIPRule(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_AddAcceptSetRule(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_AddDropSetRule(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_AddRejectSetRule(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_GetRuleWithUserData(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	rule, err := chain.GetRuleWithUserData([]byte("test"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrRuleNotFound {
 | 
			
		||||
			t.Log("rule not found")
 | 
			
		||||
			return
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("rule:", rule)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_GetRules(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	rules, err := chain.GetRules()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	for _, rule := range rules {
 | 
			
		||||
		t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
 | 
			
		||||
			"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_DeleteRule(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	rule, err := chain.GetRuleWithUserData([]byte("test"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrRuleNotFound {
 | 
			
		||||
			t.Log("rule not found")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	err = chain.DeleteRule(rule)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChain_Flush(t *testing.T) {
 | 
			
		||||
	var chain = getIPv4Chain(t)
 | 
			
		||||
	err := chain.Flush()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										84
									
								
								internal/firewalls/nftables/conn.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								internal/firewalls/nftables/conn.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,84 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	nft "github.com/google/nftables"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const MaxTableNameLength = 27
 | 
			
		||||
 | 
			
		||||
type Conn struct {
 | 
			
		||||
	rawConn *nft.Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewConn() *Conn {
 | 
			
		||||
	return &Conn{
 | 
			
		||||
		rawConn: &nft.Conn{},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Conn) Raw() *nft.Conn {
 | 
			
		||||
	return this.rawConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
 | 
			
		||||
	rawTables, err := this.rawConn.ListTables()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, rawTable := range rawTables {
 | 
			
		||||
		if rawTable.Name == name && rawTable.Family == family {
 | 
			
		||||
			return NewTable(this, rawTable), nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, ErrTableNotFound
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
 | 
			
		||||
	if len(name) > MaxTableNameLength {
 | 
			
		||||
		return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var rawTable = this.rawConn.AddTable(&nft.Table{
 | 
			
		||||
		Family: family,
 | 
			
		||||
		Name:   name,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	err := this.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return NewTable(this, rawTable), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
 | 
			
		||||
	return this.AddTable(name, TableFamilyIPv4)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
 | 
			
		||||
	return this.AddTable(name, TableFamilyIPv6)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Conn) DeleteTable(name string, family TableFamily) error {
 | 
			
		||||
	table, err := this.GetTable(name, family)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == ErrTableNotFound {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	this.rawConn.DelTable(table.Raw())
 | 
			
		||||
	return this.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Conn) Commit() error {
 | 
			
		||||
	return this.rawConn.Flush()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										78
									
								
								internal/firewalls/nftables/conn_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								internal/firewalls/nftables/conn_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,78 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestConn_Test(t *testing.T) {
 | 
			
		||||
	_, err := exec.LookPath("nft")
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		t.Log("ok")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	t.Log(err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConn_GetTable_NotFound(t *testing.T) {
 | 
			
		||||
	var conn = nftables.NewConn()
 | 
			
		||||
 | 
			
		||||
	table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrTableNotFound {
 | 
			
		||||
			t.Log("table not found")
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		t.Log("table:", table)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConn_GetTable(t *testing.T) {
 | 
			
		||||
	var conn = nftables.NewConn()
 | 
			
		||||
 | 
			
		||||
	table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrTableNotFound {
 | 
			
		||||
			t.Log("table not found")
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		t.Log("table:", table)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConn_AddTable(t *testing.T) {
 | 
			
		||||
	var conn = nftables.NewConn()
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		table, err := conn.AddIPv4Table("test_ipv4")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log(table.Name())
 | 
			
		||||
	}
 | 
			
		||||
	{
 | 
			
		||||
		table, err := conn.AddIPv6Table("test_ipv6")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log(table.Name())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConn_DeleteTable(t *testing.T) {
 | 
			
		||||
	var conn = nftables.NewConn()
 | 
			
		||||
	err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										8
									
								
								internal/firewalls/nftables/element.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								internal/firewalls/nftables/element.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,8 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
type Element struct {
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								internal/firewalls/nftables/errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								internal/firewalls/nftables/errors.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import "errors"
 | 
			
		||||
 | 
			
		||||
var ErrTableNotFound = errors.New("table not found")
 | 
			
		||||
var ErrChainNotFound = errors.New("chain not found")
 | 
			
		||||
var ErrSetNotFound = errors.New("set not found")
 | 
			
		||||
var ErrRuleNotFound = errors.New("rule not found")
 | 
			
		||||
 | 
			
		||||
func IsNotFound(err error) bool {
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										18
									
								
								internal/firewalls/nftables/family.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								internal/firewalls/nftables/family.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	nft "github.com/google/nftables"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TableFamily = nft.TableFamily
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TableFamilyINet   TableFamily = nft.TableFamilyINet
 | 
			
		||||
	TableFamilyIPv4   TableFamily = nft.TableFamilyIPv4
 | 
			
		||||
	TableFamilyIPv6   TableFamily = nft.TableFamilyIPv6
 | 
			
		||||
	TableFamilyARP    TableFamily = nft.TableFamilyARP
 | 
			
		||||
	TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
 | 
			
		||||
	TableFamilyBridge TableFamily = nft.TableFamilyBridge
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										51
									
								
								internal/firewalls/nftables/rule.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								internal/firewalls/nftables/rule.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,51 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	nft "github.com/google/nftables"
 | 
			
		||||
	"github.com/google/nftables/expr"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Rule struct {
 | 
			
		||||
	rawRule *nft.Rule
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRule(rawRule *nft.Rule) *Rule {
 | 
			
		||||
	return &Rule{
 | 
			
		||||
		rawRule: rawRule,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Rule) Raw() *nft.Rule {
 | 
			
		||||
	return this.rawRule
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Rule) LookupSetName() string {
 | 
			
		||||
	for _, e := range this.rawRule.Exprs {
 | 
			
		||||
		exp, ok := e.(*expr.Lookup)
 | 
			
		||||
		if ok {
 | 
			
		||||
			return exp.SetName
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Rule) VerDict() expr.VerdictKind {
 | 
			
		||||
	for _, e := range this.rawRule.Exprs {
 | 
			
		||||
		exp, ok := e.(*expr.Verdict)
 | 
			
		||||
		if ok {
 | 
			
		||||
			return exp.Kind
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return -100
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Rule) Handle() uint64 {
 | 
			
		||||
	return this.rawRule.Handle
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Rule) UserData() []byte {
 | 
			
		||||
	return this.rawRule.UserData
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										161
									
								
								internal/firewalls/nftables/set.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										161
									
								
								internal/firewalls/nftables/set.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,161 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	nft "github.com/google/nftables"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const MaxSetNameLength = 15
 | 
			
		||||
 | 
			
		||||
type SetOptions struct {
 | 
			
		||||
	Id         uint32
 | 
			
		||||
	HasTimeout bool
 | 
			
		||||
	Timeout    time.Duration
 | 
			
		||||
	KeyType    SetDataType
 | 
			
		||||
	DataType   SetDataType
 | 
			
		||||
	Constant   bool
 | 
			
		||||
	Interval   bool
 | 
			
		||||
	Anonymous  bool
 | 
			
		||||
	IsMap      bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ElementOptions struct {
 | 
			
		||||
	Timeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Set struct {
 | 
			
		||||
	conn   *Conn
 | 
			
		||||
	rawSet *nft.Set
 | 
			
		||||
	batch  *SetBatch
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
 | 
			
		||||
	return &Set{
 | 
			
		||||
		conn:   conn,
 | 
			
		||||
		rawSet: rawSet,
 | 
			
		||||
		batch: &SetBatch{
 | 
			
		||||
			conn:   conn,
 | 
			
		||||
			rawSet: rawSet,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) Raw() *nft.Set {
 | 
			
		||||
	return this.rawSet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) Name() string {
 | 
			
		||||
	return this.rawSet.Name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
 | 
			
		||||
	var rawElement = nft.SetElement{
 | 
			
		||||
		Key: key,
 | 
			
		||||
	}
 | 
			
		||||
	if options != nil {
 | 
			
		||||
		rawElement.Timeout = options.Timeout
 | 
			
		||||
	}
 | 
			
		||||
	err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
 | 
			
		||||
		rawElement,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = this.conn.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// retry if exists
 | 
			
		||||
		if strings.Contains(err.Error(), "file exists") {
 | 
			
		||||
			deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
 | 
			
		||||
				{
 | 
			
		||||
					Key: key,
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			if deleteErr == nil {
 | 
			
		||||
				err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
 | 
			
		||||
					rawElement,
 | 
			
		||||
				})
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					err = this.conn.Commit()
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
 | 
			
		||||
	var ipObj = net.ParseIP(ip)
 | 
			
		||||
	if ipObj == nil {
 | 
			
		||||
		return errors.New("invalid ip '" + ip + "'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if utils.IsIPv4(ip) {
 | 
			
		||||
		return this.AddElement(ipObj.To4(), options)
 | 
			
		||||
	} else {
 | 
			
		||||
		return this.AddElement(ipObj.To16(), options)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) DeleteElement(key []byte) error {
 | 
			
		||||
	err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
 | 
			
		||||
		{
 | 
			
		||||
			Key: key,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = this.conn.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "no such file or directory") {
 | 
			
		||||
			err = nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) DeleteIPElement(ip string) error {
 | 
			
		||||
	var ipObj = net.ParseIP(ip)
 | 
			
		||||
	if ipObj == nil {
 | 
			
		||||
		return errors.New("invalid ip '" + ip + "'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if utils.IsIPv4(ip) {
 | 
			
		||||
		return this.DeleteElement(ipObj.To4())
 | 
			
		||||
	} else {
 | 
			
		||||
		return this.DeleteElement(ipObj.To16())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) Batch() *SetBatch {
 | 
			
		||||
	return this.batch
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) GetIPElements() ([]string, error) {
 | 
			
		||||
	elements, err := this.conn.Raw().GetSetElements(this.rawSet)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var result = []string{}
 | 
			
		||||
	for _, element := range elements {
 | 
			
		||||
		result = append(result, net.IP(element.Key).String())
 | 
			
		||||
	}
 | 
			
		||||
	return result, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// not work current time
 | 
			
		||||
/**func (this *Set) Flush() error {
 | 
			
		||||
	this.conn.Raw().FlushSet(this.rawSet)
 | 
			
		||||
	return this.conn.Commit()
 | 
			
		||||
}**/
 | 
			
		||||
							
								
								
									
										36
									
								
								internal/firewalls/nftables/set_batch.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/firewalls/nftables/set_batch.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	nft "github.com/google/nftables"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SetBatch struct {
 | 
			
		||||
	conn   *Conn
 | 
			
		||||
	rawSet *nft.Set
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
 | 
			
		||||
	var rawElement = nft.SetElement{
 | 
			
		||||
		Key: key,
 | 
			
		||||
	}
 | 
			
		||||
	if options != nil {
 | 
			
		||||
		rawElement.Timeout = options.Timeout
 | 
			
		||||
	}
 | 
			
		||||
	return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
 | 
			
		||||
		rawElement,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *SetBatch) DeleteElement(key []byte) error {
 | 
			
		||||
	return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
 | 
			
		||||
		{
 | 
			
		||||
			Key: key,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *SetBatch) Commit() error {
 | 
			
		||||
	return this.conn.Commit()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										57
									
								
								internal/firewalls/nftables/set_data_type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								internal/firewalls/nftables/set_data_type.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,57 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import nft "github.com/google/nftables"
 | 
			
		||||
 | 
			
		||||
type SetDataType = nft.SetDatatype
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	TypeInvalid     = nft.TypeInvalid
 | 
			
		||||
	TypeVerdict     = nft.TypeVerdict
 | 
			
		||||
	TypeNFProto     = nft.TypeNFProto
 | 
			
		||||
	TypeBitmask     = nft.TypeBitmask
 | 
			
		||||
	TypeInteger     = nft.TypeInteger
 | 
			
		||||
	TypeString      = nft.TypeString
 | 
			
		||||
	TypeLLAddr      = nft.TypeLLAddr
 | 
			
		||||
	TypeIPAddr      = nft.TypeIPAddr
 | 
			
		||||
	TypeIP6Addr     = nft.TypeIP6Addr
 | 
			
		||||
	TypeEtherAddr   = nft.TypeEtherAddr
 | 
			
		||||
	TypeEtherType   = nft.TypeEtherType
 | 
			
		||||
	TypeARPOp       = nft.TypeARPOp
 | 
			
		||||
	TypeInetProto   = nft.TypeInetProto
 | 
			
		||||
	TypeInetService = nft.TypeInetService
 | 
			
		||||
	TypeICMPType    = nft.TypeICMPType
 | 
			
		||||
	TypeTCPFlag     = nft.TypeTCPFlag
 | 
			
		||||
	TypeDCCPPktType = nft.TypeDCCPPktType
 | 
			
		||||
	TypeMHType      = nft.TypeMHType
 | 
			
		||||
	TypeTime        = nft.TypeTime
 | 
			
		||||
	TypeMark        = nft.TypeMark
 | 
			
		||||
	TypeIFIndex     = nft.TypeIFIndex
 | 
			
		||||
	TypeARPHRD      = nft.TypeARPHRD
 | 
			
		||||
	TypeRealm       = nft.TypeRealm
 | 
			
		||||
	TypeClassID     = nft.TypeClassID
 | 
			
		||||
	TypeUID         = nft.TypeUID
 | 
			
		||||
	TypeGID         = nft.TypeGID
 | 
			
		||||
	TypeCTState     = nft.TypeCTState
 | 
			
		||||
	TypeCTDir       = nft.TypeCTDir
 | 
			
		||||
	TypeCTStatus    = nft.TypeCTStatus
 | 
			
		||||
	TypeICMP6Type   = nft.TypeICMP6Type
 | 
			
		||||
	TypeCTLabel     = nft.TypeCTLabel
 | 
			
		||||
	TypePktType     = nft.TypePktType
 | 
			
		||||
	TypeICMPCode    = nft.TypeICMPCode
 | 
			
		||||
	TypeICMPV6Code  = nft.TypeICMPV6Code
 | 
			
		||||
	TypeICMPXCode   = nft.TypeICMPXCode
 | 
			
		||||
	TypeDevGroup    = nft.TypeDevGroup
 | 
			
		||||
	TypeDSCP        = nft.TypeDSCP
 | 
			
		||||
	TypeECN         = nft.TypeECN
 | 
			
		||||
	TypeFIBAddr     = nft.TypeFIBAddr
 | 
			
		||||
	TypeBoolean     = nft.TypeBoolean
 | 
			
		||||
	TypeCTEventBit  = nft.TypeCTEventBit
 | 
			
		||||
	TypeIFName      = nft.TypeIFName
 | 
			
		||||
	TypeIGMPType    = nft.TypeIGMPType
 | 
			
		||||
	TypeTimeDate    = nft.TypeTimeDate
 | 
			
		||||
	TypeTimeHour    = nft.TypeTimeHour
 | 
			
		||||
	TypeTimeDay     = nft.TypeTimeDay
 | 
			
		||||
	TypeCGroupV2    = nft.TypeCGroupV2
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										110
									
								
								internal/firewalls/nftables/set_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								internal/firewalls/nftables/set_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,110 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
 | 
			
		||||
package nftables_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"github.com/mdlayher/netlink"
 | 
			
		||||
	"net"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getIPv4Set(t *testing.T) *nftables.Set {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
	set, err := table.GetSet("test_ipv4_set")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrSetNotFound {
 | 
			
		||||
			set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
 | 
			
		||||
				KeyType:    nftables.TypeIPAddr,
 | 
			
		||||
				HasTimeout: true,
 | 
			
		||||
			})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return set
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSet_AddElement(t *testing.T) {
 | 
			
		||||
	var set = getIPv4Set(t)
 | 
			
		||||
	err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSet_DeleteElement(t *testing.T) {
 | 
			
		||||
	var set = getIPv4Set(t)
 | 
			
		||||
	err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSet_Batch(t *testing.T) {
 | 
			
		||||
	var batch = getIPv4Set(t).Batch()
 | 
			
		||||
 | 
			
		||||
	for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
 | 
			
		||||
		var ipData = net.ParseIP(ip).To4()
 | 
			
		||||
		//err := batch.DeleteElement(ipData)
 | 
			
		||||
		//if err != nil {
 | 
			
		||||
		//	t.Fatal(err)
 | 
			
		||||
		//}
 | 
			
		||||
		err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := batch.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSet_Add_Many(t *testing.T) {
 | 
			
		||||
	var set = getIPv4Set(t)
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < 255; i++ {
 | 
			
		||||
		t.Log(i)
 | 
			
		||||
		for j := 0; j < 255; j++ {
 | 
			
		||||
			var ip = "192.167." + types.String(i) + "." + types.String(j)
 | 
			
		||||
			var ipData = net.ParseIP(ip).To4()
 | 
			
		||||
			err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if j%10 == 0 {
 | 
			
		||||
				err = set.Batch().Commit()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		err := set.Batch().Commit()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**func TestSet_Flush(t *testing.T) {
 | 
			
		||||
	var set = getIPv4Set(t)
 | 
			
		||||
	err := set.Flush()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}**/
 | 
			
		||||
							
								
								
									
										157
									
								
								internal/firewalls/nftables/table.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										157
									
								
								internal/firewalls/nftables/table.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,157 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	nft "github.com/google/nftables"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Table struct {
 | 
			
		||||
	conn     *Conn
 | 
			
		||||
	rawTable *nft.Table
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
 | 
			
		||||
	return &Table{
 | 
			
		||||
		conn:     conn,
 | 
			
		||||
		rawTable: rawTable,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) Raw() *nft.Table {
 | 
			
		||||
	return this.rawTable
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) Name() string {
 | 
			
		||||
	return this.rawTable.Name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) Family() TableFamily {
 | 
			
		||||
	return this.rawTable.Family
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) GetChain(name string) (*Chain, error) {
 | 
			
		||||
	rawChains, err := this.conn.Raw().ListChains()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	for _, rawChain := range rawChains {
 | 
			
		||||
		// must compare table name
 | 
			
		||||
		if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
 | 
			
		||||
			return NewChain(this.conn, this.rawTable, rawChain), nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, ErrChainNotFound
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
 | 
			
		||||
	if len(name) > MaxChainNameLength {
 | 
			
		||||
		return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var rawChain = this.conn.Raw().AddChain(&nft.Chain{
 | 
			
		||||
		Name:     name,
 | 
			
		||||
		Table:    this.rawTable,
 | 
			
		||||
		Hooknum:  nft.ChainHookInput,
 | 
			
		||||
		Priority: nft.ChainPriorityFilter,
 | 
			
		||||
		Type:     nft.ChainTypeFilter,
 | 
			
		||||
		Policy:   chainPolicy,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	err := this.conn.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return NewChain(this.conn, this.rawTable, rawChain), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
 | 
			
		||||
	var policy = ChainPolicyAccept
 | 
			
		||||
	return this.AddChain(name, &policy)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) AddDropChain(name string) (*Chain, error) {
 | 
			
		||||
	var policy = ChainPolicyDrop
 | 
			
		||||
	return this.AddChain(name, &policy)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) DeleteChain(name string) error {
 | 
			
		||||
	chain, err := this.GetChain(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == ErrChainNotFound {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	this.conn.Raw().DelChain(chain.Raw())
 | 
			
		||||
	return this.conn.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) GetSet(name string) (*Set, error) {
 | 
			
		||||
	rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "no such file or directory") {
 | 
			
		||||
			return nil, ErrSetNotFound
 | 
			
		||||
		}
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return NewSet(this.conn, rawSet), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
 | 
			
		||||
	if len(name) > MaxSetNameLength {
 | 
			
		||||
		return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if options == nil {
 | 
			
		||||
		options = &SetOptions{}
 | 
			
		||||
	}
 | 
			
		||||
	var rawSet = &nft.Set{
 | 
			
		||||
		Table:      this.rawTable,
 | 
			
		||||
		ID:         options.Id,
 | 
			
		||||
		Name:       name,
 | 
			
		||||
		Anonymous:  options.Anonymous,
 | 
			
		||||
		Constant:   options.Constant,
 | 
			
		||||
		Interval:   options.Interval,
 | 
			
		||||
		IsMap:      options.IsMap,
 | 
			
		||||
		HasTimeout: options.HasTimeout,
 | 
			
		||||
		Timeout:    options.Timeout,
 | 
			
		||||
		KeyType:    options.KeyType,
 | 
			
		||||
		DataType:   options.DataType,
 | 
			
		||||
	}
 | 
			
		||||
	err := this.conn.Raw().AddSet(rawSet, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = this.conn.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return NewSet(this.conn, rawSet), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) DeleteSet(name string) error {
 | 
			
		||||
	set, err := this.GetSet(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == ErrSetNotFound {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	this.conn.Raw().DelSet(set.Raw())
 | 
			
		||||
	return this.conn.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Table) Flush() error {
 | 
			
		||||
	this.conn.Raw().FlushTable(this.rawTable)
 | 
			
		||||
	return this.conn.Commit()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										140
									
								
								internal/firewalls/nftables/table_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										140
									
								
								internal/firewalls/nftables/table_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,140 @@
 | 
			
		||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
 | 
			
		||||
//go:build linux
 | 
			
		||||
// +build linux
 | 
			
		||||
 | 
			
		||||
package nftables_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getIPv4Table(t *testing.T) *nftables.Table {
 | 
			
		||||
	var conn = nftables.NewConn()
 | 
			
		||||
	table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == nftables.ErrTableNotFound {
 | 
			
		||||
			table, err = conn.AddIPv4Table("test_ipv4")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return table
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTable_AddChain(t *testing.T) {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		chain, err := table.AddChain("test_default_chain", nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log("created:", chain.Name())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		chain, err := table.AddAcceptChain("test_accept_chain")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log("created:", chain.Name())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
 | 
			
		||||
	/**{
 | 
			
		||||
		chain, err := table.AddDropChain("test_drop_chain")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log("created:", chain.Name())
 | 
			
		||||
	}**/
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTable_GetChain(t *testing.T) {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
	for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
 | 
			
		||||
		chain, err := table.GetChain(chainName)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err == nftables.ErrChainNotFound {
 | 
			
		||||
				t.Log(chainName, ":", "not found")
 | 
			
		||||
			} else {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Log(chainName, ":", chain)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTable_DeleteChain(t *testing.T) {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
	err := table.DeleteChain("test_default_chain")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTable_AddSet(t *testing.T) {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
	{
 | 
			
		||||
		set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
 | 
			
		||||
			HasTimeout: false,
 | 
			
		||||
			KeyType:    nftables.TypeIPAddr,
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log(set.Name())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
 | 
			
		||||
			HasTimeout: true,
 | 
			
		||||
			//Timeout:    3600 * time.Second,
 | 
			
		||||
			KeyType: nftables.TypeIP6Addr,
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		t.Log(set.Name())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTable_GetSet(t *testing.T) {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
	for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
 | 
			
		||||
		set, err := table.GetSet(setName)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err == nftables.ErrSetNotFound {
 | 
			
		||||
				t.Log(setName, ": not found")
 | 
			
		||||
			} else {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Log(setName, ":", set)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTable_DeleteSet(t *testing.T) {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
	err := table.DeleteSet("ipv4_black_set")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTable_Flush(t *testing.T) {
 | 
			
		||||
	var table = getIPv4Table(t)
 | 
			
		||||
	err := table.Flush()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	t.Log("ok")
 | 
			
		||||
}
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
@@ -14,16 +15,20 @@ import (
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/errors"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/events"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
@@ -119,6 +124,8 @@ func (this *APIStream) loop() error {
 | 
			
		||||
			err = this.handleNewNodeTask(message)
 | 
			
		||||
		case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
 | 
			
		||||
			err = this.handleCheckSystemdService(message)
 | 
			
		||||
		case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
 | 
			
		||||
			err = this.handleCheckLocalFirewall(message)
 | 
			
		||||
		case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
 | 
			
		||||
			err = this.handleChangeAPINode(message)
 | 
			
		||||
		default:
 | 
			
		||||
@@ -569,7 +576,7 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cmd := utils.NewCommandExecutor()
 | 
			
		||||
	var cmd = utils.NewCommandExecutor()
 | 
			
		||||
	shortName := teaconst.SystemdServiceName
 | 
			
		||||
	cmd.Add(systemctl, "is-enabled", shortName)
 | 
			
		||||
	output, err := cmd.Run()
 | 
			
		||||
@@ -585,6 +592,63 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 检查本地防火墙
 | 
			
		||||
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
 | 
			
		||||
	var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
 | 
			
		||||
	err := json.Unmarshal(message.DataJSON, dataMessage)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// nft
 | 
			
		||||
	if dataMessage.Name == "nftables" {
 | 
			
		||||
		if runtime.GOOS != "linux" {
 | 
			
		||||
			this.replyFail(message.RequestId, "not Linux system")
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		nft, err := exec.LookPath("nft")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var cmd = exec.Command(nft, "--version")
 | 
			
		||||
		var output = &bytes.Buffer{}
 | 
			
		||||
		cmd.Stdout = output
 | 
			
		||||
		err = cmd.Run()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			this.replyFail(message.RequestId, "get version failed: "+err.Error())
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var outputString = output.String()
 | 
			
		||||
		var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
 | 
			
		||||
		if len(versionMatches) <= 1 {
 | 
			
		||||
			this.replyFail(message.RequestId, "can not get nft version")
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		var version = versionMatches[1]
 | 
			
		||||
 | 
			
		||||
		var result = maps.Map{
 | 
			
		||||
			"version": version,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var protectionConfig = sharedNodeConfig.DDOSProtection
 | 
			
		||||
		err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
 | 
			
		||||
		} else {
 | 
			
		||||
			this.replyOk(message.RequestId, string(result.AsJSON()))
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 修改API地址
 | 
			
		||||
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
 | 
			
		||||
	config, err := configs.LoadAPIConfig()
 | 
			
		||||
@@ -660,6 +724,11 @@ func (this *APIStream) replyOk(requestId int64, message string) {
 | 
			
		||||
	_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 回复成功并包含数据
 | 
			
		||||
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
 | 
			
		||||
	_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取缓存存取对象
 | 
			
		||||
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
 | 
			
		||||
	cachePolicy := &serverconfigs.HTTPCachePolicy{}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,6 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
			
		||||
@@ -21,8 +20,7 @@ import (
 | 
			
		||||
 | 
			
		||||
// ClientConn 客户端连接
 | 
			
		||||
type ClientConn struct {
 | 
			
		||||
	once          sync.Once
 | 
			
		||||
	globalLimiter *ratelimit.Counter
 | 
			
		||||
	once sync.Once
 | 
			
		||||
 | 
			
		||||
	isTLS       bool
 | 
			
		||||
	hasDeadline bool
 | 
			
		||||
@@ -33,7 +31,7 @@ type ClientConn struct {
 | 
			
		||||
	BaseClientConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
 | 
			
		||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn {
 | 
			
		||||
	if quickClose {
 | 
			
		||||
		// TCP
 | 
			
		||||
		tcpConn, ok := conn.(*net.TCPConn)
 | 
			
		||||
@@ -43,7 +41,7 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter}
 | 
			
		||||
	return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
 | 
			
		||||
@@ -96,13 +94,6 @@ func (this *ClientConn) Close() error {
 | 
			
		||||
 | 
			
		||||
	err := this.rawConn.Close()
 | 
			
		||||
 | 
			
		||||
	// 全局并发数限制
 | 
			
		||||
	this.once.Do(func() {
 | 
			
		||||
		if this.globalLimiter != nil {
 | 
			
		||||
			this.globalLimiter.Release()
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	// 单个服务并发数限制
 | 
			
		||||
	sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,16 +3,12 @@
 | 
			
		||||
package nodes
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxConnections)
 | 
			
		||||
 | 
			
		||||
// ClientListener 客户端网络监听
 | 
			
		||||
type ClientListener struct {
 | 
			
		||||
	rawListener net.Listener
 | 
			
		||||
@@ -36,13 +32,8 @@ func (this *ClientListener) IsTLS() bool {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientListener) Accept() (net.Conn, error) {
 | 
			
		||||
	// 限制并发连接数
 | 
			
		||||
	var limiter = sharedConnectionsLimiter
 | 
			
		||||
	limiter.Ack()
 | 
			
		||||
 | 
			
		||||
	conn, err := this.rawListener.Accept()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		limiter.Release()
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -60,12 +51,11 @@ func (this *ClientListener) Accept() (net.Conn, error) {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			_ = conn.Close()
 | 
			
		||||
			limiter.Release()
 | 
			
		||||
			return this.Accept()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil
 | 
			
		||||
	return NewClientConn(conn, this.isTLS, this.quickClose), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *ClientListener) Close() error {
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/caches"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/configs"
 | 
			
		||||
	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
 | 
			
		||||
@@ -15,7 +16,6 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/goman"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/metrics"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/rpc"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/stats"
 | 
			
		||||
@@ -368,6 +368,38 @@ func (this *Node) loop() error {
 | 
			
		||||
			}
 | 
			
		||||
			sharedNodeConfig.ParentNodes = parentNodes
 | 
			
		||||
 | 
			
		||||
			// 修改为已同步
 | 
			
		||||
			_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
 | 
			
		||||
				NodeTaskId: task.Id,
 | 
			
		||||
				IsOk:       true,
 | 
			
		||||
				Error:      "",
 | 
			
		||||
			})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
		case "ddosProtectionChanged":
 | 
			
		||||
			resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if len(resp.DdosProtectionJSON) == 0 {
 | 
			
		||||
				if sharedNodeConfig != nil {
 | 
			
		||||
					sharedNodeConfig.DDOSProtection = nil
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
 | 
			
		||||
				err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return errors.New("decode DDoS protection config failed: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					// 不阻塞
 | 
			
		||||
					remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 修改为已同步
 | 
			
		||||
			_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
 | 
			
		||||
				NodeTaskId: task.Id,
 | 
			
		||||
@@ -730,7 +762,6 @@ func (this *Node) listenSock() error {
 | 
			
		||||
						"ipConns":     ipConns,
 | 
			
		||||
						"serverConns": serverConns,
 | 
			
		||||
						"total":       sharedListenerManager.TotalActiveConnections(),
 | 
			
		||||
						"limiter":     sharedConnectionsLimiter.Len(),
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
			case "dropIP":
 | 
			
		||||
@@ -854,17 +885,6 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
 | 
			
		||||
		this.maxThreads = config.MaxThreads
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// max tcp connections
 | 
			
		||||
	if config.TCPMaxConnections <= 0 {
 | 
			
		||||
		config.TCPMaxConnections = nodeconfigs.DefaultTCPMaxConnections
 | 
			
		||||
	}
 | 
			
		||||
	if config.TCPMaxConnections != sharedConnectionsLimiter.Count() {
 | 
			
		||||
		remotelogs.Println("NODE", "[TCP]changed tcp max connections to '"+types.String(config.TCPMaxConnections)+"'")
 | 
			
		||||
 | 
			
		||||
		sharedConnectionsLimiter.Close()
 | 
			
		||||
		sharedConnectionsLimiter = ratelimit.NewCounter(config.TCPMaxConnections)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// timezone
 | 
			
		||||
	var timeZone = config.TimeZone
 | 
			
		||||
	if len(timeZone) == 0 {
 | 
			
		||||
 
 | 
			
		||||
@@ -5,9 +5,12 @@ import (
 | 
			
		||||
	"github.com/cespare/xxhash"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ipv4Reg = regexp.MustCompile(`\d+\.`)
 | 
			
		||||
 | 
			
		||||
// IP2Long 将IP转换为整型
 | 
			
		||||
// 注意IPv6没有顺序
 | 
			
		||||
func IP2Long(ip string) uint64 {
 | 
			
		||||
@@ -54,3 +57,24 @@ func IsLocalIP(ipString string) bool {
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsIPv4 是否为IPv4
 | 
			
		||||
func IsIPv4(ip string) bool {
 | 
			
		||||
	var data = net.ParseIP(ip)
 | 
			
		||||
	if data == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(ip, ":") {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return data.To4() != nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsIPv6 是否为IPv6
 | 
			
		||||
func IsIPv6(ip string) bool {
 | 
			
		||||
	var data = net.ParseIP(ip)
 | 
			
		||||
	if data == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return !IsIPv4(ip)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -26,3 +26,26 @@ func TestIsLocalIP(t *testing.T) {
 | 
			
		||||
	a.IsFalse(IsLocalIP("::1:2:3"))
 | 
			
		||||
	a.IsFalse(IsLocalIP("8.8.8.8"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIsIPv4(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
	a.IsTrue(IsIPv4("192.168.1.1"))
 | 
			
		||||
	a.IsTrue(IsIPv4("0.0.0.0"))
 | 
			
		||||
	a.IsFalse(IsIPv4("192.168.1.256"))
 | 
			
		||||
	a.IsFalse(IsIPv4("192.168.1"))
 | 
			
		||||
	a.IsFalse(IsIPv4("::1"))
 | 
			
		||||
	a.IsFalse(IsIPv4("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
 | 
			
		||||
	a.IsFalse(IsIPv4("::ffff:192.168.0.1"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIsIPv6(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
	a.IsFalse(IsIPv6("192.168.1.1"))
 | 
			
		||||
	a.IsFloat32(IsIPv6("0.0.0.0"))
 | 
			
		||||
	a.IsFalse(IsIPv6("192.168.1.256"))
 | 
			
		||||
	a.IsFalse(IsIPv6("192.168.1"))
 | 
			
		||||
	a.IsTrue(IsIPv6("::1"))
 | 
			
		||||
	a.IsTrue(IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
 | 
			
		||||
	a.IsTrue(IsIPv4("::ffff:192.168.0.1"))
 | 
			
		||||
	a.IsTrue(IsIPv6("::ffff:192.168.0.1"))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package utils
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"unsafe"
 | 
			
		||||
)
 | 
			
		||||
@@ -36,6 +37,22 @@ func FormatAddressList(addrList []string) []string {
 | 
			
		||||
	return result
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ToValidUTF8string 去除字符串中的非UTF-8字符
 | 
			
		||||
func ToValidUTF8string(v string) string {
 | 
			
		||||
	return strings.ToValidUTF8(v, "")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContainsSameStrings 检查两个字符串slice内容是否一致
 | 
			
		||||
func ContainsSameStrings(s1 []string, s2 []string) bool {
 | 
			
		||||
	if len(s1) != len(s2) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	sort.Strings(s1)
 | 
			
		||||
	sort.Strings(s2)
 | 
			
		||||
	for index, v1 := range s1 {
 | 
			
		||||
		if v1 != s2[index] {
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,56 +1,67 @@
 | 
			
		||||
package utils
 | 
			
		||||
package utils_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/iwind/TeaGo/assert"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestBytesToString(t *testing.T) {
 | 
			
		||||
	t.Log(UnsafeBytesToString([]byte("Hello,World")))
 | 
			
		||||
	t.Log(utils.UnsafeBytesToString([]byte("Hello,World")))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStringToBytes(t *testing.T) {
 | 
			
		||||
	t.Log(string(UnsafeStringToBytes("Hello,World")))
 | 
			
		||||
	t.Log(string(utils.UnsafeStringToBytes("Hello,World")))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkBytesToString(b *testing.B) {
 | 
			
		||||
	data := []byte("Hello,World")
 | 
			
		||||
	var data = []byte("Hello,World")
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = UnsafeBytesToString(data)
 | 
			
		||||
		_ = utils.UnsafeBytesToString(data)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkBytesToString2(b *testing.B) {
 | 
			
		||||
	data := []byte("Hello,World")
 | 
			
		||||
	var data = []byte("Hello,World")
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = string(data)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkStringToBytes(b *testing.B) {
 | 
			
		||||
	s := strings.Repeat("Hello,World", 1024)
 | 
			
		||||
	var s = strings.Repeat("Hello,World", 1024)
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = UnsafeStringToBytes(s)
 | 
			
		||||
		_ = utils.UnsafeStringToBytes(s)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkStringToBytes2(b *testing.B) {
 | 
			
		||||
	s := strings.Repeat("Hello,World", 1024)
 | 
			
		||||
	var s = strings.Repeat("Hello,World", 1024)
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_ = []byte(s)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFormatAddress(t *testing.T) {
 | 
			
		||||
	t.Log(FormatAddress("127.0.0.1:1234"))
 | 
			
		||||
	t.Log(FormatAddress("127.0.0.1 : 1234"))
 | 
			
		||||
	t.Log(FormatAddress("127.0.0.1:1234"))
 | 
			
		||||
	t.Log(utils.FormatAddress("127.0.0.1:1234"))
 | 
			
		||||
	t.Log(utils.FormatAddress("127.0.0.1 : 1234"))
 | 
			
		||||
	t.Log(utils.FormatAddress("127.0.0.1:1234"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFormatAddressList(t *testing.T) {
 | 
			
		||||
	t.Log(FormatAddressList([]string{
 | 
			
		||||
	t.Log(utils.FormatAddressList([]string{
 | 
			
		||||
		"127.0.0.1:1234",
 | 
			
		||||
		"127.0.0.1 : 1234",
 | 
			
		||||
		"127.0.0.1:1234",
 | 
			
		||||
	}))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestContainsSameStrings(t *testing.T) {
 | 
			
		||||
	var a = assert.NewAssertion(t)
 | 
			
		||||
	a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"}))
 | 
			
		||||
	a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"}))
 | 
			
		||||
	a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
 | 
			
		||||
	a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"}))
 | 
			
		||||
	a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"}))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user