实现基础的DDoS防护

This commit is contained in:
GoEdgeLab
2022-05-18 21:03:51 +08:00
parent 23192f6fec
commit 9d68710531
31 changed files with 2605 additions and 58 deletions

View File

@@ -3,9 +3,10 @@ package events
type Event = string type Event = string
const ( const (
EventStart Event = "start" // start loading EventStart Event = "start" // start loading
EventLoaded Event = "loaded" // first load EventLoaded Event = "loaded" // first load
EventQuit Event = "quit" // quit node gracefully EventQuit Event = "quit" // quit node gracefully
EventReload Event = "reload" // reload config EventReload Event = "reload" // reload config
EventTerminated Event = "terminated" // process terminated EventTerminated Event = "terminated" // process terminated
EventNFTablesReady Event = "nftablesReady" // nftables ready
) )

View File

@@ -1 +0,0 @@
firewall_nftables_test.go

View File

@@ -0,0 +1,494 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package firewalls
import (
"bytes"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/zero"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"net"
"os/exec"
"strings"
)
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
func init() {
events.On(events.EventReload, func() {
if nftablesInstance == nil {
return
}
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
if err != nil {
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
}
}
})
events.On(events.EventNFTablesReady, func() {
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
if err != nil {
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
}
}
})
}
// DDoSProtectionManager DDoS防护
type DDoSProtectionManager struct {
nftPath string
lastAllowIPList []string
lastConfig []byte
}
// NewDDoSProtectionManager 获取新对象
func NewDDoSProtectionManager() *DDoSProtectionManager {
nftPath, _ := exec.LookPath("nft")
return &DDoSProtectionManager{
nftPath: nftPath,
}
}
// Apply 应用配置
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
// 同集群节点IP白名单
var allowIPListChanged = false
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
if nodeConfig != nil {
var allowIPList = nodeConfig.AllowedIPs
if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) {
allowIPListChanged = true
this.lastAllowIPList = allowIPList
}
}
// 对比配置
configJSON, err := json.Marshal(config)
if err != nil {
return errors.New("encode config to json failed: " + err.Error())
}
if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
return nil
}
remotelogs.Println("FIREWALL", "change DDoS protection config")
if len(this.nftPath) == 0 {
return errors.New("can not find nft command")
}
if nftablesInstance == nil {
return errors.New("nftables instance should not be nil")
}
if config == nil {
// TCP
err := this.removeTCPRules()
if err != nil {
return err
}
// TODO other protocols
return nil
}
// TCP
if config.TCP == nil {
err := this.removeTCPRules()
if err != nil {
return err
}
} else {
// allow ip list
var allowIPList = []string{}
for _, ipConfig := range config.TCP.AllowIPList {
allowIPList = append(allowIPList, ipConfig.IP)
}
for _, ip := range this.lastAllowIPList {
if !lists.ContainsString(allowIPList, ip) {
allowIPList = append(allowIPList, ip)
}
}
err = this.updateAllowIPList(allowIPList)
if err != nil {
return err
}
// tcp
if config.TCP.IsOn {
err := this.addTCPRules(config.TCP)
if err != nil {
return err
}
} else {
err := this.removeTCPRules()
if err != nil {
return err
}
}
}
this.lastConfig = configJSON
return nil
}
// 添加TCP规则
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
var ports = []int32{}
for _, portConfig := range tcpConfig.Ports {
if !lists.ContainsInt32(ports, portConfig.Port) {
ports = append(ports, portConfig.Port)
}
}
if len(ports) == 0 {
ports = []int32{80, 443}
}
for _, filter := range nftablesFilters {
chain, oldRules, err := this.getRules(filter)
if err != nil {
return errors.New("get old rules failed: " + err.Error())
}
var protocol = filter.protocol()
// max connections
var maxConnections = tcpConfig.MaxConnections
if maxConnections <= 0 {
maxConnections = nodeconfigs.DefaultTCPMaxConnections
if maxConnections <= 0 {
maxConnections = 100000
}
}
// max connections per ip
var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
if maxConnectionsPerIP <= 0 {
maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP
if maxConnectionsPerIP <= 0 {
maxConnectionsPerIP = 100000
}
}
// new connections rate
var newConnectionsRate = tcpConfig.NewConnectionsRate
if newConnectionsRate <= 0 {
newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate
if newConnectionsRate <= 0 {
newConnectionsRate = 100000
}
}
// 检查是否有变化
var hasChanges = false
for _, port := range ports {
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
hasChanges = true
break
}
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
hasChanges = true
break
}
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) {
hasChanges = true
break
}
}
if !hasChanges {
// 检查是否有多余的端口
var oldPorts = this.getTCPPorts(oldRules)
if !this.eqPorts(ports, oldPorts) {
hasChanges = true
}
}
if !hasChanges {
return nil
}
// 先清空所有相关规则
err = this.removeOldTCPRules(chain, oldRules)
if err != nil {
return errors.New("delete old rules failed: " + err.Error())
}
// 添加新规则
for _, port := range ports {
if maxConnections > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error())
}
}
if maxConnectionsPerIP > 0 {
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
}
}
if newConnectionsRate > 0 {
// TODO 思考是否有惩罚机制
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}))
var stderr = &bytes.Buffer{}
cmd.Stderr = stderr
err := cmd.Run()
if err != nil {
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
}
}
}
}
return nil
}
// 删除TCP规则
func (this *DDoSProtectionManager) removeTCPRules() error {
for _, filter := range nftablesFilters {
chain, rules, err := this.getRules(filter)
// TCP
err = this.removeOldTCPRules(chain, rules)
if err != nil {
return err
}
}
return nil
}
// 组合user data
// 数据中不能包含字母、数字、下划线以外的数据
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
if attrs == nil {
return ""
}
return "ZZ" + strings.Join(attrs, "_") + "ZZ"
}
// 解码user data
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
if len(data) == 0 {
return nil
}
var dataCopy = make([]byte, len(data))
copy(dataCopy, data)
var separatorLen = 2
var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
if index1 < 0 {
return nil
}
dataCopy = dataCopy[index1+separatorLen:]
var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
if index2 < 0 {
return nil
}
var s = string(dataCopy[:index2])
var pieces = strings.Split(s, "_")
for index, piece := range pieces {
pieces[index] = strings.TrimSpace(piece)
}
return pieces
}
// 清除规则
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
for _, rule := range rules {
var pieces = this.decodeUserData(rule.UserData())
if len(pieces) != 4 {
continue
}
if pieces[0] != "tcp" {
continue
}
switch pieces[2] {
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate":
err := chain.DeleteRule(rule)
if err != nil {
return err
}
}
}
return nil
}
// 根据参数检查规则是否存在
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
if len(attrs) == 0 {
return false
}
for _, oldRule := range rules {
var pieces = this.decodeUserData(oldRule.UserData())
if len(attrs) != len(pieces) {
continue
}
var isSame = true
for index, piece := range pieces {
if strings.TrimSpace(piece) != attrs[index] {
isSame = false
break
}
}
if isSame {
return true
}
}
return false
}
// 获取规则中的端口号
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
var ports = []int32{}
for _, rule := range rules {
var pieces = this.decodeUserData(rule.UserData())
if len(pieces) != 4 {
continue
}
if pieces[0] != "tcp" {
continue
}
var port = types.Int32(pieces[1])
if port > 0 && !lists.ContainsInt32(ports, port) {
ports = append(ports, port)
}
}
return ports
}
// 检查端口是否一样
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
if len(ports1) != len(ports2) {
return false
}
var portMap = map[int32]bool{}
for _, port := range ports2 {
portMap[port] = true
}
for _, port := range ports1 {
_, ok := portMap[port]
if !ok {
return false
}
}
return true
}
// 查找Table
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
var family nftables.TableFamily
if filter.IsIPv4 {
family = nftables.TableFamilyIPv4
} else if filter.IsIPv6 {
family = nftables.TableFamilyIPv6
} else {
return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
}
return nftablesInstance.conn.GetTable(filter.Name, family)
}
// 查找所有规则
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
table, err := this.getTable(filter)
if err != nil {
return nil, nil, errors.New("get table failed: " + err.Error())
}
chain, err := table.GetChain(nftablesChainName)
if err != nil {
return nil, nil, errors.New("get chain failed: " + err.Error())
}
rules, err := chain.GetRules()
return chain, rules, err
}
// 更新白名单
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
if nftablesInstance == nil {
return nil
}
var allMap = map[string]zero.Zero{}
for _, ip := range allIPList {
allMap[ip] = zero.New()
}
for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
var isIPv4 = set == nftablesInstance.allowIPv4Set
var isIPv6 = !isIPv4
// 现有的
oldList, err := set.GetIPElements()
if err != nil {
return err
}
var oldMap = map[string]zero.Zero{} // ip=> zero
for _, ip := range oldList {
oldMap[ip] = zero.New()
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
_, ok := allMap[ip]
if !ok {
// 不存在则删除
err = set.DeleteIPElement(ip)
if err != nil {
return errors.New("delete ip element '" + ip + "' failed: " + err.Error())
}
}
}
}
// 新增的
for _, ip := range allIPList {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
continue
}
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
_, ok := oldMap[ip]
if !ok {
// 不存在则添加
err = set.AddIPElement(ip, nil)
if err != nil {
return errors.New("add ip '" + ip + "' failed: " + err.Error())
}
}
}
}
}
return nil
}

