mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
396 lines
9.2 KiB
Go
396 lines
9.2 KiB
Go
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
|
//go:build linux
|
|
// +build linux
|
|
|
|
package firewalls
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"github.com/TeaOSLab/EdgeNode/internal/events"
|
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
|
"github.com/iwind/TeaGo/types"
|
|
"net"
|
|
"os/exec"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// check nft status, if being enabled we load it automatically
|
|
func init() {
|
|
if runtime.GOOS == "linux" {
|
|
var ticker = time.NewTicker(3 * time.Minute)
|
|
go func() {
|
|
for range ticker.C {
|
|
// if already ready, we break
|
|
if nftablesIsReady {
|
|
ticker.Stop()
|
|
break
|
|
}
|
|
_, err := exec.LookPath("nft")
|
|
if err == nil {
|
|
nftablesFirewall, err := NewNFTablesFirewall()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
currentFirewall = nftablesFirewall
|
|
remotelogs.Println("FIREWALL", "nftables is ready")
|
|
|
|
// fire event
|
|
if nftablesFirewall.IsReady() {
|
|
events.Notify(events.EventNFTablesReady)
|
|
}
|
|
|
|
ticker.Stop()
|
|
break
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
var nftablesInstance *NFTablesFirewall
|
|
var nftablesIsReady = false
|
|
var nftablesFilters = []*nftablesTableDefinition{
|
|
// we shorten the name for table name length restriction
|
|
{Name: "edge_dft_v4", IsIPv4: true},
|
|
{Name: "edge_dft_v6", IsIPv6: true},
|
|
}
|
|
var nftablesChainName = "input"
|
|
|
|
// nftablesTableDefinition describes one nftables table and the address
// family it belongs to.
type nftablesTableDefinition struct {
	// Name is the nftables table name.
	Name string

	// IsIPv4 and IsIPv6 select the address family of the table.
	IsIPv4, IsIPv6 bool
}

// protocol returns "ip6" when the definition is flagged as IPv6 and "ip"
// in every other case.
func (this *nftablesTableDefinition) protocol() string {
	if !this.IsIPv6 {
		return "ip"
	}
	return "ip6"
}
|
|
|
|
func NewNFTablesFirewall() (*NFTablesFirewall, error) {
|
|
var firewall = &NFTablesFirewall{
|
|
conn: nftables.NewConn(),
|
|
}
|
|
err := firewall.init()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return firewall, nil
|
|
}
|
|
|
|
type NFTablesFirewall struct {
|
|
conn *nftables.Conn
|
|
isReady bool
|
|
version string
|
|
|
|
allowIPv4Set *nftables.Set
|
|
allowIPv6Set *nftables.Set
|
|
|
|
denyIPv4Set *nftables.Set
|
|
denyIPv6Set *nftables.Set
|
|
|
|
firewalld *Firewalld
|
|
}
|
|
|
|
func (this *NFTablesFirewall) init() error {
|
|
// check nft
|
|
nftPath, err := exec.LookPath("nft")
|
|
if err != nil {
|
|
return errors.New("nft not found")
|
|
}
|
|
this.version = this.readVersion(nftPath)
|
|
|
|
// table
|
|
for _, tableDef := range nftablesFilters {
|
|
var family nftables.TableFamily
|
|
if tableDef.IsIPv4 {
|
|
family = nftables.TableFamilyIPv4
|
|
} else if tableDef.IsIPv6 {
|
|
family = nftables.TableFamilyIPv6
|
|
} else {
|
|
return errors.New("invalid table family: " + types.String(tableDef))
|
|
}
|
|
table, err := this.conn.GetTable(tableDef.Name, family)
|
|
if err != nil {
|
|
if nftables.IsNotFound(err) {
|
|
if tableDef.IsIPv4 {
|
|
table, err = this.conn.AddIPv4Table(tableDef.Name)
|
|
} else if tableDef.IsIPv6 {
|
|
table, err = this.conn.AddIPv6Table(tableDef.Name)
|
|
}
|
|
if err != nil {
|
|
return errors.New("create table '" + tableDef.Name + "' failed: " + err.Error())
|
|
}
|
|
} else {
|
|
return errors.New("get table '" + tableDef.Name + "' failed: " + err.Error())
|
|
}
|
|
}
|
|
if table == nil {
|
|
return errors.New("can not create table '" + tableDef.Name + "'")
|
|
}
|
|
|
|
// chain
|
|
var chainName = nftablesChainName
|
|
chain, err := table.GetChain(chainName)
|
|
if err != nil {
|
|
if nftables.IsNotFound(err) {
|
|
chain, err = table.AddAcceptChain(chainName)
|
|
if err != nil {
|
|
return errors.New("create chain '" + chainName + "' failed: " + err.Error())
|
|
}
|
|
} else {
|
|
return errors.New("get chain '" + chainName + "' failed: " + err.Error())
|
|
}
|
|
}
|
|
if chain == nil {
|
|
return errors.New("can not create chain '" + chainName + "'")
|
|
}
|
|
|
|
// allow lo
|
|
var loRuleName = []byte("lo")
|
|
_, err = chain.GetRuleWithUserData(loRuleName)
|
|
if err != nil {
|
|
if nftables.IsNotFound(err) {
|
|
_, err = chain.AddAcceptInterfaceRule("lo", loRuleName)
|
|
}
|
|
if err != nil {
|
|
return errors.New("add 'lo' rule failed: " + err.Error())
|
|
}
|
|
}
|
|
|
|
// allow set
|
|
// "allow" should be always first
|
|
for _, setAction := range []string{"allow", "deny"} {
|
|
var setName = setAction + "_set"
|
|
|
|
set, err := table.GetSet(setName)
|
|
if err != nil {
|
|
if nftables.IsNotFound(err) {
|
|
var keyType nftables.SetDataType
|
|
if tableDef.IsIPv4 {
|
|
keyType = nftables.TypeIPAddr
|
|
} else if tableDef.IsIPv6 {
|
|
keyType = nftables.TypeIP6Addr
|
|
}
|
|
set, err = table.AddSet(setName, &nftables.SetOptions{
|
|
KeyType: keyType,
|
|
HasTimeout: true,
|
|
})
|
|
if err != nil {
|
|
return errors.New("create set '" + setName + "' failed: " + err.Error())
|
|
}
|
|
} else {
|
|
return errors.New("get set '" + setName + "' failed: " + err.Error())
|
|
}
|
|
}
|
|
if set == nil {
|
|
return errors.New("can not create set '" + setName + "'")
|
|
}
|
|
if tableDef.IsIPv4 {
|
|
if setAction == "allow" {
|
|
this.allowIPv4Set = set
|
|
} else {
|
|
this.denyIPv4Set = set
|
|
}
|
|
} else if tableDef.IsIPv6 {
|
|
if setAction == "allow" {
|
|
this.allowIPv6Set = set
|
|
} else {
|
|
this.denyIPv6Set = set
|
|
}
|
|
}
|
|
|
|
// rule
|
|
var ruleName = []byte(setAction)
|
|
rule, err := chain.GetRuleWithUserData(ruleName)
|
|
if err != nil {
|
|
if nftables.IsNotFound(err) {
|
|
if tableDef.IsIPv4 {
|
|
if setAction == "allow" {
|
|
rule, err = chain.AddAcceptIPv4SetRule(setName, ruleName)
|
|
} else {
|
|
rule, err = chain.AddDropIPv4SetRule(setName, ruleName)
|
|
}
|
|
} else if tableDef.IsIPv6 {
|
|
if setAction == "allow" {
|
|
rule, err = chain.AddAcceptIPv6SetRule(setName, ruleName)
|
|
} else {
|
|
rule, err = chain.AddDropIPv6SetRule(setName, ruleName)
|
|
}
|
|
}
|
|
if err != nil {
|
|
return errors.New("add rule failed: " + err.Error())
|
|
}
|
|
} else {
|
|
return errors.New("get rule failed: " + err.Error())
|
|
}
|
|
}
|
|
if rule == nil {
|
|
return errors.New("can not create rule '" + string(ruleName) + "'")
|
|
}
|
|
}
|
|
}
|
|
|
|
this.isReady = true
|
|
nftablesIsReady = true
|
|
nftablesInstance = this
|
|
|
|
// load firewalld
|
|
var firewalld = NewFirewalld()
|
|
if firewalld.IsReady() {
|
|
this.firewalld = firewalld
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Name 名称
|
|
func (this *NFTablesFirewall) Name() string {
|
|
return "nftables"
|
|
}
|
|
|
|
// IsReady 是否已准备被调用
|
|
func (this *NFTablesFirewall) IsReady() bool {
|
|
return this.isReady
|
|
}
|
|
|
|
// IsMock 是否为模拟
|
|
func (this *NFTablesFirewall) IsMock() bool {
|
|
return false
|
|
}
|
|
|
|
// AllowPort 允许端口
|
|
func (this *NFTablesFirewall) AllowPort(port int, protocol string) error {
|
|
if this.firewalld != nil {
|
|
return this.firewalld.AllowPort(port, protocol)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemovePort 删除端口
|
|
func (this *NFTablesFirewall) RemovePort(port int, protocol string) error {
|
|
if this.firewalld != nil {
|
|
return this.firewalld.RemovePort(port, protocol)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AllowSourceIP Allow把IP加入白名单
|
|
func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
|
|
var data = net.ParseIP(ip)
|
|
if data == nil {
|
|
return errors.New("invalid ip '" + ip + "'")
|
|
}
|
|
|
|
if strings.Contains(ip, ":") { // ipv6
|
|
if this.allowIPv6Set == nil {
|
|
return errors.New("ipv6 ip set is nil")
|
|
}
|
|
return this.allowIPv6Set.AddElement(data.To16(), nil)
|
|
}
|
|
|
|
// ipv4
|
|
if this.allowIPv4Set == nil {
|
|
return errors.New("ipv4 ip set is nil")
|
|
}
|
|
return this.allowIPv4Set.AddElement(data.To4(), nil)
|
|
}
|
|
|
|
// RejectSourceIP 拒绝某个源IP连接
|
|
// we did not create set for drop ip, so we reuse DropSourceIP() method here
|
|
func (this *NFTablesFirewall) RejectSourceIP(ip string, timeoutSeconds int) error {
|
|
return this.DropSourceIP(ip, timeoutSeconds)
|
|
}
|
|
|
|
// DropSourceIP 丢弃某个源IP数据
|
|
func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int) error {
|
|
var data = net.ParseIP(ip)
|
|
if data == nil {
|
|
return errors.New("invalid ip '" + ip + "'")
|
|
}
|
|
|
|
if strings.Contains(ip, ":") { // ipv6
|
|
if this.denyIPv6Set == nil {
|
|
return errors.New("ipv6 ip set is nil")
|
|
}
|
|
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
|
|
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
|
})
|
|
}
|
|
|
|
// ipv4
|
|
if this.denyIPv4Set == nil {
|
|
return errors.New("ipv4 ip set is nil")
|
|
}
|
|
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
|
|
Timeout: time.Duration(timeoutSeconds) * time.Second,
|
|
})
|
|
}
|
|
|
|
// RemoveSourceIP 删除某个源IP
|
|
func (this *NFTablesFirewall) RemoveSourceIP(ip string) error {
|
|
var data = net.ParseIP(ip)
|
|
if data == nil {
|
|
return errors.New("invalid ip '" + ip + "'")
|
|
}
|
|
|
|
if strings.Contains(ip, ":") { // ipv6
|
|
if this.denyIPv6Set != nil {
|
|
err := this.denyIPv6Set.DeleteElement(data.To16())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if this.allowIPv6Set != nil {
|
|
err := this.allowIPv6Set.DeleteElement(data.To16())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ipv4
|
|
if this.allowIPv4Set != nil {
|
|
err := this.denyIPv4Set.DeleteElement(data.To4())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = this.allowIPv4Set.DeleteElement(data.To4())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// 读取版本号
|
|
func (this *NFTablesFirewall) readVersion(nftPath string) string {
|
|
var cmd = exec.Command(nftPath, "--version")
|
|
var output = &bytes.Buffer{}
|
|
cmd.Stdout = output
|
|
err := cmd.Run()
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
var outputString = output.String()
|
|
var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
|
|
if len(versionMatches) <= 1 {
|
|
return ""
|
|
}
|
|
return versionMatches[1]
|
|
}
|