Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ci/tasks/test-integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ pushd "${bosh_agent_dir}/integration/fake-blobstore"
copy_to_remote_host ./fake-blobstore /home/agent_test_user/fake-blobstore
popd

echo -e "\n Installing nftables-checker..."
pushd "${bosh_agent_dir}/integration/nftables-checker"
GOARCH=amd64 GOOS=linux CGO_ENABLED=0 go build .
copy_to_remote_host ./nftables-checker /home/agent_test_user/nftables-checker
popd

echo -e "\n Setup assets"
pushd "${bosh_agent_dir}/integration/assets"
Expand Down
22 changes: 9 additions & 13 deletions integration/nats_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,24 @@ var _ = Describe("nats firewall", func() {
It("sets up the outgoing nats firewall", func() {
format.MaxLength = 0

// Wait a maximum of 300 seconds
Eventually(func() string {
logs, _ := testEnvironment.RunCommand("sudo cat /var/vcap/bosh/log/current") //nolint:errcheck
return logs
}, 300).Should(ContainSubstring("Updated NATS firewall rules"))

output, err := testEnvironment.RunCommand("sudo nft list chain inet bosh_agent nats_access")
Expect(err).To(BeNil())

boshEnv := os.Getenv("BOSH_ENVIRONMENT")

Expect(output).To(MatchRegexp(`meta skuid 0 ip daddr %s tcp dport 4222 accept`, boshEnv))
Expect(output).To(MatchRegexp(`ip daddr %s tcp dport 4222 drop`, boshEnv))
output, err := testEnvironment.RunCommand("sudo /home/agent_test_user/nftables-checker --table bosh_agent --chain nats_access")
Expect(err).To(BeNil())
Expect(output).To(MatchRegexp(`ACCEPT uid=0 dst=%s dport=4222`, boshEnv))
Expect(output).To(MatchRegexp(`DROP dst=%s dport=4222`, boshEnv))

// check that non-root cannot access the director nats, -w2 == timeout 2 seconds
out, err := testEnvironment.RunCommand(fmt.Sprintf("nc %v 4222 -w2 -v", boshEnv))
Expect(err).NotTo(BeNil())
Expect(out).To(ContainSubstring("port 4222 (tcp) timed out"))

// root (UID 0) should be allowed through the firewall
// root (UID 0) is allowed through the nftables firewall
out, err = testEnvironment.RunCommand(fmt.Sprintf("sudo nc %v 4222 -w2 -v", boshEnv))
Expect(out).To(MatchRegexp("INFO.*server_id.*version.*host.*"))
Expect(err).To(BeNil())
Expand Down Expand Up @@ -75,24 +73,22 @@ var _ = Describe("nats firewall", func() {
AfterEach(func() {
err := testEnvironment.DetachDevice("/dev/sdh")
Expect(err).ToNot(HaveOccurred())
_, err = testEnvironment.RunCommand("sudo nft flush chain inet bosh_agent nats_access")
_, err = testEnvironment.RunCommand("sudo /home/agent_test_user/nftables-checker --table bosh_agent --chain nats_access --flush")
Expect(err).To(BeNil())
})

It("sets up the outgoing nats for firewall ipv6", func() {
format.MaxLength = 0

// Wait a maximum of 300 seconds
Eventually(func() string {
logs, _ := testEnvironment.RunCommand("sudo cat /var/vcap/bosh/log/current") //nolint:errcheck
return logs
}, 300).Should(ContainSubstring("Updated NATS firewall rules"))

output, err := testEnvironment.RunCommand("sudo nft list chain inet bosh_agent nats_access")
output, err := testEnvironment.RunCommand("sudo /home/agent_test_user/nftables-checker --table bosh_agent --chain nats_access")
Expect(err).To(BeNil())

Expect(output).To(MatchRegexp(`meta skuid 0 ip6 daddr 2001:db8::1 tcp dport 4222 accept`))
Expect(output).To(MatchRegexp(`ip6 daddr 2001:db8::1 tcp dport 4222 drop`))
Expect(output).To(ContainSubstring("ACCEPT uid=0 dst=2001:db8::1 dport=4222"))
Expect(output).To(ContainSubstring("DROP dst=2001:db8::1 dport=4222"))
})
})
})
225 changes: 225 additions & 0 deletions integration/nftables-checker/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
//go:build linux

package main

import (
"encoding/binary"
"flag"
"fmt"
"net"
"os"

"github.com/google/nftables"
"github.com/google/nftables/expr"
"golang.org/x/sys/unix"
)

func main() {
tableName := flag.String("table", "", "nftables table name")
chainName := flag.String("chain", "", "nftables chain name")
flush := flag.Bool("flush", false, "flush (delete all rules from) the chain instead of listing")
flag.Parse()

if *tableName == "" || *chainName == "" {
fmt.Fprintln(os.Stderr, "usage: nftables-checker --table TABLE --chain CHAIN [--flush]")
os.Exit(2)
}

conn, err := nftables.New()
if err != nil {
fmt.Fprintf(os.Stderr, "failed to open nftables connection: %v\n", err)
os.Exit(1)
}

table, chain, err := findChain(conn, *tableName, *chainName)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}

if *flush {
conn.FlushChain(chain)
if err := conn.Flush(); err != nil {
fmt.Fprintf(os.Stderr, "failed to flush chain: %v\n", err)
os.Exit(1)
}
fmt.Printf("flushed chain %s in table %s\n", *chainName, *tableName)
return
}

rules, err := conn.GetRules(table, chain)
if err != nil {
fmt.Fprintf(os.Stderr, "failed to get rules: %v\n", err)
os.Exit(1)
}

if len(rules) == 0 {
fmt.Println("no rules found")
os.Exit(0)
}

for _, rule := range rules {
fmt.Println(formatRule(rule))
}
}

func findChain(conn *nftables.Conn, tableName, chainName string) (*nftables.Table, *nftables.Chain, error) {
tables, err := conn.ListTables()
if err != nil {
return nil, nil, fmt.Errorf("failed to list tables: %w", err)
}

var table *nftables.Table
for _, t := range tables {
if t.Name == tableName && t.Family == nftables.TableFamilyINet {
table = t
break
}
}
if table == nil {
return nil, nil, fmt.Errorf("table inet %s not found", tableName)
}

chains, err := conn.ListChains()
if err != nil {
return nil, nil, fmt.Errorf("failed to list chains: %w", err)
}

for _, c := range chains {
if c.Table.Name == tableName && c.Name == chainName {
return table, c, nil
}
}
return nil, nil, fmt.Errorf("chain %s not found in table %s", chainName, tableName)
}

func formatRule(rule *nftables.Rule) string {
var verdict string
var uid = -1
var dstIP net.IP

for _, e := range rule.Exprs {
switch v := e.(type) {
case *expr.Verdict:
switch v.Kind {
case expr.VerdictAccept:
verdict = "ACCEPT"
case expr.VerdictDrop:
verdict = "DROP"
default:
verdict = fmt.Sprintf("verdict=%d", v.Kind)
}
case *expr.Meta:
if v.Key == expr.MetaKeySKUID {
uid = extractUID(rule)
}
}
}

dstIP = extractDestIP(rule)
dport := extractDestPort(rule)

result := verdict
if uid >= 0 {
result += fmt.Sprintf(" uid=%d", uid)
}
if dstIP != nil {
result += fmt.Sprintf(" dst=%s", dstIP)
}
if dport >= 0 {
result += fmt.Sprintf(" dport=%d", dport)
}
return result
}

func extractUID(rule *nftables.Rule) int {
foundSKUID := false
for _, e := range rule.Exprs {
switch v := e.(type) {
case *expr.Meta:
if v.Key == expr.MetaKeySKUID {
foundSKUID = true
}
case *expr.Cmp:
if foundSKUID && len(v.Data) == 4 {
return int(binary.NativeEndian.Uint32(v.Data))
}
foundSKUID = false
default:
foundSKUID = false
}
}
return -1
}

func extractDestIP(rule *nftables.Rule) net.IP {
foundNFProto := false
var proto byte

for _, e := range rule.Exprs {
switch v := e.(type) {
case *expr.Meta:
if v.Key == expr.MetaKeyNFPROTO {
foundNFProto = true
}
case *expr.Cmp:
if foundNFProto && len(v.Data) == 1 {
proto = v.Data[0]
foundNFProto = false
continue
}
// After a Payload load, the next Cmp holds the IP
if proto == unix.NFPROTO_IPV4 && len(v.Data) == 4 {
return net.IP(v.Data)
}
if proto == unix.NFPROTO_IPV6 && len(v.Data) == 16 {
return net.IP(v.Data)
}
case *expr.Payload:
// IPv4 dst offset=16 len=4, IPv6 dst offset=24 len=16
if v.Base == expr.PayloadBaseNetworkHeader &&
((v.Offset == 16 && v.Len == 4) || (v.Offset == 24 && v.Len == 16)) {
continue
}
proto = 0
default:
foundNFProto = false
}
}
return nil
}

func extractDestPort(rule *nftables.Rule) int {
foundTCP := false
foundPayload := false

for _, e := range rule.Exprs {
switch v := e.(type) {
case *expr.Meta:
if v.Key == expr.MetaKeyL4PROTO {
foundTCP = false
foundPayload = false
}
case *expr.Cmp:
if !foundTCP && len(v.Data) == 1 && v.Data[0] == unix.IPPROTO_TCP {
foundTCP = true
continue
}
if foundTCP && foundPayload && len(v.Data) == 2 {
return int(binary.BigEndian.Uint16(v.Data))
}
foundPayload = false
case *expr.Payload:
if foundTCP && v.Base == expr.PayloadBaseTransportHeader && v.Offset == 2 && v.Len == 2 {
foundPayload = true
continue
}
foundPayload = false
default:
if foundTCP {
continue
}
}
}
return -1
}