View File

@@ -0,0 +1,23 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !linux
// +build !linux
package firewalls
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
)
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
type DDoSProtectionManager struct {
nftPath string
}
func NewDDoSProtectionManager() *DDoSProtectionManager {
return &DDoSProtectionManager{}
}
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
return nil
}

View File

@@ -1,6 +1,4 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !plus
// +build !plus
package firewalls package firewalls
@@ -8,9 +6,11 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"runtime" "runtime"
"sync"
) )
var currentFirewall FirewallInterface var currentFirewall FirewallInterface
var firewallLocker = &sync.Mutex{}
// 初始化 // 初始化
func init() { func init() {
@@ -24,10 +24,28 @@ func init() {
// Firewall 查找当前系统中最适合的防火墙 // Firewall 查找当前系统中最适合的防火墙
func Firewall() FirewallInterface { func Firewall() FirewallInterface {
firewallLocker.Lock()
defer firewallLocker.Unlock()
if currentFirewall != nil { if currentFirewall != nil {
return currentFirewall return currentFirewall
} }
// nftables
if runtime.GOOS == "linux" {
nftables, err := NewNFTablesFirewall()
if err != nil {
remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
} else {
if nftables.IsReady() {
currentFirewall = nftables
events.Notify(events.EventNFTablesReady)
return nftables
} else {
remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
}
}
}
// firewalld // firewalld
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
var firewalld = NewFirewalld() var firewalld = NewFirewalld()

View File

@@ -0,0 +1,373 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package firewalls
import (
"errors"
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/types"
"net"
"os/exec"
"runtime"
"strings"
"time"
)
// check nft status, if being enabled we load it automatically
func init() {
if runtime.GOOS == "linux" {
var ticker = time.NewTicker(3 * time.Minute)
go func() {
for range ticker.C {
// if already ready, we break
if nftablesIsReady {
ticker.Stop()
break
}
_, err := exec.LookPath("nft")
if err == nil {
nftablesFirewall, err := NewNFTablesFirewall()
if err != nil {
continue
}
currentFirewall = nftablesFirewall
remotelogs.Println("FIREWALL", "nftables is ready")
// fire event
if nftablesFirewall.IsReady() {
events.Notify(events.EventNFTablesReady)
}
ticker.Stop()
break
}
}
}()
}
}
var nftablesInstance *NFTablesFirewall
var nftablesIsReady = false
var nftablesFilters = []*nftablesTableDefinition{
// we shorten the name for table name length restriction
{Name: "edge_dft_v4", IsIPv4: true},
{Name: "edge_dft_v6", IsIPv6: true},
}
var nftablesChainName = "input"
type nftablesTableDefinition struct {
Name string
IsIPv4 bool
IsIPv6 bool
}
func (this *nftablesTableDefinition) protocol() string {
if this.IsIPv6 {
return "ip6"
}
return "ip"
}
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
var firewall = &NFTablesFirewall{
conn: nftables.NewConn(),
}
err := firewall.init()
if err != nil {
return nil, err
}
return firewall, nil
}
type NFTablesFirewall struct {
conn *nftables.Conn
isReady bool
allowIPv4Set *nftables.Set
allowIPv6Set *nftables.Set
denyIPv4Set *nftables.Set
denyIPv6Set *nftables.Set
firewalld *Firewalld
}
func (this *NFTablesFirewall) init() error {
// check nft
_, err := exec.LookPath("nft")
if err != nil {
return errors.New("nft not found")
}
// table
for _, tableDef := range nftablesFilters {
var family nftables.TableFamily
if tableDef.IsIPv4 {
family = nftables.TableFamilyIPv4
} else if tableDef.IsIPv6 {
family = nftables.TableFamilyIPv6
} else {
return errors.New("invalid table family: " + types.String(tableDef))
}
table, err := this.conn.GetTable(tableDef.Name, family)
if err != nil {
if nftables.IsNotFound(err) {
if tableDef.IsIPv4 {
table, err = this.conn.AddIPv4Table(tableDef.Name)
} else if tableDef.IsIPv6 {
table, err = this.conn.AddIPv6Table(tableDef.Name)
}
if err != nil {
return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error())
}
} else {
return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error())
}
}
if table == nil {
return errors.New("can not create table '" + tableDef.Name + "'")
}
// chain
var chainName = nftablesChainName
chain, err := table.GetChain(chainName)
if err != nil {
if nftables.IsNotFound(err) {
chain, err = table.AddAcceptChain(chainName)
if err != nil {
return errors.New("create chain '" + chainName + "' failed: " + err.Error())
}
} else {
return errors.New("get chain '" + chainName + "' failed: " + err.Error())
}
}
if chain == nil {
return errors.New("can not create chain '" + chainName + "'")
}
// allow lo
var loRuleName = []byte("lo")
_, err = chain.GetRuleWithUserData(loRuleName)
if err != nil {
if nftables.IsNotFound(err) {
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
}
if err != nil {
return errors.New("add 'lo' rule failed: " + err.Error())
}
}
// allow set
// "allow" should be always first
for _, setAction := range []string{"allow", "deny"} {
var setName = setAction + "_set"
set, err := table.GetSet(setName)
if err != nil {
if nftables.IsNotFound(err) {
var keyType nftables.SetDataType
if tableDef.IsIPv4 {
keyType = nftables.TypeIPAddr
} else if tableDef.IsIPv6 {
keyType = nftables.TypeIP6Addr
}
set, err = table.AddSet(setName, &nftables.SetOptions{
KeyType: keyType,
HasTimeout: true,
})
if err != nil {
return errors.New("create set '" + setName + "' failed: " + err.Error())
}
} else {
return errors.New("get set '" + setName + "' failed: " + err.Error())
}
}
if set == nil {
return errors.New("can not create set '" + setName + "'")
}
if tableDef.IsIPv4 {
if setAction == "allow" {
this.allowIPv4Set = set
} else {
this.denyIPv4Set = set
}
} else if tableDef.IsIPv6 {
if setAction == "allow" {
this.allowIPv6Set = set
} else {
this.denyIPv6Set = set
}
}
// rule
var ruleName = []byte(setAction)
rule, err := chain.GetRuleWithUserData(ruleName)
if err != nil {
if nftables.IsNotFound(err) {
if tableDef.IsIPv4 {
if setAction == "allow" {
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
} else {
rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
}
} else if tableDef.IsIPv6 {
if setAction == "allow" {
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
} else {
rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
}
}
if err != nil {
return errors.New("add rule failed: " + err.Error())
}
} else {
return errors.New("get rule failed: " + err.Error())
}
}
if rule == nil {
return errors.New("can not create rule '" + string(ruleName) + "'")
}
}
}
this.isReady = true
nftablesIsReady = true
nftablesInstance = this
// load firewalld
var firewalld = NewFirewalld()
if firewalld.IsReady() {
this.firewalld = firewalld
}
return nil
}
// Name 名称
func (this *NFTablesFirewall) Name() string {
return "nftables"
}
// IsReady 是否已准备被调用
func (this *NFTablesFirewall) IsReady() bool {
return this.isReady
}
// IsMock 是否为模拟
func (this *NFTablesFirewall) IsMock() bool {
return false
}
// AllowPort 允许端口
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
if this.firewalld != nil {
return this.firewalld.AllowPort(port, protocol)
}
return nil
}
// RemovePort 删除端口
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
if this.firewalld != nil {
return this.firewalld.RemovePort(port, protocol)
}
return nil
}
// AllowSourceIP Allow把IP加入白名单
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.allowIPv6Set == nil {
return errors.New("ipv6 ip set is nil")
}
return this.allowIPv6Set.AddElement(data.To16(), nil)
}
// ipv4
if this.allowIPv4Set == nil {
return errors.New("ipv4 ip set is nil")
}
return this.allowIPv4Set.AddElement(data.To4(), nil)
}
// RejectSourceIP 拒绝某个源IP连接
// we did not create set for drop ip, so we reuse DropSourceIP() method here
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
return this.DropSourceIP(ip, timeoutSeconds)
}
// DropSourceIP 丢弃某个源IP数据
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.denyIPv6Set == nil {
return errors.New("ipv6 ip set is nil")
}
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second,
})
}
// ipv4
if this.denyIPv4Set == nil {
return errors.New("ipv4 ip set is nil")
}
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second,
})
}
// RemoveSourceIP 删除某个源IP
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.denyIPv6Set != nil {
err := this.denyIPv6Set.DeleteElement(data.To16())
if err != nil {
return err
}
}
if this.allowIPv6Set != nil {
err := this.allowIPv6Set.DeleteElement(data.To16())
if err != nil {
return err
}
}
return nil
}
// ipv4
if this.allowIPv4Set != nil {
err := this.denyIPv4Set.DeleteElement(data.To4())
if err != nil {
return err
}
err = this.allowIPv4Set.DeleteElement(data.To4())
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,61 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build !linux
// +build !linux
package firewalls
import (
"errors"
)
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
return nil, errors.New("not implemented")
}
type NFTablesFirewall struct {
}
// Name 名称
func (this *NFTablesFirewall) Name() string {
return "nftables"
}
// IsReady 是否已准备被调用
func (this *NFTablesFirewall) IsReady() bool {
return false
}
// IsMock 是否为模拟
func (this *NFTablesFirewall) IsMock() bool {
return true
}
// AllowPort 允许端口
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
return nil
}
// RemovePort 删除端口
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
return nil
}
// AllowSourceIP Allow把IP加入白名单
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
return nil
}
// RejectSourceIP 拒绝某个源IP连接
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
return nil
}
// DropSourceIP 丢弃某个源IP数据
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
return nil
}
// RemoveSourceIP 删除某个源IP
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
return nil
}

