Files
EdgeNode/internal/firewalls/firewall_nftables.go

501 lines
12 KiB
Go
Raw Normal View History

2024-05-17 18:30:33 +08:00
// Copyright 2022 GoEdge goedge.cdn@gmail.com. All rights reserved.
2022-05-18 21:03:51 +08:00
//go:build linux
package firewalls
import (
"errors"
2023-08-11 14:38:00 +08:00
"fmt"
"github.com/TeaOSLab/EdgeCommon/pkg/iputils"
"github.com/TeaOSLab/EdgeNode/internal/conns"
2022-07-26 09:41:43 +08:00
teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
2022-05-18 21:03:51 +08:00
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
2022-09-15 11:14:33 +08:00
executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
2024-05-11 09:23:54 +08:00
"github.com/TeaOSLab/EdgeNode/internal/utils/goman"
"github.com/google/nftables/expr"
2022-05-18 21:03:51 +08:00
"github.com/iwind/TeaGo/types"
"net"
2022-06-09 19:12:10 +08:00
"regexp"
2022-05-18 21:03:51 +08:00
"runtime"
"strings"
"time"
)
// check nft status, if being enabled we load it automatically
func init() {
if !teaconst.IsMain {
2022-07-26 09:41:43 +08:00
return
}
2022-05-18 21:03:51 +08:00
if runtime.GOOS == "linux" {
var ticker = time.NewTicker(3 * time.Minute)
2022-08-04 11:01:16 +08:00
goman.New(func() {
2022-05-18 21:03:51 +08:00
for range ticker.C {
// if already ready, we break
if nftablesIsReady {
ticker.Stop()
break
}
2023-04-05 09:33:03 +08:00
var nftExe = nftables.NftExePath()
if len(nftExe) > 0 {
2022-05-18 21:03:51 +08:00
nftablesFirewall, err := NewNFTablesFirewall()
if err != nil {
continue
}
currentFirewall = nftablesFirewall
remotelogs.Println("FIREWALL", "nftables is ready")
// fire event
if nftablesFirewall.IsReady() {
events.Notify(events.EventNFTablesReady)
}
ticker.Stop()
break
}
}
2022-08-04 11:01:16 +08:00
})
2022-05-18 21:03:51 +08:00
}
}
var (
	// nftablesInstance is assigned by (*NFTablesFirewall).init after a
	// successful initialization.
	nftablesInstance *NFTablesFirewall

	// nftablesIsReady becomes true after a successful init; the background
	// watcher started by the package init() stops polling once it is set.
	nftablesIsReady bool

	// nftablesFilters lists the tables managed by the firewall, one per
	// address family. The names are kept short because nftables restricts
	// the length of table names.
	nftablesFilters = []*nftablesTableDefinition{
		{Name: "edge_dft_v4", IsIPv4: true},
		{Name: "edge_dft_v6", IsIPv6: true},
	}

	// nftablesChainName is the chain used inside every table of nftablesFilters.
	nftablesChainName = "input"
)
// nftablesTableDefinition describes one managed nftables table and the
// address family it belongs to.
type nftablesTableDefinition struct {
	Name           string
	IsIPv4, IsIPv6 bool
}

// protocol returns the nft address-family keyword for the table: "ip6" when
// IsIPv6 is set, otherwise "ip".
func (d *nftablesTableDefinition) protocol() string {
	var keyword = "ip"
	if d.IsIPv6 {
		keyword = "ip6"
	}
	return keyword
}

