mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-02 14:00:25 +08:00
134 lines
2.7 KiB
Go
134 lines
2.7 KiB
Go
// Copyright 2022 GoEdge goedge.cdn@gmail.com. All rights reserved.
|
|
//go:build linux
|
|
|
|
package nftables_test
|
|
|
|
import (
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
|
|
)
|
|
|
|
func getIPv4Chain(t *testing.T) *nftables.Chain {
|
|
conn, err := nftables.NewConn()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
|
|
if err != nil {
|
|
if err == nftables.ErrTableNotFound {
|
|
table, err = conn.AddIPv4Table("test_ipv4")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
} else {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
chain, err := table.GetChain("test_chain")
|
|
if err != nil {
|
|
if err == nftables.ErrChainNotFound {
|
|
chain, err = table.AddAcceptChain("test_chain")
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return chain
|
|
}
|
|
|
|
func TestChain_AddAcceptIPRule(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestChain_AddDropIPRule(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestChain_AddAcceptSetRule(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestChain_AddDropSetRule(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestChain_AddRejectSetRule(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestChain_GetRuleWithUserData(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
|
if err != nil {
|
|
if err == nftables.ErrRuleNotFound {
|
|
t.Log("rule not found")
|
|
return
|
|
} else {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
t.Log("rule:", rule)
|
|
}
|
|
|
|
func TestChain_GetRules(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
rules, err := chain.GetRules()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for _, rule := range rules {
|
|
t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
|
|
"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
|
|
}
|
|
}
|
|
|
|
func TestChain_DeleteRule(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
rule, err := chain.GetRuleWithUserData([]byte("test"))
|
|
if err != nil {
|
|
if err == nftables.ErrRuleNotFound {
|
|
t.Log("rule not found")
|
|
return
|
|
}
|
|
t.Fatal(err)
|
|
}
|
|
err = chain.DeleteRule(rule)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func TestChain_Flush(t *testing.T) {
|
|
var chain = getIPv4Chain(t)
|
|
err := chain.Flush()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Log("ok")
|
|
}
|