View File

@@ -0,0 +1 @@
build_remote.sh

View File

@@ -0,0 +1,370 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"bytes"
"errors"
nft "github.com/google/nftables"
"github.com/google/nftables/expr"
)
const MaxChainNameLength = 31
type RuleOptions struct {
Exprs []expr.Any
UserData []byte
}
// Chain chain object in table
type Chain struct {
conn *Conn
rawTable *nft.Table
rawChain *nft.Chain
}
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
return &Chain{
conn: conn,
rawTable: rawTable,
rawChain: rawChain,
}
}
func (this *Chain) Raw() *nft.Chain {
return this.rawChain
}
func (this *Chain) Name() string {
return this.rawChain.Name
}
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
var rawRule = this.conn.Raw().AddRule(&nft.Rule{
Table: this.rawTable,
Chain: this.rawChain,
Exprs: options.Exprs,
UserData: options.UserData,
})
err := this.conn.Commit()
if err != nil {
return nil, err
}
return NewRule(rawRule), nil
}
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ip,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
&expr.Lookup{
SourceRegister: 1,
SetName: setName,
},
&expr.Reject{},
},
UserData: userData,
})
}
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
if len(interfaceName) >= 16 {
return nil, errors.New("invalid interface name '" + interfaceName + "'")
}
var ifname = make([]byte, 16)
copy(ifname, interfaceName+"\x00")
return this.AddRule(&RuleOptions{
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: userData,
})
}
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
if err != nil {
return nil, err
}
for _, rawRule := range rawRules {
if bytes.Compare(rawRule.UserData, userData) == 0 {
return NewRule(rawRule), nil
}
}
return nil, ErrRuleNotFound
}
func (this *Chain) GetRules() ([]*Rule, error) {
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
if err != nil {
return nil, err
}
var result = []*Rule{}
for _, rawRule := range rawRules {
result = append(result, NewRule(rawRule))
}
return result, nil
}
func (this *Chain) DeleteRule(rule *Rule) error {
err := this.conn.Raw().DelRule(rule.Raw())
if err != nil {
return err
}
return this.conn.Commit()
}
func (this *Chain) Flush() error {
this.conn.Raw().FlushChain(this.rawChain)
return this.conn.Commit()
}

