mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-10 04:20:27 +08:00
实现基础的DDoS防护
This commit is contained in:
@@ -8,4 +8,5 @@ const (
|
||||
EventQuit Event = "quit" // quit node gracefully
|
||||
EventReload Event = "reload" // reload config
|
||||
EventTerminated Event = "terminated" // process terminated
|
||||
EventNFTablesReady Event = "nftablesReady" // nftables ready
|
||||
)
|
||||
|
||||
1
internal/firewalls/.gitignore
vendored
1
internal/firewalls/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
firewall_nftables_test.go
|
||||
494
internal/firewalls/ddos_protection.go
Normal file
494
internal/firewalls/ddos_protection.go
Normal file
@@ -0,0 +1,494 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
|
||||
func init() {
|
||||
events.On(events.EventReload, func() {
|
||||
if nftablesInstance == nil {
|
||||
return
|
||||
}
|
||||
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
|
||||
if err != nil {
|
||||
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
events.On(events.EventNFTablesReady, func() {
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
|
||||
if err != nil {
|
||||
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// DDoSProtectionManager DDoS防护
|
||||
type DDoSProtectionManager struct {
|
||||
nftPath string
|
||||
|
||||
lastAllowIPList []string
|
||||
lastConfig []byte
|
||||
}
|
||||
|
||||
// NewDDoSProtectionManager 获取新对象
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
nftPath, _ := exec.LookPath("nft")
|
||||
|
||||
return &DDoSProtectionManager{
|
||||
nftPath: nftPath,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply 应用配置
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
// 同集群节点IP白名单
|
||||
var allowIPListChanged = false
|
||||
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||
if nodeConfig != nil {
|
||||
var allowIPList = nodeConfig.AllowedIPs
|
||||
if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) {
|
||||
allowIPListChanged = true
|
||||
this.lastAllowIPList = allowIPList
|
||||
}
|
||||
}
|
||||
|
||||
// 对比配置
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return errors.New("encode config to json failed: " + err.Error())
|
||||
}
|
||||
if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
|
||||
return nil
|
||||
}
|
||||
remotelogs.Println("FIREWALL", "change DDoS protection config")
|
||||
|
||||
if len(this.nftPath) == 0 {
|
||||
return errors.New("can not find nft command")
|
||||
}
|
||||
|
||||
if nftablesInstance == nil {
|
||||
return errors.New("nftables instance should not be nil")
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
// TCP
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO other protocols
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TCP
|
||||
if config.TCP == nil {
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// allow ip list
|
||||
var allowIPList = []string{}
|
||||
for _, ipConfig := range config.TCP.AllowIPList {
|
||||
allowIPList = append(allowIPList, ipConfig.IP)
|
||||
}
|
||||
for _, ip := range this.lastAllowIPList {
|
||||
if !lists.ContainsString(allowIPList, ip) {
|
||||
allowIPList = append(allowIPList, ip)
|
||||
}
|
||||
}
|
||||
err = this.updateAllowIPList(allowIPList)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// tcp
|
||||
if config.TCP.IsOn {
|
||||
err := this.addTCPRules(config.TCP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err := this.removeTCPRules()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.lastConfig = configJSON
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 添加TCP规则
|
||||
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
|
||||
var ports = []int32{}
|
||||
for _, portConfig := range tcpConfig.Ports {
|
||||
if !lists.ContainsInt32(ports, portConfig.Port) {
|
||||
ports = append(ports, portConfig.Port)
|
||||
}
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
ports = []int32{80, 443}
|
||||
}
|
||||
|
||||
for _, filter := range nftablesFilters {
|
||||
chain, oldRules, err := this.getRules(filter)
|
||||
if err != nil {
|
||||
return errors.New("get old rules failed: " + err.Error())
|
||||
}
|
||||
|
||||
var protocol = filter.protocol()
|
||||
|
||||
// max connections
|
||||
var maxConnections = tcpConfig.MaxConnections
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = nodeconfigs.DefaultTCPMaxConnections
|
||||
if maxConnections <= 0 {
|
||||
maxConnections = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// max connections per ip
|
||||
var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
|
||||
if maxConnectionsPerIP <= 0 {
|
||||
maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP
|
||||
if maxConnectionsPerIP <= 0 {
|
||||
maxConnectionsPerIP = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// new connections rate
|
||||
var newConnectionsRate = tcpConfig.NewConnectionsRate
|
||||
if newConnectionsRate <= 0 {
|
||||
newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate
|
||||
if newConnectionsRate <= 0 {
|
||||
newConnectionsRate = 100000
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否有变化
|
||||
var hasChanges = false
|
||||
for _, port := range ports {
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) {
|
||||
hasChanges = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasChanges {
|
||||
// 检查是否有多余的端口
|
||||
var oldPorts = this.getTCPPorts(oldRules)
|
||||
if !this.eqPorts(ports, oldPorts) {
|
||||
hasChanges = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasChanges {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先清空所有相关规则
|
||||
err = this.removeOldTCPRules(chain, oldRules)
|
||||
if err != nil {
|
||||
return errors.New("delete old rules failed: " + err.Error())
|
||||
}
|
||||
|
||||
// 添加新规则
|
||||
for _, port := range ports {
|
||||
if maxConnections > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if maxConnectionsPerIP > 0 {
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
}
|
||||
}
|
||||
|
||||
if newConnectionsRate > 0 {
|
||||
// TODO 思考是否有惩罚机制
|
||||
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}))
|
||||
var stderr = &bytes.Buffer{}
|
||||
cmd.Stderr = stderr
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除TCP规则
|
||||
func (this *DDoSProtectionManager) removeTCPRules() error {
|
||||
for _, filter := range nftablesFilters {
|
||||
chain, rules, err := this.getRules(filter)
|
||||
|
||||
// TCP
|
||||
err = this.removeOldTCPRules(chain, rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 组合user data
|
||||
// 数据中不能包含字母、数字、下划线以外的数据
|
||||
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
|
||||
if attrs == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "ZZ" + strings.Join(attrs, "_") + "ZZ"
|
||||
}
|
||||
|
||||
// 解码user data
|
||||
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var dataCopy = make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
var separatorLen = 2
|
||||
var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
|
||||
if index1 < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
dataCopy = dataCopy[index1+separatorLen:]
|
||||
var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
|
||||
if index2 < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var s = string(dataCopy[:index2])
|
||||
var pieces = strings.Split(s, "_")
|
||||
for index, piece := range pieces {
|
||||
pieces[index] = strings.TrimSpace(piece)
|
||||
}
|
||||
return pieces
|
||||
}
|
||||
|
||||
// 清除规则
|
||||
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
|
||||
for _, rule := range rules {
|
||||
var pieces = this.decodeUserData(rule.UserData())
|
||||
if len(pieces) != 4 {
|
||||
continue
|
||||
}
|
||||
if pieces[0] != "tcp" {
|
||||
continue
|
||||
}
|
||||
switch pieces[2] {
|
||||
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate":
|
||||
err := chain.DeleteRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据参数检查规则是否存在
|
||||
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
|
||||
if len(attrs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, oldRule := range rules {
|
||||
var pieces = this.decodeUserData(oldRule.UserData())
|
||||
if len(attrs) != len(pieces) {
|
||||
continue
|
||||
}
|
||||
var isSame = true
|
||||
for index, piece := range pieces {
|
||||
if strings.TrimSpace(piece) != attrs[index] {
|
||||
isSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if isSame {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取规则中的端口号
|
||||
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
|
||||
var ports = []int32{}
|
||||
for _, rule := range rules {
|
||||
var pieces = this.decodeUserData(rule.UserData())
|
||||
if len(pieces) != 4 {
|
||||
continue
|
||||
}
|
||||
if pieces[0] != "tcp" {
|
||||
continue
|
||||
}
|
||||
var port = types.Int32(pieces[1])
|
||||
if port > 0 && !lists.ContainsInt32(ports, port) {
|
||||
ports = append(ports, port)
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
// 检查端口是否一样
|
||||
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
|
||||
if len(ports1) != len(ports2) {
|
||||
return false
|
||||
}
|
||||
|
||||
var portMap = map[int32]bool{}
|
||||
for _, port := range ports2 {
|
||||
portMap[port] = true
|
||||
}
|
||||
|
||||
for _, port := range ports1 {
|
||||
_, ok := portMap[port]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 查找Table
|
||||
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
|
||||
var family nftables.TableFamily
|
||||
if filter.IsIPv4 {
|
||||
family = nftables.TableFamilyIPv4
|
||||
} else if filter.IsIPv6 {
|
||||
family = nftables.TableFamilyIPv6
|
||||
} else {
|
||||
return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
|
||||
}
|
||||
return nftablesInstance.conn.GetTable(filter.Name, family)
|
||||
}
|
||||
|
||||
// 查找所有规则
|
||||
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
|
||||
table, err := this.getTable(filter)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("get table failed: " + err.Error())
|
||||
}
|
||||
chain, err := table.GetChain(nftablesChainName)
|
||||
if err != nil {
|
||||
return nil, nil, errors.New("get chain failed: " + err.Error())
|
||||
}
|
||||
rules, err := chain.GetRules()
|
||||
return chain, rules, err
|
||||
}
|
||||
|
||||
// 更新白名单
|
||||
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
|
||||
if nftablesInstance == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var allMap = map[string]zero.Zero{}
|
||||
for _, ip := range allIPList {
|
||||
allMap[ip] = zero.New()
|
||||
}
|
||||
|
||||
for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
|
||||
var isIPv4 = set == nftablesInstance.allowIPv4Set
|
||||
var isIPv6 = !isIPv4
|
||||
|
||||
// 现有的
|
||||
oldList, err := set.GetIPElements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var oldMap = map[string]zero.Zero{} // ip=> zero
|
||||
for _, ip := range oldList {
|
||||
oldMap[ip] = zero.New()
|
||||
|
||||
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||
_, ok := allMap[ip]
|
||||
if !ok {
|
||||
// 不存在则删除
|
||||
err = set.DeleteIPElement(ip)
|
||||
if err != nil {
|
||||
return errors.New("delete ip element '" + ip + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 新增的
|
||||
for _, ip := range allIPList {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
continue
|
||||
}
|
||||
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||
_, ok := oldMap[ip]
|
||||
if !ok {
|
||||
// 不存在则添加
|
||||
err = set.AddIPElement(ip, nil)
|
||||
if err != nil {
|
||||
return errors.New("add ip '" + ip + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
23
internal/firewalls/ddos_protection_others.go
Normal file
23
internal/firewalls/ddos_protection_others.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
)
|
||||
|
||||
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||
|
||||
type DDoSProtectionManager struct {
|
||||
nftPath string
|
||||
}
|
||||
|
||||
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||
return &DDoSProtectionManager{}
|
||||
}
|
||||
|
||||
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,4 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !plus
|
||||
// +build !plus
|
||||
|
||||
package firewalls
|
||||
|
||||
@@ -8,9 +6,11 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var currentFirewall FirewallInterface
|
||||
var firewallLocker = &sync.Mutex{}
|
||||
|
||||
// 初始化
|
||||
func init() {
|
||||
@@ -24,10 +24,28 @@ func init() {
|
||||
|
||||
// Firewall 查找当前系统中最适合的防火墙
|
||||
func Firewall() FirewallInterface {
|
||||
firewallLocker.Lock()
|
||||
defer firewallLocker.Unlock()
|
||||
if currentFirewall != nil {
|
||||
return currentFirewall
|
||||
}
|
||||
|
||||
// nftables
|
||||
if runtime.GOOS == "linux" {
|
||||
nftables, err := NewNFTablesFirewall()
|
||||
if err != nil {
|
||||
remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
|
||||
} else {
|
||||
if nftables.IsReady() {
|
||||
currentFirewall = nftables
|
||||
events.Notify(events.EventNFTablesReady)
|
||||
return nftables
|
||||
} else {
|
||||
remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// firewalld
|
||||
if runtime.GOOS == "linux" {
|
||||
var firewalld = NewFirewalld()
|
||||
|
||||
373
internal/firewalls/firewall_nftables.go
Normal file
373
internal/firewalls/firewall_nftables.go
Normal file
@@ -0,0 +1,373 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// check nft status, if being enabled we load it automatically
|
||||
func init() {
|
||||
if runtime.GOOS == "linux" {
|
||||
var ticker = time.NewTicker(3 * time.Minute)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
// if already ready, we break
|
||||
if nftablesIsReady {
|
||||
ticker.Stop()
|
||||
break
|
||||
}
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
nftablesFirewall, err := NewNFTablesFirewall()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
currentFirewall = nftablesFirewall
|
||||
remotelogs.Println("FIREWALL", "nftables is ready")
|
||||
|
||||
// fire event
|
||||
if nftablesFirewall.IsReady() {
|
||||
events.Notify(events.EventNFTablesReady)
|
||||
}
|
||||
|
||||
ticker.Stop()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var nftablesInstance *NFTablesFirewall
|
||||
var nftablesIsReady = false
|
||||
var nftablesFilters = []*nftablesTableDefinition{
|
||||
// we shorten the name for table name length restriction
|
||||
{Name: "edge_dft_v4", IsIPv4: true},
|
||||
{Name: "edge_dft_v6", IsIPv6: true},
|
||||
}
|
||||
var nftablesChainName = "input"
|
||||
|
||||
type nftablesTableDefinition struct {
|
||||
Name string
|
||||
IsIPv4 bool
|
||||
IsIPv6 bool
|
||||
}
|
||||
|
||||
func (this *nftablesTableDefinition) protocol() string {
|
||||
if this.IsIPv6 {
|
||||
return "ip6"
|
||||
}
|
||||
return "ip"
|
||||
}
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
var firewall = &NFTablesFirewall{
|
||||
conn: nftables.NewConn(),
|
||||
}
|
||||
err := firewall.init()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return firewall, nil
|
||||
}
|
||||
|
||||
type NFTablesFirewall struct {
|
||||
conn *nftables.Conn
|
||||
isReady bool
|
||||
|
||||
allowIPv4Set *nftables.Set
|
||||
allowIPv6Set *nftables.Set
|
||||
|
||||
denyIPv4Set *nftables.Set
|
||||
denyIPv6Set *nftables.Set
|
||||
|
||||
firewalld *Firewalld
|
||||
}
|
||||
|
||||
func (this *NFTablesFirewall) init() error {
|
||||
// check nft
|
||||
_, err := exec.LookPath("nft")
|
||||
if err != nil {
|
||||
return errors.New("nft not found")
|
||||
}
|
||||
|
||||
// table
|
||||
for _, tableDef := range nftablesFilters {
|
||||
var family nftables.TableFamily
|
||||
if tableDef.IsIPv4 {
|
||||
family = nftables.TableFamilyIPv4
|
||||
} else if tableDef.IsIPv6 {
|
||||
family = nftables.TableFamilyIPv6
|
||||
} else {
|
||||
return errors.New("invalid table family: " + types.String(tableDef))
|
||||
}
|
||||
table, err := this.conn.GetTable(tableDef.Name, family)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
if tableDef.IsIPv4 {
|
||||
table, err = this.conn.AddIPv4Table(tableDef.Name)
|
||||
} else if tableDef.IsIPv6 {
|
||||
table, err = this.conn.AddIPv6Table(tableDef.Name)
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if table == nil {
|
||||
return errors.New("can not create table '" + tableDef.Name + "'")
|
||||
}
|
||||
|
||||
// chain
|
||||
var chainName = nftablesChainName
|
||||
chain, err := table.GetChain(chainName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
chain, err = table.AddAcceptChain(chainName)
|
||||
if err != nil {
|
||||
return errors.New("create chain '" + chainName + "' failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get chain '" + chainName + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if chain == nil {
|
||||
return errors.New("can not create chain '" + chainName + "'")
|
||||
}
|
||||
|
||||
// allow lo
|
||||
var loRuleName = []byte("lo")
|
||||
_, err = chain.GetRuleWithUserData(loRuleName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("add 'lo' rule failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// allow set
|
||||
// "allow" should be always first
|
||||
for _, setAction := range []string{"allow", "deny"} {
|
||||
var setName = setAction + "_set"
|
||||
|
||||
set, err := table.GetSet(setName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
var keyType nftables.SetDataType
|
||||
if tableDef.IsIPv4 {
|
||||
keyType = nftables.TypeIPAddr
|
||||
} else if tableDef.IsIPv6 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
set, err = table.AddSet(setName, &nftables.SetOptions{
|
||||
KeyType: keyType,
|
||||
HasTimeout: true,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.New("create set '" + setName + "' failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get set '" + setName + "' failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if set == nil {
|
||||
return errors.New("can not create set '" + setName + "'")
|
||||
}
|
||||
if tableDef.IsIPv4 {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv4Set = set
|
||||
} else {
|
||||
this.denyIPv4Set = set
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
this.allowIPv6Set = set
|
||||
} else {
|
||||
this.denyIPv6Set = set
|
||||
}
|
||||
}
|
||||
|
||||
// rule
|
||||
var ruleName = []byte(setAction)
|
||||
rule, err := chain.GetRuleWithUserData(ruleName)
|
||||
if err != nil {
|
||||
if nftables.IsNotFound(err) {
|
||||
if tableDef.IsIPv4 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
|
||||
}
|
||||
} else if tableDef.IsIPv6 {
|
||||
if setAction == "allow" {
|
||||
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
|
||||
} else {
|
||||
rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return errors.New("add rule failed: " + err.Error())
|
||||
}
|
||||
} else {
|
||||
return errors.New("get rule failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
if rule == nil {
|
||||
return errors.New("can not create rule '" + string(ruleName) + "'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.isReady = true
|
||||
nftablesIsReady = true
|
||||
nftablesInstance = this
|
||||
|
||||
// load firewalld
|
||||
var firewalld = NewFirewalld()
|
||||
if firewalld.IsReady() {
|
||||
this.firewalld = firewalld
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *NFTablesFirewall) Name() string {
|
||||
return "nftables"
|
||||
}
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
func (this *NFTablesFirewall) IsReady() bool {
|
||||
return this.isReady
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *NFTablesFirewall) IsMock() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AllowPort 允许端口
|
||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||
if this.firewalld != nil {
|
||||
return this.firewalld.AllowPort(port, protocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePort 删除端口
|
||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||
if this.firewalld != nil {
|
||||
return this.firewalld.RemovePort(port, protocol)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowSourceIP Allow把IP加入白名单
|
||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.allowIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
}
|
||||
return this.allowIPv6Set.AddElement(data.To16(), nil)
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.allowIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
}
|
||||
return this.allowIPv4Set.AddElement(data.To4(), nil)
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
// we did not create set for drop ip, so we reuse DropSourceIP() method here
|
||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return this.DropSourceIP(ip, timeoutSeconds)
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set == nil {
|
||||
return errors.New("ipv6 ip set is nil")
|
||||
}
|
||||
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.denyIPv4Set == nil {
|
||||
return errors.New("ipv4 ip set is nil")
|
||||
}
|
||||
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
|
||||
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if strings.Contains(ip, ":") { // ipv6
|
||||
if this.denyIPv6Set != nil {
|
||||
err := this.denyIPv6Set.DeleteElement(data.To16())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if this.allowIPv6Set != nil {
|
||||
err := this.allowIPv6Set.DeleteElement(data.To16())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ipv4
|
||||
if this.allowIPv4Set != nil {
|
||||
err := this.denyIPv4Set.DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.allowIPv4Set.DeleteElement(data.To4())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
61
internal/firewalls/firewall_nftables_others.go
Normal file
61
internal/firewalls/firewall_nftables_others.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type NFTablesFirewall struct {
|
||||
}
|
||||
|
||||
// Name 名称
|
||||
func (this *NFTablesFirewall) Name() string {
|
||||
return "nftables"
|
||||
}
|
||||
|
||||
// IsReady 是否已准备被调用
|
||||
func (this *NFTablesFirewall) IsReady() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsMock 是否为模拟
|
||||
func (this *NFTablesFirewall) IsMock() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AllowPort 允许端口
|
||||
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePort 删除端口
|
||||
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AllowSourceIP Allow把IP加入白名单
|
||||
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RejectSourceIP 拒绝某个源IP连接
|
||||
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropSourceIP 丢弃某个源IP数据
|
||||
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSourceIP 删除某个源IP
|
||||
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
1
internal/firewalls/nftables/.gitignore
vendored
Normal file
1
internal/firewalls/nftables/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
build_remote.sh
|
||||
370
internal/firewalls/nftables/chain.go
Normal file
370
internal/firewalls/nftables/chain.go
Normal file
@@ -0,0 +1,370 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
const MaxChainNameLength = 31
|
||||
|
||||
type RuleOptions struct {
|
||||
Exprs []expr.Any
|
||||
UserData []byte
|
||||
}
|
||||
|
||||
// Chain chain object in table
|
||||
type Chain struct {
|
||||
conn *Conn
|
||||
rawTable *nft.Table
|
||||
rawChain *nft.Chain
|
||||
}
|
||||
|
||||
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
|
||||
return &Chain{
|
||||
conn: conn,
|
||||
rawTable: rawTable,
|
||||
rawChain: rawChain,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Chain) Raw() *nft.Chain {
|
||||
return this.rawChain
|
||||
}
|
||||
|
||||
func (this *Chain) Name() string {
|
||||
return this.rawChain.Name
|
||||
}
|
||||
|
||||
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
|
||||
var rawRule = this.conn.Raw().AddRule(&nft.Rule{
|
||||
Table: this.rawTable,
|
||||
Chain: this.rawChain,
|
||||
Exprs: options.Exprs,
|
||||
UserData: options.UserData,
|
||||
})
|
||||
err := this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRule(rawRule), nil
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ip,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: setName,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
|
||||
if len(interfaceName) >= 16 {
|
||||
return nil, errors.New("invalid interface name '" + interfaceName + "'")
|
||||
}
|
||||
var ifname = make([]byte, 16)
|
||||
copy(ifname, interfaceName+"\x00")
|
||||
|
||||
return this.AddRule(&RuleOptions{
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: userData,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
|
||||
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rawRule := range rawRules {
|
||||
if bytes.Compare(rawRule.UserData, userData) == 0 {
|
||||
return NewRule(rawRule), nil
|
||||
}
|
||||
}
|
||||
return nil, ErrRuleNotFound
|
||||
}
|
||||
|
||||
func (this *Chain) GetRules() ([]*Rule, error) {
|
||||
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result = []*Rule{}
|
||||
for _, rawRule := range rawRules {
|
||||
result = append(result, NewRule(rawRule))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (this *Chain) DeleteRule(rule *Rule) error {
|
||||
err := this.conn.Raw().DelRule(rule.Raw())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Chain) Flush() error {
|
||||
this.conn.Raw().FlushChain(this.rawChain)
|
||||
return this.conn.Commit()
|
||||
}
|
||||
13
internal/firewalls/nftables/chain_policy.go
Normal file
13
internal/firewalls/nftables/chain_policy.go
Normal file
@@ -0,0 +1,13 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import nft "github.com/google/nftables"
|
||||
|
||||
type ChainPolicy = nft.ChainPolicy
|
||||
|
||||
// Possible ChainPolicy values.
|
||||
const (
|
||||
ChainPolicyDrop = nft.ChainPolicyDrop
|
||||
ChainPolicyAccept = nft.ChainPolicyAccept
|
||||
)
|
||||
130
internal/firewalls/nftables/chain_test.go
Normal file
130
internal/firewalls/nftables/chain_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getIPv4Chain(t *testing.T) *nftables.Chain {
|
||||
var conn = nftables.NewConn()
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
table, err = conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
chain, err := table.GetChain("test_chain")
|
||||
if err != nil {
|
||||
if err == nftables.ErrChainNotFound {
|
||||
chain, err = table.AddAcceptChain("test_chain")
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func TestChain_AddAcceptIPRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddDropIPRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddAcceptSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddDropSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_AddRejectSetRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_GetRuleWithUserData(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||
if err != nil {
|
||||
if err == nftables.ErrRuleNotFound {
|
||||
t.Log("rule not found")
|
||||
return
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
t.Log("rule:", rule)
|
||||
}
|
||||
|
||||
func TestChain_GetRules(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rules, err := chain.GetRules()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, rule := range rules {
|
||||
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
|
||||
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_DeleteRule(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||
if err != nil {
|
||||
if err == nftables.ErrRuleNotFound {
|
||||
t.Log("rule not found")
|
||||
return
|
||||
}
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = chain.DeleteRule(rule)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChain_Flush(t *testing.T) {
|
||||
var chain = getIPv4Chain(t)
|
||||
err := chain.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
84
internal/firewalls/nftables/conn.go
Normal file
84
internal/firewalls/nftables/conn.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
)
|
||||
|
||||
const MaxTableNameLength = 27
|
||||
|
||||
type Conn struct {
|
||||
rawConn *nft.Conn
|
||||
}
|
||||
|
||||
func NewConn() *Conn {
|
||||
return &Conn{
|
||||
rawConn: &nft.Conn{},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Conn) Raw() *nft.Conn {
|
||||
return this.rawConn
|
||||
}
|
||||
|
||||
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
|
||||
rawTables, err := this.rawConn.ListTables()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, rawTable := range rawTables {
|
||||
if rawTable.Name == name && rawTable.Family == family {
|
||||
return NewTable(this, rawTable), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrTableNotFound
|
||||
}
|
||||
|
||||
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
|
||||
if len(name) > MaxTableNameLength {
|
||||
return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
|
||||
}
|
||||
|
||||
var rawTable = this.rawConn.AddTable(&nft.Table{
|
||||
Family: family,
|
||||
Name: name,
|
||||
})
|
||||
|
||||
err := this.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewTable(this, rawTable), nil
|
||||
}
|
||||
|
||||
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
|
||||
return this.AddTable(name, TableFamilyIPv4)
|
||||
}
|
||||
|
||||
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
|
||||
return this.AddTable(name, TableFamilyIPv6)
|
||||
}
|
||||
|
||||
func (this *Conn) DeleteTable(name string, family TableFamily) error {
|
||||
table, err := this.GetTable(name, family)
|
||||
if err != nil {
|
||||
if err == ErrTableNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
this.rawConn.DelTable(table.Raw())
|
||||
return this.Commit()
|
||||
}
|
||||
|
||||
func (this *Conn) Commit() error {
|
||||
return this.rawConn.Flush()
|
||||
}
|
||||
78
internal/firewalls/nftables/conn_test.go
Normal file
78
internal/firewalls/nftables/conn_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConn_Test(t *testing.T) {
|
||||
_, err := exec.LookPath("nft")
|
||||
if err == nil {
|
||||
t.Log("ok")
|
||||
return
|
||||
}
|
||||
t.Log(err)
|
||||
}
|
||||
|
||||
func TestConn_GetTable_NotFound(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
t.Log("table not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log("table:", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_GetTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
t.Log("table not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log("table:", table)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_AddTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
|
||||
{
|
||||
table, err := conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(table.Name())
|
||||
}
|
||||
{
|
||||
table, err := conn.AddIPv6Table("test_ipv6")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(table.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_DeleteTable(t *testing.T) {
|
||||
var conn = nftables.NewConn()
|
||||
err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
8
internal/firewalls/nftables/element.go
Normal file
8
internal/firewalls/nftables/element.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
type Element struct {
|
||||
}
|
||||
19
internal/firewalls/nftables/errors.go
Normal file
19
internal/firewalls/nftables/errors.go
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrTableNotFound = errors.New("table not found")
|
||||
var ErrChainNotFound = errors.New("chain not found")
|
||||
var ErrSetNotFound = errors.New("set not found")
|
||||
var ErrRuleNotFound = errors.New("rule not found")
|
||||
|
||||
func IsNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
|
||||
}
|
||||
18
internal/firewalls/nftables/family.go
Normal file
18
internal/firewalls/nftables/family.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
)
|
||||
|
||||
type TableFamily = nft.TableFamily
|
||||
|
||||
const (
|
||||
TableFamilyINet TableFamily = nft.TableFamilyINet
|
||||
TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4
|
||||
TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6
|
||||
TableFamilyARP TableFamily = nft.TableFamilyARP
|
||||
TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
|
||||
TableFamilyBridge TableFamily = nft.TableFamilyBridge
|
||||
)
|
||||
51
internal/firewalls/nftables/rule.go
Normal file
51
internal/firewalls/nftables/rule.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
)
|
||||
|
||||
type Rule struct {
|
||||
rawRule *nft.Rule
|
||||
}
|
||||
|
||||
func NewRule(rawRule *nft.Rule) *Rule {
|
||||
return &Rule{
|
||||
rawRule: rawRule,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Rule) Raw() *nft.Rule {
|
||||
return this.rawRule
|
||||
}
|
||||
|
||||
func (this *Rule) LookupSetName() string {
|
||||
for _, e := range this.rawRule.Exprs {
|
||||
exp, ok := e.(*expr.Lookup)
|
||||
if ok {
|
||||
return exp.SetName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (this *Rule) VerDict() expr.VerdictKind {
|
||||
for _, e := range this.rawRule.Exprs {
|
||||
exp, ok := e.(*expr.Verdict)
|
||||
if ok {
|
||||
return exp.Kind
|
||||
}
|
||||
}
|
||||
|
||||
return -100
|
||||
}
|
||||
|
||||
func (this *Rule) Handle() uint64 {
|
||||
return this.rawRule.Handle
|
||||
}
|
||||
|
||||
func (this *Rule) UserData() []byte {
|
||||
return this.rawRule.UserData
|
||||
}
|
||||
161
internal/firewalls/nftables/set.go
Normal file
161
internal/firewalls/nftables/set.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
nft "github.com/google/nftables"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const MaxSetNameLength = 15
|
||||
|
||||
type SetOptions struct {
|
||||
Id uint32
|
||||
HasTimeout bool
|
||||
Timeout time.Duration
|
||||
KeyType SetDataType
|
||||
DataType SetDataType
|
||||
Constant bool
|
||||
Interval bool
|
||||
Anonymous bool
|
||||
IsMap bool
|
||||
}
|
||||
|
||||
type ElementOptions struct {
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type Set struct {
|
||||
conn *Conn
|
||||
rawSet *nft.Set
|
||||
batch *SetBatch
|
||||
}
|
||||
|
||||
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
|
||||
return &Set{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
batch: &SetBatch{
|
||||
conn: conn,
|
||||
rawSet: rawSet,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) Raw() *nft.Set {
|
||||
return this.rawSet
|
||||
}
|
||||
|
||||
func (this *Set) Name() string {
|
||||
return this.rawSet.Name
|
||||
}
|
||||
|
||||
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||
var rawElement = nft.SetElement{
|
||||
Key: key,
|
||||
}
|
||||
if options != nil {
|
||||
rawElement.Timeout = options.Timeout
|
||||
}
|
||||
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
// retry if exists
|
||||
if strings.Contains(err.Error(), "file exists") {
|
||||
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
if deleteErr == nil {
|
||||
err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
if err == nil {
|
||||
err = this.conn.Commit()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if utils.IsIPv4(ip) {
|
||||
return this.AddElement(ipObj.To4(), options)
|
||||
} else {
|
||||
return this.AddElement(ipObj.To16(), options)
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) DeleteElement(key []byte) error {
|
||||
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no such file or directory") {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *Set) DeleteIPElement(ip string) error {
|
||||
var ipObj = net.ParseIP(ip)
|
||||
if ipObj == nil {
|
||||
return errors.New("invalid ip '" + ip + "'")
|
||||
}
|
||||
|
||||
if utils.IsIPv4(ip) {
|
||||
return this.DeleteElement(ipObj.To4())
|
||||
} else {
|
||||
return this.DeleteElement(ipObj.To16())
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Set) Batch() *SetBatch {
|
||||
return this.batch
|
||||
}
|
||||
|
||||
func (this *Set) GetIPElements() ([]string, error) {
|
||||
elements, err := this.conn.Raw().GetSetElements(this.rawSet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result = []string{}
|
||||
for _, element := range elements {
|
||||
result = append(result, net.IP(element.Key).String())
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// not work current time
|
||||
/**func (this *Set) Flush() error {
|
||||
this.conn.Raw().FlushSet(this.rawSet)
|
||||
return this.conn.Commit()
|
||||
}**/
|
||||
36
internal/firewalls/nftables/set_batch.go
Normal file
36
internal/firewalls/nftables/set_batch.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
nft "github.com/google/nftables"
|
||||
)
|
||||
|
||||
type SetBatch struct {
|
||||
conn *Conn
|
||||
rawSet *nft.Set
|
||||
}
|
||||
|
||||
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
|
||||
var rawElement = nft.SetElement{
|
||||
Key: key,
|
||||
}
|
||||
if options != nil {
|
||||
rawElement.Timeout = options.Timeout
|
||||
}
|
||||
return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||
rawElement,
|
||||
})
|
||||
}
|
||||
|
||||
func (this *SetBatch) DeleteElement(key []byte) error {
|
||||
return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||
{
|
||||
Key: key,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (this *SetBatch) Commit() error {
|
||||
return this.conn.Commit()
|
||||
}
|
||||
57
internal/firewalls/nftables/set_data_type.go
Normal file
57
internal/firewalls/nftables/set_data_type.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables
|
||||
|
||||
import nft "github.com/google/nftables"
|
||||
|
||||
type SetDataType = nft.SetDatatype
|
||||
|
||||
var (
|
||||
TypeInvalid = nft.TypeInvalid
|
||||
TypeVerdict = nft.TypeVerdict
|
||||
TypeNFProto = nft.TypeNFProto
|
||||
TypeBitmask = nft.TypeBitmask
|
||||
TypeInteger = nft.TypeInteger
|
||||
TypeString = nft.TypeString
|
||||
TypeLLAddr = nft.TypeLLAddr
|
||||
TypeIPAddr = nft.TypeIPAddr
|
||||
TypeIP6Addr = nft.TypeIP6Addr
|
||||
TypeEtherAddr = nft.TypeEtherAddr
|
||||
TypeEtherType = nft.TypeEtherType
|
||||
TypeARPOp = nft.TypeARPOp
|
||||
TypeInetProto = nft.TypeInetProto
|
||||
TypeInetService = nft.TypeInetService
|
||||
TypeICMPType = nft.TypeICMPType
|
||||
TypeTCPFlag = nft.TypeTCPFlag
|
||||
TypeDCCPPktType = nft.TypeDCCPPktType
|
||||
TypeMHType = nft.TypeMHType
|
||||
TypeTime = nft.TypeTime
|
||||
TypeMark = nft.TypeMark
|
||||
TypeIFIndex = nft.TypeIFIndex
|
||||
TypeARPHRD = nft.TypeARPHRD
|
||||
TypeRealm = nft.TypeRealm
|
||||
TypeClassID = nft.TypeClassID
|
||||
TypeUID = nft.TypeUID
|
||||
TypeGID = nft.TypeGID
|
||||
TypeCTState = nft.TypeCTState
|
||||
TypeCTDir = nft.TypeCTDir
|
||||
TypeCTStatus = nft.TypeCTStatus
|
||||
TypeICMP6Type = nft.TypeICMP6Type
|
||||
TypeCTLabel = nft.TypeCTLabel
|
||||
TypePktType = nft.TypePktType
|
||||
TypeICMPCode = nft.TypeICMPCode
|
||||
TypeICMPV6Code = nft.TypeICMPV6Code
|
||||
TypeICMPXCode = nft.TypeICMPXCode
|
||||
TypeDevGroup = nft.TypeDevGroup
|
||||
TypeDSCP = nft.TypeDSCP
|
||||
TypeECN = nft.TypeECN
|
||||
TypeFIBAddr = nft.TypeFIBAddr
|
||||
TypeBoolean = nft.TypeBoolean
|
||||
TypeCTEventBit = nft.TypeCTEventBit
|
||||
TypeIFName = nft.TypeIFName
|
||||
TypeIGMPType = nft.TypeIGMPType
|
||||
TypeTimeDate = nft.TypeTimeDate
|
||||
TypeTimeHour = nft.TypeTimeHour
|
||||
TypeTimeDay = nft.TypeTimeDay
|
||||
TypeCGroupV2 = nft.TypeCGroupV2
|
||||
)
|
||||
110
internal/firewalls/nftables/set_test.go
Normal file
110
internal/firewalls/nftables/set_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"github.com/mdlayher/netlink"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func getIPv4Set(t *testing.T) *nftables.Set {
|
||||
var table = getIPv4Table(t)
|
||||
set, err := table.GetSet("test_ipv4_set")
|
||||
if err != nil {
|
||||
if err == nftables.ErrSetNotFound {
|
||||
set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
HasTimeout: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func TestSet_AddElement(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_DeleteElement(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_Batch(t *testing.T) {
|
||||
var batch = getIPv4Set(t).Batch()
|
||||
|
||||
for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
|
||||
var ipData = net.ParseIP(ip).To4()
|
||||
//err := batch.DeleteElement(ipData)
|
||||
//if err != nil {
|
||||
// t.Fatal(err)
|
||||
//}
|
||||
err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := batch.Commit()
|
||||
if err != nil {
|
||||
t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestSet_Add_Many(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
|
||||
for i := 0; i < 255; i++ {
|
||||
t.Log(i)
|
||||
for j := 0; j < 255; j++ {
|
||||
var ip = "192.167." + types.String(i) + "." + types.String(j)
|
||||
var ipData = net.ParseIP(ip).To4()
|
||||
err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if j%10 == 0 {
|
||||
err = set.Batch().Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
err := set.Batch().Commit()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
/**func TestSet_Flush(t *testing.T) {
|
||||
var set = getIPv4Set(t)
|
||||
err := set.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}**/
|
||||
157
internal/firewalls/nftables/table.go
Normal file
157
internal/firewalls/nftables/table.go
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
nft "github.com/google/nftables"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
conn *Conn
|
||||
rawTable *nft.Table
|
||||
}
|
||||
|
||||
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
|
||||
return &Table{
|
||||
conn: conn,
|
||||
rawTable: rawTable,
|
||||
}
|
||||
}
|
||||
|
||||
func (this *Table) Raw() *nft.Table {
|
||||
return this.rawTable
|
||||
}
|
||||
|
||||
func (this *Table) Name() string {
|
||||
return this.rawTable.Name
|
||||
}
|
||||
|
||||
func (this *Table) Family() TableFamily {
|
||||
return this.rawTable.Family
|
||||
}
|
||||
|
||||
func (this *Table) GetChain(name string) (*Chain, error) {
|
||||
rawChains, err := this.conn.Raw().ListChains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, rawChain := range rawChains {
|
||||
// must compare table name
|
||||
if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
|
||||
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||
}
|
||||
}
|
||||
return nil, ErrChainNotFound
|
||||
}
|
||||
|
||||
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
|
||||
if len(name) > MaxChainNameLength {
|
||||
return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
|
||||
}
|
||||
|
||||
var rawChain = this.conn.Raw().AddChain(&nft.Chain{
|
||||
Name: name,
|
||||
Table: this.rawTable,
|
||||
Hooknum: nft.ChainHookInput,
|
||||
Priority: nft.ChainPriorityFilter,
|
||||
Type: nft.ChainTypeFilter,
|
||||
Policy: chainPolicy,
|
||||
})
|
||||
|
||||
err := this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||
}
|
||||
|
||||
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
|
||||
var policy = ChainPolicyAccept
|
||||
return this.AddChain(name, &policy)
|
||||
}
|
||||
|
||||
func (this *Table) AddDropChain(name string) (*Chain, error) {
|
||||
var policy = ChainPolicyDrop
|
||||
return this.AddChain(name, &policy)
|
||||
}
|
||||
|
||||
func (this *Table) DeleteChain(name string) error {
|
||||
chain, err := this.GetChain(name)
|
||||
if err != nil {
|
||||
if err == ErrChainNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
this.conn.Raw().DelChain(chain.Raw())
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Table) GetSet(name string) (*Set, error) {
|
||||
rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no such file or directory") {
|
||||
return nil, ErrSetNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSet(this.conn, rawSet), nil
|
||||
}
|
||||
|
||||
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
|
||||
if len(name) > MaxSetNameLength {
|
||||
return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
|
||||
}
|
||||
|
||||
if options == nil {
|
||||
options = &SetOptions{}
|
||||
}
|
||||
var rawSet = &nft.Set{
|
||||
Table: this.rawTable,
|
||||
ID: options.Id,
|
||||
Name: name,
|
||||
Anonymous: options.Anonymous,
|
||||
Constant: options.Constant,
|
||||
Interval: options.Interval,
|
||||
IsMap: options.IsMap,
|
||||
HasTimeout: options.HasTimeout,
|
||||
Timeout: options.Timeout,
|
||||
KeyType: options.KeyType,
|
||||
DataType: options.DataType,
|
||||
}
|
||||
err := this.conn.Raw().AddSet(rawSet, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = this.conn.Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewSet(this.conn, rawSet), nil
|
||||
}
|
||||
|
||||
func (this *Table) DeleteSet(name string) error {
|
||||
set, err := this.GetSet(name)
|
||||
if err != nil {
|
||||
if err == ErrSetNotFound {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
this.conn.Raw().DelSet(set.Raw())
|
||||
return this.conn.Commit()
|
||||
}
|
||||
|
||||
func (this *Table) Flush() error {
|
||||
this.conn.Raw().FlushTable(this.rawTable)
|
||||
return this.conn.Commit()
|
||||
}
|
||||
140
internal/firewalls/nftables/table_test.go
Normal file
140
internal/firewalls/nftables/table_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package nftables_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getIPv4Table(t *testing.T) *nftables.Table {
|
||||
var conn = nftables.NewConn()
|
||||
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
if err == nftables.ErrTableNotFound {
|
||||
table, err = conn.AddIPv4Table("test_ipv4")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func TestTable_AddChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
|
||||
{
|
||||
chain, err := table.AddChain("test_default_chain", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}
|
||||
|
||||
{
|
||||
chain, err := table.AddAcceptChain("test_accept_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}
|
||||
|
||||
// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
|
||||
/**{
|
||||
chain, err := table.AddDropChain("test_drop_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("created:", chain.Name())
|
||||
}**/
|
||||
}
|
||||
|
||||
func TestTable_GetChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
|
||||
chain, err := table.GetChain(chainName)
|
||||
if err != nil {
|
||||
if err == nftables.ErrChainNotFound {
|
||||
t.Log(chainName, ":", "not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log(chainName, ":", chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_DeleteChain(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.DeleteChain("test_default_chain")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTable_AddSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
{
|
||||
set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
|
||||
HasTimeout: false,
|
||||
KeyType: nftables.TypeIPAddr,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(set.Name())
|
||||
}
|
||||
|
||||
{
|
||||
set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
|
||||
HasTimeout: true,
|
||||
//Timeout: 3600 * time.Second,
|
||||
KeyType: nftables.TypeIP6Addr,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(set.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_GetSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
|
||||
set, err := table.GetSet(setName)
|
||||
if err != nil {
|
||||
if err == nftables.ErrSetNotFound {
|
||||
t.Log(setName, ": not found")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
t.Log(setName, ":", set)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTable_DeleteSet(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.DeleteSet("ipv4_black_set")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
|
||||
func TestTable_Flush(t *testing.T) {
|
||||
var table = getIPv4Table(t)
|
||||
err := table.Flush()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("ok")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
@@ -14,16 +15,20 @@ import (
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -119,6 +124,8 @@ func (this *APIStream) loop() error {
|
||||
err = this.handleNewNodeTask(message)
|
||||
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
|
||||
err = this.handleCheckSystemdService(message)
|
||||
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
|
||||
err = this.handleCheckLocalFirewall(message)
|
||||
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
|
||||
err = this.handleChangeAPINode(message)
|
||||
default:
|
||||
@@ -569,7 +576,7 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd := utils.NewCommandExecutor()
|
||||
var cmd = utils.NewCommandExecutor()
|
||||
shortName := teaconst.SystemdServiceName
|
||||
cmd.Add(systemctl, "is-enabled", shortName)
|
||||
output, err := cmd.Run()
|
||||
@@ -585,6 +592,63 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查本地防火墙
|
||||
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
|
||||
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
|
||||
err := json.Unmarshal(message.DataJSON, dataMessage)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
// nft
|
||||
if dataMessage.Name == "nftables" {
|
||||
if runtime.GOOS != "linux" {
|
||||
this.replyFail(message.RequestId, "not Linux system")
|
||||
return nil
|
||||
}
|
||||
|
||||
nft, err := exec.LookPath("nft")
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd = exec.Command(nft, "--version")
|
||||
var output = &bytes.Buffer{}
|
||||
cmd.Stdout = output
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputString = output.String()
|
||||
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||
if len(versionMatches) <= 1 {
|
||||
this.replyFail(message.RequestId, "can not get nft version")
|
||||
return nil
|
||||
}
|
||||
var version = versionMatches[1]
|
||||
|
||||
var result = maps.Map{
|
||||
"version": version,
|
||||
}
|
||||
|
||||
var protectionConfig = sharedNodeConfig.DDOSProtection
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||
if err != nil {
|
||||
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
|
||||
} else {
|
||||
this.replyOk(message.RequestId, string(result.AsJSON()))
|
||||
}
|
||||
} else {
|
||||
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修改API地址
|
||||
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
|
||||
config, err := configs.LoadAPIConfig()
|
||||
@@ -660,6 +724,11 @@ func (this *APIStream) replyOk(requestId int64, message string) {
|
||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
||||
}
|
||||
|
||||
// 回复成功并包含数据
|
||||
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
|
||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
|
||||
}
|
||||
|
||||
// 获取缓存存取对象
|
||||
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
|
||||
cachePolicy := &serverconfigs.HTTPCachePolicy{}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
// ClientConn 客户端连接
|
||||
type ClientConn struct {
|
||||
once sync.Once
|
||||
globalLimiter *ratelimit.Counter
|
||||
|
||||
isTLS bool
|
||||
hasDeadline bool
|
||||
@@ -33,7 +31,7 @@ type ClientConn struct {
|
||||
BaseClientConn
|
||||
}
|
||||
|
||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
|
||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn {
|
||||
if quickClose {
|
||||
// TCP
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
@@ -43,7 +41,7 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra
|
||||
}
|
||||
}
|
||||
|
||||
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter}
|
||||
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS}
|
||||
}
|
||||
|
||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||
@@ -96,13 +94,6 @@ func (this *ClientConn) Close() error {
|
||||
|
||||
err := this.rawConn.Close()
|
||||
|
||||
// 全局并发数限制
|
||||
this.once.Do(func() {
|
||||
if this.globalLimiter != nil {
|
||||
this.globalLimiter.Release()
|
||||
}
|
||||
})
|
||||
|
||||
// 单个服务并发数限制
|
||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||
|
||||
|
||||
@@ -3,16 +3,12 @@
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
)
|
||||
|
||||
var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxConnections)
|
||||
|
||||
// ClientListener 客户端网络监听
|
||||
type ClientListener struct {
|
||||
rawListener net.Listener
|
||||
@@ -36,13 +32,8 @@ func (this *ClientListener) IsTLS() bool {
|
||||
}
|
||||
|
||||
func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
// 限制并发连接数
|
||||
var limiter = sharedConnectionsLimiter
|
||||
limiter.Ack()
|
||||
|
||||
conn, err := this.rawListener.Accept()
|
||||
if err != nil {
|
||||
limiter.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -60,12 +51,11 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
}
|
||||
|
||||
_ = conn.Close()
|
||||
limiter.Release()
|
||||
return this.Accept()
|
||||
}
|
||||
}
|
||||
|
||||
return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil
|
||||
return NewClientConn(conn, this.isTLS, this.quickClose), nil
|
||||
}
|
||||
|
||||
func (this *ClientListener) Close() error {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||
@@ -15,7 +16,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||
@@ -368,6 +368,38 @@ func (this *Node) loop() error {
|
||||
}
|
||||
sharedNodeConfig.ParentNodes = parentNodes
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
IsOk: true,
|
||||
Error: "",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case "ddosProtectionChanged":
|
||||
resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(resp.DdosProtectionJSON) == 0 {
|
||||
if sharedNodeConfig != nil {
|
||||
sharedNodeConfig.DDOSProtection = nil
|
||||
}
|
||||
} else {
|
||||
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||
if err != nil {
|
||||
return errors.New("decode DDoS protection config failed: " + err.Error())
|
||||
}
|
||||
|
||||
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||
if err != nil {
|
||||
// 不阻塞
|
||||
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// 修改为已同步
|
||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||
NodeTaskId: task.Id,
|
||||
@@ -730,7 +762,6 @@ func (this *Node) listenSock() error {
|
||||
"ipConns": ipConns,
|
||||
"serverConns": serverConns,
|
||||
"total": sharedListenerManager.TotalActiveConnections(),
|
||||
"limiter": sharedConnectionsLimiter.Len(),
|
||||
},
|
||||
})
|
||||
case "dropIP":
|
||||
@@ -854,17 +885,6 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
|
||||
this.maxThreads = config.MaxThreads
|
||||
}
|
||||
|
||||
// max tcp connections
|
||||
if config.TCPMaxConnections <= 0 {
|
||||
config.TCPMaxConnections = nodeconfigs.DefaultTCPMaxConnections
|
||||
}
|
||||
if config.TCPMaxConnections != sharedConnectionsLimiter.Count() {
|
||||
remotelogs.Println("NODE", "[TCP]changed tcp max connections to '"+types.String(config.TCPMaxConnections)+"'")
|
||||
|
||||
sharedConnectionsLimiter.Close()
|
||||
sharedConnectionsLimiter = ratelimit.NewCounter(config.TCPMaxConnections)
|
||||
}
|
||||
|
||||
// timezone
|
||||
var timeZone = config.TimeZone
|
||||
if len(timeZone) == 0 {
|
||||
|
||||
@@ -5,9 +5,12 @@ import (
|
||||
"github.com/cespare/xxhash"
|
||||
"math"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ipv4Reg = regexp.MustCompile(`\d+\.`)
|
||||
|
||||
// IP2Long 将IP转换为整型
|
||||
// 注意IPv6没有顺序
|
||||
func IP2Long(ip string) uint64 {
|
||||
@@ -54,3 +57,24 @@ func IsLocalIP(ipString string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsIPv4 是否为IPv4
|
||||
func IsIPv4(ip string) bool {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(ip, ":") {
|
||||
return false
|
||||
}
|
||||
return data.To4() != nil
|
||||
}
|
||||
|
||||
// IsIPv6 是否为IPv6
|
||||
func IsIPv6(ip string) bool {
|
||||
var data = net.ParseIP(ip)
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
return !IsIPv4(ip)
|
||||
}
|
||||
|
||||
@@ -26,3 +26,26 @@ func TestIsLocalIP(t *testing.T) {
|
||||
a.IsFalse(IsLocalIP("::1:2:3"))
|
||||
a.IsFalse(IsLocalIP("8.8.8.8"))
|
||||
}
|
||||
|
||||
func TestIsIPv4(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsTrue(IsIPv4("192.168.1.1"))
|
||||
a.IsTrue(IsIPv4("0.0.0.0"))
|
||||
a.IsFalse(IsIPv4("192.168.1.256"))
|
||||
a.IsFalse(IsIPv4("192.168.1"))
|
||||
a.IsFalse(IsIPv4("::1"))
|
||||
a.IsFalse(IsIPv4("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
|
||||
a.IsFalse(IsIPv4("::ffff:192.168.0.1"))
|
||||
}
|
||||
|
||||
func TestIsIPv6(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsFalse(IsIPv6("192.168.1.1"))
|
||||
a.IsFloat32(IsIPv6("0.0.0.0"))
|
||||
a.IsFalse(IsIPv6("192.168.1.256"))
|
||||
a.IsFalse(IsIPv6("192.168.1"))
|
||||
a.IsTrue(IsIPv6("::1"))
|
||||
a.IsTrue(IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
|
||||
a.IsTrue(IsIPv4("::ffff:192.168.0.1"))
|
||||
a.IsTrue(IsIPv6("::ffff:192.168.0.1"))
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
@@ -36,6 +37,22 @@ func FormatAddressList(addrList []string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// ToValidUTF8string 去除字符串中的非UTF-8字符
|
||||
func ToValidUTF8string(v string) string {
|
||||
return strings.ToValidUTF8(v, "")
|
||||
}
|
||||
|
||||
// ContainsSameStrings 检查两个字符串slice内容是否一致
|
||||
func ContainsSameStrings(s1 []string, s2 []string) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
sort.Strings(s1)
|
||||
sort.Strings(s2)
|
||||
for index, v1 := range s1 {
|
||||
if v1 != s2[index] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,56 +1,67 @@
|
||||
package utils
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBytesToString(t *testing.T) {
|
||||
t.Log(UnsafeBytesToString([]byte("Hello,World")))
|
||||
t.Log(utils.UnsafeBytesToString([]byte("Hello,World")))
|
||||
}
|
||||
|
||||
func TestStringToBytes(t *testing.T) {
|
||||
t.Log(string(UnsafeStringToBytes("Hello,World")))
|
||||
t.Log(string(utils.UnsafeStringToBytes("Hello,World")))
|
||||
}
|
||||
|
||||
func BenchmarkBytesToString(b *testing.B) {
|
||||
data := []byte("Hello,World")
|
||||
var data = []byte("Hello,World")
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = UnsafeBytesToString(data)
|
||||
_ = utils.UnsafeBytesToString(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBytesToString2(b *testing.B) {
|
||||
data := []byte("Hello,World")
|
||||
var data = []byte("Hello,World")
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringToBytes(b *testing.B) {
|
||||
s := strings.Repeat("Hello,World", 1024)
|
||||
var s = strings.Repeat("Hello,World", 1024)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = UnsafeStringToBytes(s)
|
||||
_ = utils.UnsafeStringToBytes(s)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStringToBytes2(b *testing.B) {
|
||||
s := strings.Repeat("Hello,World", 1024)
|
||||
var s = strings.Repeat("Hello,World", 1024)
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = []byte(s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatAddress(t *testing.T) {
|
||||
t.Log(FormatAddress("127.0.0.1:1234"))
|
||||
t.Log(FormatAddress("127.0.0.1 : 1234"))
|
||||
t.Log(FormatAddress("127.0.0.1:1234"))
|
||||
t.Log(utils.FormatAddress("127.0.0.1:1234"))
|
||||
t.Log(utils.FormatAddress("127.0.0.1 : 1234"))
|
||||
t.Log(utils.FormatAddress("127.0.0.1:1234"))
|
||||
}
|
||||
|
||||
func TestFormatAddressList(t *testing.T) {
|
||||
t.Log(FormatAddressList([]string{
|
||||
t.Log(utils.FormatAddressList([]string{
|
||||
"127.0.0.1:1234",
|
||||
"127.0.0.1 : 1234",
|
||||
"127.0.0.1:1234",
|
||||
}))
|
||||
}
|
||||
|
||||
func TestContainsSameStrings(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"}))
|
||||
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"}))
|
||||
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
|
||||
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"}))
|
||||
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"}))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user