From c7947a5f38432ce666084a60b04e7551e5b18e30 Mon Sep 17 00:00:00 2001 From: Chris Selzo Date: Wed, 25 Feb 2026 20:13:29 -0800 Subject: [PATCH] Add nftables-checker binary to integration tests Previously we had assumed we could use the nftables command directly, but the stemcell does not have that binary. Instead, take a page from the agent itself and use the github.com/google/nftables library to verify that the NATs firewall rule is in place ai-assisted=yes [TNZ-60576] --- ci/tasks/test-integration.sh | 5 + integration/nats_firewall_test.go | 22 ++- integration/nftables-checker/main.go | 225 +++++++++++++++++++++++++++ 3 files changed, 239 insertions(+), 13 deletions(-) create mode 100644 integration/nftables-checker/main.go diff --git a/ci/tasks/test-integration.sh b/ci/tasks/test-integration.sh index 78c48d995..9ed9c497b 100755 --- a/ci/tasks/test-integration.sh +++ b/ci/tasks/test-integration.sh @@ -75,6 +75,11 @@ pushd "${bosh_agent_dir}/integration/fake-blobstore" copy_to_remote_host ./fake-blobstore /home/agent_test_user/fake-blobstore popd +echo -e "\n Installing nftables-checker..." +pushd "${bosh_agent_dir}/integration/nftables-checker" + GOARCH=amd64 GOOS=linux CGO_ENABLED=0 go build . + copy_to_remote_host ./nftables-checker /home/agent_test_user/nftables-checker +popd echo -e "\n Setup assets" pushd "${bosh_agent_dir}/integration/assets" diff --git a/integration/nats_firewall_test.go b/integration/nats_firewall_test.go index fd9b32c0f..5ebdec1f5 100644 --- a/integration/nats_firewall_test.go +++ b/integration/nats_firewall_test.go @@ -22,26 +22,24 @@ var _ = Describe("nats firewall", func() { It("sets up the outgoing nats firewall", func() { format.MaxLength = 0 - // Wait a maximum of 300 seconds Eventually(func() string { logs, _ := testEnvironment.RunCommand("sudo cat /var/vcap/bosh/log/current") //nolint:errcheck return logs }, 300).Should(ContainSubstring("Updated NATS firewall rules")) - output, err := testEnvironment.RunCommand("sudo nft list chain inet bosh_agent nats_access") - Expect(err).To(BeNil()) - boshEnv := os.Getenv("BOSH_ENVIRONMENT") - Expect(output).To(MatchRegexp(`meta skuid 0 ip daddr %s tcp dport 4222 accept`, boshEnv)) - Expect(output).To(MatchRegexp(`ip daddr %s tcp dport 4222 drop`, boshEnv)) + output, err := testEnvironment.RunCommand("sudo /home/agent_test_user/nftables-checker --table bosh_agent --chain nats_access") + Expect(err).To(BeNil()) + Expect(output).To(MatchRegexp(`ACCEPT uid=0 dst=%s dport=4222`, boshEnv)) + Expect(output).To(MatchRegexp(`DROP dst=%s dport=4222`, boshEnv)) // check that non-root cannot access the director nats, -w2 == timeout 2 seconds out, err := testEnvironment.RunCommand(fmt.Sprintf("nc %v 4222 -w2 -v", boshEnv)) Expect(err).NotTo(BeNil()) Expect(out).To(ContainSubstring("port 4222 (tcp) timed out")) - // root (UID 0) should be allowed through the firewall + // root (UID 0) is allowed through the nftables firewall out, err = testEnvironment.RunCommand(fmt.Sprintf("sudo nc %v 4222 -w2 -v", boshEnv)) Expect(out).To(MatchRegexp("INFO.*server_id.*version.*host.*")) Expect(err).To(BeNil()) @@ -75,24 +73,22 @@ var _ = Describe("nats firewall", func() { AfterEach(func() { err := testEnvironment.DetachDevice("/dev/sdh") Expect(err).ToNot(HaveOccurred()) - _, err = testEnvironment.RunCommand("sudo nft flush chain inet bosh_agent nats_access") + _, err = testEnvironment.RunCommand("sudo /home/agent_test_user/nftables-checker --table bosh_agent --chain nats_access --flush") Expect(err).To(BeNil()) }) It("sets up the outgoing nats for firewall ipv6", func() { format.MaxLength = 0 - // Wait a maximum of 300 seconds Eventually(func() string { logs, _ := testEnvironment.RunCommand("sudo cat /var/vcap/bosh/log/current") //nolint:errcheck return logs }, 300).Should(ContainSubstring("Updated NATS firewall rules")) - output, err := testEnvironment.RunCommand("sudo nft list chain inet bosh_agent nats_access") + output, err := testEnvironment.RunCommand("sudo /home/agent_test_user/nftables-checker --table bosh_agent --chain nats_access") Expect(err).To(BeNil()) - - Expect(output).To(MatchRegexp(`meta skuid 0 ip6 daddr 2001:db8::1 tcp dport 4222 accept`)) - Expect(output).To(MatchRegexp(`ip6 daddr 2001:db8::1 tcp dport 4222 drop`)) + Expect(output).To(ContainSubstring("ACCEPT uid=0 dst=2001:db8::1 dport=4222")) + Expect(output).To(ContainSubstring("DROP dst=2001:db8::1 dport=4222")) }) }) }) diff --git a/integration/nftables-checker/main.go b/integration/nftables-checker/main.go new file mode 100644 index 000000000..df98044be --- /dev/null +++ b/integration/nftables-checker/main.go @@ -0,0 +1,225 @@ +//go:build linux + +package main + +import ( + "encoding/binary" + "flag" + "fmt" + "net" + "os" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + "golang.org/x/sys/unix" +) + +func main() { + tableName := flag.String("table", "", "nftables table name") + chainName := flag.String("chain", "", "nftables chain name") + flush := flag.Bool("flush", false, "flush (delete all rules from) the chain instead of listing") + flag.Parse() + + if *tableName == "" || *chainName == "" { + fmt.Fprintln(os.Stderr, "usage: nftables-checker --table TABLE --chain CHAIN [--flush]") + os.Exit(2) + } + + conn, err := nftables.New() + if err != nil { + fmt.Fprintf(os.Stderr, "failed to open nftables connection: %v\n", err) + os.Exit(1) + } + + table, chain, err := findChain(conn, *tableName, *chainName) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + + if *flush { + conn.FlushChain(chain) + if err := conn.Flush(); err != nil { + fmt.Fprintf(os.Stderr, "failed to flush chain: %v\n", err) + os.Exit(1) + } + fmt.Printf("flushed chain %s in table %s\n", *chainName, *tableName) + return + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get rules: %v\n", err) + os.Exit(1) + } + + if len(rules) == 0 { + fmt.Println("no rules found") + os.Exit(0) + } + + for _, rule := range rules { + fmt.Println(formatRule(rule)) + } +} + +func findChain(conn *nftables.Conn, tableName, chainName string) (*nftables.Table, *nftables.Chain, error) { + tables, err := conn.ListTables() + if err != nil { + return nil, nil, fmt.Errorf("failed to list tables: %w", err) + } + + var table *nftables.Table + for _, t := range tables { + if t.Name == tableName && t.Family == nftables.TableFamilyINet { + table = t + break + } + } + if table == nil { + return nil, nil, fmt.Errorf("table inet %s not found", tableName) + } + + chains, err := conn.ListChains() + if err != nil { + return nil, nil, fmt.Errorf("failed to list chains: %w", err) + } + + for _, c := range chains { + if c.Table.Name == tableName && c.Name == chainName { + return table, c, nil + } + } + return nil, nil, fmt.Errorf("chain %s not found in table %s", chainName, tableName) +} + +func formatRule(rule *nftables.Rule) string { + var verdict string + var uid = -1 + var dstIP net.IP + + for _, e := range rule.Exprs { + switch v := e.(type) { + case *expr.Verdict: + switch v.Kind { + case expr.VerdictAccept: + verdict = "ACCEPT" + case expr.VerdictDrop: + verdict = "DROP" + default: + verdict = fmt.Sprintf("verdict=%d", v.Kind) + } + case *expr.Meta: + if v.Key == expr.MetaKeySKUID { + uid = extractUID(rule) + } + } + } + + dstIP = extractDestIP(rule) + dport := extractDestPort(rule) + + result := verdict + if uid >= 0 { + result += fmt.Sprintf(" uid=%d", uid) + } + if dstIP != nil { + result += fmt.Sprintf(" dst=%s", dstIP) + } + if dport >= 0 { + result += fmt.Sprintf(" dport=%d", dport) + } + return result +} + +func extractUID(rule *nftables.Rule) int { + foundSKUID := false + for _, e := range rule.Exprs { + switch v := e.(type) { + case *expr.Meta: + if v.Key == expr.MetaKeySKUID { + foundSKUID = true + } + case *expr.Cmp: + if foundSKUID && len(v.Data) == 4 { + return int(binary.NativeEndian.Uint32(v.Data)) + } + foundSKUID = false + default: + foundSKUID = false + } + } + return -1 +} + +func extractDestIP(rule *nftables.Rule) net.IP { + foundNFProto := false + var proto byte + + for _, e := range rule.Exprs { + switch v := e.(type) { + case *expr.Meta: + if v.Key == expr.MetaKeyNFPROTO { + foundNFProto = true + } + case *expr.Cmp: + if foundNFProto && len(v.Data) == 1 { + proto = v.Data[0] + foundNFProto = false + continue + } + // After a Payload load, the next Cmp holds the IP + if proto == unix.NFPROTO_IPV4 && len(v.Data) == 4 { + return net.IP(v.Data) + } + if proto == unix.NFPROTO_IPV6 && len(v.Data) == 16 { + return net.IP(v.Data) + } + case *expr.Payload: + // IPv4 dst offset=16 len=4, IPv6 dst offset=24 len=16 + if v.Base == expr.PayloadBaseNetworkHeader && + ((v.Offset == 16 && v.Len == 4) || (v.Offset == 24 && v.Len == 16)) { + continue + } + proto = 0 + default: + foundNFProto = false + } + } + return nil +} + +func extractDestPort(rule *nftables.Rule) int { + foundTCP := false + foundPayload := false + + for _, e := range rule.Exprs { + switch v := e.(type) { + case *expr.Meta: + if v.Key == expr.MetaKeyL4PROTO { + foundTCP = false + foundPayload = false + } + case *expr.Cmp: + if !foundTCP && len(v.Data) == 1 && v.Data[0] == unix.IPPROTO_TCP { + foundTCP = true + continue + } + if foundTCP && foundPayload && len(v.Data) == 2 { + return int(binary.BigEndian.Uint16(v.Data)) + } + foundPayload = false + case *expr.Payload: + if foundTCP && v.Base == expr.PayloadBaseTransportHeader && v.Offset == 2 && v.Len == 2 { + foundPayload = true + continue + } + foundPayload = false + default: + if foundTCP { + continue + } + } + } + return -1 +}