View File

@@ -0,0 +1,13 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import nft "github.com/google/nftables"
type ChainPolicy = nft.ChainPolicy
// Possible ChainPolicy values.
const (
ChainPolicyDrop = nft.ChainPolicyDrop
ChainPolicyAccept = nft.ChainPolicyAccept
)

View File

@@ -0,0 +1,130 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"net"
"testing"
)
func getIPv4Chain(t *testing.T) *nftables.Chain {
var conn = nftables.NewConn()
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
table, err = conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
chain, err := table.GetChain("test_chain")
if err != nil {
if err == nftables.ErrChainNotFound {
chain, err = table.AddAcceptChain("test_chain")
}
}
if err != nil {
t.Fatal(err)
}
return chain
}
func TestChain_AddAcceptIPRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddDropIPRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddAcceptSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddDropSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_AddRejectSetRule(t *testing.T) {
var chain = getIPv4Chain(t)
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
if err != nil {
t.Fatal(err)
}
}
func TestChain_GetRuleWithUserData(t *testing.T) {
var chain = getIPv4Chain(t)
rule, err := chain.GetRuleWithUserData([]byte("test"))
if err != nil {
if err == nftables.ErrRuleNotFound {
t.Log("rule not found")
return
} else {
t.Fatal(err)
}
}
t.Log("rule:", rule)
}
func TestChain_GetRules(t *testing.T) {
var chain = getIPv4Chain(t)
rules, err := chain.GetRules()
if err != nil {
t.Fatal(err)
}
for _, rule := range rules {
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
}
}
func TestChain_DeleteRule(t *testing.T) {
var chain = getIPv4Chain(t)
rule, err := chain.GetRuleWithUserData([]byte("test"))
if err != nil {
if err == nftables.ErrRuleNotFound {
t.Log("rule not found")
return
}
t.Fatal(err)
}
err = chain.DeleteRule(rule)
if err != nil {
t.Fatal(err)
}
}
func TestChain_Flush(t *testing.T) {
var chain = getIPv4Chain(t)
err := chain.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,84 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"errors"
nft "github.com/google/nftables"
"github.com/iwind/TeaGo/types"
)
const MaxTableNameLength = 27
type Conn struct {
rawConn *nft.Conn
}
func NewConn() *Conn {
return &Conn{
rawConn: &nft.Conn{},
}
}
func (this *Conn) Raw() *nft.Conn {
return this.rawConn
}
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
rawTables, err := this.rawConn.ListTables()
if err != nil {
return nil, err
}
for _, rawTable := range rawTables {
if rawTable.Name == name && rawTable.Family == family {
return NewTable(this, rawTable), nil
}
}
return nil, ErrTableNotFound
}
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
if len(name) > MaxTableNameLength {
return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
}
var rawTable = this.rawConn.AddTable(&nft.Table{
Family: family,
Name: name,
})
err := this.Commit()
if err != nil {
return nil, err
}
return NewTable(this, rawTable), nil
}
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
return this.AddTable(name, TableFamilyIPv4)
}
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
return this.AddTable(name, TableFamilyIPv6)
}
func (this *Conn) DeleteTable(name string, family TableFamily) error {
table, err := this.GetTable(name, family)
if err != nil {
if err == ErrTableNotFound {
return nil
}
return err
}
this.rawConn.DelTable(table.Raw())
return this.Commit()
}
func (this *Conn) Commit() error {
return this.rawConn.Flush()
}

View File