// blockIPItem is a request queued on NFTablesFirewall.dropIPQueue.
// The queue consumer started in init() only handles action "drop".
type blockIPItem struct {
	action, ip     string
	timeoutSeconds int
}
2022-05-18 21:03:51 +08:00
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
2023-04-19 12:01:02 +08:00
conn, err := nftables.NewConn()
if err != nil {
return nil, err
}
2022-05-18 21:03:51 +08:00
var firewall = &NFTablesFirewall{
2023-04-19 12:01:02 +08:00
conn: conn,
2022-08-04 11:01:16 +08:00
dropIPQueue: make(chan *blockIPItem, 4096),
2022-05-18 21:03:51 +08:00
}
2023-04-19 12:01:02 +08:00
err = firewall.init()
2022-05-18 21:03:51 +08:00
if err != nil {
return nil, err
}
return firewall, nil
}
type NFTablesFirewall struct {
BaseFirewall
2022-05-18 21:03:51 +08:00
conn *nftables.Conn
isReady bool
2022-06-09 19:12:10 +08:00
version string
2022-05-18 21:03:51 +08:00
allowIPv4Set *nftables.Set
allowIPv6Set *nftables.Set
2023-04-02 20:32:36 +08:00
denyIPv4Sets []*nftables.Set
denyIPv6Sets []*nftables.Set
2022-05-18 21:03:51 +08:00
firewalld *Firewalld
2022-08-04 11:01:16 +08:00
dropIPQueue chan *blockIPItem
2022-05-18 21:03:51 +08:00
}
// init prepares the nftables objects this firewall relies on.
//
// It first checks that the `nft` executable can be found and records its
// version. Then, for every table definition in nftablesFilters, it looks up
// each of the following objects and creates it when the lookup reports
// "not found" (any other lookup error aborts init):
//
//   - the table (IPv4 or IPv6 family);
//   - the chain named nftablesChainName, created with AddAcceptChain;
//   - a rule accepting the "lo" interface, tagged with user data "lo";
//   - the sets allow_set, deny_set and deny1_set..deny4_set, keyed by
//     IPv4/IPv6 address and created with HasTimeout;
//   - one rule per set, tagged with the set action as user data: an accept
//     rule for "allow" and reject rules for the deny sets.
//
// The allow set and the deny sets of each family are stored on the firewall.
// On success the firewall is marked ready and published through
// nftablesInstance/nftablesIsReady, a goroutine consuming dropIPQueue is
// started, and firewalld is attached when NewFirewalld() reports IsReady().
func (this *NFTablesFirewall) init() error {
	// check nft
2023-04-05 09:33:03 +08:00
	var nftPath = nftables.NftExePath()
	if len(nftPath) == 0 {
		return errors.New("'nft' not found")
2022-05-18 21:03:51 +08:00
	}
2022-06-09 19:12:10 +08:00
	// readVersion returns "" when the version cannot be determined; that is
	// not treated as an error here
	this.version = this.readVersion(nftPath)
2022-05-18 21:03:51 +08:00
	// table
	for _, tableDef := range nftablesFilters {
		// every definition must name exactly one supported family
		var family nftables.TableFamily
		if tableDef.IsIPv4 {
			family = nftables.TableFamilyIPv4
		} else if tableDef.IsIPv6 {
			family = nftables.TableFamilyIPv6
		} else {
			return errors.New("invalid table family: " + types.String(tableDef))
		}
		table, err := this.conn.GetTable(tableDef.Name, family)
		if err != nil {
			// create the table only when it does not exist yet
			if nftables.IsNotFound(err) {
				if tableDef.IsIPv4 {
					table, err = this.conn.AddIPv4Table(tableDef.Name)
				} else if tableDef.IsIPv6 {
					table, err = this.conn.AddIPv6Table(tableDef.Name)
				}
				if err != nil {
2023-08-11 14:38:00 +08:00
					return fmt.Errorf("create table '%s' failed: %w", tableDef.Name, err)
2022-05-18 21:03:51 +08:00
				}
			} else {
2023-08-11 14:38:00 +08:00
				return fmt.Errorf("get table '%s' failed: %w", tableDef.Name, err)
2022-05-18 21:03:51 +08:00
			}
		}
		if table == nil {
			return errors.New("can not create table '" + tableDef.Name + "'")
		}
		// chain
		var chainName = nftablesChainName
		chain, err := table.GetChain(chainName)
		if err != nil {
			if nftables.IsNotFound(err) {
				chain, err = table.AddAcceptChain(chainName)
				if err != nil {
2023-08-11 14:38:00 +08:00
					return fmt.Errorf("create chain '%s' failed: %w", chainName, err)
2022-05-18 21:03:51 +08:00
				}
			} else {
2023-08-11 14:38:00 +08:00
				return fmt.Errorf("get chain '%s' failed: %w", chainName, err)
2022-05-18 21:03:51 +08:00
			}
		}
		if chain == nil {
			return errors.New("can not create chain '" + chainName + "'")
		}
		// allow lo: accept traffic on the loopback interface; the rule is
		// identified by its user data "lo"
		// NOTE(review): a lookup error other than not-found is also reported
		// as "add 'lo' rule failed"
		var loRuleName = []byte("lo")
		_, err = chain.GetRuleWithUserData(loRuleName)
		if err != nil {
			if nftables.IsNotFound(err) {
				_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
			}
			if err != nil {
2023-08-11 14:38:00 +08:00
				return fmt.Errorf("add 'lo' rule failed: %w", err)
2022-05-18 21:03:51 +08:00
			}
		}
		// allow set
		// "allow" should be always first: sets and their rules are processed in
		// this order, so on a fresh chain the accept rule for allow_set is
		// added before the deny rules (presumably rules are appended in order)
2023-04-02 20:32:36 +08:00
		for _, setAction := range []string{"allow", "deny", "deny1", "deny2", "deny3", "deny4"} {
2022-05-18 21:03:51 +08:00
			var setName = setAction + "_set"
			set, err := table.GetSet(setName)
			if err != nil {
				if nftables.IsNotFound(err) {
					// key type follows the table's address family
					var keyType nftables.SetDataType
					if tableDef.IsIPv4 {
						keyType = nftables.TypeIPAddr
					} else if tableDef.IsIPv6 {
						keyType = nftables.TypeIP6Addr
					}
					set, err = table.AddSet(setName, &nftables.SetOptions{
						KeyType:    keyType,
						HasTimeout: true,
					})
					if err != nil {
2023-08-11 14:38:00 +08:00
						return fmt.Errorf("create set '%s' failed: %w", setName, err)
2022-05-18 21:03:51 +08:00
					}
				} else {
2023-08-11 14:38:00 +08:00
					return fmt.Errorf("get set '%s' failed: %w", setName, err)
2022-05-18 21:03:51 +08:00
				}
			}
			if set == nil {
				return errors.New("can not create set '" + setName + "'")
			}
			// remember the set: one allow set and a list of deny sets per family
			if tableDef.IsIPv4 {
				if setAction == "allow" {
					this.allowIPv4Set = set
				} else {
2023-04-02 20:32:36 +08:00
					this.denyIPv4Sets = append(this.denyIPv4Sets, set)
2022-05-18 21:03:51 +08:00
				}
			} else if tableDef.IsIPv6 {
				if setAction == "allow" {
					this.allowIPv6Set = set
				} else {
2023-04-02 20:32:36 +08:00
					this.denyIPv6Sets = append(this.denyIPv6Sets, set)
2022-05-18 21:03:51 +08:00
				}
			}
			// rule: identified by its user data, which is the set action name
			var ruleName = []byte(setAction)
			rule, err := chain.GetRuleWithUserData(ruleName)
			// Delete a previously created drop rule so that it is replaced by the
			// reject rule below: after a successful delete, pretend the rule was
			// not found so the creation branch runs.
			// NOTE(review): if DeleteRule fails, the error is ignored and the old
			// drop rule is kept.
			if err == nil && setAction != "allow" && rule != nil && rule.VerDict() == expr.VerdictDrop {
				deleteErr := chain.DeleteRule(rule)
				if deleteErr == nil {
					err = nftables.ErrRuleNotFound
					rule = nil
				}
			}
2022-05-18 21:03:51 +08:00
			if err != nil {
				if nftables.IsNotFound(err) {
					// "allow" accepts matching sources; every deny set rejects them
					if tableDef.IsIPv4 {
						if setAction == "allow" {
							rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
						} else {
							rule, err = chain.AddRejectIPv4SetRule(setName, ruleName)
2022-05-18 21:03:51 +08:00
						}
					} else if tableDef.IsIPv6 {
						if setAction == "allow" {
							rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
						} else {
							rule, err = chain.AddRejectIPv6SetRule(setName, ruleName)
2022-05-18 21:03:51 +08:00
						}
					}
					if err != nil {
2023-08-11 14:38:00 +08:00
						return fmt.Errorf("add rule failed: %w", err)
2022-05-18 21:03:51 +08:00
					}
				} else {
2023-08-11 14:38:00 +08:00
					return fmt.Errorf("get rule failed: %w", err)
2022-05-18 21:03:51 +08:00
				}
			}
			if rule == nil {
				return errors.New("can not create rule '" + string(ruleName) + "'")
			}
		}
	}
	// everything exists: publish this firewall as the ready nftables instance
	this.isReady = true
	nftablesIsReady = true
	nftablesInstance = this
2022-08-04 11:01:16 +08:00
	// consume asynchronous drop requests queued by DropSourceIP(..., async=true);
	// dropIPQueue is not closed anywhere in this file, so this goroutine keeps
	// running
	goman.New(func() {
		for ipItem := range this.dropIPQueue {
			switch ipItem.action {
			case "drop":
				err := this.DropSourceIP(ipItem.ip, ipItem.timeoutSeconds, false)
2022-08-04 11:01:16 +08:00
				if err != nil {
					remotelogs.Warn("NFTABLES", "drop ip '"+ipItem.ip+"' failed: "+err.Error())
				}
			}
		}
	})
2022-05-18 21:03:51 +08:00
	// load firewalld: used for AllowPort/RemovePort when it reports ready
	var firewalld = NewFirewalld()
	if firewalld.IsReady() {
		this.firewalld = firewalld
	}
	return nil
}
// Name returns the name of this firewall implementation.
func (this *NFTablesFirewall) Name() string {
	const name = "nftables"
	return name
}

