mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
实现基础的DDoS防护
This commit is contained in:
@@ -3,9 +3,10 @@ package events
|
|||||||
type Event = string
|
type Event = string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EventStart Event = "start" // start loading
|
EventStart Event = "start" // start loading
|
||||||
EventLoaded Event = "loaded" // first load
|
EventLoaded Event = "loaded" // first load
|
||||||
EventQuit Event = "quit" // quit node gracefully
|
EventQuit Event = "quit" // quit node gracefully
|
||||||
EventReload Event = "reload" // reload config
|
EventReload Event = "reload" // reload config
|
||||||
EventTerminated Event = "terminated" // process terminated
|
EventTerminated Event = "terminated" // process terminated
|
||||||
|
EventNFTablesReady Event = "nftablesReady" // nftables ready
|
||||||
)
|
)
|
||||||
|
|||||||
1
internal/firewalls/.gitignore
vendored
1
internal/firewalls/.gitignore
vendored
@@ -1 +0,0 @@
|
|||||||
firewall_nftables_test.go
|
|
||||||
494
internal/firewalls/ddos_protection.go
Normal file
494
internal/firewalls/ddos_protection.go
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package firewalls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/zero"
|
||||||
|
"github.com/iwind/TeaGo/lists"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
events.On(events.EventReload, func() {
|
||||||
|
if nftablesInstance == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||||
|
if nodeConfig != nil {
|
||||||
|
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
events.On(events.EventNFTablesReady, func() {
|
||||||
|
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||||
|
if nodeConfig != nil {
|
||||||
|
err := SharedDDoSProtectionManager.Apply(nodeConfig.DDOSProtection)
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Error("FIREWALL", "apply DDoS protection failed: "+err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// DDoSProtectionManager DDoS防护
|
||||||
|
type DDoSProtectionManager struct {
|
||||||
|
nftPath string
|
||||||
|
|
||||||
|
lastAllowIPList []string
|
||||||
|
lastConfig []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDDoSProtectionManager 获取新对象
|
||||||
|
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||||
|
nftPath, _ := exec.LookPath("nft")
|
||||||
|
|
||||||
|
return &DDoSProtectionManager{
|
||||||
|
nftPath: nftPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply 应用配置
|
||||||
|
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||||
|
// 同集群节点IP白名单
|
||||||
|
var allowIPListChanged = false
|
||||||
|
nodeConfig, _ := nodeconfigs.SharedNodeConfig()
|
||||||
|
if nodeConfig != nil {
|
||||||
|
var allowIPList = nodeConfig.AllowedIPs
|
||||||
|
if !utils.ContainsSameStrings(allowIPList, this.lastAllowIPList) {
|
||||||
|
allowIPListChanged = true
|
||||||
|
this.lastAllowIPList = allowIPList
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对比配置
|
||||||
|
configJSON, err := json.Marshal(config)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("encode config to json failed: " + err.Error())
|
||||||
|
}
|
||||||
|
if !allowIPListChanged && bytes.Equal(this.lastConfig, configJSON) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
remotelogs.Println("FIREWALL", "change DDoS protection config")
|
||||||
|
|
||||||
|
if len(this.nftPath) == 0 {
|
||||||
|
return errors.New("can not find nft command")
|
||||||
|
}
|
||||||
|
|
||||||
|
if nftablesInstance == nil {
|
||||||
|
return errors.New("nftables instance should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
// TCP
|
||||||
|
err := this.removeTCPRules()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO other protocols
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCP
|
||||||
|
if config.TCP == nil {
|
||||||
|
err := this.removeTCPRules()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// allow ip list
|
||||||
|
var allowIPList = []string{}
|
||||||
|
for _, ipConfig := range config.TCP.AllowIPList {
|
||||||
|
allowIPList = append(allowIPList, ipConfig.IP)
|
||||||
|
}
|
||||||
|
for _, ip := range this.lastAllowIPList {
|
||||||
|
if !lists.ContainsString(allowIPList, ip) {
|
||||||
|
allowIPList = append(allowIPList, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = this.updateAllowIPList(allowIPList)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// tcp
|
||||||
|
if config.TCP.IsOn {
|
||||||
|
err := this.addTCPRules(config.TCP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err := this.removeTCPRules()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.lastConfig = configJSON
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加TCP规则
|
||||||
|
func (this *DDoSProtectionManager) addTCPRules(tcpConfig *ddosconfigs.TCPConfig) error {
|
||||||
|
var ports = []int32{}
|
||||||
|
for _, portConfig := range tcpConfig.Ports {
|
||||||
|
if !lists.ContainsInt32(ports, portConfig.Port) {
|
||||||
|
ports = append(ports, portConfig.Port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(ports) == 0 {
|
||||||
|
ports = []int32{80, 443}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, filter := range nftablesFilters {
|
||||||
|
chain, oldRules, err := this.getRules(filter)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("get old rules failed: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
var protocol = filter.protocol()
|
||||||
|
|
||||||
|
// max connections
|
||||||
|
var maxConnections = tcpConfig.MaxConnections
|
||||||
|
if maxConnections <= 0 {
|
||||||
|
maxConnections = nodeconfigs.DefaultTCPMaxConnections
|
||||||
|
if maxConnections <= 0 {
|
||||||
|
maxConnections = 100000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// max connections per ip
|
||||||
|
var maxConnectionsPerIP = tcpConfig.MaxConnectionsPerIP
|
||||||
|
if maxConnectionsPerIP <= 0 {
|
||||||
|
maxConnectionsPerIP = nodeconfigs.DefaultTCPMaxConnectionsPerIP
|
||||||
|
if maxConnectionsPerIP <= 0 {
|
||||||
|
maxConnectionsPerIP = 100000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// new connections rate
|
||||||
|
var newConnectionsRate = tcpConfig.NewConnectionsRate
|
||||||
|
if newConnectionsRate <= 0 {
|
||||||
|
newConnectionsRate = nodeconfigs.DefaultTCPNewConnectionsRate
|
||||||
|
if newConnectionsRate <= 0 {
|
||||||
|
newConnectionsRate = 100000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否有变化
|
||||||
|
var hasChanges = false
|
||||||
|
for _, port := range ports {
|
||||||
|
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}) {
|
||||||
|
hasChanges = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}) {
|
||||||
|
hasChanges = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !this.existsRule(oldRules, []string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}) {
|
||||||
|
hasChanges = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasChanges {
|
||||||
|
// 检查是否有多余的端口
|
||||||
|
var oldPorts = this.getTCPPorts(oldRules)
|
||||||
|
if !this.eqPorts(ports, oldPorts) {
|
||||||
|
hasChanges = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasChanges {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先清空所有相关规则
|
||||||
|
err = this.removeOldTCPRules(chain, oldRules)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("delete old rules failed: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加新规则
|
||||||
|
for _, port := range ports {
|
||||||
|
if maxConnections > 0 {
|
||||||
|
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "count", "over", types.String(maxConnections), "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnections", types.String(maxConnections)}))
|
||||||
|
err := cmd.Run()
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxConnectionsPerIP > 0 {
|
||||||
|
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "meter", "meter-"+protocol+"-"+types.String(port)+"-max-connections", "{ "+protocol+" saddr ct count over "+types.String(maxConnectionsPerIP)+" }", "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "maxConnectionsPerIP", types.String(maxConnectionsPerIP)}))
|
||||||
|
var stderr = &bytes.Buffer{}
|
||||||
|
cmd.Stderr = stderr
|
||||||
|
err := cmd.Run()
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if newConnectionsRate > 0 {
|
||||||
|
// TODO 思考是否有惩罚机制
|
||||||
|
var cmd = exec.Command(this.nftPath, "add", "rule", protocol, filter.Name, nftablesChainName, "tcp", "dport", types.String(port), "ct", "state", "new", "meter", "meter-"+protocol+"-"+types.String(port)+"-new-connections-rate", "{ "+protocol+" saddr limit rate over "+types.String(newConnectionsRate)+"/minute burst "+types.String(newConnectionsRate+3)+" packets }" /**"add", "@deny_set", "{"+protocol+" saddr}",**/, "counter", "drop", "comment", this.encodeUserData([]string{"tcp", types.String(port), "newConnectionsRate", types.String(newConnectionsRate)}))
|
||||||
|
var stderr = &bytes.Buffer{}
|
||||||
|
cmd.Stderr = stderr
|
||||||
|
err := cmd.Run()
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("add nftables rule '" + cmd.String() + "' failed: " + err.Error() + " (" + stderr.String() + ")")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除TCP规则
|
||||||
|
func (this *DDoSProtectionManager) removeTCPRules() error {
|
||||||
|
for _, filter := range nftablesFilters {
|
||||||
|
chain, rules, err := this.getRules(filter)
|
||||||
|
|
||||||
|
// TCP
|
||||||
|
err = this.removeOldTCPRules(chain, rules)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 组合user data
|
||||||
|
// 数据中不能包含字母、数字、下划线以外的数据
|
||||||
|
func (this *DDoSProtectionManager) encodeUserData(attrs []string) string {
|
||||||
|
if attrs == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return "ZZ" + strings.Join(attrs, "_") + "ZZ"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解码user data
|
||||||
|
func (this *DDoSProtectionManager) decodeUserData(data []byte) []string {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var dataCopy = make([]byte, len(data))
|
||||||
|
copy(dataCopy, data)
|
||||||
|
|
||||||
|
var separatorLen = 2
|
||||||
|
var index1 = bytes.Index(dataCopy, []byte{'Z', 'Z'})
|
||||||
|
if index1 < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dataCopy = dataCopy[index1+separatorLen:]
|
||||||
|
var index2 = bytes.LastIndex(dataCopy, []byte{'Z', 'Z'})
|
||||||
|
if index2 < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var s = string(dataCopy[:index2])
|
||||||
|
var pieces = strings.Split(s, "_")
|
||||||
|
for index, piece := range pieces {
|
||||||
|
pieces[index] = strings.TrimSpace(piece)
|
||||||
|
}
|
||||||
|
return pieces
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清除规则
|
||||||
|
func (this *DDoSProtectionManager) removeOldTCPRules(chain *nftables.Chain, rules []*nftables.Rule) error {
|
||||||
|
for _, rule := range rules {
|
||||||
|
var pieces = this.decodeUserData(rule.UserData())
|
||||||
|
if len(pieces) != 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pieces[0] != "tcp" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch pieces[2] {
|
||||||
|
case "maxConnections", "maxConnectionsPerIP", "newConnectionsRate":
|
||||||
|
err := chain.DeleteRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据参数检查规则是否存在
|
||||||
|
func (this *DDoSProtectionManager) existsRule(rules []*nftables.Rule, attrs []string) (exists bool) {
|
||||||
|
if len(attrs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, oldRule := range rules {
|
||||||
|
var pieces = this.decodeUserData(oldRule.UserData())
|
||||||
|
if len(attrs) != len(pieces) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var isSame = true
|
||||||
|
for index, piece := range pieces {
|
||||||
|
if strings.TrimSpace(piece) != attrs[index] {
|
||||||
|
isSame = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isSame {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取规则中的端口号
|
||||||
|
func (this *DDoSProtectionManager) getTCPPorts(rules []*nftables.Rule) []int32 {
|
||||||
|
var ports = []int32{}
|
||||||
|
for _, rule := range rules {
|
||||||
|
var pieces = this.decodeUserData(rule.UserData())
|
||||||
|
if len(pieces) != 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pieces[0] != "tcp" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var port = types.Int32(pieces[1])
|
||||||
|
if port > 0 && !lists.ContainsInt32(ports, port) {
|
||||||
|
ports = append(ports, port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ports
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查端口是否一样
|
||||||
|
func (this *DDoSProtectionManager) eqPorts(ports1 []int32, ports2 []int32) bool {
|
||||||
|
if len(ports1) != len(ports2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var portMap = map[int32]bool{}
|
||||||
|
for _, port := range ports2 {
|
||||||
|
portMap[port] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, port := range ports1 {
|
||||||
|
_, ok := portMap[port]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找Table
|
||||||
|
func (this *DDoSProtectionManager) getTable(filter *nftablesTableDefinition) (*nftables.Table, error) {
|
||||||
|
var family nftables.TableFamily
|
||||||
|
if filter.IsIPv4 {
|
||||||
|
family = nftables.TableFamilyIPv4
|
||||||
|
} else if filter.IsIPv6 {
|
||||||
|
family = nftables.TableFamilyIPv6
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("table '" + filter.Name + "' should be IPv4 or IPv6")
|
||||||
|
}
|
||||||
|
return nftablesInstance.conn.GetTable(filter.Name, family)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找所有规则
|
||||||
|
func (this *DDoSProtectionManager) getRules(filter *nftablesTableDefinition) (*nftables.Chain, []*nftables.Rule, error) {
|
||||||
|
table, err := this.getTable(filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.New("get table failed: " + err.Error())
|
||||||
|
}
|
||||||
|
chain, err := table.GetChain(nftablesChainName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.New("get chain failed: " + err.Error())
|
||||||
|
}
|
||||||
|
rules, err := chain.GetRules()
|
||||||
|
return chain, rules, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新白名单
|
||||||
|
func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
|
||||||
|
if nftablesInstance == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var allMap = map[string]zero.Zero{}
|
||||||
|
for _, ip := range allIPList {
|
||||||
|
allMap[ip] = zero.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, set := range []*nftables.Set{nftablesInstance.allowIPv4Set, nftablesInstance.allowIPv6Set} {
|
||||||
|
var isIPv4 = set == nftablesInstance.allowIPv4Set
|
||||||
|
var isIPv6 = !isIPv4
|
||||||
|
|
||||||
|
// 现有的
|
||||||
|
oldList, err := set.GetIPElements()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var oldMap = map[string]zero.Zero{} // ip=> zero
|
||||||
|
for _, ip := range oldList {
|
||||||
|
oldMap[ip] = zero.New()
|
||||||
|
|
||||||
|
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||||
|
_, ok := allMap[ip]
|
||||||
|
if !ok {
|
||||||
|
// 不存在则删除
|
||||||
|
err = set.DeleteIPElement(ip)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("delete ip element '" + ip + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 新增的
|
||||||
|
for _, ip := range allIPList {
|
||||||
|
var ipObj = net.ParseIP(ip)
|
||||||
|
if ipObj == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (utils.IsIPv4(ip) && isIPv4) || (utils.IsIPv6(ip) && isIPv6) {
|
||||||
|
_, ok := oldMap[ip]
|
||||||
|
if !ok {
|
||||||
|
// 不存在则添加
|
||||||
|
err = set.AddIPElement(ip, nil)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("add ip '" + ip + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
23
internal/firewalls/ddos_protection_others.go
Normal file
23
internal/firewalls/ddos_protection_others.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
package firewalls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SharedDDoSProtectionManager = NewDDoSProtectionManager()
|
||||||
|
|
||||||
|
type DDoSProtectionManager struct {
|
||||||
|
nftPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDDoSProtectionManager() *DDoSProtectionManager {
|
||||||
|
return &DDoSProtectionManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *DDoSProtectionManager) Apply(config *ddosconfigs.ProtectionConfig) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
//go:build !plus
|
|
||||||
// +build !plus
|
|
||||||
|
|
||||||
package firewalls
|
package firewalls
|
||||||
|
|
||||||
@@ -8,9 +6,11 @@ import (
|
|||||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
var currentFirewall FirewallInterface
|
var currentFirewall FirewallInterface
|
||||||
|
var firewallLocker = &sync.Mutex{}
|
||||||
|
|
||||||
// 初始化
|
// 初始化
|
||||||
func init() {
|
func init() {
|
||||||
@@ -24,10 +24,28 @@ func init() {
|
|||||||
|
|
||||||
// Firewall 查找当前系统中最适合的防火墙
|
// Firewall 查找当前系统中最适合的防火墙
|
||||||
func Firewall() FirewallInterface {
|
func Firewall() FirewallInterface {
|
||||||
|
firewallLocker.Lock()
|
||||||
|
defer firewallLocker.Unlock()
|
||||||
if currentFirewall != nil {
|
if currentFirewall != nil {
|
||||||
return currentFirewall
|
return currentFirewall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nftables
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
nftables, err := NewNFTablesFirewall()
|
||||||
|
if err != nil {
|
||||||
|
remotelogs.Warn("FIREWALL", "'nftables' should be installed on the system to enhance security (init failed: "+err.Error()+")")
|
||||||
|
} else {
|
||||||
|
if nftables.IsReady() {
|
||||||
|
currentFirewall = nftables
|
||||||
|
events.Notify(events.EventNFTablesReady)
|
||||||
|
return nftables
|
||||||
|
} else {
|
||||||
|
remotelogs.Warn("FIREWALL", "'nftables' should be enabled on the system to enhance security")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// firewalld
|
// firewalld
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" {
|
||||||
var firewalld = NewFirewalld()
|
var firewalld = NewFirewalld()
|
||||||
|
|||||||
373
internal/firewalls/firewall_nftables.go
Normal file
373
internal/firewalls/firewall_nftables.go
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package firewalls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// check nft status, if being enabled we load it automatically
|
||||||
|
func init() {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
var ticker = time.NewTicker(3 * time.Minute)
|
||||||
|
go func() {
|
||||||
|
for range ticker.C {
|
||||||
|
// if already ready, we break
|
||||||
|
if nftablesIsReady {
|
||||||
|
ticker.Stop()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_, err := exec.LookPath("nft")
|
||||||
|
if err == nil {
|
||||||
|
nftablesFirewall, err := NewNFTablesFirewall()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentFirewall = nftablesFirewall
|
||||||
|
remotelogs.Println("FIREWALL", "nftables is ready")
|
||||||
|
|
||||||
|
// fire event
|
||||||
|
if nftablesFirewall.IsReady() {
|
||||||
|
events.Notify(events.EventNFTablesReady)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker.Stop()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var nftablesInstance *NFTablesFirewall
|
||||||
|
var nftablesIsReady = false
|
||||||
|
var nftablesFilters = []*nftablesTableDefinition{
|
||||||
|
// we shorten the name for table name length restriction
|
||||||
|
{Name: "edge_dft_v4", IsIPv4: true},
|
||||||
|
{Name: "edge_dft_v6", IsIPv6: true},
|
||||||
|
}
|
||||||
|
var nftablesChainName = "input"
|
||||||
|
|
||||||
|
type nftablesTableDefinition struct {
|
||||||
|
Name string
|
||||||
|
IsIPv4 bool
|
||||||
|
IsIPv6 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *nftablesTableDefinition) protocol() string {
|
||||||
|
if this.IsIPv6 {
|
||||||
|
return "ip6"
|
||||||
|
}
|
||||||
|
return "ip"
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||||
|
var firewall = &NFTablesFirewall{
|
||||||
|
conn: nftables.NewConn(),
|
||||||
|
}
|
||||||
|
err := firewall.init()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return firewall, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type NFTablesFirewall struct {
|
||||||
|
conn *nftables.Conn
|
||||||
|
isReady bool
|
||||||
|
|
||||||
|
allowIPv4Set *nftables.Set
|
||||||
|
allowIPv6Set *nftables.Set
|
||||||
|
|
||||||
|
denyIPv4Set *nftables.Set
|
||||||
|
denyIPv6Set *nftables.Set
|
||||||
|
|
||||||
|
firewalld *Firewalld
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *NFTablesFirewall) init() error {
|
||||||
|
// check nft
|
||||||
|
_, err := exec.LookPath("nft")
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("nft not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// table
|
||||||
|
for _, tableDef := range nftablesFilters {
|
||||||
|
var family nftables.TableFamily
|
||||||
|
if tableDef.IsIPv4 {
|
||||||
|
family = nftables.TableFamilyIPv4
|
||||||
|
} else if tableDef.IsIPv6 {
|
||||||
|
family = nftables.TableFamilyIPv6
|
||||||
|
} else {
|
||||||
|
return errors.New("invalid table family: " + types.String(tableDef))
|
||||||
|
}
|
||||||
|
table, err := this.conn.GetTable(tableDef.Name, family)
|
||||||
|
if err != nil {
|
||||||
|
if nftables.IsNotFound(err) {
|
||||||
|
if tableDef.IsIPv4 {
|
||||||
|
table, err = this.conn.AddIPv4Table(tableDef.Name)
|
||||||
|
} else if tableDef.IsIPv6 {
|
||||||
|
table, err = this.conn.AddIPv6Table(tableDef.Name)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if table == nil {
|
||||||
|
return errors.New("can not create table '" + tableDef.Name + "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// chain
|
||||||
|
var chainName = nftablesChainName
|
||||||
|
chain, err := table.GetChain(chainName)
|
||||||
|
if err != nil {
|
||||||
|
if nftables.IsNotFound(err) {
|
||||||
|
chain, err = table.AddAcceptChain(chainName)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("create chain '" + chainName + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.New("get chain '" + chainName + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if chain == nil {
|
||||||
|
return errors.New("can not create chain '" + chainName + "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow lo
|
||||||
|
var loRuleName = []byte("lo")
|
||||||
|
_, err = chain.GetRuleWithUserData(loRuleName)
|
||||||
|
if err != nil {
|
||||||
|
if nftables.IsNotFound(err) {
|
||||||
|
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("add 'lo' rule failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow set
|
||||||
|
// "allow" should be always first
|
||||||
|
for _, setAction := range []string{"allow", "deny"} {
|
||||||
|
var setName = setAction + "_set"
|
||||||
|
|
||||||
|
set, err := table.GetSet(setName)
|
||||||
|
if err != nil {
|
||||||
|
if nftables.IsNotFound(err) {
|
||||||
|
var keyType nftables.SetDataType
|
||||||
|
if tableDef.IsIPv4 {
|
||||||
|
keyType = nftables.TypeIPAddr
|
||||||
|
} else if tableDef.IsIPv6 {
|
||||||
|
keyType = nftables.TypeIP6Addr
|
||||||
|
}
|
||||||
|
set, err = table.AddSet(setName, &nftables.SetOptions{
|
||||||
|
KeyType: keyType,
|
||||||
|
HasTimeout: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("create set '" + setName + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.New("get set '" + setName + "' failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if set == nil {
|
||||||
|
return errors.New("can not create set '" + setName + "'")
|
||||||
|
}
|
||||||
|
if tableDef.IsIPv4 {
|
||||||
|
if setAction == "allow" {
|
||||||
|
this.allowIPv4Set = set
|
||||||
|
} else {
|
||||||
|
this.denyIPv4Set = set
|
||||||
|
}
|
||||||
|
} else if tableDef.IsIPv6 {
|
||||||
|
if setAction == "allow" {
|
||||||
|
this.allowIPv6Set = set
|
||||||
|
} else {
|
||||||
|
this.denyIPv6Set = set
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// rule
|
||||||
|
var ruleName = []byte(setAction)
|
||||||
|
rule, err := chain.GetRuleWithUserData(ruleName)
|
||||||
|
if err != nil {
|
||||||
|
if nftables.IsNotFound(err) {
|
||||||
|
if tableDef.IsIPv4 {
|
||||||
|
if setAction == "allow" {
|
||||||
|
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
|
||||||
|
} else {
|
||||||
|
rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
|
||||||
|
}
|
||||||
|
} else if tableDef.IsIPv6 {
|
||||||
|
if setAction == "allow" {
|
||||||
|
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
|
||||||
|
} else {
|
||||||
|
rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("add rule failed: " + err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.New("get rule failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rule == nil {
|
||||||
|
return errors.New("can not create rule '" + string(ruleName) + "'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.isReady = true
|
||||||
|
nftablesIsReady = true
|
||||||
|
nftablesInstance = this
|
||||||
|
|
||||||
|
// load firewalld
|
||||||
|
var firewalld = NewFirewalld()
|
||||||
|
if firewalld.IsReady() {
|
||||||
|
this.firewalld = firewalld
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name 名称
|
||||||
|
func (this *NFTablesFirewall) Name() string {
|
||||||
|
return "nftables"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReady 是否已准备被调用
|
||||||
|
func (this *NFTablesFirewall) IsReady() bool {
|
||||||
|
return this.isReady
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMock 是否为模拟
|
||||||
|
func (this *NFTablesFirewall) IsMock() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowPort 允许端口
|
||||||
|
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||||
|
if this.firewalld != nil {
|
||||||
|
return this.firewalld.AllowPort(port, protocol)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePort 删除端口
|
||||||
|
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||||
|
if this.firewalld != nil {
|
||||||
|
return this.firewalld.RemovePort(port, protocol)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowSourceIP Allow把IP加入白名单
|
||||||
|
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||||
|
var data = net.ParseIP(ip)
|
||||||
|
if data == nil {
|
||||||
|
return errors.New("invalid ip '" + ip + "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(ip, ":") { // ipv6
|
||||||
|
if this.allowIPv6Set == nil {
|
||||||
|
return errors.New("ipv6 ip set is nil")
|
||||||
|
}
|
||||||
|
return this.allowIPv6Set.AddElement(data.To16(), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipv4
|
||||||
|
if this.allowIPv4Set == nil {
|
||||||
|
return errors.New("ipv4 ip set is nil")
|
||||||
|
}
|
||||||
|
return this.allowIPv4Set.AddElement(data.To4(), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectSourceIP 拒绝某个源IP连接
|
||||||
|
// we did not create set for drop ip, so we reuse DropSourceIP() method here
|
||||||
|
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||||
|
return this.DropSourceIP(ip, timeoutSeconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropSourceIP 丢弃某个源IP数据
|
||||||
|
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||||
|
var data = net.ParseIP(ip)
|
||||||
|
if data == nil {
|
||||||
|
return errors.New("invalid ip '" + ip + "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(ip, ":") { // ipv6
|
||||||
|
if this.denyIPv6Set == nil {
|
||||||
|
return errors.New("ipv6 ip set is nil")
|
||||||
|
}
|
||||||
|
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
|
||||||
|
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipv4
|
||||||
|
if this.denyIPv4Set == nil {
|
||||||
|
return errors.New("ipv4 ip set is nil")
|
||||||
|
}
|
||||||
|
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
|
||||||
|
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveSourceIP 删除某个源IP
|
||||||
|
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||||
|
var data = net.ParseIP(ip)
|
||||||
|
if data == nil {
|
||||||
|
return errors.New("invalid ip '" + ip + "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(ip, ":") { // ipv6
|
||||||
|
if this.denyIPv6Set != nil {
|
||||||
|
err := this.denyIPv6Set.DeleteElement(data.To16())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if this.allowIPv6Set != nil {
|
||||||
|
err := this.allowIPv6Set.DeleteElement(data.To16())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipv4
|
||||||
|
if this.allowIPv4Set != nil {
|
||||||
|
err := this.denyIPv4Set.DeleteElement(data.To4())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = this.allowIPv4Set.DeleteElement(data.To4())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
61
internal/firewalls/firewall_nftables_others.go
Normal file
61
internal/firewalls/firewall_nftables_others.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
package firewalls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
type NFTablesFirewall struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name 名称
|
||||||
|
func (this *NFTablesFirewall) Name() string {
|
||||||
|
return "nftables"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReady 是否已准备被调用
|
||||||
|
func (this *NFTablesFirewall) IsReady() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsMock 是否为模拟
|
||||||
|
func (this *NFTablesFirewall) IsMock() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowPort 允许端口
|
||||||
|
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePort 删除端口
|
||||||
|
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowSourceIP Allow把IP加入白名单
|
||||||
|
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RejectSourceIP 拒绝某个源IP连接
|
||||||
|
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DropSourceIP 丢弃某个源IP数据
|
||||||
|
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveSourceIP 删除某个源IP
|
||||||
|
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
1
internal/firewalls/nftables/.gitignore
vendored
Normal file
1
internal/firewalls/nftables/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
build_remote.sh
|
||||||
370
internal/firewalls/nftables/chain.go
Normal file
370
internal/firewalls/nftables/chain.go
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
nft "github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxChainNameLength = 31
|
||||||
|
|
||||||
|
type RuleOptions struct {
|
||||||
|
Exprs []expr.Any
|
||||||
|
UserData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chain chain object in table
|
||||||
|
type Chain struct {
|
||||||
|
conn *Conn
|
||||||
|
rawTable *nft.Table
|
||||||
|
rawChain *nft.Chain
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChain(conn *Conn, rawTable *nft.Table, rawChain *nft.Chain) *Chain {
|
||||||
|
return &Chain{
|
||||||
|
conn: conn,
|
||||||
|
rawTable: rawTable,
|
||||||
|
rawChain: rawChain,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) Raw() *nft.Chain {
|
||||||
|
return this.rawChain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) Name() string {
|
||||||
|
return this.rawChain.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddRule(options *RuleOptions) (*Rule, error) {
|
||||||
|
var rawRule = this.conn.Raw().AddRule(&nft.Rule{
|
||||||
|
Table: this.rawTable,
|
||||||
|
Chain: this.rawChain,
|
||||||
|
Exprs: options.Exprs,
|
||||||
|
UserData: options.UserData,
|
||||||
|
})
|
||||||
|
err := this.conn.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewRule(rawRule), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddAcceptIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 12,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ip,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddAcceptIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 8,
|
||||||
|
Len: 16,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ip,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddDropIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 12,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ip,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictDrop,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddDropIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 8,
|
||||||
|
Len: 16,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ip,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictDrop,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddRejectIPv4Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 12,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ip,
|
||||||
|
},
|
||||||
|
&expr.Reject{},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddRejectIPv6Rule(ip []byte, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 8,
|
||||||
|
Len: 16,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ip,
|
||||||
|
},
|
||||||
|
&expr.Reject{},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddAcceptIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 12,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: setName,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddAcceptIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 8,
|
||||||
|
Len: 16,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: setName,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddDropIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 12,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: setName,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictDrop,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddDropIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 8,
|
||||||
|
Len: 16,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: setName,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictDrop,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddRejectIPv4SetRule(setName string, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 12,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: setName,
|
||||||
|
},
|
||||||
|
&expr.Reject{},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddRejectIPv6SetRule(setName string, userData []byte) (*Rule, error) {
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 8,
|
||||||
|
Len: 16,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: setName,
|
||||||
|
},
|
||||||
|
&expr.Reject{},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) AddAcceptInterfaceRule(interfaceName string, userData []byte) (*Rule, error) {
|
||||||
|
if len(interfaceName) >= 16 {
|
||||||
|
return nil, errors.New("invalid interface name '" + interfaceName + "'")
|
||||||
|
}
|
||||||
|
var ifname = make([]byte, 16)
|
||||||
|
copy(ifname, interfaceName+"\x00")
|
||||||
|
|
||||||
|
return this.AddRule(&RuleOptions{
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname,
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) GetRuleWithUserData(userData []byte) (*Rule, error) {
|
||||||
|
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, rawRule := range rawRules {
|
||||||
|
if bytes.Compare(rawRule.UserData, userData) == 0 {
|
||||||
|
return NewRule(rawRule), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrRuleNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) GetRules() ([]*Rule, error) {
|
||||||
|
rawRules, err := this.conn.Raw().GetRule(this.rawTable, this.rawChain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var result = []*Rule{}
|
||||||
|
for _, rawRule := range rawRules {
|
||||||
|
result = append(result, NewRule(rawRule))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) DeleteRule(rule *Rule) error {
|
||||||
|
err := this.conn.Raw().DelRule(rule.Raw())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return this.conn.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Chain) Flush() error {
|
||||||
|
this.conn.Raw().FlushChain(this.rawChain)
|
||||||
|
return this.conn.Commit()
|
||||||
|
}
|
||||||
13
internal/firewalls/nftables/chain_policy.go
Normal file
13
internal/firewalls/nftables/chain_policy.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import nft "github.com/google/nftables"
|
||||||
|
|
||||||
|
type ChainPolicy = nft.ChainPolicy
|
||||||
|
|
||||||
|
// Possible ChainPolicy values.
|
||||||
|
const (
|
||||||
|
ChainPolicyDrop = nft.ChainPolicyDrop
|
||||||
|
ChainPolicyAccept = nft.ChainPolicyAccept
|
||||||
|
)
|
||||||
130
internal/firewalls/nftables/chain_test.go
Normal file
130
internal/firewalls/nftables/chain_test.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getIPv4Chain(t *testing.T) *nftables.Chain {
|
||||||
|
var conn = nftables.NewConn()
|
||||||
|
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrTableNotFound {
|
||||||
|
table, err = conn.AddIPv4Table("test_ipv4")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chain, err := table.GetChain("test_chain")
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrChainNotFound {
|
||||||
|
chain, err = table.AddAcceptChain("test_chain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_AddAcceptIPRule(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_AddDropIPRule(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_AddAcceptSetRule(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_AddDropSetRule(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_AddRejectSetRule(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_GetRuleWithUserData(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrRuleNotFound {
|
||||||
|
t.Log("rule not found")
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Log("rule:", rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_GetRules(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
rules, err := chain.GetRules()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, rule := range rules {
|
||||||
|
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
|
||||||
|
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_DeleteRule(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrRuleNotFound {
|
||||||
|
t.Log("rule not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = chain.DeleteRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain_Flush(t *testing.T) {
|
||||||
|
var chain = getIPv4Chain(t)
|
||||||
|
err := chain.Flush()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
84
internal/firewalls/nftables/conn.go
Normal file
84
internal/firewalls/nftables/conn.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
nft "github.com/google/nftables"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxTableNameLength = 27
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
rawConn *nft.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn() *Conn {
|
||||||
|
return &Conn{
|
||||||
|
rawConn: &nft.Conn{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Conn) Raw() *nft.Conn {
|
||||||
|
return this.rawConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Conn) GetTable(name string, family TableFamily) (*Table, error) {
|
||||||
|
rawTables, err := this.rawConn.ListTables()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rawTable := range rawTables {
|
||||||
|
if rawTable.Name == name && rawTable.Family == family {
|
||||||
|
return NewTable(this, rawTable), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrTableNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Conn) AddTable(name string, family TableFamily) (*Table, error) {
|
||||||
|
if len(name) > MaxTableNameLength {
|
||||||
|
return nil, errors.New("table name too long (max " + types.String(MaxTableNameLength) + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawTable = this.rawConn.AddTable(&nft.Table{
|
||||||
|
Family: family,
|
||||||
|
Name: name,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := this.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewTable(this, rawTable), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Conn) AddIPv4Table(name string) (*Table, error) {
|
||||||
|
return this.AddTable(name, TableFamilyIPv4)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Conn) AddIPv6Table(name string) (*Table, error) {
|
||||||
|
return this.AddTable(name, TableFamilyIPv6)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Conn) DeleteTable(name string, family TableFamily) error {
|
||||||
|
table, err := this.GetTable(name, family)
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrTableNotFound {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
this.rawConn.DelTable(table.Raw())
|
||||||
|
return this.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Conn) Commit() error {
|
||||||
|
return this.rawConn.Flush()
|
||||||
|
}
|
||||||
78
internal/firewalls/nftables/conn_test.go
Normal file
78
internal/firewalls/nftables/conn_test.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||||
|
"os/exec"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConn_Test(t *testing.T) {
|
||||||
|
_, err := exec.LookPath("nft")
|
||||||
|
if err == nil {
|
||||||
|
t.Log("ok")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_GetTable_NotFound(t *testing.T) {
|
||||||
|
var conn = nftables.NewConn()
|
||||||
|
|
||||||
|
table, err := conn.GetTable("a", nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrTableNotFound {
|
||||||
|
t.Log("table not found")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Log("table:", table)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_GetTable(t *testing.T) {
|
||||||
|
var conn = nftables.NewConn()
|
||||||
|
|
||||||
|
table, err := conn.GetTable("myFilter", nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrTableNotFound {
|
||||||
|
t.Log("table not found")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Log("table:", table)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_AddTable(t *testing.T) {
|
||||||
|
var conn = nftables.NewConn()
|
||||||
|
|
||||||
|
{
|
||||||
|
table, err := conn.AddIPv4Table("test_ipv4")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(table.Name())
|
||||||
|
}
|
||||||
|
{
|
||||||
|
table, err := conn.AddIPv6Table("test_ipv6")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(table.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConn_DeleteTable(t *testing.T) {
|
||||||
|
var conn = nftables.NewConn()
|
||||||
|
err := conn.DeleteTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
8
internal/firewalls/nftables/element.go
Normal file
8
internal/firewalls/nftables/element.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
type Element struct {
|
||||||
|
}
|
||||||
19
internal/firewalls/nftables/errors.go
Normal file
19
internal/firewalls/nftables/errors.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrTableNotFound = errors.New("table not found")
|
||||||
|
var ErrChainNotFound = errors.New("chain not found")
|
||||||
|
var ErrSetNotFound = errors.New("set not found")
|
||||||
|
var ErrRuleNotFound = errors.New("rule not found")
|
||||||
|
|
||||||
|
func IsNotFound(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return err == ErrTableNotFound || err == ErrChainNotFound || err == ErrSetNotFound || err == ErrRuleNotFound
|
||||||
|
}
|
||||||
18
internal/firewalls/nftables/family.go
Normal file
18
internal/firewalls/nftables/family.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
nft "github.com/google/nftables"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TableFamily = nft.TableFamily
|
||||||
|
|
||||||
|
const (
|
||||||
|
TableFamilyINet TableFamily = nft.TableFamilyINet
|
||||||
|
TableFamilyIPv4 TableFamily = nft.TableFamilyIPv4
|
||||||
|
TableFamilyIPv6 TableFamily = nft.TableFamilyIPv6
|
||||||
|
TableFamilyARP TableFamily = nft.TableFamilyARP
|
||||||
|
TableFamilyNetdev TableFamily = nft.TableFamilyNetdev
|
||||||
|
TableFamilyBridge TableFamily = nft.TableFamilyBridge
|
||||||
|
)
|
||||||
51
internal/firewalls/nftables/rule.go
Normal file
51
internal/firewalls/nftables/rule.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
nft "github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Rule struct {
|
||||||
|
rawRule *nft.Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRule(rawRule *nft.Rule) *Rule {
|
||||||
|
return &Rule{
|
||||||
|
rawRule: rawRule,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) Raw() *nft.Rule {
|
||||||
|
return this.rawRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) LookupSetName() string {
|
||||||
|
for _, e := range this.rawRule.Exprs {
|
||||||
|
exp, ok := e.(*expr.Lookup)
|
||||||
|
if ok {
|
||||||
|
return exp.SetName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) VerDict() expr.VerdictKind {
|
||||||
|
for _, e := range this.rawRule.Exprs {
|
||||||
|
exp, ok := e.(*expr.Verdict)
|
||||||
|
if ok {
|
||||||
|
return exp.Kind
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return -100
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) Handle() uint64 {
|
||||||
|
return this.rawRule.Handle
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Rule) UserData() []byte {
|
||||||
|
return this.rawRule.UserData
|
||||||
|
}
|
||||||
161
internal/firewalls/nftables/set.go
Normal file
161
internal/firewalls/nftables/set.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
nft "github.com/google/nftables"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const MaxSetNameLength = 15
|
||||||
|
|
||||||
|
type SetOptions struct {
|
||||||
|
Id uint32
|
||||||
|
HasTimeout bool
|
||||||
|
Timeout time.Duration
|
||||||
|
KeyType SetDataType
|
||||||
|
DataType SetDataType
|
||||||
|
Constant bool
|
||||||
|
Interval bool
|
||||||
|
Anonymous bool
|
||||||
|
IsMap bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type ElementOptions struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type Set struct {
|
||||||
|
conn *Conn
|
||||||
|
rawSet *nft.Set
|
||||||
|
batch *SetBatch
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSet(conn *Conn, rawSet *nft.Set) *Set {
|
||||||
|
return &Set{
|
||||||
|
conn: conn,
|
||||||
|
rawSet: rawSet,
|
||||||
|
batch: &SetBatch{
|
||||||
|
conn: conn,
|
||||||
|
rawSet: rawSet,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) Raw() *nft.Set {
|
||||||
|
return this.rawSet
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) Name() string {
|
||||||
|
return this.rawSet.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
|
||||||
|
var rawElement = nft.SetElement{
|
||||||
|
Key: key,
|
||||||
|
}
|
||||||
|
if options != nil {
|
||||||
|
rawElement.Timeout = options.Timeout
|
||||||
|
}
|
||||||
|
err := this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||||
|
rawElement,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = this.conn.Commit()
|
||||||
|
if err != nil {
|
||||||
|
// retry if exists
|
||||||
|
if strings.Contains(err.Error(), "file exists") {
|
||||||
|
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||||
|
{
|
||||||
|
Key: key,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if deleteErr == nil {
|
||||||
|
err = this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||||
|
rawElement,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
err = this.conn.Commit()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
|
||||||
|
var ipObj = net.ParseIP(ip)
|
||||||
|
if ipObj == nil {
|
||||||
|
return errors.New("invalid ip '" + ip + "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if utils.IsIPv4(ip) {
|
||||||
|
return this.AddElement(ipObj.To4(), options)
|
||||||
|
} else {
|
||||||
|
return this.AddElement(ipObj.To16(), options)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) DeleteElement(key []byte) error {
|
||||||
|
err := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||||
|
{
|
||||||
|
Key: key,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = this.conn.Commit()
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "no such file or directory") {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) DeleteIPElement(ip string) error {
|
||||||
|
var ipObj = net.ParseIP(ip)
|
||||||
|
if ipObj == nil {
|
||||||
|
return errors.New("invalid ip '" + ip + "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if utils.IsIPv4(ip) {
|
||||||
|
return this.DeleteElement(ipObj.To4())
|
||||||
|
} else {
|
||||||
|
return this.DeleteElement(ipObj.To16())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) Batch() *SetBatch {
|
||||||
|
return this.batch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Set) GetIPElements() ([]string, error) {
|
||||||
|
elements, err := this.conn.Raw().GetSetElements(this.rawSet)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result = []string{}
|
||||||
|
for _, element := range elements {
|
||||||
|
result = append(result, net.IP(element.Key).String())
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// not work current time
|
||||||
|
/**func (this *Set) Flush() error {
|
||||||
|
this.conn.Raw().FlushSet(this.rawSet)
|
||||||
|
return this.conn.Commit()
|
||||||
|
}**/
|
||||||
36
internal/firewalls/nftables/set_batch.go
Normal file
36
internal/firewalls/nftables/set_batch.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
nft "github.com/google/nftables"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SetBatch struct {
|
||||||
|
conn *Conn
|
||||||
|
rawSet *nft.Set
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SetBatch) AddElement(key []byte, options *ElementOptions) error {
|
||||||
|
var rawElement = nft.SetElement{
|
||||||
|
Key: key,
|
||||||
|
}
|
||||||
|
if options != nil {
|
||||||
|
rawElement.Timeout = options.Timeout
|
||||||
|
}
|
||||||
|
return this.conn.Raw().SetAddElements(this.rawSet, []nft.SetElement{
|
||||||
|
rawElement,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SetBatch) DeleteElement(key []byte) error {
|
||||||
|
return this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
|
||||||
|
{
|
||||||
|
Key: key,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *SetBatch) Commit() error {
|
||||||
|
return this.conn.Commit()
|
||||||
|
}
|
||||||
57
internal/firewalls/nftables/set_data_type.go
Normal file
57
internal/firewalls/nftables/set_data_type.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import nft "github.com/google/nftables"
|
||||||
|
|
||||||
|
type SetDataType = nft.SetDatatype
|
||||||
|
|
||||||
|
var (
|
||||||
|
TypeInvalid = nft.TypeInvalid
|
||||||
|
TypeVerdict = nft.TypeVerdict
|
||||||
|
TypeNFProto = nft.TypeNFProto
|
||||||
|
TypeBitmask = nft.TypeBitmask
|
||||||
|
TypeInteger = nft.TypeInteger
|
||||||
|
TypeString = nft.TypeString
|
||||||
|
TypeLLAddr = nft.TypeLLAddr
|
||||||
|
TypeIPAddr = nft.TypeIPAddr
|
||||||
|
TypeIP6Addr = nft.TypeIP6Addr
|
||||||
|
TypeEtherAddr = nft.TypeEtherAddr
|
||||||
|
TypeEtherType = nft.TypeEtherType
|
||||||
|
TypeARPOp = nft.TypeARPOp
|
||||||
|
TypeInetProto = nft.TypeInetProto
|
||||||
|
TypeInetService = nft.TypeInetService
|
||||||
|
TypeICMPType = nft.TypeICMPType
|
||||||
|
TypeTCPFlag = nft.TypeTCPFlag
|
||||||
|
TypeDCCPPktType = nft.TypeDCCPPktType
|
||||||
|
TypeMHType = nft.TypeMHType
|
||||||
|
TypeTime = nft.TypeTime
|
||||||
|
TypeMark = nft.TypeMark
|
||||||
|
TypeIFIndex = nft.TypeIFIndex
|
||||||
|
TypeARPHRD = nft.TypeARPHRD
|
||||||
|
TypeRealm = nft.TypeRealm
|
||||||
|
TypeClassID = nft.TypeClassID
|
||||||
|
TypeUID = nft.TypeUID
|
||||||
|
TypeGID = nft.TypeGID
|
||||||
|
TypeCTState = nft.TypeCTState
|
||||||
|
TypeCTDir = nft.TypeCTDir
|
||||||
|
TypeCTStatus = nft.TypeCTStatus
|
||||||
|
TypeICMP6Type = nft.TypeICMP6Type
|
||||||
|
TypeCTLabel = nft.TypeCTLabel
|
||||||
|
TypePktType = nft.TypePktType
|
||||||
|
TypeICMPCode = nft.TypeICMPCode
|
||||||
|
TypeICMPV6Code = nft.TypeICMPV6Code
|
||||||
|
TypeICMPXCode = nft.TypeICMPXCode
|
||||||
|
TypeDevGroup = nft.TypeDevGroup
|
||||||
|
TypeDSCP = nft.TypeDSCP
|
||||||
|
TypeECN = nft.TypeECN
|
||||||
|
TypeFIBAddr = nft.TypeFIBAddr
|
||||||
|
TypeBoolean = nft.TypeBoolean
|
||||||
|
TypeCTEventBit = nft.TypeCTEventBit
|
||||||
|
TypeIFName = nft.TypeIFName
|
||||||
|
TypeIGMPType = nft.TypeIGMPType
|
||||||
|
TypeTimeDate = nft.TypeTimeDate
|
||||||
|
TypeTimeHour = nft.TypeTimeHour
|
||||||
|
TypeTimeDay = nft.TypeTimeDay
|
||||||
|
TypeCGroupV2 = nft.TypeCGroupV2
|
||||||
|
)
|
||||||
110
internal/firewalls/nftables/set_test.go
Normal file
110
internal/firewalls/nftables/set_test.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
|
||||||
|
package nftables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"github.com/mdlayher/netlink"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getIPv4Set(t *testing.T) *nftables.Set {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
set, err := table.GetSet("test_ipv4_set")
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrSetNotFound {
|
||||||
|
set, err = table.AddSet("test_ipv4_set", &nftables.SetOptions{
|
||||||
|
KeyType: nftables.TypeIPAddr,
|
||||||
|
HasTimeout: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSet_AddElement(t *testing.T) {
|
||||||
|
var set = getIPv4Set(t)
|
||||||
|
err := set.AddElement(net.ParseIP("192.168.2.31").To4(), &nftables.ElementOptions{Timeout: 86400 * time.Second})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSet_DeleteElement(t *testing.T) {
|
||||||
|
var set = getIPv4Set(t)
|
||||||
|
err := set.DeleteElement(net.ParseIP("192.168.2.31").To4())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSet_Batch(t *testing.T) {
|
||||||
|
var batch = getIPv4Set(t).Batch()
|
||||||
|
|
||||||
|
for _, ip := range []string{"192.168.2.30", "192.168.2.31", "192.168.2.32", "192.168.2.33", "192.168.2.34"} {
|
||||||
|
var ipData = net.ParseIP(ip).To4()
|
||||||
|
//err := batch.DeleteElement(ipData)
|
||||||
|
//if err != nil {
|
||||||
|
// t.Fatal(err)
|
||||||
|
//}
|
||||||
|
err := batch.AddElement(ipData, &nftables.ElementOptions{Timeout: 10 * time.Second})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := batch.Commit()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("%#v", errors.Unwrap(err).(*netlink.OpError))
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSet_Add_Many(t *testing.T) {
|
||||||
|
var set = getIPv4Set(t)
|
||||||
|
|
||||||
|
for i := 0; i < 255; i++ {
|
||||||
|
t.Log(i)
|
||||||
|
for j := 0; j < 255; j++ {
|
||||||
|
var ip = "192.167." + types.String(i) + "." + types.String(j)
|
||||||
|
var ipData = net.ParseIP(ip).To4()
|
||||||
|
err := set.Batch().AddElement(ipData, &nftables.ElementOptions{Timeout: 3600 * time.Second})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if j%10 == 0 {
|
||||||
|
err = set.Batch().Commit()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := set.Batch().Commit()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**func TestSet_Flush(t *testing.T) {
|
||||||
|
var set = getIPv4Set(t)
|
||||||
|
err := set.Flush()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}**/
|
||||||
157
internal/firewalls/nftables/table.go
Normal file
157
internal/firewalls/nftables/table.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
nft "github.com/google/nftables"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Table struct {
|
||||||
|
conn *Conn
|
||||||
|
rawTable *nft.Table
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTable(conn *Conn, rawTable *nft.Table) *Table {
|
||||||
|
return &Table{
|
||||||
|
conn: conn,
|
||||||
|
rawTable: rawTable,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) Raw() *nft.Table {
|
||||||
|
return this.rawTable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) Name() string {
|
||||||
|
return this.rawTable.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) Family() TableFamily {
|
||||||
|
return this.rawTable.Family
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) GetChain(name string) (*Chain, error) {
|
||||||
|
rawChains, err := this.conn.Raw().ListChains()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, rawChain := range rawChains {
|
||||||
|
// must compare table name
|
||||||
|
if rawChain.Name == name && rawChain.Table.Name == this.rawTable.Name {
|
||||||
|
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrChainNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) AddChain(name string, chainPolicy *ChainPolicy) (*Chain, error) {
|
||||||
|
if len(name) > MaxChainNameLength {
|
||||||
|
return nil, errors.New("chain name too long (max " + types.String(MaxChainNameLength) + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawChain = this.conn.Raw().AddChain(&nft.Chain{
|
||||||
|
Name: name,
|
||||||
|
Table: this.rawTable,
|
||||||
|
Hooknum: nft.ChainHookInput,
|
||||||
|
Priority: nft.ChainPriorityFilter,
|
||||||
|
Type: nft.ChainTypeFilter,
|
||||||
|
Policy: chainPolicy,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := this.conn.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewChain(this.conn, this.rawTable, rawChain), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) AddAcceptChain(name string) (*Chain, error) {
|
||||||
|
var policy = ChainPolicyAccept
|
||||||
|
return this.AddChain(name, &policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) AddDropChain(name string) (*Chain, error) {
|
||||||
|
var policy = ChainPolicyDrop
|
||||||
|
return this.AddChain(name, &policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) DeleteChain(name string) error {
|
||||||
|
chain, err := this.GetChain(name)
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrChainNotFound {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
this.conn.Raw().DelChain(chain.Raw())
|
||||||
|
return this.conn.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) GetSet(name string) (*Set, error) {
|
||||||
|
rawSet, err := this.conn.Raw().GetSetByName(this.rawTable, name)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "no such file or directory") {
|
||||||
|
return nil, ErrSetNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewSet(this.conn, rawSet), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) AddSet(name string, options *SetOptions) (*Set, error) {
|
||||||
|
if len(name) > MaxSetNameLength {
|
||||||
|
return nil, errors.New("set name too long (max " + types.String(MaxSetNameLength) + ")")
|
||||||
|
}
|
||||||
|
|
||||||
|
if options == nil {
|
||||||
|
options = &SetOptions{}
|
||||||
|
}
|
||||||
|
var rawSet = &nft.Set{
|
||||||
|
Table: this.rawTable,
|
||||||
|
ID: options.Id,
|
||||||
|
Name: name,
|
||||||
|
Anonymous: options.Anonymous,
|
||||||
|
Constant: options.Constant,
|
||||||
|
Interval: options.Interval,
|
||||||
|
IsMap: options.IsMap,
|
||||||
|
HasTimeout: options.HasTimeout,
|
||||||
|
Timeout: options.Timeout,
|
||||||
|
KeyType: options.KeyType,
|
||||||
|
DataType: options.DataType,
|
||||||
|
}
|
||||||
|
err := this.conn.Raw().AddSet(rawSet, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = this.conn.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewSet(this.conn, rawSet), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) DeleteSet(name string) error {
|
||||||
|
set, err := this.GetSet(name)
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrSetNotFound {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
this.conn.Raw().DelSet(set.Raw())
|
||||||
|
return this.conn.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Table) Flush() error {
|
||||||
|
this.conn.Raw().FlushTable(this.rawTable)
|
||||||
|
return this.conn.Commit()
|
||||||
|
}
|
||||||
140
internal/firewalls/nftables/table_test.go
Normal file
140
internal/firewalls/nftables/table_test.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package nftables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getIPv4Table(t *testing.T) *nftables.Table {
|
||||||
|
var conn = nftables.NewConn()
|
||||||
|
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrTableNotFound {
|
||||||
|
table, err = conn.AddIPv4Table("test_ipv4")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_AddChain(t *testing.T) {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
|
||||||
|
{
|
||||||
|
chain, err := table.AddChain("test_default_chain", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("created:", chain.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
chain, err := table.AddAcceptChain("test_accept_chain")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("created:", chain.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not test drop chain before adding accept rule, you will drop yourself!!!!!!!
|
||||||
|
/**{
|
||||||
|
chain, err := table.AddDropChain("test_drop_chain")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("created:", chain.Name())
|
||||||
|
}**/
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_GetChain(t *testing.T) {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
for _, chainName := range []string{"not_found_chain", "test_default_chain"} {
|
||||||
|
chain, err := table.GetChain(chainName)
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrChainNotFound {
|
||||||
|
t.Log(chainName, ":", "not found")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Log(chainName, ":", chain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_DeleteChain(t *testing.T) {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
err := table.DeleteChain("test_default_chain")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_AddSet(t *testing.T) {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
{
|
||||||
|
set, err := table.AddSet("ipv4_black_set", &nftables.SetOptions{
|
||||||
|
HasTimeout: false,
|
||||||
|
KeyType: nftables.TypeIPAddr,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(set.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
set, err := table.AddSet("ipv6_black_set", &nftables.SetOptions{
|
||||||
|
HasTimeout: true,
|
||||||
|
//Timeout: 3600 * time.Second,
|
||||||
|
KeyType: nftables.TypeIP6Addr,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log(set.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_GetSet(t *testing.T) {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
for _, setName := range []string{"not_found_set", "ipv4_black_set"} {
|
||||||
|
set, err := table.GetSet(setName)
|
||||||
|
if err != nil {
|
||||||
|
if err == nftables.ErrSetNotFound {
|
||||||
|
t.Log(setName, ": not found")
|
||||||
|
} else {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Log(setName, ":", set)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_DeleteSet(t *testing.T) {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
err := table.DeleteSet("ipv4_black_set")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_Flush(t *testing.T) {
|
||||||
|
var table = getIPv4Table(t)
|
||||||
|
err := table.Flush()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Log("ok")
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package nodes
|
package nodes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -14,16 +15,20 @@ import (
|
|||||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
"github.com/TeaOSLab/EdgeNode/internal/errors"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/events"
|
"github.com/TeaOSLab/EdgeNode/internal/events"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
"github.com/iwind/TeaGo/Tea"
|
"github.com/iwind/TeaGo/Tea"
|
||||||
|
"github.com/iwind/TeaGo/maps"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -119,6 +124,8 @@ func (this *APIStream) loop() error {
|
|||||||
err = this.handleNewNodeTask(message)
|
err = this.handleNewNodeTask(message)
|
||||||
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
|
case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
|
||||||
err = this.handleCheckSystemdService(message)
|
err = this.handleCheckSystemdService(message)
|
||||||
|
case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
|
||||||
|
err = this.handleCheckLocalFirewall(message)
|
||||||
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
|
case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
|
||||||
err = this.handleChangeAPINode(message)
|
err = this.handleChangeAPINode(message)
|
||||||
default:
|
default:
|
||||||
@@ -569,7 +576,7 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := utils.NewCommandExecutor()
|
var cmd = utils.NewCommandExecutor()
|
||||||
shortName := teaconst.SystemdServiceName
|
shortName := teaconst.SystemdServiceName
|
||||||
cmd.Add(systemctl, "is-enabled", shortName)
|
cmd.Add(systemctl, "is-enabled", shortName)
|
||||||
output, err := cmd.Run()
|
output, err := cmd.Run()
|
||||||
@@ -585,6 +592,63 @@ func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查本地防火墙
|
||||||
|
func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
|
||||||
|
var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
|
||||||
|
err := json.Unmarshal(message.DataJSON, dataMessage)
|
||||||
|
if err != nil {
|
||||||
|
this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nft
|
||||||
|
if dataMessage.Name == "nftables" {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
this.replyFail(message.RequestId, "not Linux system")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nft, err := exec.LookPath("nft")
|
||||||
|
if err != nil {
|
||||||
|
this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd = exec.Command(nft, "--version")
|
||||||
|
var output = &bytes.Buffer{}
|
||||||
|
cmd.Stdout = output
|
||||||
|
err = cmd.Run()
|
||||||
|
if err != nil {
|
||||||
|
this.replyFail(message.RequestId, "get version failed: "+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputString = output.String()
|
||||||
|
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
||||||
|
if len(versionMatches) <= 1 {
|
||||||
|
this.replyFail(message.RequestId, "can not get nft version")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var version = versionMatches[1]
|
||||||
|
|
||||||
|
var result = maps.Map{
|
||||||
|
"version": version,
|
||||||
|
}
|
||||||
|
|
||||||
|
var protectionConfig = sharedNodeConfig.DDOSProtection
|
||||||
|
err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
|
||||||
|
if err != nil {
|
||||||
|
this.replyFail(message.RequestId, dataMessage.Name+"was installed, but apply DDoS protection config failed: "+err.Error())
|
||||||
|
} else {
|
||||||
|
this.replyOk(message.RequestId, string(result.AsJSON()))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 修改API地址
|
// 修改API地址
|
||||||
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
|
func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
|
||||||
config, err := configs.LoadAPIConfig()
|
config, err := configs.LoadAPIConfig()
|
||||||
@@ -660,6 +724,11 @@ func (this *APIStream) replyOk(requestId int64, message string) {
|
|||||||
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 回复成功并包含数据
|
||||||
|
func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
|
||||||
|
_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
|
||||||
|
}
|
||||||
|
|
||||||
// 获取缓存存取对象
|
// 获取缓存存取对象
|
||||||
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
|
func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
|
||||||
cachePolicy := &serverconfigs.HTTPCachePolicy{}
|
cachePolicy := &serverconfigs.HTTPCachePolicy{}
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
"github.com/TeaOSLab/EdgeNode/internal/ttlcache"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||||
@@ -21,8 +20,7 @@ import (
|
|||||||
|
|
||||||
// ClientConn 客户端连接
|
// ClientConn 客户端连接
|
||||||
type ClientConn struct {
|
type ClientConn struct {
|
||||||
once sync.Once
|
once sync.Once
|
||||||
globalLimiter *ratelimit.Counter
|
|
||||||
|
|
||||||
isTLS bool
|
isTLS bool
|
||||||
hasDeadline bool
|
hasDeadline bool
|
||||||
@@ -33,7 +31,7 @@ type ClientConn struct {
|
|||||||
BaseClientConn
|
BaseClientConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ratelimit.Counter) net.Conn {
|
func NewClientConn(conn net.Conn, isTLS bool, quickClose bool) net.Conn {
|
||||||
if quickClose {
|
if quickClose {
|
||||||
// TCP
|
// TCP
|
||||||
tcpConn, ok := conn.(*net.TCPConn)
|
tcpConn, ok := conn.(*net.TCPConn)
|
||||||
@@ -43,7 +41,7 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS, globalLimiter: globalLimiter}
|
return &ClientConn{BaseClientConn: BaseClientConn{rawConn: conn}, isTLS: isTLS}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
func (this *ClientConn) Read(b []byte) (n int, err error) {
|
||||||
@@ -96,13 +94,6 @@ func (this *ClientConn) Close() error {
|
|||||||
|
|
||||||
err := this.rawConn.Close()
|
err := this.rawConn.Close()
|
||||||
|
|
||||||
// 全局并发数限制
|
|
||||||
this.once.Do(func() {
|
|
||||||
if this.globalLimiter != nil {
|
|
||||||
this.globalLimiter.Release()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// 单个服务并发数限制
|
// 单个服务并发数限制
|
||||||
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
sharedClientConnLimiter.Remove(this.rawConn.RemoteAddr().String())
|
||||||
|
|
||||||
|
|||||||
@@ -3,16 +3,12 @@
|
|||||||
package nodes
|
package nodes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
var sharedConnectionsLimiter = ratelimit.NewCounter(nodeconfigs.DefaultTCPMaxConnections)
|
|
||||||
|
|
||||||
// ClientListener 客户端网络监听
|
// ClientListener 客户端网络监听
|
||||||
type ClientListener struct {
|
type ClientListener struct {
|
||||||
rawListener net.Listener
|
rawListener net.Listener
|
||||||
@@ -36,13 +32,8 @@ func (this *ClientListener) IsTLS() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (this *ClientListener) Accept() (net.Conn, error) {
|
func (this *ClientListener) Accept() (net.Conn, error) {
|
||||||
// 限制并发连接数
|
|
||||||
var limiter = sharedConnectionsLimiter
|
|
||||||
limiter.Ack()
|
|
||||||
|
|
||||||
conn, err := this.rawListener.Accept()
|
conn, err := this.rawListener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
limiter.Release()
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,12 +51,11 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
limiter.Release()
|
|
||||||
return this.Accept()
|
return this.Accept()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewClientConn(conn, this.isTLS, this.quickClose, limiter), nil
|
return NewClientConn(conn, this.isTLS, this.quickClose), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *ClientListener) Close() error {
|
func (this *ClientListener) Close() error {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ddosconfigs"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
"github.com/TeaOSLab/EdgeNode/internal/caches"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
"github.com/TeaOSLab/EdgeNode/internal/configs"
|
||||||
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
|
||||||
@@ -15,7 +16,6 @@ import (
|
|||||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
"github.com/TeaOSLab/EdgeNode/internal/metrics"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/ratelimit"
|
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
"github.com/TeaOSLab/EdgeNode/internal/rpc"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
"github.com/TeaOSLab/EdgeNode/internal/stats"
|
||||||
@@ -368,6 +368,38 @@ func (this *Node) loop() error {
|
|||||||
}
|
}
|
||||||
sharedNodeConfig.ParentNodes = parentNodes
|
sharedNodeConfig.ParentNodes = parentNodes
|
||||||
|
|
||||||
|
// 修改为已同步
|
||||||
|
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||||
|
NodeTaskId: task.Id,
|
||||||
|
IsOk: true,
|
||||||
|
Error: "",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "ddosProtectionChanged":
|
||||||
|
resp, err := rpcClient.NodeRPC().FindNodeDDoSProtection(nodeCtx, &pb.FindNodeDDoSProtectionRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(resp.DdosProtectionJSON) == 0 {
|
||||||
|
if sharedNodeConfig != nil {
|
||||||
|
sharedNodeConfig.DDOSProtection = nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var ddosProtectionConfig = &ddosconfigs.ProtectionConfig{}
|
||||||
|
err = json.Unmarshal(resp.DdosProtectionJSON, ddosProtectionConfig)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("decode DDoS protection config failed: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = firewalls.SharedDDoSProtectionManager.Apply(ddosProtectionConfig)
|
||||||
|
if err != nil {
|
||||||
|
// 不阻塞
|
||||||
|
remotelogs.Error("NODE", "apply DDoS protection failed: "+err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 修改为已同步
|
// 修改为已同步
|
||||||
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
_, err = rpcClient.NodeTaskRPC().ReportNodeTaskDone(nodeCtx, &pb.ReportNodeTaskDoneRequest{
|
||||||
NodeTaskId: task.Id,
|
NodeTaskId: task.Id,
|
||||||
@@ -730,7 +762,6 @@ func (this *Node) listenSock() error {
|
|||||||
"ipConns": ipConns,
|
"ipConns": ipConns,
|
||||||
"serverConns": serverConns,
|
"serverConns": serverConns,
|
||||||
"total": sharedListenerManager.TotalActiveConnections(),
|
"total": sharedListenerManager.TotalActiveConnections(),
|
||||||
"limiter": sharedConnectionsLimiter.Len(),
|
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case "dropIP":
|
case "dropIP":
|
||||||
@@ -854,17 +885,6 @@ func (this *Node) onReload(config *nodeconfigs.NodeConfig) {
|
|||||||
this.maxThreads = config.MaxThreads
|
this.maxThreads = config.MaxThreads
|
||||||
}
|
}
|
||||||
|
|
||||||
// max tcp connections
|
|
||||||
if config.TCPMaxConnections <= 0 {
|
|
||||||
config.TCPMaxConnections = nodeconfigs.DefaultTCPMaxConnections
|
|
||||||
}
|
|
||||||
if config.TCPMaxConnections != sharedConnectionsLimiter.Count() {
|
|
||||||
remotelogs.Println("NODE", "[TCP]changed tcp max connections to '"+types.String(config.TCPMaxConnections)+"'")
|
|
||||||
|
|
||||||
sharedConnectionsLimiter.Close()
|
|
||||||
sharedConnectionsLimiter = ratelimit.NewCounter(config.TCPMaxConnections)
|
|
||||||
}
|
|
||||||
|
|
||||||
// timezone
|
// timezone
|
||||||
var timeZone = config.TimeZone
|
var timeZone = config.TimeZone
|
||||||
if len(timeZone) == 0 {
|
if len(timeZone) == 0 {
|
||||||
|
|||||||
@@ -5,9 +5,12 @@ import (
|
|||||||
"github.com/cespare/xxhash"
|
"github.com/cespare/xxhash"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ipv4Reg = regexp.MustCompile(`\d+\.`)
|
||||||
|
|
||||||
// IP2Long 将IP转换为整型
|
// IP2Long 将IP转换为整型
|
||||||
// 注意IPv6没有顺序
|
// 注意IPv6没有顺序
|
||||||
func IP2Long(ip string) uint64 {
|
func IP2Long(ip string) uint64 {
|
||||||
@@ -54,3 +57,24 @@ func IsLocalIP(ipString string) bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsIPv4 是否为IPv4
|
||||||
|
func IsIPv4(ip string) bool {
|
||||||
|
var data = net.ParseIP(ip)
|
||||||
|
if data == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.Contains(ip, ":") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return data.To4() != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIPv6 是否为IPv6
|
||||||
|
func IsIPv6(ip string) bool {
|
||||||
|
var data = net.ParseIP(ip)
|
||||||
|
if data == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !IsIPv4(ip)
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,3 +26,26 @@ func TestIsLocalIP(t *testing.T) {
|
|||||||
a.IsFalse(IsLocalIP("::1:2:3"))
|
a.IsFalse(IsLocalIP("::1:2:3"))
|
||||||
a.IsFalse(IsLocalIP("8.8.8.8"))
|
a.IsFalse(IsLocalIP("8.8.8.8"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsIPv4(t *testing.T) {
|
||||||
|
var a = assert.NewAssertion(t)
|
||||||
|
a.IsTrue(IsIPv4("192.168.1.1"))
|
||||||
|
a.IsTrue(IsIPv4("0.0.0.0"))
|
||||||
|
a.IsFalse(IsIPv4("192.168.1.256"))
|
||||||
|
a.IsFalse(IsIPv4("192.168.1"))
|
||||||
|
a.IsFalse(IsIPv4("::1"))
|
||||||
|
a.IsFalse(IsIPv4("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
|
||||||
|
a.IsFalse(IsIPv4("::ffff:192.168.0.1"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsIPv6(t *testing.T) {
|
||||||
|
var a = assert.NewAssertion(t)
|
||||||
|
a.IsFalse(IsIPv6("192.168.1.1"))
|
||||||
|
a.IsFloat32(IsIPv6("0.0.0.0"))
|
||||||
|
a.IsFalse(IsIPv6("192.168.1.256"))
|
||||||
|
a.IsFalse(IsIPv6("192.168.1"))
|
||||||
|
a.IsTrue(IsIPv6("::1"))
|
||||||
|
a.IsTrue(IsIPv6("2001:0db8:85a3:0000:0000:8a2e:0370:7334"))
|
||||||
|
a.IsTrue(IsIPv4("::ffff:192.168.0.1"))
|
||||||
|
a.IsTrue(IsIPv6("::ffff:192.168.0.1"))
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
@@ -36,6 +37,22 @@ func FormatAddressList(addrList []string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToValidUTF8string 去除字符串中的非UTF-8字符
|
||||||
func ToValidUTF8string(v string) string {
|
func ToValidUTF8string(v string) string {
|
||||||
return strings.ToValidUTF8(v, "")
|
return strings.ToValidUTF8(v, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContainsSameStrings 检查两个字符串slice内容是否一致
|
||||||
|
func ContainsSameStrings(s1 []string, s2 []string) bool {
|
||||||
|
if len(s1) != len(s2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
sort.Strings(s1)
|
||||||
|
sort.Strings(s2)
|
||||||
|
for index, v1 := range s1 {
|
||||||
|
if v1 != s2[index] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,56 +1,67 @@
|
|||||||
package utils
|
package utils_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
"github.com/iwind/TeaGo/assert"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBytesToString(t *testing.T) {
|
func TestBytesToString(t *testing.T) {
|
||||||
t.Log(UnsafeBytesToString([]byte("Hello,World")))
|
t.Log(utils.UnsafeBytesToString([]byte("Hello,World")))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStringToBytes(t *testing.T) {
|
func TestStringToBytes(t *testing.T) {
|
||||||
t.Log(string(UnsafeStringToBytes("Hello,World")))
|
t.Log(string(utils.UnsafeStringToBytes("Hello,World")))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkBytesToString(b *testing.B) {
|
func BenchmarkBytesToString(b *testing.B) {
|
||||||
data := []byte("Hello,World")
|
var data = []byte("Hello,World")
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = UnsafeBytesToString(data)
|
_ = utils.UnsafeBytesToString(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkBytesToString2(b *testing.B) {
|
func BenchmarkBytesToString2(b *testing.B) {
|
||||||
data := []byte("Hello,World")
|
var data = []byte("Hello,World")
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = string(data)
|
_ = string(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkStringToBytes(b *testing.B) {
|
func BenchmarkStringToBytes(b *testing.B) {
|
||||||
s := strings.Repeat("Hello,World", 1024)
|
var s = strings.Repeat("Hello,World", 1024)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = UnsafeStringToBytes(s)
|
_ = utils.UnsafeStringToBytes(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkStringToBytes2(b *testing.B) {
|
func BenchmarkStringToBytes2(b *testing.B) {
|
||||||
s := strings.Repeat("Hello,World", 1024)
|
var s = strings.Repeat("Hello,World", 1024)
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = []byte(s)
|
_ = []byte(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatAddress(t *testing.T) {
|
func TestFormatAddress(t *testing.T) {
|
||||||
t.Log(FormatAddress("127.0.0.1:1234"))
|
t.Log(utils.FormatAddress("127.0.0.1:1234"))
|
||||||
t.Log(FormatAddress("127.0.0.1 : 1234"))
|
t.Log(utils.FormatAddress("127.0.0.1 : 1234"))
|
||||||
t.Log(FormatAddress("127.0.0.1:1234"))
|
t.Log(utils.FormatAddress("127.0.0.1:1234"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatAddressList(t *testing.T) {
|
func TestFormatAddressList(t *testing.T) {
|
||||||
t.Log(FormatAddressList([]string{
|
t.Log(utils.FormatAddressList([]string{
|
||||||
"127.0.0.1:1234",
|
"127.0.0.1:1234",
|
||||||
"127.0.0.1 : 1234",
|
"127.0.0.1 : 1234",
|
||||||
"127.0.0.1:1234",
|
"127.0.0.1:1234",
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContainsSameStrings(t *testing.T) {
|
||||||
|
var a = assert.NewAssertion(t)
|
||||||
|
a.IsFalse(utils.ContainsSameStrings([]string{"a"}, []string{"b"}))
|
||||||
|
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b"}))
|
||||||
|
a.IsFalse(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b", "c"}))
|
||||||
|
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"a", "b"}))
|
||||||
|
a.IsTrue(utils.ContainsSameStrings([]string{"a", "b"}, []string{"b", "a"}))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user