@@ -0,0 +1,78 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"os/exec"
"testing"
)
func TestConn_Test(t *testing.T) {
_, err := exec.LookPath("nft")
if err == nil {
t.Log("ok")
return
}
t.Log(err)
}
func TestConn_GetTable_NotFound(t *testing.T) {
var conn = nftables.NewConn()
table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
t.Log("table not found")
} else {
t.Fatal(err)
}
} else {
t.Log("table:", table)
}
}
func TestConn_GetTable(t *testing.T) {
var conn = nftables.NewConn()
table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
t.Log("table not found")
} else {
t.Fatal(err)
}
} else {
t.Log("table:", table)
}
}
func TestConn_AddTable(t *testing.T) {
var conn = nftables.NewConn()
{
table, err := conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
t.Log(table.Name())
}
{
table, err := conn.AddIPv6Table("test_ipv6")
if err != nil {
t.Fatal(err)
}
t.Log(table.Name())
}
}
func TestConn_DeleteTable(t *testing.T) {
var conn = nftables.NewConn()
err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -0,0 +1,8 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
type Element struct {
}

View File

@@ -0,0 +1,19 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import "errors"
var ErrTableNotFound = errors.New("table not found")
var ErrChainNotFound = errors.New("chain not found")
var ErrSetNotFound = errors.New("set not found")
var ErrRuleNotFound = errors.New("rule not found")
func IsNotFound(err error) bool {
if err == nil {
return false
}
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
}

View File

@@ -0,0 +1,18 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import (
nft "github.com/google/nftables"
)
type TableFamily = nft.TableFamily
const (
TableFamilyINet TableFamily = nft.TableFamilyINet
TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4
TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6
TableFamilyARP TableFamily = nft.TableFamilyARP
TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
TableFamilyBridge TableFamily = nft.TableFamilyBridge
)

View File

@@ -0,0 +1,51 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import (
nft "github.com/google/nftables"
"github.com/google/nftables/expr"
)
type Rule struct {
rawRule *nft.Rule
}
func NewRule(rawRule *nft.Rule) *Rule {
return &Rule{
rawRule: rawRule,
}
}
func (this *Rule) Raw() *nft.Rule {
return this.rawRule
}
func (this *Rule) LookupSetName() string {
for _, e := range this.rawRule.Exprs {
exp, ok := e.(*expr.Lookup)
if ok {
return exp.SetName
}
}
return ""
}
func (this *Rule) VerDict() expr.VerdictKind {
for _, e := range this.rawRule.Exprs {
exp, ok := e.(*expr.Verdict)
if ok {
return exp.Kind
}
}
return -100
}
func (this *Rule) Handle() uint64 {
return this.rawRule.Handle
}
func (this *Rule) UserData() []byte {
return this.rawRule.UserData
}

View File

@@ -0,0 +1,161 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils"
nft "github.com/google/nftables"
"net"
"strings"
"time"
)
const MaxSetNameLength = 15
type SetOptions struct {
Id uint32
HasTimeout bool
Timeout time.Duration
KeyType SetDataType
DataType SetDataType
Constant bool
Interval bool
Anonymous bool
IsMap bool
}
type ElementOptions struct {
Timeout time.Duration
}
type Set struct {
conn *Conn
rawSet *nft.Set
batch *SetBatch
}
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
return &Set{
conn: conn,
rawSet: rawSet,
batch: &SetBatch{
conn: conn,
rawSet: rawSet,
},
}
}
func (this *Set) Raw() *nft.Set {
return this.rawSet
}
func (this *Set) Name() string {
return this.rawSet.Name
}
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
var rawElement = nft.SetElement{
Key: key,
}
if options != nil {
rawElement.Timeout = options.Timeout
}
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
if err != nil {
return err
}
err = this.conn.Commit()
if err != nil {
// retry if exists
if strings.Contains(err.Error(), "file exists") {
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
if deleteErr == nil {
err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
if err == nil {
err = this.conn.Commit()
}
}
}
}
return err
}
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
return errors.New("invalid ip '" + ip + "'")
}
if utils.IsIPv4(ip) {
return this.AddElement(ipObj.To4(), options)
} else {
return this.AddElement(ipObj.To16(), options)
}
}
func (this *Set) DeleteElement(key []byte) error {
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
if err != nil {
return err
}
err = this.conn.Commit()
if err != nil {
if strings.Contains(err.Error(), "no such file or directory") {
err = nil
}
}
return err
}
func (this *Set) DeleteIPElement(ip string) error {
var ipObj = net.ParseIP(ip)
if ipObj == nil {
return errors.New("invalid ip '" + ip + "'")
}
if utils.IsIPv4(ip) {
return this.DeleteElement(ipObj.To4())
} else {
return this.DeleteElement(ipObj.To16())
}
}
func (this *Set) Batch() *SetBatch {
return this.batch
}
func (this *Set) GetIPElements() ([]string, error) {
elements, err := this.conn.Raw().GetSetElements(this.rawSet)
if err != nil {
return nil, err
}
var result = []string{}
for _, element := range elements {
result = append(result, net.IP(element.Key).String())
}
return result, nil
}
// not work current time
/**func (this *Set) Flush() error {
this.conn.Raw().FlushSet(this.rawSet)
return this.conn.Commit()
}**/

View File

@@ -0,0 +1,36 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import (
nft "github.com/google/nftables"
)
type SetBatch struct {
conn *Conn
rawSet *nft.Set
}
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
var rawElement = nft.SetElement{
Key: key,
}
if options != nil {
rawElement.Timeout = options.Timeout
}
return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
rawElement,
})
}
func (this *SetBatch) DeleteElement(key []byte) error {
return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{
Key: key,
},
})
}
func (this *SetBatch) Commit() error {
return this.conn.Commit()
}

View File

@@ -0,0 +1,57 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables
import nft "github.com/google/nftables"
type SetDataType = nft.SetDatatype
var (
TypeInvalid = nft.TypeInvalid
TypeVerdict = nft.TypeVerdict
TypeNFProto = nft.TypeNFProto
TypeBitmask = nft.TypeBitmask
TypeInteger = nft.TypeInteger
TypeString = nft.TypeString
TypeLLAddr = nft.TypeLLAddr
TypeIPAddr = nft.TypeIPAddr
TypeIP6Addr = nft.TypeIP6Addr
TypeEtherAddr = nft.TypeEtherAddr
TypeEtherType = nft.TypeEtherType
TypeARPOp = nft.TypeARPOp
TypeInetProto = nft.TypeInetProto
TypeInetService = nft.TypeInetService
TypeICMPType = nft.TypeICMPType
TypeTCPFlag = nft.TypeTCPFlag
TypeDCCPPktType = nft.TypeDCCPPktType
TypeMHType = nft.TypeMHType
TypeTime = nft.TypeTime
TypeMark = nft.TypeMark
TypeIFIndex = nft.TypeIFIndex
TypeARPHRD = nft.TypeARPHRD
TypeRealm = nft.TypeRealm
TypeClassID = nft.TypeClassID
TypeUID = nft.TypeUID
TypeGID = nft.TypeGID
TypeCTState = nft.TypeCTState
TypeCTDir = nft.TypeCTDir
TypeCTStatus = nft.TypeCTStatus
TypeICMP6Type = nft.TypeICMP6Type
TypeCTLabel = nft.TypeCTLabel
TypePktType = nft.TypePktType
TypeICMPCode = nft.TypeICMPCode
TypeICMPV6Code = nft.TypeICMPV6Code
TypeICMPXCode = nft.TypeICMPXCode
TypeDevGroup = nft.TypeDevGroup
TypeDSCP = nft.TypeDSCP
TypeECN = nft.TypeECN
TypeFIBAddr = nft.TypeFIBAddr
TypeBoolean = nft.TypeBoolean
TypeCTEventBit = nft.TypeCTEventBit
TypeIFName = nft.TypeIFName
TypeIGMPType = nft.TypeIGMPType
TypeTimeDate = nft.TypeTimeDate
TypeTimeHour = nft.TypeTimeHour
TypeTimeDay = nft.TypeTimeDay
TypeCGroupV2 = nft.TypeCGroupV2
)

View File

