diff --git a/internal/events/events.go b/internal/events/events.go index 0257ce3..167b7a1 100644 --- a/internal/events/events.go +++ b/internal/events/events.go @@ -3,9 +3,10 @@ package events type Event = string const ( - EventStart Event = "start" // start loading - EventLoaded Event = "loaded" // first load - EventQuit Event = "quit" // quit node gracefully - EventReload Event = "reload" // reload config - EventTerminated Event = "terminated" // process terminated + EventStart Event = "start" // start loading + EventLoaded Event = "loaded" // first load + EventQuit Event = "quit" // quit node gracefully + EventReload Event = "reload" // reload config + EventTerminated Event = "terminated" // process terminated + EventNFTablesReady Event = "nftablesReady" // nftables ready ) diff --git a/internal/firewalls/.gitignore b/internal/firewalls/.gitignore deleted file mode 100644 index c1b01e3..0000000 --- a/internal/firewalls/.gitignore +++ /dev/null @@ -1 +0,0 @@ -firewall_nftables_test.go \ No newline at end of file diff --git a/internal/firewalls/ddos_protection.go b/internal/firewalls/ddos_protection.go new file mode 100644 index 0000000..dd1e5c1 --- /dev/null +++ b/internal/firewalls/ddos_protection.go @@ -0,0 +1,494 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package firewalls + +import ( + "bytes" + "encoding/json" + "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/zero" + "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/types" + "net" + "os/exec" + "strings" +) + +var SharedDDoSProtectionManager = NewDDoSProtectionManager() + +func init() { + events.On(events.EventReload, func() { + if nftablesInstance == nil { + return + } + + nodeConfig, _ := nodeconfigs.SharedNodeConfig() + if nodeConfig != nil { + err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection) + if err != nil { + remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error()) + } + } + }) + + events.On(events.EventNFTablesReady, func() { + nodeConfig, _ := nodeconfigs.SharedNodeConfig() + if nodeConfig != nil { + err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection) + if err != nil { + remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error()) + } + } + }) +} + +// DDoSProtectionManager DDoS防护 +type DDoSProtectionManager struct { + nftPath string + + lastAllowIPList []string + lastConfig []byte +} + +// NewDDoSProtectionManager 获取新对象 +func NewDDoSProtectionManager() *DDoSProtectionManager { + nftPath, _ := exec.LookPath("nft") + + return &DDoSProtectionManager{ + nftPath: nftPath, + } +} + +// Apply 应用配置 +func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error { + // 同集群节点IP白名单 + var allowIPListChanged = false + nodeConfig, _ := nodeconfigs.SharedNodeConfig() + if nodeConfig != nil { + var allowIPList = nodeConfig.AllowedIPs + if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) { + allowIPListChanged = true + this.lastAllowIPList = allowIPList + } + } + + // 对比配置 + configJSON, err := json.Marshal(config) + if err != nil { + return errors.New("encode config to json failed: " + err.Error()) + } + if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) { + return nil + } + remotelogs.Println("FIREWALL", "change DDoS protection config") + + if len(this.nftPath) == 0 { + return errors.New("can not find nft command") + } + + if nftablesInstance == nil { + return errors.New("nftables instance should not be nil") + } + + if config == nil { + // TCP + err := this.removeTCPRules() + if err != nil { + return err + } + + // TODO other protocols + + return nil + } + + // TCP + if config.TCP == nil { + err := this.removeTCPRules() + if err != nil { + return err + } + } else { + // allow ip list + var allowIPList = []string{} + for _, ipConfig := range config.TCP.AllowIPList { + allowIPList = append(allowIPList, ipConfig.IP) + } + for _, ip := range this.lastAllowIPList { + if !lists.ContainsString(allowIPList, ip) { + allowIPList = append(allowIPList, ip) + } + } + err = this.updateAllowIPList(allowIPList) + if err != nil { + return err + } + + // tcp + if config.TCP.IsOn { + err := this.addTCPRules(config.TCP) + if err != nil { + return err + } + } else { + err := this.removeTCPRules() + if err != nil { + return err + } + } + } + + this.lastConfig = configJSON + + return nil +} + +// 添加TCP规则 +func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error { + var ports = []int32{} + for _, portConfig := range tcpConfig.Ports { + if !lists.ContainsInt32(ports, portConfig.Port) { + ports = append(ports, portConfig.Port) + } + } + if len(ports) == 0 { + ports = []int32{80, 443} + } + + for _, filter := range nftablesFilters { + chain, oldRules, err := this.getRules(filter) + if err != nil { + return errors.New("get old rules failed: " + err.Error()) + } + + var protocol = filter.protocol() + + // max connections + var maxConnections = tcpConfig.MaxConnections + if maxConnections <= 0 { + maxConnections = nodeconfigs.DefaultTCPMaxConnections + if maxConnections <= 0 { + maxConnections = 100000 + } + } + + // max connections per ip + var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP + if maxConnectionsPerIP <= 0 { + maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP + if maxConnectionsPerIP <= 0 { + maxConnectionsPerIP = 100000 + } + } + + // new connections rate + var newConnectionsRate = tcpConfig.NewConnectionsRate + if newConnectionsRate <= 0 { + newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate + if newConnectionsRate <= 0 { + newConnectionsRate = 100000 + } + } + + // 检查是否有变化 + var hasChanges = false + for _, port := range ports { + if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) { + hasChanges = true + break + } + if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) { + hasChanges = true + break + } + if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) { + hasChanges = true + break + } + } + + if !hasChanges { + // 检查是否有多余的端口 + var oldPorts = this.getTCPPorts(oldRules) + if !this.eqPorts(ports, oldPorts) { + hasChanges = true + } + } + + if !hasChanges { + return nil + } + + // 先清空所有相关规则 + err = this.removeOldTCPRules(chain, oldRules) + if err != nil { + return errors.New("delete old rules failed: " + err.Error()) + } + + // 添加新规则 + for _, port := range ports { + if maxConnections > 0 { + var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)})) + err := cmd.Run() + if err != nil { + return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error()) + } + } + + if maxConnectionsPerIP > 0 { + var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)})) + var stderr = &bytes.Buffer{} + cmd.Stderr = stderr + err := cmd.Run() + if err != nil { + return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")") + } + } + + if newConnectionsRate > 0 { + // TODO 思考是否有惩罚机制 + var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)})) + var stderr = &bytes.Buffer{} + cmd.Stderr = stderr + err := cmd.Run() + if err != nil { + return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")") + } + } + } + } + + return nil +} + +// 删除TCP规则 +func (this *DDoSProtectionManager) removeTCPRules() error { + for _, filter := range nftablesFilters { + chain, rules, err := this.getRules(filter) + + // TCP + err = this.removeOldTCPRules(chain, rules) + if err != nil { + return err + } + } + + return nil +} + +// 组合user data +// 数据中不能包含字母、数字、下划线以外的数据 +func (this *DDoSProtectionManager) encodeUserData(attrs []string) string { + if attrs == nil { + return "" + } + + return "ZZ" + strings.Join(attrs, "_") + "ZZ" +} + +// 解码user data +func (this *DDoSProtectionManager) decodeUserData(data []byte) []string { + if len(data) == 0 { + return nil + } + + var dataCopy = make([]byte, len(data)) + copy(dataCopy, data) + + var separatorLen = 2 + var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'}) + if index1 < 0 { + return nil + } + + dataCopy = dataCopy[index1+separatorLen:] + var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'}) + if index2 < 0 { + return nil + } + + var s = string(dataCopy[:index2]) + var pieces = strings.Split(s, "_") + for index, piece := range pieces { + pieces[index] = strings.TrimSpace(piece) + } + return pieces +} + +// 清除规则 +func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error { + for _, rule := range rules { + var pieces = this.decodeUserData(rule.UserData()) + if len(pieces) != 4 { + continue + } + if pieces[0] != "tcp" { + continue + } + switch pieces[2] { + case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate": + err := chain.DeleteRule(rule) + if err != nil { + return err + } + } + } + + return nil +} + +// 根据参数检查规则是否存在 +func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) { + if len(attrs) == 0 { + return false + } + for _, oldRule := range rules { + var pieces = this.decodeUserData(oldRule.UserData()) + if len(attrs) != len(pieces) { + continue + } + var isSame = true + for index, piece := range pieces { + if strings.TrimSpace(piece) != attrs[index] { + isSame = false + break + } + } + if isSame { + return true + } + } + return false +} + +// 获取规则中的端口号 +func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 { + var ports = []int32{} + for _, rule := range rules { + var pieces = this.decodeUserData(rule.UserData()) + if len(pieces) != 4 { + continue + } + if pieces[0] != "tcp" { + continue + } + var port = types.Int32(pieces[1]) + if port > 0 && !lists.ContainsInt32(ports, port) { + ports = append(ports, port) + } + } + return ports +} + +// 检查端口是否一样 +func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool { + if len(ports1) != len(ports2) { + return false + } + + var portMap = map[int32]bool{} + for _, port := range ports2 { + portMap[port] = true + } + + for _, port := range ports1 { + _, ok := portMap[port] + if !ok { + return false + } + } + return true +} + +// 查找Table +func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) { + var family nftables.TableFamily + if filter.IsIPv4 { + family = nftables.TableFamilyIPv4 + } else if filter.IsIPv6 { + family = nftables.TableFamilyIPv6 + } else { + return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6") + } + return nftablesInstance.conn.GetTable(filter.Name, family) +} + +// 查找所有规则 +func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) { + table, err := this.getTable(filter) + if err != nil { + return nil, nil, errors.New("get table failed: " + err.Error()) + } + chain, err := table.GetChain(nftablesChainName) + if err != nil { + return nil, nil, errors.New("get chain failed: " + err.Error()) + } + rules, err := chain.GetRules() + return chain, rules, err +} + +// 更新白名单 +func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error { + if nftablesInstance == nil { + return nil + } + + var allMap = map[string]zero.Zero{} + for _, ip := range allIPList { + allMap[ip] = zero.New() + } + + for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} { + var isIPv4 = set == nftablesInstance.allowIPv4Set + var isIPv6 = !isIPv4 + + // 现有的 + oldList, err := set.GetIPElements() + if err != nil { + return err + } + var oldMap = map[string]zero.Zero{} // ip=> zero + for _, ip := range oldList { + oldMap[ip] = zero.New() + + if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) { + _, ok := allMap[ip] + if !ok { + // 不存在则删除 + err = set.DeleteIPElement(ip) + if err != nil { + return errors.New("delete ip element '" + ip + "' failed: " + err.Error()) + } + } + } + } + + // 新增的 + for _, ip := range allIPList { + var ipObj = net.ParseIP(ip) + if ipObj == nil { + continue + } + if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) { + _, ok := oldMap[ip] + if !ok { + // 不存在则添加 + err = set.AddIPElement(ip, nil) + if err != nil { + return errors.New("add ip '" + ip + "' failed: " + err.Error()) + } + } + } + } + } + + return nil +} diff --git a/internal/firewalls/ddos_protection_others.go b/internal/firewalls/ddos_protection_others.go new file mode 100644 index 0000000..8f3afb5 --- /dev/null +++ b/internal/firewalls/ddos_protection_others.go @@ -0,0 +1,23 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build !linux +// +build !linux + +package firewalls + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs" +) + +var SharedDDoSProtectionManager = NewDDoSProtectionManager() + +type DDoSProtectionManager struct { + nftPath string +} + +func NewDDoSProtectionManager() *DDoSProtectionManager { + return &DDoSProtectionManager{} +} + +func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error { + return nil +} diff --git a/internal/firewalls/firewall.go b/internal/firewalls/firewall.go index 96aa19a..ea56c9e 100644 --- a/internal/firewalls/firewall.go +++ b/internal/firewalls/firewall.go @@ -1,6 +1,4 @@ // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. -//go:build !plus -// +build !plus package firewalls @@ -8,9 +6,11 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "runtime" + "sync" ) var currentFirewall FirewallInterface +var firewallLocker = &sync.Mutex{} // 初始化 func init() { @@ -24,10 +24,28 @@ func init() { // Firewall 查找当前系统中最适合的防火墙 func Firewall() FirewallInterface { + firewallLocker.Lock() + defer firewallLocker.Unlock() if currentFirewall != nil { return currentFirewall } + // nftables + if runtime.GOOS == "linux" { + nftables, err := NewNFTablesFirewall() + if err != nil { + remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")") + } else { + if nftables.IsReady() { + currentFirewall = nftables + events.Notify(events.EventNFTablesReady) + return nftables + } else { + remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security") + } + } + } + // firewalld if runtime.GOOS == "linux" { var firewalld = NewFirewalld() diff --git a/internal/firewalls/firewall_nftables.go b/internal/firewalls/firewall_nftables.go new file mode 100644 index 0000000..3a959eb --- /dev/null +++ b/internal/firewalls/firewall_nftables.go @@ -0,0 +1,373 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package firewalls + +import ( + "errors" + "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/iwind/TeaGo/types" + "net" + "os/exec" + "runtime" + "strings" + "time" +) + +// check nft status, if being enabled we load it automatically +func init() { + if runtime.GOOS == "linux" { + var ticker = time.NewTicker(3 * time.Minute) + go func() { + for range ticker.C { + // if already ready, we break + if nftablesIsReady { + ticker.Stop() + break + } + _, err := exec.LookPath("nft") + if err == nil { + nftablesFirewall, err := NewNFTablesFirewall() + if err != nil { + continue + } + currentFirewall = nftablesFirewall + remotelogs.Println("FIREWALL", "nftables is ready") + + // fire event + if nftablesFirewall.IsReady() { + events.Notify(events.EventNFTablesReady) + } + + ticker.Stop() + break + } + } + }() + } +} + +var nftablesInstance *NFTablesFirewall +var nftablesIsReady = false +var nftablesFilters = []*nftablesTableDefinition{ + // we shorten the name for table name length restriction + {Name: "edge_dft_v4", IsIPv4: true}, + {Name: "edge_dft_v6", IsIPv6: true}, +} +var nftablesChainName = "input" + +type nftablesTableDefinition struct { + Name string + IsIPv4 bool + IsIPv6 bool +} + +func (this *nftablesTableDefinition) protocol() string { + if this.IsIPv6 { + return "ip6" + } + return "ip" +} + +func NewNFTablesFirewall() (*NFTablesFirewall, error) { + var firewall = &NFTablesFirewall{ + conn: nftables.NewConn(), + } + err := firewall.init() + if err != nil { + return nil, err + } + + return firewall, nil +} + +type NFTablesFirewall struct { + conn *nftables.Conn + isReady bool + + allowIPv4Set *nftables.Set + allowIPv6Set *nftables.Set + + denyIPv4Set *nftables.Set + denyIPv6Set *nftables.Set + + firewalld *Firewalld +} + +func (this *NFTablesFirewall) init() error { + // check nft + _, err := exec.LookPath("nft") + if err != nil { + return errors.New("nft not found") + } + + // table + for _, tableDef := range nftablesFilters { + var family nftables.TableFamily + if tableDef.IsIPv4 { + family = nftables.TableFamilyIPv4 + } else if tableDef.IsIPv6 { + family = nftables.TableFamilyIPv6 + } else { + return errors.New("invalid table family: " + types.String(tableDef)) + } + table, err := this.conn.GetTable(tableDef.Name, family) + if err != nil { + if nftables.IsNotFound(err) { + if tableDef.IsIPv4 { + table, err = this.conn.AddIPv4Table(tableDef.Name) + } else if tableDef.IsIPv6 { + table, err = this.conn.AddIPv6Table(tableDef.Name) + } + if err != nil { + return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error()) + } + } else { + return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error()) + } + } + if table == nil { + return errors.New("can not create table '" + tableDef.Name + "'") + } + + // chain + var chainName = nftablesChainName + chain, err := table.GetChain(chainName) + if err != nil { + if nftables.IsNotFound(err) { + chain, err = table.AddAcceptChain(chainName) + if err != nil { + return errors.New("create chain '" + chainName + "' failed: " + err.Error()) + } + } else { + return errors.New("get chain '" + chainName + "' failed: " + err.Error()) + } + } + if chain == nil { + return errors.New("can not create chain '" + chainName + "'") + } + + // allow lo + var loRuleName = []byte("lo") + _, err = chain.GetRuleWithUserData(loRuleName) + if err != nil { + if nftables.IsNotFound(err) { + _, err = chain.AddAcceptInterfaceRule("lo", loRuleName) + } + if err != nil { + return errors.New("add 'lo' rule failed: " + err.Error()) + } + } + + // allow set + // "allow" should be always first + for _, setAction := range []string{"allow", "deny"} { + var setName = setAction + "_set" + + set, err := table.GetSet(setName) + if err != nil { + if nftables.IsNotFound(err) { + var keyType nftables.SetDataType + if tableDef.IsIPv4 { + keyType = nftables.TypeIPAddr + } else if tableDef.IsIPv6 { + keyType = nftables.TypeIP6Addr + } + set, err = table.AddSet(setName, &nftables.SetOptions{ + KeyType: keyType, + HasTimeout: true, + }) + if err != nil { + return errors.New("create set '" + setName + "' failed: " + err.Error()) + } + } else { + return errors.New("get set '" + setName + "' failed: " + err.Error()) + } + } + if set == nil { + return errors.New("can not create set '" + setName + "'") + } + if tableDef.IsIPv4 { + if setAction == "allow" { + this.allowIPv4Set = set + } else { + this.denyIPv4Set = set + } + } else if tableDef.IsIPv6 { + if setAction == "allow" { + this.allowIPv6Set = set + } else { + this.denyIPv6Set = set + } + } + + // rule + var ruleName = []byte(setAction) + rule, err := chain.GetRuleWithUserData(ruleName) + if err != nil { + if nftables.IsNotFound(err) { + if tableDef.IsIPv4 { + if setAction == "allow" { + rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName) + } else { + rule, err = chain.AddDropIPv4SetRule(setName, ruleName) + } + } else if tableDef.IsIPv6 { + if setAction == "allow" { + rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName) + } else { + rule, err = chain.AddDropIPv6SetRule(setName, ruleName) + } + } + if err != nil { + return errors.New("add rule failed: " + err.Error()) + } + } else { + return errors.New("get rule failed: " + err.Error()) + } + } + if rule == nil { + return errors.New("can not create rule '" + string(ruleName) + "'") + } + } + } + + this.isReady = true + nftablesIsReady = true + nftablesInstance = this + + // load firewalld + var firewalld = NewFirewalld() + if firewalld.IsReady() { + this.firewalld = firewalld + } + + return nil +} + +// Name 名称 +func (this *NFTablesFirewall) Name() string { + return "nftables" +} + +// IsReady 是否已准备被调用 +func (this *NFTablesFirewall) IsReady() bool { + return this.isReady +} + +// IsMock 是否为模拟 +func (this *NFTablesFirewall) IsMock() bool { + return false +} + +// AllowPort 允许端口 +func (this *NFTablesFirewall) AllowPort(port int, protocol string) error { + if this.firewalld != nil { + return this.firewalld.AllowPort(port, protocol) + } + return nil +} + +// RemovePort 删除端口 +func (this *NFTablesFirewall) RemovePort(port int, protocol string) error { + if this.firewalld != nil { + return this.firewalld.RemovePort(port, protocol) + } + return nil +} + +// AllowSourceIP Allow把IP加入白名单 +func (this *NFTablesFirewall) AllowSourceIP(ip string) error { + var data = net.ParseIP(ip) + if data == nil { + return errors.New("invalid ip '" + ip + "'") + } + + if strings.Contains(ip, ":") { // ipv6 + if this.allowIPv6Set == nil { + return errors.New("ipv6 ip set is nil") + } + return this.allowIPv6Set.AddElement(data.To16(), nil) + } + + // ipv4 + if this.allowIPv4Set == nil { + return errors.New("ipv4 ip set is nil") + } + return this.allowIPv4Set.AddElement(data.To4(), nil) +} + +// RejectSourceIP 拒绝某个源IP连接 +// we did not create set for drop ip, so we reuse DropSourceIP() method here +func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error { + return this.DropSourceIP(ip, timeoutSeconds) +} + +// DropSourceIP 丢弃某个源IP数据 +func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error { + var data = net.ParseIP(ip) + if data == nil { + return errors.New("invalid ip '" + ip + "'") + } + + if strings.Contains(ip, ":") { // ipv6 + if this.denyIPv6Set == nil { + return errors.New("ipv6 ip set is nil") + } + return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{ + Timeout: time.Duration(timeoutSeconds) * time.Second, + }) + } + + // ipv4 + if this.denyIPv4Set == nil { + return errors.New("ipv4 ip set is nil") + } + return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{ + Timeout: time.Duration(timeoutSeconds) * time.Second, + }) +} + +// RemoveSourceIP 删除某个源IP +func (this *NFTablesFirewall) RemoveSourceIP(ip string) error { + var data = net.ParseIP(ip) + if data == nil { + return errors.New("invalid ip '" + ip + "'") + } + + if strings.Contains(ip, ":") { // ipv6 + if this.denyIPv6Set != nil { + err := this.denyIPv6Set.DeleteElement(data.To16()) + if err != nil { + return err + } + } + + if this.allowIPv6Set != nil { + err := this.allowIPv6Set.DeleteElement(data.To16()) + if err != nil { + return err + } + } + + return nil + } + + // ipv4 + if this.allowIPv4Set != nil { + err := this.denyIPv4Set.DeleteElement(data.To4()) + if err != nil { + return err + } + + err = this.allowIPv4Set.DeleteElement(data.To4()) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/firewalls/firewall_nftables_others.go b/internal/firewalls/firewall_nftables_others.go new file mode 100644 index 0000000..b880be3 --- /dev/null +++ b/internal/firewalls/firewall_nftables_others.go @@ -0,0 +1,61 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build !linux +// +build !linux + +package firewalls + +import ( + "errors" +) + +func NewNFTablesFirewall() (*NFTablesFirewall, error) { + return nil, errors.New("not implemented") +} + +type NFTablesFirewall struct { +} + +// Name 名称 +func (this *NFTablesFirewall) Name() string { + return "nftables" +} + +// IsReady 是否已准备被调用 +func (this *NFTablesFirewall) IsReady() bool { + return false +} + +// IsMock 是否为模拟 +func (this *NFTablesFirewall) IsMock() bool { + return true +} + +// AllowPort 允许端口 +func (this *NFTablesFirewall) AllowPort(port int, protocol string) error { + return nil +} + +// RemovePort 删除端口 +func (this *NFTablesFirewall) RemovePort(port int, protocol string) error { + return nil +} + +// AllowSourceIP Allow把IP加入白名单 +func (this *NFTablesFirewall) AllowSourceIP(ip string) error { + return nil +} + +// RejectSourceIP 拒绝某个源IP连接 +func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error { + return nil +} + +// DropSourceIP 丢弃某个源IP数据 +func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error { + return nil +} + +// RemoveSourceIP 删除某个源IP +func (this *NFTablesFirewall) RemoveSourceIP(ip string) error { + return nil +} diff --git a/internal/firewalls/nftables/.gitignore b/internal/firewalls/nftables/.gitignore new file mode 100644 index 0000000..54071d2 --- /dev/null +++ b/internal/firewalls/nftables/.gitignore @@ -0,0 +1 @@ +build_remote.sh \ No newline at end of file diff --git a/internal/firewalls/nftables/chain.go b/internal/firewalls/nftables/chain.go new file mode 100644 index 0000000..005b585 --- /dev/null +++ b/internal/firewalls/nftables/chain.go @@ -0,0 +1,370 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables + +import ( + "bytes" + "errors" + nft "github.com/google/nftables" + "github.com/google/nftables/expr" +) + +const MaxChainNameLength = 31 + +type RuleOptions struct { + Exprs []expr.Any + UserData []byte +} + +// Chain chain object in table +type Chain struct { + conn *Conn + rawTable *nft.Table + rawChain *nft.Chain +} + +func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain { + return &Chain{ + conn: conn, + rawTable: rawTable, + rawChain: rawChain, + } +} + +func (this *Chain) Raw() *nft.Chain { + return this.rawChain +} + +func (this *Chain) Name() string { + return this.rawChain.Name +} + +func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) { + var rawRule = this.conn.Raw().AddRule(&nft.Rule{ + Table: this.rawTable, + Chain: this.rawChain, + Exprs: options.Exprs, + UserData: options.UserData, + }) + err := this.conn.Commit() + if err != nil { + return nil, err + } + return NewRule(rawRule), nil +} + +func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 8, + Len: 16, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip, + }, + &expr.Verdict{ + Kind: expr.VerdictDrop, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 8, + Len: 16, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip, + }, + &expr.Verdict{ + Kind: expr.VerdictDrop, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip, + }, + &expr.Reject{}, + }, + UserData: userData, + }) +} + +func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 8, + Len: 16, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip, + }, + &expr.Reject{}, + }, + UserData: userData, + }) +} + +func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: setName, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 8, + Len: 16, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: setName, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: setName, + }, + &expr.Verdict{ + Kind: expr.VerdictDrop, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 8, + Len: 16, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: setName, + }, + &expr.Verdict{ + Kind: expr.VerdictDrop, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: setName, + }, + &expr.Reject{}, + }, + UserData: userData, + }) +} + +func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) { + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 8, + Len: 16, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: setName, + }, + &expr.Reject{}, + }, + UserData: userData, + }) +} + +func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) { + if len(interfaceName) >= 16 { + return nil, errors.New("invalid interface name '" + interfaceName + "'") + } + var ifname = make([]byte, 16) + copy(ifname, interfaceName+"\x00") + + return this.AddRule(&RuleOptions{ + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: userData, + }) +} + +func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) { + rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain) + if err != nil { + return nil, err + } + for _, rawRule := range rawRules { + if bytes.Compare(rawRule.UserData, userData) == 0 { + return NewRule(rawRule), nil + } + } + return nil, ErrRuleNotFound +} + +func (this *Chain) GetRules() ([]*Rule, error) { + rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain) + if err != nil { + return nil, err + } + var result = []*Rule{} + for _, rawRule := range rawRules { + result = append(result, NewRule(rawRule)) + } + return result, nil +} + +func (this *Chain) DeleteRule(rule *Rule) error { + err := this.conn.Raw().DelRule(rule.Raw()) + if err != nil { + return err + } + return this.conn.Commit() +} + +func (this *Chain) Flush() error { + this.conn.Raw().FlushChain(this.rawChain) + return this.conn.Commit() +} diff --git a/internal/firewalls/nftables/chain_policy.go b/internal/firewalls/nftables/chain_policy.go new file mode 100644 index 0000000..677c573 --- /dev/null +++ b/internal/firewalls/nftables/chain_policy.go @@ -0,0 +1,13 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nftables + +import nft "github.com/google/nftables" + +type ChainPolicy = nft.ChainPolicy + +// Possible ChainPolicy values. +const ( + ChainPolicyDrop = nft.ChainPolicyDrop + ChainPolicyAccept = nft.ChainPolicyAccept +) diff --git a/internal/firewalls/nftables/chain_test.go b/internal/firewalls/nftables/chain_test.go new file mode 100644 index 0000000..b75341f --- /dev/null +++ b/internal/firewalls/nftables/chain_test.go @@ -0,0 +1,130 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables_test + +import ( + "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" + "net" + "testing" +) + +func getIPv4Chain(t *testing.T) *nftables.Chain { + var conn = nftables.NewConn() + table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4) + if err != nil { + if err == nftables.ErrTableNotFound { + table, err = conn.AddIPv4Table("test_ipv4") + if err != nil { + t.Fatal(err) + } + } else { + t.Fatal(err) + } + } + + chain, err := table.GetChain("test_chain") + if err != nil { + if err == nftables.ErrChainNotFound { + chain, err = table.AddAcceptChain("test_chain") + } + } + + if err != nil { + t.Fatal(err) + } + + return chain +} + +func TestChain_AddAcceptIPRule(t *testing.T) { + var chain = getIPv4Chain(t) + _, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil) + if err != nil { + t.Fatal(err) + } +} + +func TestChain_AddDropIPRule(t *testing.T) { + var chain = getIPv4Chain(t) + _, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil) + if err != nil { + t.Fatal(err) + } +} + +func TestChain_AddAcceptSetRule(t *testing.T) { + var chain = getIPv4Chain(t) + _, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil) + if err != nil { + t.Fatal(err) + } +} + +func TestChain_AddDropSetRule(t *testing.T) { + var chain = getIPv4Chain(t) + _, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil) + if err != nil { + t.Fatal(err) + } +} + +func TestChain_AddRejectSetRule(t *testing.T) { + var chain = getIPv4Chain(t) + _, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil) + if err != nil { + t.Fatal(err) + } +} + +func TestChain_GetRuleWithUserData(t *testing.T) { + var chain = getIPv4Chain(t) + rule, err := chain.GetRuleWithUserData([]byte("test")) + if err != nil { + if err == nftables.ErrRuleNotFound { + t.Log("rule not found") + return + } else { + t.Fatal(err) + } + } + t.Log("rule:", rule) +} + +func TestChain_GetRules(t *testing.T) { + var chain = getIPv4Chain(t) + rules, err := chain.GetRules() + if err != nil { + t.Fatal(err) + } + for _, rule := range rules { + t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(), + "verdict:", rule.VerDict(), "user data:", string(rule.UserData())) + } +} + +func TestChain_DeleteRule(t *testing.T) { + var chain = getIPv4Chain(t) + rule, err := chain.GetRuleWithUserData([]byte("test")) + if err != nil { + if err == nftables.ErrRuleNotFound { + t.Log("rule not found") + return + } + t.Fatal(err) + } + err = chain.DeleteRule(rule) + if err != nil { + t.Fatal(err) + } +} + +func TestChain_Flush(t *testing.T) { + var chain = getIPv4Chain(t) + err := chain.Flush() + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/firewalls/nftables/conn.go b/internal/firewalls/nftables/conn.go new file mode 100644 index 0000000..d859ee4 --- /dev/null +++ b/internal/firewalls/nftables/conn.go @@ -0,0 +1,84 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables + +import ( + "errors" + nft "github.com/google/nftables" + "github.com/iwind/TeaGo/types" +) + +const MaxTableNameLength = 27 + +type Conn struct { + rawConn *nft.Conn +} + +func NewConn() *Conn { + return &Conn{ + rawConn: &nft.Conn{}, + } +} + +func (this *Conn) Raw() *nft.Conn { + return this.rawConn +} + +func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) { + rawTables, err := this.rawConn.ListTables() + if err != nil { + return nil, err + } + + for _, rawTable := range rawTables { + if rawTable.Name == name && rawTable.Family == family { + return NewTable(this, rawTable), nil + } + } + + return nil, ErrTableNotFound +} + +func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) { + if len(name) > MaxTableNameLength { + return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")") + } + + var rawTable = this.rawConn.AddTable(&nft.Table{ + Family: family, + Name: name, + }) + + err := this.Commit() + if err != nil { + return nil, err + } + + return NewTable(this, rawTable), nil +} + +func (this *Conn) AddIPv4Table(name string) (*Table, error) { + return this.AddTable(name, TableFamilyIPv4) +} + +func (this *Conn) AddIPv6Table(name string) (*Table, error) { + return this.AddTable(name, TableFamilyIPv6) +} + +func (this *Conn) DeleteTable(name string, family TableFamily) error { + table, err := this.GetTable(name, family) + if err != nil { + if err == ErrTableNotFound { + return nil + } + return err + } + this.rawConn.DelTable(table.Raw()) + return this.Commit() +} + +func (this *Conn) Commit() error { + return this.rawConn.Flush() +} diff --git a/internal/firewalls/nftables/conn_test.go b/internal/firewalls/nftables/conn_test.go new file mode 100644 index 0000000..37fa999 --- /dev/null +++ b/internal/firewalls/nftables/conn_test.go @@ -0,0 +1,78 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables_test + +import ( + "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" + "os/exec" + "testing" +) + +func TestConn_Test(t *testing.T) { + _, err := exec.LookPath("nft") + if err == nil { + t.Log("ok") + return + } + t.Log(err) +} + +func TestConn_GetTable_NotFound(t *testing.T) { + var conn = nftables.NewConn() + + table, err := conn.GetTable("a", nftables.TableFamilyIPv4) + if err != nil { + if err == nftables.ErrTableNotFound { + t.Log("table not found") + } else { + t.Fatal(err) + } + } else { + t.Log("table:", table) + } +} + +func TestConn_GetTable(t *testing.T) { + var conn = nftables.NewConn() + + table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4) + if err != nil { + if err == nftables.ErrTableNotFound { + t.Log("table not found") + } else { + t.Fatal(err) + } + } else { + t.Log("table:", table) + } +} + +func TestConn_AddTable(t *testing.T) { + var conn = nftables.NewConn() + + { + table, err := conn.AddIPv4Table("test_ipv4") + if err != nil { + t.Fatal(err) + } + t.Log(table.Name()) + } + { + table, err := conn.AddIPv6Table("test_ipv6") + if err != nil { + t.Fatal(err) + } + t.Log(table.Name()) + } +} + +func TestConn_DeleteTable(t *testing.T) { + var conn = nftables.NewConn() + err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/firewalls/nftables/element.go b/internal/firewalls/nftables/element.go new file mode 100644 index 0000000..495b642 --- /dev/null +++ b/internal/firewalls/nftables/element.go @@ -0,0 +1,8 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables + +type Element struct { +} diff --git a/internal/firewalls/nftables/errors.go b/internal/firewalls/nftables/errors.go new file mode 100644 index 0000000..7eba54c --- /dev/null +++ b/internal/firewalls/nftables/errors.go @@ -0,0 +1,19 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables + +import "errors" + +var ErrTableNotFound = errors.New("table not found") +var ErrChainNotFound = errors.New("chain not found") +var ErrSetNotFound = errors.New("set not found") +var ErrRuleNotFound = errors.New("rule not found") + +func IsNotFound(err error) bool { + if err == nil { + return false + } + return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound +} diff --git a/internal/firewalls/nftables/family.go b/internal/firewalls/nftables/family.go new file mode 100644 index 0000000..42fa71a --- /dev/null +++ b/internal/firewalls/nftables/family.go @@ -0,0 +1,18 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nftables + +import ( + nft "github.com/google/nftables" +) + +type TableFamily = nft.TableFamily + +const ( + TableFamilyINet TableFamily = nft.TableFamilyINet + TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4 + TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6 + TableFamilyARP TableFamily = nft.TableFamilyARP + TableFamilyNetdev TableFamily = nft.TableFamilyNetdev + TableFamilyBridge TableFamily = nft.TableFamilyBridge +) diff --git a/internal/firewalls/nftables/rule.go b/internal/firewalls/nftables/rule.go new file mode 100644 index 0000000..75b3874 --- /dev/null +++ b/internal/firewalls/nftables/rule.go @@ -0,0 +1,51 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nftables + +import ( + nft "github.com/google/nftables" + "github.com/google/nftables/expr" +) + +type Rule struct { + rawRule *nft.Rule +} + +func NewRule(rawRule *nft.Rule) *Rule { + return &Rule{ + rawRule: rawRule, + } +} + +func (this *Rule) Raw() *nft.Rule { + return this.rawRule +} + +func (this *Rule) LookupSetName() string { + for _, e := range this.rawRule.Exprs { + exp, ok := e.(*expr.Lookup) + if ok { + return exp.SetName + } + } + return "" +} + +func (this *Rule) VerDict() expr.VerdictKind { + for _, e := range this.rawRule.Exprs { + exp, ok := e.(*expr.Verdict) + if ok { + return exp.Kind + } + } + + return -100 +} + +func (this *Rule) Handle() uint64 { + return this.rawRule.Handle +} + +func (this *Rule) UserData() []byte { + return this.rawRule.UserData +} diff --git a/internal/firewalls/nftables/set.go b/internal/firewalls/nftables/set.go new file mode 100644 index 0000000..3f5deee --- /dev/null +++ b/internal/firewalls/nftables/set.go @@ -0,0 +1,161 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables + +import ( + "errors" + "github.com/TeaOSLab/EdgeNode/internal/utils" + nft "github.com/google/nftables" + "net" + "strings" + "time" +) + +const MaxSetNameLength = 15 + +type SetOptions struct { + Id uint32 + HasTimeout bool + Timeout time.Duration + KeyType SetDataType + DataType SetDataType + Constant bool + Interval bool + Anonymous bool + IsMap bool +} + +type ElementOptions struct { + Timeout time.Duration +} + +type Set struct { + conn *Conn + rawSet *nft.Set + batch *SetBatch +} + +func NewSet(conn *Conn, rawSet *nft.Set) *Set { + return &Set{ + conn: conn, + rawSet: rawSet, + batch: &SetBatch{ + conn: conn, + rawSet: rawSet, + }, + } +} + +func (this *Set) Raw() *nft.Set { + return this.rawSet +} + +func (this *Set) Name() string { + return this.rawSet.Name +} + +func (this *Set) AddElement(key []byte, options *ElementOptions) error { + var rawElement = nft.SetElement{ + Key: key, + } + if options != nil { + rawElement.Timeout = options.Timeout + } + err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{ + rawElement, + }) + if err != nil { + return err + } + + err = this.conn.Commit() + if err != nil { + // retry if exists + if strings.Contains(err.Error(), "file exists") { + deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ + { + Key: key, + }, + }) + if deleteErr == nil { + err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{ + rawElement, + }) + if err == nil { + err = this.conn.Commit() + } + } + } + } + + return err +} + +func (this *Set) AddIPElement(ip string, options *ElementOptions) error { + var ipObj = net.ParseIP(ip) + if ipObj == nil { + return errors.New("invalid ip '" + ip + "'") + } + + if utils.IsIPv4(ip) { + return this.AddElement(ipObj.To4(), options) + } else { + return this.AddElement(ipObj.To16(), options) + } +} + +func (this *Set) DeleteElement(key []byte) error { + err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ + { + Key: key, + }, + }) + if err != nil { + return err + } + err = this.conn.Commit() + if err != nil { + if strings.Contains(err.Error(), "no such file or directory") { + err = nil + } + } + return err +} + +func (this *Set) DeleteIPElement(ip string) error { + var ipObj = net.ParseIP(ip) + if ipObj == nil { + return errors.New("invalid ip '" + ip + "'") + } + + if utils.IsIPv4(ip) { + return this.DeleteElement(ipObj.To4()) + } else { + return this.DeleteElement(ipObj.To16()) + } +} + +func (this *Set) Batch() *SetBatch { + return this.batch +} + +func (this *Set) GetIPElements() ([]string, error) { + elements, err := this.conn.Raw().GetSetElements(this.rawSet) + if err != nil { + return nil, err + } + + var result = []string{} + for _, element := range elements { + result = append(result, net.IP(element.Key).String()) + } + return result, nil +} + +// not work current time +/**func (this *Set) Flush() error { + this.conn.Raw().FlushSet(this.rawSet) + return this.conn.Commit() +}**/ diff --git a/internal/firewalls/nftables/set_batch.go b/internal/firewalls/nftables/set_batch.go new file mode 100644 index 0000000..c296561 --- /dev/null +++ b/internal/firewalls/nftables/set_batch.go @@ -0,0 +1,36 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nftables + +import ( + nft "github.com/google/nftables" +) + +type SetBatch struct { + conn *Conn + rawSet *nft.Set +} + +func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error { + var rawElement = nft.SetElement{ + Key: key, + } + if options != nil { + rawElement.Timeout = options.Timeout + } + return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{ + rawElement, + }) +} + +func (this *SetBatch) DeleteElement(key []byte) error { + return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ + { + Key: key, + }, + }) +} + +func (this *SetBatch) Commit() error { + return this.conn.Commit() +} diff --git a/internal/firewalls/nftables/set_data_type.go b/internal/firewalls/nftables/set_data_type.go new file mode 100644 index 0000000..6fb80c8 --- /dev/null +++ b/internal/firewalls/nftables/set_data_type.go @@ -0,0 +1,57 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nftables + +import nft "github.com/google/nftables" + +type SetDataType = nft.SetDatatype + +var ( + TypeInvalid = nft.TypeInvalid + TypeVerdict = nft.TypeVerdict + TypeNFProto = nft.TypeNFProto + TypeBitmask = nft.TypeBitmask + TypeInteger = nft.TypeInteger + TypeString = nft.TypeString + TypeLLAddr = nft.TypeLLAddr + TypeIPAddr = nft.TypeIPAddr + TypeIP6Addr = nft.TypeIP6Addr + TypeEtherAddr = nft.TypeEtherAddr + TypeEtherType = nft.TypeEtherType + TypeARPOp = nft.TypeARPOp + TypeInetProto = nft.TypeInetProto + TypeInetService = nft.TypeInetService + TypeICMPType = nft.TypeICMPType + TypeTCPFlag = nft.TypeTCPFlag + TypeDCCPPktType = nft.TypeDCCPPktType + TypeMHType = nft.TypeMHType + TypeTime = nft.TypeTime + TypeMark = nft.TypeMark + TypeIFIndex = nft.TypeIFIndex + TypeARPHRD = nft.TypeARPHRD + TypeRealm = nft.TypeRealm + TypeClassID = nft.TypeClassID + TypeUID = nft.TypeUID + TypeGID = nft.TypeGID + TypeCTState = nft.TypeCTState + TypeCTDir = nft.TypeCTDir + TypeCTStatus = nft.TypeCTStatus + TypeICMP6Type = nft.TypeICMP6Type + TypeCTLabel = nft.TypeCTLabel + TypePktType = nft.TypePktType + TypeICMPCode = nft.TypeICMPCode + TypeICMPV6Code = nft.TypeICMPV6Code + TypeICMPXCode = nft.TypeICMPXCode + TypeDevGroup = nft.TypeDevGroup + TypeDSCP = nft.TypeDSCP + TypeECN = nft.TypeECN + TypeFIBAddr = nft.TypeFIBAddr + TypeBoolean = nft.TypeBoolean + TypeCTEventBit = nft.TypeCTEventBit + TypeIFName = nft.TypeIFName + TypeIGMPType = nft.TypeIGMPType + TypeTimeDate = nft.TypeTimeDate + TypeTimeHour = nft.TypeTimeHour + TypeTimeDay = nft.TypeTimeDay + TypeCGroupV2 = nft.TypeCGroupV2 +) diff --git a/internal/firewalls/nftables/set_test.go b/internal/firewalls/nftables/set_test.go new file mode 100644 index 0000000..baa7c59 --- /dev/null +++ b/internal/firewalls/nftables/set_test.go @@ -0,0 +1,110 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nftables_test + +import ( + "errors" + "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" + "github.com/iwind/TeaGo/types" + "github.com/mdlayher/netlink" + "net" + "testing" + "time" +) + +func getIPv4Set(t *testing.T) *nftables.Set { + var table = getIPv4Table(t) + set, err := table.GetSet("test_ipv4_set") + if err != nil { + if err == nftables.ErrSetNotFound { + set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{ + KeyType: nftables.TypeIPAddr, + HasTimeout: true, + }) + if err != nil { + t.Fatal(err) + } + } else { + t.Fatal(err) + } + } + return set +} + +func TestSet_AddElement(t *testing.T) { + var set = getIPv4Set(t) + err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second}) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func TestSet_DeleteElement(t *testing.T) { + var set = getIPv4Set(t) + err := set.DeleteElement(net.ParseIP("192.168.2.31").To4()) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func TestSet_Batch(t *testing.T) { + var batch = getIPv4Set(t).Batch() + + for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} { + var ipData = net.ParseIP(ip).To4() + //err := batch.DeleteElement(ipData) + //if err != nil { + // t.Fatal(err) + //} + err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second}) + if err != nil { + t.Fatal(err) + } + } + + err := batch.Commit() + if err != nil { + t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError)) + t.Fatal(err) + } + t.Log("ok") +} + +func TestSet_Add_Many(t *testing.T) { + var set = getIPv4Set(t) + + for i := 0; i < 255; i++ { + t.Log(i) + for j := 0; j < 255; j++ { + var ip = "192.167." + types.String(i) + "." + types.String(j) + var ipData = net.ParseIP(ip).To4() + err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second}) + if err != nil { + t.Fatal(err) + } + + if j%10 == 0 { + err = set.Batch().Commit() + if err != nil { + t.Fatal(err) + } + } + } + err := set.Batch().Commit() + if err != nil { + t.Fatal(err) + } + } + t.Log("ok") +} + +/**func TestSet_Flush(t *testing.T) { + var set = getIPv4Set(t) + err := set.Flush() + if err != nil { + t.Fatal(err) + } + t.Log("ok") +}**/ diff --git a/internal/firewalls/nftables/table.go b/internal/firewalls/nftables/table.go new file mode 100644 index 0000000..cfbe766 --- /dev/null +++ b/internal/firewalls/nftables/table.go @@ -0,0 +1,157 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables + +import ( + "errors" + nft "github.com/google/nftables" + "github.com/iwind/TeaGo/types" + "strings" +) + +type Table struct { + conn *Conn + rawTable *nft.Table +} + +func NewTable(conn *Conn, rawTable *nft.Table) *Table { + return &Table{ + conn: conn, + rawTable: rawTable, + } +} + +func (this *Table) Raw() *nft.Table { + return this.rawTable +} + +func (this *Table) Name() string { + return this.rawTable.Name +} + +func (this *Table) Family() TableFamily { + return this.rawTable.Family +} + +func (this *Table) GetChain(name string) (*Chain, error) { + rawChains, err := this.conn.Raw().ListChains() + if err != nil { + return nil, err + } + for _, rawChain := range rawChains { + // must compare table name + if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name { + return NewChain(this.conn, this.rawTable, rawChain), nil + } + } + return nil, ErrChainNotFound +} + +func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) { + if len(name) > MaxChainNameLength { + return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")") + } + + var rawChain = this.conn.Raw().AddChain(&nft.Chain{ + Name: name, + Table: this.rawTable, + Hooknum: nft.ChainHookInput, + Priority: nft.ChainPriorityFilter, + Type: nft.ChainTypeFilter, + Policy: chainPolicy, + }) + + err := this.conn.Commit() + if err != nil { + return nil, err + } + return NewChain(this.conn, this.rawTable, rawChain), nil +} + +func (this *Table) AddAcceptChain(name string) (*Chain, error) { + var policy = ChainPolicyAccept + return this.AddChain(name, &policy) +} + +func (this *Table) AddDropChain(name string) (*Chain, error) { + var policy = ChainPolicyDrop + return this.AddChain(name, &policy) +} + +func (this *Table) DeleteChain(name string) error { + chain, err := this.GetChain(name) + if err != nil { + if err == ErrChainNotFound { + return nil + } + return err + } + this.conn.Raw().DelChain(chain.Raw()) + return this.conn.Commit() +} + +func (this *Table) GetSet(name string) (*Set, error) { + rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name) + if err != nil { + if strings.Contains(err.Error(), "no such file or directory") { + return nil, ErrSetNotFound + } + return nil, err + } + + return NewSet(this.conn, rawSet), nil +} + +func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) { + if len(name) > MaxSetNameLength { + return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")") + } + + if options == nil { + options = &SetOptions{} + } + var rawSet = &nft.Set{ + Table: this.rawTable, + ID: options.Id, + Name: name, + Anonymous: options.Anonymous, + Constant: options.Constant, + Interval: options.Interval, + IsMap: options.IsMap, + HasTimeout: options.HasTimeout, + Timeout: options.Timeout, + KeyType: options.KeyType, + DataType: options.DataType, + } + err := this.conn.Raw().AddSet(rawSet, nil) + if err != nil { + return nil, err + } + + err = this.conn.Commit() + if err != nil { + return nil, err + } + + return NewSet(this.conn, rawSet), nil +} + +func (this *Table) DeleteSet(name string) error { + set, err := this.GetSet(name) + if err != nil { + if err == ErrSetNotFound { + return nil + } + return err + } + + this.conn.Raw().DelSet(set.Raw()) + return this.conn.Commit() +} + +func (this *Table) Flush() error { + this.conn.Raw().FlushTable(this.rawTable) + return this.conn.Commit() +} diff --git a/internal/firewalls/nftables/table_test.go b/internal/firewalls/nftables/table_test.go new file mode 100644 index 0000000..e414a29 --- /dev/null +++ b/internal/firewalls/nftables/table_test.go @@ -0,0 +1,140 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. +//go:build linux +// +build linux + +package nftables_test + +import ( + "github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables" + "testing" +) + +func getIPv4Table(t *testing.T) *nftables.Table { + var conn = nftables.NewConn() + table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4) + if err != nil { + if err == nftables.ErrTableNotFound { + table, err = conn.AddIPv4Table("test_ipv4") + if err != nil { + t.Fatal(err) + } + } else { + t.Fatal(err) + } + } + return table +} + +func TestTable_AddChain(t *testing.T) { + var table = getIPv4Table(t) + + { + chain, err := table.AddChain("test_default_chain", nil) + if err != nil { + t.Fatal(err) + } + t.Log("created:", chain.Name()) + } + + { + chain, err := table.AddAcceptChain("test_accept_chain") + if err != nil { + t.Fatal(err) + } + t.Log("created:", chain.Name()) + } + + // Do not test drop chain before adding accept rule, you will drop yourself!!!!!!! + /**{ + chain, err := table.AddDropChain("test_drop_chain") + if err != nil { + t.Fatal(err) + } + t.Log("created:", chain.Name()) + }**/ +} + +func TestTable_GetChain(t *testing.T) { + var table = getIPv4Table(t) + for _, chainName := range []string{"not_found_chain", "test_default_chain"} { + chain, err := table.GetChain(chainName) + if err != nil { + if err == nftables.ErrChainNotFound { + t.Log(chainName, ":", "not found") + } else { + t.Fatal(err) + } + } else { + t.Log(chainName, ":", chain) + } + } +} + +func TestTable_DeleteChain(t *testing.T) { + var table = getIPv4Table(t) + err := table.DeleteChain("test_default_chain") + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func TestTable_AddSet(t *testing.T) { + var table = getIPv4Table(t) + { + set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{ + HasTimeout: false, + KeyType: nftables.TypeIPAddr, + }) + if err != nil { + t.Fatal(err) + } + t.Log(set.Name()) + } + + { + set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{ + HasTimeout: true, + //Timeout: 3600 * time.Second, + KeyType: nftables.TypeIP6Addr, + }) + if err != nil { + t.Fatal(err) + } + t.Log(set.Name()) + } +} + +func TestTable_GetSet(t *testing.T) { + var table = getIPv4Table(t) + for _, setName := range []string{"not_found_set", "ipv4_black_set"} { + set, err := table.GetSet(setName) + if err != nil { + if err == nftables.ErrSetNotFound { + t.Log(setName, ": not found") + } else { + t.Fatal(err) + } + } else { + t.Log(setName, ":", set) + } + } +} + +func TestTable_DeleteSet(t *testing.T) { + var table = getIPv4Table(t) + err := table.DeleteSet("ipv4_black_set") + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func TestTable_Flush(t *testing.T) { + var table = getIPv4Table(t) + err := table.Flush() + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/nodes/api_stream.go b/internal/nodes/api_stream.go index cf174b8..5a9746b 100644 --- a/internal/nodes/api_stream.go +++ b/internal/nodes/api_stream.go @@ -1,6 +1,7 @@ package nodes import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -14,16 +15,20 @@ import ( teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/errors" "github.com/TeaOSLab/EdgeNode/internal/events" + "github.com/TeaOSLab/EdgeNode/internal/firewalls" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/maps" "io" "net" "net/http" "net/url" "os/exec" + "regexp" + "runtime" "strconv" "strings" "sync" @@ -119,6 +124,8 @@ func (this *APIStream) loop() error { err = this.handleNewNodeTask(message) case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务 err = this.handleCheckSystemdService(message) + case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙 + err = this.handleCheckLocalFirewall(message) case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址 err = this.handleChangeAPINode(message) default: @@ -569,7 +576,7 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage) return nil } - cmd := utils.NewCommandExecutor() + var cmd = utils.NewCommandExecutor() shortName := teaconst.SystemdServiceName cmd.Add(systemctl, "is-enabled", shortName) output, err := cmd.Run() @@ -585,6 +592,63 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage) return nil } +// 检查本地防火墙 +func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error { + var dataMessage = &messageconfigs.CheckLocalFirewallMessage{} + err := json.Unmarshal(message.DataJSON, dataMessage) + if err != nil { + this.replyFail(message.RequestId, "decode message data failed: "+err.Error()) + return nil + } + + // nft + if dataMessage.Name == "nftables" { + if runtime.GOOS != "linux" { + this.replyFail(message.RequestId, "not Linux system") + return nil + } + + nft, err := exec.LookPath("nft") + if err != nil { + this.replyFail(message.RequestId, "'nft' not found: "+err.Error()) + return nil + } + + var cmd = exec.Command(nft, "--version") + var output = &bytes.Buffer{} + cmd.Stdout = output + err = cmd.Run() + if err != nil { + this.replyFail(message.RequestId, "get version failed: "+err.Error()) + return nil + } + + var outputString = output.String() + var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString) + if len(versionMatches) <= 1 { + this.replyFail(message.RequestId, "can not get nft version") + return nil + } + var version = versionMatches[1] + + var result = maps.Map{ + "version": version, + } + + var protectionConfig = sharedNodeConfig.DDOSProtection + err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig) + if err != nil { + this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error()) + } else { + this.replyOk(message.RequestId, string(result.AsJSON())) + } + } else { + this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'") + } + + return nil +} + // 修改API地址 func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error { config, err := configs.LoadAPIConfig() @@ -660,6 +724,11 @@ func (this *APIStream) replyOk(requestId int64, message string) { _ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message}) } +// 回复成功并包含数据 +func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) { + _ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON}) +} + // 获取缓存存取对象 func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) { cachePolicy := &serverconfigs.HTTPCachePolicy{} diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index ce2708c..224f0e5 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -7,7 +7,6 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" - "github.com/TeaOSLab/EdgeNode/internal/ratelimit" "github.com/TeaOSLab/EdgeNode/internal/ttlcache" "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf" @@ -21,8 +20,7 @@ import ( // ClientConn 客户端连接 type ClientConn struct { - once sync.Once - globalLimiter *ratelimit.Counter + once sync.Once isTLS bool hasDeadline bool @@ -33,7 +31,7 @@ type ClientConn struct { BaseClientConn } -func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn { +func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn { if quickClose { // TCP tcpConn, ok := conn.(*net.TCPConn) @@ -43,7 +41,7 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra } } - return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter} + return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS} } func (this *ClientConn) Read(b []byte) (n int, err error) { @@ -96,13 +94,6 @@ func (this *ClientConn) Close() error { err := this.rawConn.Close() - // 全局并发数限制 - this.once.Do(func() { - if this.globalLimiter != nil { - this.globalLimiter.Release() - } - }) - // 单个服务并发数限制 sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String()) diff --git a/internal/nodes/client_listener.go b/internal/nodes/client_listener.go index e0f155a..58a2496 100644 --- a/internal/nodes/client_listener.go +++ b/internal/nodes/client_listener.go @@ -3,16 +3,12 @@ package nodes import ( - "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" - "github.com/TeaOSLab/EdgeNode/internal/ratelimit" "github.com/TeaOSLab/EdgeNode/internal/waf" "net" ) -var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxConnections) - // ClientListener 客户端网络监听 type ClientListener struct { rawListener net.Listener @@ -36,13 +32,8 @@ func (this *ClientListener) IsTLS() bool { } func (this *ClientListener) Accept() (net.Conn, error) { - // 限制并发连接数 - var limiter = sharedConnectionsLimiter - limiter.Ack() - conn, err := this.rawListener.Accept() if err != nil { - limiter.Release() return nil, err } @@ -60,12 +51,11 @@ func (this *ClientListener) Accept() (net.Conn, error) { } _ = conn.Close() - limiter.Release() return this.Accept() } } - return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil + return NewClientConn(conn, this.isTLS, this.quickClose), nil } func (this *ClientListener) Close() error { diff --git a/internal/nodes/node.go b/internal/nodes/node.go index 6e04029..046a20a 100644 --- a/internal/nodes/node.go +++ b/internal/nodes/node.go @@ -7,6 +7,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs" "github.com/TeaOSLab/EdgeNode/internal/caches" "github.com/TeaOSLab/EdgeNode/internal/configs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" @@ -15,7 +16,6 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/metrics" - "github.com/TeaOSLab/EdgeNode/internal/ratelimit" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/stats" @@ -368,6 +368,38 @@ func (this *Node) loop() error { } sharedNodeConfig.ParentNodes = parentNodes + // 修改为已同步 + _, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{ + NodeTaskId: task.Id, + IsOk: true, + Error: "", + }) + if err != nil { + return err + } + case "ddosProtectionChanged": + resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{}) + if err != nil { + return err + } + if len(resp.DdosProtectionJSON) == 0 { + if sharedNodeConfig != nil { + sharedNodeConfig.DDOSProtection = nil + } + } else { + var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{} + err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig) + if err != nil { + return errors.New("decode DDoS protection config failed: " + err.Error()) + } + + err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig) + if err != nil { + // 不阻塞 + remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error()) + } + } + // 修改为已同步 _, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{ NodeTaskId: task.Id, @@ -730,7 +762,6 @@ func (this *Node) listenSock() error { "ipConns": ipConns, "serverConns": serverConns, "total": sharedListenerManager.TotalActiveConnections(), - "limiter": sharedConnectionsLimiter.Len(), }, }) case "dropIP": @@ -854,17 +885,6 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) { this.maxThreads = config.MaxThreads } - // max tcp connections - if config.TCPMaxConnections <= 0 { - config.TCPMaxConnections = nodeconfigs.DefaultTCPMaxConnections - } - if config.TCPMaxConnections != sharedConnectionsLimiter.Count() { - remotelogs.Println("NODE", "[TCP]changed tcp max connections to '"+types.String(config.TCPMaxConnections)+"'") - - sharedConnectionsLimiter.Close() - sharedConnectionsLimiter = ratelimit.NewCounter(config.TCPMaxConnections) - } - // timezone var timeZone = config.TimeZone if len(timeZone) == 0 { diff --git a/internal/utils/ip.go b/internal/utils/ip.go index 6dae6a0..35aa56d 100644 --- a/internal/utils/ip.go +++ b/internal/utils/ip.go @@ -5,9 +5,12 @@ import ( "github.com/cespare/xxhash" "math" "net" + "regexp" "strings" ) +var ipv4Reg = regexp.MustCompile(`\d+\.`) + // IP2Long 将IP转换为整型 // 注意IPv6没有顺序 func IP2Long(ip string) uint64 { @@ -54,3 +57,24 @@ func IsLocalIP(ipString string) bool { return false } + +// IsIPv4 是否为IPv4 +func IsIPv4(ip string) bool { + var data = net.ParseIP(ip) + if data == nil { + return false + } + if strings.Contains(ip, ":") { + return false + } + return data.To4() != nil +} + +// IsIPv6 是否为IPv6 +func IsIPv6(ip string) bool { + var data = net.ParseIP(ip) + if data == nil { + return false + } + return !IsIPv4(ip) +} diff --git a/internal/utils/ip_test.go b/internal/utils/ip_test.go index ef2a412..5f9ac56 100644 --- a/internal/utils/ip_test.go +++ b/internal/utils/ip_test.go @@ -26,3 +26,26 @@ func TestIsLocalIP(t *testing.T) { a.IsFalse(IsLocalIP("::1:2:3")) a.IsFalse(IsLocalIP("8.8.8.8")) } + +func TestIsIPv4(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsTrue(IsIPv4("192.168.1.1")) + a.IsTrue(IsIPv4("0.0.0.0")) + a.IsFalse(IsIPv4("192.168.1.256")) + a.IsFalse(IsIPv4("192.168.1")) + a.IsFalse(IsIPv4("::1")) + a.IsFalse(IsIPv4("2001:0db8:85a3:0000:0000:8a2e:0370:7334")) + a.IsFalse(IsIPv4("::ffff:192.168.0.1")) +} + +func TestIsIPv6(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsFalse(IsIPv6("192.168.1.1")) + a.IsFloat32(IsIPv6("0.0.0.0")) + a.IsFalse(IsIPv6("192.168.1.256")) + a.IsFalse(IsIPv6("192.168.1")) + a.IsTrue(IsIPv6("::1")) + a.IsTrue(IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334")) + a.IsTrue(IsIPv4("::ffff:192.168.0.1")) + a.IsTrue(IsIPv6("::ffff:192.168.0.1")) +} diff --git a/internal/utils/string.go b/internal/utils/string.go index 6ad9599..3a41437 100644 --- a/internal/utils/string.go +++ b/internal/utils/string.go @@ -1,6 +1,7 @@ package utils import ( + "sort" "strings" "unsafe" ) @@ -36,6 +37,22 @@ func FormatAddressList(addrList []string) []string { return result } +// ToValidUTF8string 去除字符串中的非UTF-8字符 func ToValidUTF8string(v string) string { return strings.ToValidUTF8(v, "") } + +// ContainsSameStrings 检查两个字符串slice内容是否一致 +func ContainsSameStrings(s1 []string, s2 []string) bool { + if len(s1) != len(s2) { + return false + } + sort.Strings(s1) + sort.Strings(s2) + for index, v1 := range s1 { + if v1 != s2[index] { + return false + } + } + return true +} diff --git a/internal/utils/string_test.go b/internal/utils/string_test.go index 14e0ba6..5c6084d 100644 --- a/internal/utils/string_test.go +++ b/internal/utils/string_test.go @@ -1,56 +1,67 @@ -package utils +package utils_test import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/iwind/TeaGo/assert" "strings" "testing" ) func TestBytesToString(t *testing.T) { - t.Log(UnsafeBytesToString([]byte("Hello,World"))) + t.Log(utils.UnsafeBytesToString([]byte("Hello,World"))) } func TestStringToBytes(t *testing.T) { - t.Log(string(UnsafeStringToBytes("Hello,World"))) + t.Log(string(utils.UnsafeStringToBytes("Hello,World"))) } func BenchmarkBytesToString(b *testing.B) { - data := []byte("Hello,World") + var data = []byte("Hello,World") for i := 0; i < b.N; i++ { - _ = UnsafeBytesToString(data) + _ = utils.UnsafeBytesToString(data) } } func BenchmarkBytesToString2(b *testing.B) { - data := []byte("Hello,World") + var data = []byte("Hello,World") for i := 0; i < b.N; i++ { _ = string(data) } } func BenchmarkStringToBytes(b *testing.B) { - s := strings.Repeat("Hello,World", 1024) + var s = strings.Repeat("Hello,World", 1024) for i := 0; i < b.N; i++ { - _ = UnsafeStringToBytes(s) + _ = utils.UnsafeStringToBytes(s) } } func BenchmarkStringToBytes2(b *testing.B) { - s := strings.Repeat("Hello,World", 1024) + var s = strings.Repeat("Hello,World", 1024) for i := 0; i < b.N; i++ { _ = []byte(s) } } func TestFormatAddress(t *testing.T) { - t.Log(FormatAddress("127.0.0.1:1234")) - t.Log(FormatAddress("127.0.0.1 : 1234")) - t.Log(FormatAddress("127.0.0.1:1234")) + t.Log(utils.FormatAddress("127.0.0.1:1234")) + t.Log(utils.FormatAddress("127.0.0.1 : 1234")) + t.Log(utils.FormatAddress("127.0.0.1:1234")) } func TestFormatAddressList(t *testing.T) { - t.Log(FormatAddressList([]string{ + t.Log(utils.FormatAddressList([]string{ "127.0.0.1:1234", "127.0.0.1 : 1234", "127.0.0.1:1234", })) } + +func TestContainsSameStrings(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"})) + a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"})) + a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"})) + a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"})) + a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"})) +}