// IsReady reports whether init() completed successfully, i.e. whether the
// firewall can be used.
func (this *NFTablesFirewall) IsReady() bool {
	ready := this.isReady
	return ready
}

// IsMock reports whether this is a mock firewall; nftables is a real one.
func (this *NFTablesFirewall) IsMock() bool {
	const mock = false
	return mock
}
// AllowPort opens port/protocol through firewalld when firewalld is
// available; otherwise it does nothing and returns nil.
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
	if this.firewalld == nil {
		return nil
	}
	return this.firewalld.AllowPort(port, protocol)
}

// RemovePort removes a port/protocol rule through firewalld when firewalld is
// available; otherwise it does nothing and returns nil.
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
	if this.firewalld == nil {
		return nil
	}
	return this.firewalld.RemovePort(port, protocol)
}
// AllowSourceIP Allow把IP加入白名单
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
if this.allowIPv6Set == nil {
return errors.New("ipv6 ip set is nil")
}
2023-04-01 20:51:49 +08:00
return this.allowIPv6Set.AddElement(data.To16(), nil, false)
2022-05-18 21:03:51 +08:00
}
// ipv4
if this.allowIPv4Set == nil {
return errors.New("ipv4 ip set is nil")
}
2023-04-01 20:51:49 +08:00
return this.allowIPv4Set.AddElement(data.To4(), nil, false)
2022-05-18 21:03:51 +08:00
}
// RejectSourceIP 拒绝某个源IP连接
// we did not create set for drop ip, so we reuse DropSourceIP() method here
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
2022-08-04 11:01:16 +08:00
return this.DropSourceIP(ip, timeoutSeconds, true)
2022-05-18 21:03:51 +08:00
}
// DropSourceIP 丢弃某个源IP数据
2022-08-04 11:01:16 +08:00
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async bool) error {
2022-05-18 21:03:51 +08:00
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
// 尝试关闭连接
conns.SharedMap.CloseIPConns(ip)
// 避免短时间内重复添加
if async && this.checkLatestIP(ip) {
return nil
}
2022-08-04 11:01:16 +08:00
if async {
select {
case this.dropIPQueue <- &blockIPItem{
action: "drop",
ip: ip,
timeoutSeconds: timeoutSeconds,
}:
default:
return errors.New("drop ip queue is full")
}
return nil
}
// 再次尝试关闭连接
defer conns.SharedMap.CloseIPConns(ip)
2022-05-18 21:03:51 +08:00
if strings.Contains(ip, ":") { // ipv6
2023-04-02 20:32:36 +08:00
if len(this.denyIPv6Sets) == 0 {
return errors.New("ipv6 ip set not found")
2022-05-18 21:03:51 +08:00
}
var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets))
return this.denyIPv6Sets[setIndex].AddElement(data.To16(), &nftables.ElementOptions{
2022-05-18 21:03:51 +08:00
Timeout: time.Duration(timeoutSeconds) * time.Second,
2023-04-01 20:51:49 +08:00
}, false)
2022-05-18 21:03:51 +08:00
}
// ipv4
2023-04-02 20:32:36 +08:00
if len(this.denyIPv4Sets) == 0 {
return errors.New("ipv4 ip set not found")
2022-05-18 21:03:51 +08:00
}
var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets))
return this.denyIPv4Sets[setIndex].AddElement(data.To4(), &nftables.ElementOptions{
2022-05-18 21:03:51 +08:00
Timeout: time.Duration(timeoutSeconds) * time.Second,
2023-04-01 20:51:49 +08:00
}, false)
2022-05-18 21:03:51 +08:00
}
// RemoveSourceIP 删除某个源IP
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
var data = net.ParseIP(ip)
if data == nil {
return errors.New("invalid ip '" + ip + "'")
}
if strings.Contains(ip, ":") { // ipv6
var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets))
2023-04-02 20:32:36 +08:00
if len(this.denyIPv6Sets) > 0 {
err := this.denyIPv6Sets[setIndex].DeleteElement(data.To16())
2022-05-18 21:03:51 +08:00
if err != nil {
return err
}
}
if this.allowIPv6Set != nil {
err := this.allowIPv6Set.DeleteElement(data.To16())
if err != nil {
return err
}
}
return nil
}
// ipv4
2023-04-02 20:32:36 +08:00
if len(this.denyIPv4Sets) > 0 {
var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets))
err := this.denyIPv4Sets[setIndex].DeleteElement(data.To4())
2022-05-18 21:03:51 +08:00
if err != nil {
return err
}
2023-04-02 20:32:36 +08:00
}
if this.allowIPv4Set != nil {
err := this.allowIPv4Set.DeleteElement(data.To4())
2022-05-18 21:03:51 +08:00
if err != nil {
return err
}
}
return nil
}
2022-06-09 19:12:10 +08:00
// 读取版本号
func (this *NFTablesFirewall) readVersion(nftPath string) string {
2022-09-15 11:14:33 +08:00
var cmd = executils.NewTimeoutCmd(10*time.Second, nftPath, "--version")
cmd.WithStdout()
2022-06-09 19:12:10 +08:00
err := cmd.Run()
if err != nil {
return ""
}
2022-09-15 11:14:33 +08:00
var outputString = cmd.Stdout()
2022-06-09 19:12:10 +08:00
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
if len(versionMatches) <= 1 {
return ""
}
return versionMatches[1]
}
// existLatestIP reports whether ip was recorded within the last 3 seconds.
// When it was not, ip is recorded with the current Unix time and false is
// returned.
//
// Entries in this.latestIPTimes have the form "ip@unixSeconds". This method
// appends them in time order, so expired entries are dropped by re-slicing
// past the last expired one. At most 128 entries are kept.
func (this *NFTablesFirewall) existLatestIP(ip string) bool {
	this.locker.Lock()
	defer this.locker.Unlock()

	const expireSeconds = 3 // entries older than this are expired
	const maxLen = 128      // upper bound on remembered entries

	// read the clock once so every entry is compared against the same instant
	var now = time.Now().Unix()

	var expiredIndex = -1
	for index, ipTime := range this.latestIPTimes {
		var pieces = strings.Split(ipTime, "@")
		// latestIPTimes may be written outside this method; treat malformed
		// entries (no "@") as expired instead of panicking on pieces[1]
		if len(pieces) < 2 || types.Int64(pieces[1]) < now-expireSeconds {
			expiredIndex = index
			continue
		}
		if pieces[0] == ip {
			return true
		}
	}
	if expiredIndex > -1 {
		this.latestIPTimes = this.latestIPTimes[expiredIndex+1:]
	}
	this.latestIPTimes = append(this.latestIPTimes, ip+"@"+types.String(now))
	if len(this.latestIPTimes) > maxLen {
		this.latestIPTimes = this.latestIPTimes[1:]
	}
	return false
}