@@ -0,0 +1,110 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nftables_test
import (
"errors"
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/iwind/TeaGo/types"
"github.com/mdlayher/netlink"
"net"
"testing"
"time"
)
func getIPv4Set(t *testing.T) *nftables.Set {
var table = getIPv4Table(t)
set, err := table.GetSet("test_ipv4_set")
if err != nil {
if err == nftables.ErrSetNotFound {
set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
KeyType: nftables.TypeIPAddr,
HasTimeout: true,
})
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
return set
}
func TestSet_AddElement(t *testing.T) {
var set = getIPv4Set(t)
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_DeleteElement(t *testing.T) {
var set = getIPv4Set(t)
err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_Batch(t *testing.T) {
var batch = getIPv4Set(t).Batch()
for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
var ipData = net.ParseIP(ip).To4()
//err := batch.DeleteElement(ipData)
//if err != nil {
// t.Fatal(err)
//}
err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
if err != nil {
t.Fatal(err)
}
}
err := batch.Commit()
if err != nil {
t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
t.Fatal(err)
}
t.Log("ok")
}
func TestSet_Add_Many(t *testing.T) {
var set = getIPv4Set(t)
for i := 0; i < 255; i++ {
t.Log(i)
for j := 0; j < 255; j++ {
var ip = "192.167." + types.String(i) + "." + types.String(j)
var ipData = net.ParseIP(ip).To4()
err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
if err != nil {
t.Fatal(err)
}
if j%10 == 0 {
err = set.Batch().Commit()
if err != nil {
t.Fatal(err)
}
}
}
err := set.Batch().Commit()
if err != nil {
t.Fatal(err)
}
}
t.Log("ok")
}
/**func TestSet_Flush(t *testing.T) {
var set = getIPv4Set(t)
err := set.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}**/

View File

@@ -0,0 +1,157 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables
import (
"errors"
nft "github.com/google/nftables"
"github.com/iwind/TeaGo/types"
"strings"
)
type Table struct {
conn *Conn
rawTable *nft.Table
}
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
return &Table{
conn: conn,
rawTable: rawTable,
}
}
func (this *Table) Raw() *nft.Table {
return this.rawTable
}
func (this *Table) Name() string {
return this.rawTable.Name
}
func (this *Table) Family() TableFamily {
return this.rawTable.Family
}
func (this *Table) GetChain(name string) (*Chain, error) {
rawChains, err := this.conn.Raw().ListChains()
if err != nil {
return nil, err
}
for _, rawChain := range rawChains {
// must compare table name
if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
return NewChain(this.conn, this.rawTable, rawChain), nil
}
}
return nil, ErrChainNotFound
}
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
if len(name) > MaxChainNameLength {
return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
}
var rawChain = this.conn.Raw().AddChain(&nft.Chain{
Name: name,
Table: this.rawTable,
Hooknum: nft.ChainHookInput,
Priority: nft.ChainPriorityFilter,
Type: nft.ChainTypeFilter,
Policy: chainPolicy,
})
err := this.conn.Commit()
if err != nil {
return nil, err
}
return NewChain(this.conn, this.rawTable, rawChain), nil
}
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
var policy = ChainPolicyAccept
return this.AddChain(name, &policy)
}
func (this *Table) AddDropChain(name string) (*Chain, error) {
var policy = ChainPolicyDrop
return this.AddChain(name, &policy)
}
func (this *Table) DeleteChain(name string) error {
chain, err := this.GetChain(name)
if err != nil {
if err == ErrChainNotFound {
return nil
}
return err
}
this.conn.Raw().DelChain(chain.Raw())
return this.conn.Commit()
}
func (this *Table) GetSet(name string) (*Set, error) {
rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
if err != nil {
if strings.Contains(err.Error(), "no such file or directory") {
return nil, ErrSetNotFound
}
return nil, err
}
return NewSet(this.conn, rawSet), nil
}
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
if len(name) > MaxSetNameLength {
return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
}
if options == nil {
options = &SetOptions{}
}
var rawSet = &nft.Set{
Table: this.rawTable,
ID: options.Id,
Name: name,
Anonymous: options.Anonymous,
Constant: options.Constant,
Interval: options.Interval,
IsMap: options.IsMap,
HasTimeout: options.HasTimeout,
Timeout: options.Timeout,
KeyType: options.KeyType,
DataType: options.DataType,
}
err := this.conn.Raw().AddSet(rawSet, nil)
if err != nil {
return nil, err
}
err = this.conn.Commit()
if err != nil {
return nil, err
}
return NewSet(this.conn, rawSet), nil
}
func (this *Table) DeleteSet(name string) error {
set, err := this.GetSet(name)
if err != nil {
if err == ErrSetNotFound {
return nil
}
return err
}
this.conn.Raw().DelSet(set.Raw())
return this.conn.Commit()
}
func (this *Table) Flush() error {
this.conn.Raw().FlushTable(this.rawTable)
return this.conn.Commit()
}

View File

@@ -0,0 +1,140 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
//go:build linux
// +build linux
package nftables_test
import (
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"testing"
)
func getIPv4Table(t *testing.T) *nftables.Table {
var conn = nftables.NewConn()
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
if err != nil {
if err == nftables.ErrTableNotFound {
table, err = conn.AddIPv4Table("test_ipv4")
if err != nil {
t.Fatal(err)
}
} else {
t.Fatal(err)
}
}
return table
}
func TestTable_AddChain(t *testing.T) {
var table = getIPv4Table(t)
{
chain, err := table.AddChain("test_default_chain", nil)
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}
{
chain, err := table.AddAcceptChain("test_accept_chain")
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}
// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
/**{
chain, err := table.AddDropChain("test_drop_chain")
if err != nil {
t.Fatal(err)
}
t.Log("created:", chain.Name())
}**/
}
func TestTable_GetChain(t *testing.T) {
var table = getIPv4Table(t)
for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
chain, err := table.GetChain(chainName)
if err != nil {
if err == nftables.ErrChainNotFound {
t.Log(chainName, ":", "not found")
} else {
t.Fatal(err)
}
} else {
t.Log(chainName, ":", chain)
}
}
}
func TestTable_DeleteChain(t *testing.T) {
var table = getIPv4Table(t)
err := table.DeleteChain("test_default_chain")
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTable_AddSet(t *testing.T) {
var table = getIPv4Table(t)
{
set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
HasTimeout: false,
KeyType: nftables.TypeIPAddr,
})
if err != nil {
t.Fatal(err)
}
t.Log(set.Name())
}
{
set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
HasTimeout: true,
//Timeout: 3600 * time.Second,
KeyType: nftables.TypeIP6Addr,
})
if err != nil {
t.Fatal(err)
}
t.Log(set.Name())
}
}
func TestTable_GetSet(t *testing.T) {
var table = getIPv4Table(t)
for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
set, err := table.GetSet(setName)
if err != nil {
if err == nftables.ErrSetNotFound {
t.Log(setName, ": not found")
} else {
t.Fatal(err)
}
} else {
t.Log(setName, ":", set)
}
}
}
func TestTable_DeleteSet(t *testing.T) {
var table = getIPv4Table(t)
err := table.DeleteSet("ipv4_black_set")
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}
func TestTable_Flush(t *testing.T) {
var table = getIPv4Table(t)
err := table.Flush()
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -1,6 +1,7 @@
package nodes package nodes
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
@@ -14,16 +15,20 @@ import (
teaconst "github.com/TeaOSLab/EdgeNode/internal/const" teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/errors" "github.com/TeaOSLab/EdgeNode/internal/errors"
"github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/maps"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os/exec" "os/exec"
"regexp"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -119,6 +124,8 @@ func (this *APIStream) loop() error {
err = this.handleNewNodeTask(message) err = this.handleNewNodeTask(message)
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务 case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
err = this.handleCheckSystemdService(message) err = this.handleCheckSystemdService(message)
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
err = this.handleCheckLocalFirewall(message)
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址 case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
err = this.handleChangeAPINode(message) err = this.handleChangeAPINode(message)
default: default:
@@ -569,7 +576,7 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
return nil return nil
} }
cmd := utils.NewCommandExecutor() var cmd = utils.NewCommandExecutor()
shortName := teaconst.SystemdServiceName shortName := teaconst.SystemdServiceName
cmd.Add(systemctl, "is-enabled", shortName) cmd.Add(systemctl, "is-enabled", shortName)
output, err := cmd.Run() output, err := cmd.Run()
@@ -585,6 +592,63 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
return nil return nil
} }
// 检查本地防火墙
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
err := json.Unmarshal(message.DataJSON, dataMessage)
if err != nil {
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
return nil
}
// nft
if dataMessage.Name == "nftables" {
if runtime.GOOS != "linux" {
this.replyFail(message.RequestId, "not Linux system")
return nil
}
nft, err := exec.LookPath("nft")
if err != nil {
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
return nil
}
var cmd = exec.Command(nft, "--version")
var output = &bytes.Buffer{}
cmd.Stdout = output
err = cmd.Run()
if err != nil {
this.replyFail(message.RequestId, "get version failed: "+err.Error())
return nil
}
var outputString = output.String()
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
this.replyFail(message.RequestId, "can not get nft version")
return nil
}
var version = versionMatches[1]
var result = maps.Map{
"version": version,
}
var protectionConfig = sharedNodeConfig.DDOSProtection
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
if err != nil {
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
} else {
this.replyOk(message.RequestId, string(result.AsJSON()))
}
} else {
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
}
return nil
}
// 修改API地址 // 修改API地址
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error { func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
config, err := configs.LoadAPIConfig() config, err := configs.LoadAPIConfig()
@@ -660,6 +724,11 @@ func (this *APIStream) replyOk(requestId int64, message string) {
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message}) _ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
} }
// 回复成功并包含数据
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
}
// 获取缓存存取对象 // 获取缓存存取对象
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) { func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
cachePolicy := &serverconfigs.HTTPCachePolicy{} cachePolicy := &serverconfigs.HTTPCachePolicy{}

View File

@@ -7,7 +7,6 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const" teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/ttlcache" "github.com/TeaOSLab/EdgeNode/internal/ttlcache"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf"
@@ -21,8 +20,7 @@ import (
// ClientConn 客户端连接 // ClientConn 客户端连接
type ClientConn struct { type ClientConn struct {
once sync.Once once sync.Once
globalLimiter *ratelimit.Counter
isTLS bool isTLS bool
hasDeadline bool hasDeadline bool
@@ -33,7 +31,7 @@ type ClientConn struct {
BaseClientConn BaseClientConn
} }
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn { func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn {
if quickClose { if quickClose {
// TCP // TCP
tcpConn, ok := conn.(*net.TCPConn) tcpConn, ok := conn.(*net.TCPConn)
@@ -43,7 +41,7 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra
} }
} }
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter} return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS}
} }
func (this *ClientConn) Read(b []byte) (n int, err error) { func (this *ClientConn) Read(b []byte) (n int, err error) {
@@ -96,13 +94,6 @@ func (this *ClientConn) Close() error {
err := this.rawConn.Close() err := this.rawConn.Close()
// 全局并发数限制
this.once.Do(func() {
if this.globalLimiter != nil {
this.globalLimiter.Release()
}
})
// 单个服务并发数限制 // 单个服务并发数限制
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String()) sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())

View File

@@ -3,16 +3,12 @@
package nodes package nodes
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf"
"net" "net"
) )
var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxConnections)
// ClientListener 客户端网络监听 // ClientListener 客户端网络监听
type ClientListener struct { type ClientListener struct {
rawListener net.Listener rawListener net.Listener
@@ -36,13 +32,8 @@ func (this *ClientListener) IsTLS() bool {
} }
func (this *ClientListener) Accept() (net.Conn, error) { func (this *ClientListener) Accept() (net.Conn, error) {
// 限制并发连接数
var limiter = sharedConnectionsLimiter
limiter.Ack()
conn, err := this.rawListener.Accept() conn, err := this.rawListener.Accept()
if err != nil { if err != nil {
limiter.Release()
return nil, err return nil, err
} }
@@ -60,12 +51,11 @@ func (this *ClientListener) Accept() (net.Conn, error) {
} }
_ = conn.Close() _ = conn.Close()
limiter.Release()
return this.Accept() return this.Accept()
} }
} }
return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil return NewClientConn(conn, this.isTLS, this.quickClose), nil
} }
func (this *ClientListener) Close() error { func (this *ClientListener) Close() error {

View File

@@ -7,6 +7,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
"github.com/TeaOSLab/EdgeNode/internal/caches" "github.com/TeaOSLab/EdgeNode/internal/caches"
"github.com/TeaOSLab/EdgeNode/internal/configs" "github.com/TeaOSLab/EdgeNode/internal/configs"
teaconst "github.com/TeaOSLab/EdgeNode/internal/const" teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
@@ -15,7 +16,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/metrics" "github.com/TeaOSLab/EdgeNode/internal/metrics"
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/stats"
@@ -368,6 +368,38 @@ func (this *Node) loop() error {
} }
sharedNodeConfig.ParentNodes = parentNodes sharedNodeConfig.ParentNodes = parentNodes
// 修改为已同步
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id,
IsOk: true,
Error: "",
})
if err != nil {
return err
}
case "ddosProtectionChanged":
resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
if err != nil {
return err
}
if len(resp.DdosProtectionJSON) == 0 {
if sharedNodeConfig != nil {
sharedNodeConfig.DDOSProtection = nil
}
} else {
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
if err != nil {
return errors.New("decode DDoS protection config failed: " + err.Error())
}
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
if err != nil {
// 不阻塞
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
}
}
// 修改为已同步 // 修改为已同步
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{ _, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
NodeTaskId: task.Id, NodeTaskId: task.Id,
@@ -730,7 +762,6 @@ func (this *Node) listenSock() error {
"ipConns": ipConns, "ipConns": ipConns,
"serverConns": serverConns, "serverConns": serverConns,
"total": sharedListenerManager.TotalActiveConnections(), "total": sharedListenerManager.TotalActiveConnections(),
"limiter": sharedConnectionsLimiter.Len(),
}, },
}) })
case "dropIP": case "dropIP":
@@ -854,17 +885,6 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
this.maxThreads = config.MaxThreads this.maxThreads = config.MaxThreads
} }
// max tcp connections
if config.TCPMaxConnections <= 0 {
config.TCPMaxConnections = nodeconfigs.DefaultTCPMaxConnections
}
if config.TCPMaxConnections != sharedConnectionsLimiter.Count() {
remotelogs.Println("NODE", "[TCP]changed tcp max connections to '"+types.String(config.TCPMaxConnections)+"'")
sharedConnectionsLimiter.Close()
sharedConnectionsLimiter = ratelimit.NewCounter(config.TCPMaxConnections)
}
// timezone // timezone
var timeZone = config.TimeZone var timeZone = config.TimeZone
if len(timeZone) == 0 { if len(timeZone) == 0 {

View File

@@ -5,9 +5,12 @@ import (
"github.com/cespare/xxhash" "github.com/cespare/xxhash"
"math" "math"
"net" "net"
"regexp"
"strings" "strings"
) )
var ipv4Reg = regexp.MustCompile(`\d+\.`)
// IP2Long 将IP转换为整型 // IP2Long 将IP转换为整型
// 注意IPv6没有顺序 // 注意IPv6没有顺序
func IP2Long(ip string) uint64 { func IP2Long(ip string) uint64 {
@@ -54,3 +57,24 @@ func IsLocalIP(ipString string) bool {
return false return false
} }
// IsIPv4 是否为IPv4
func IsIPv4(ip string) bool {
var data = net.ParseIP(ip)
if data == nil {
return false
}
if strings.Contains(ip, ":") {
return false
}
return data.To4() != nil
}
// IsIPv6 是否为IPv6
func IsIPv6(ip string) bool {
var data = net.ParseIP(ip)
if data == nil {
return false
}
return !IsIPv4(ip)
}

View File

@@ -26,3 +26,26 @@ func TestIsLocalIP(t *testing.T) {
a.IsFalse(IsLocalIP("::1:2:3")) a.IsFalse(IsLocalIP("::1:2:3"))
a.IsFalse(IsLocalIP("8.8.8.8")) a.IsFalse(IsLocalIP("8.8.8.8"))
} }
func TestIsIPv4(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsTrue(IsIPv4("192.168.1.1"))
a.IsTrue(IsIPv4("0.0.0.0"))
a.IsFalse(IsIPv4("192.168.1.256"))
a.IsFalse(IsIPv4("192.168.1"))
a.IsFalse(IsIPv4("::1"))
a.IsFalse(IsIPv4("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
a.IsFalse(IsIPv4("::ffff:192.168.0.1"))
}
func TestIsIPv6(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(IsIPv6("192.168.1.1"))
a.IsFloat32(IsIPv6("0.0.0.0"))
a.IsFalse(IsIPv6("192.168.1.256"))
a.IsFalse(IsIPv6("192.168.1"))
a.IsTrue(IsIPv6("::1"))
a.IsTrue(IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
a.IsTrue(IsIPv4("::ffff:192.168.0.1"))
a.IsTrue(IsIPv6("::ffff:192.168.0.1"))
}

View File

@@ -1,6 +1,7 @@
package utils package utils
import ( import (
"sort"
"strings" "strings"
"unsafe" "unsafe"
) )
@@ -36,6 +37,22 @@ func FormatAddressList(addrList []string) []string {
return result return result
} }
// ToValidUTF8string 去除字符串中的非UTF-8字符
func ToValidUTF8string(v string) string { func ToValidUTF8string(v string) string {
return strings.ToValidUTF8(v, "") return strings.ToValidUTF8(v, "")
} }
// ContainsSameStrings 检查两个字符串slice内容是否一致
func ContainsSameStrings(s1 []string, s2 []string) bool {
if len(s1) != len(s2) {
return false
}
sort.Strings(s1)
sort.Strings(s2)
for index, v1 := range s1 {
if v1 != s2[index] {
return false
}
}
return true
}

View File

@@ -1,56 +1,67 @@
package utils package utils_test
import ( import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/assert"
"strings" "strings"
"testing" "testing"
) )
func TestBytesToString(t *testing.T) { func TestBytesToString(t *testing.T) {
t.Log(UnsafeBytesToString([]byte("Hello,World"))) t.Log(utils.UnsafeBytesToString([]byte("Hello,World")))
} }
func TestStringToBytes(t *testing.T) { func TestStringToBytes(t *testing.T) {
t.Log(string(UnsafeStringToBytes("Hello,World"))) t.Log(string(utils.UnsafeStringToBytes("Hello,World")))
} }
func BenchmarkBytesToString(b *testing.B) { func BenchmarkBytesToString(b *testing.B) {
data := []byte("Hello,World") var data = []byte("Hello,World")
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = UnsafeBytesToString(data) _ = utils.UnsafeBytesToString(data)
} }
} }
func BenchmarkBytesToString2(b *testing.B) { func BenchmarkBytesToString2(b *testing.B) {
data := []byte("Hello,World") var data = []byte("Hello,World")
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = string(data) _ = string(data)
} }
} }
func BenchmarkStringToBytes(b *testing.B) { func BenchmarkStringToBytes(b *testing.B) {
s := strings.Repeat("Hello,World", 1024) var s = strings.Repeat("Hello,World", 1024)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = UnsafeStringToBytes(s) _ = utils.UnsafeStringToBytes(s)
} }
} }
func BenchmarkStringToBytes2(b *testing.B) { func BenchmarkStringToBytes2(b *testing.B) {
s := strings.Repeat("Hello,World", 1024) var s = strings.Repeat("Hello,World", 1024)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = []byte(s) _ = []byte(s)
} }
} }
func TestFormatAddress(t *testing.T) { func TestFormatAddress(t *testing.T) {
t.Log(FormatAddress("127.0.0.1:1234")) t.Log(utils.FormatAddress("127.0.0.1:1234"))
t.Log(FormatAddress("127.0.0.1 : 1234")) t.Log(utils.FormatAddress("127.0.0.1 : 1234"))
t.Log(FormatAddress("127.0.0.11234")) t.Log(utils.FormatAddress("127.0.0.11234"))
} }
func TestFormatAddressList(t *testing.T) { func TestFormatAddressList(t *testing.T) {
t.Log(FormatAddressList([]string{ t.Log(utils.FormatAddressList([]string{
"127.0.0.1:1234", "127.0.0.1:1234",
"127.0.0.1 : 1234", "127.0.0.1 : 1234",
"127.0.0.11234", "127.0.0.11234",
})) }))
} }
func TestContainsSameStrings(t *testing.T) {
var a = assert.NewAssertion(t)
a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"}))
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"}))
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"}))
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"}))
}