Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion main/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ func main() {

switch binaryName(os.Args[0]) {
case "bosh-enable-monit-access":
firewall.EnableMonitAccess(logger, "enable-monit-access")
uidArg := ""
if len(os.Args) > 1 {
uidArg = os.Args[1]
}
firewall.EnableMonitAccess(logger, "enable-monit-access", uidArg)
return
}

Expand Down
6 changes: 4 additions & 2 deletions platform/firewall/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ type Manager interface {
SetupMonitFirewall() error

// EnableMonitAccess enables monit access by adding firewall rules.
// It first tries to use cgroup-based matching, then falls back to UID-based matching.
EnableMonitAccess() error
// Cgroup-based matching is always tried first for better isolation.
// If cgroup matching fails and uid is provided, a UID-based rule is added
// as a fallback. If uid is nil and cgroup matching fails, an error is returned.
EnableMonitAccess(uid *uint32) error

// SetupNATSFirewall creates firewall rules to protect NATS.
// Only root (UID 0) is allowed to connect to the resolved NATS address.
Expand Down
21 changes: 15 additions & 6 deletions platform/firewall/firewallfakes/fake_manager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 20 additions & 4 deletions platform/firewall/monit_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
//
// Usage:
//
// bosh-agent enable-monit-access # Add firewall rule (cgroup preferred, UID fallback)
// bosh-enable-monit-access [<uid>]
//
// When a UID is provided, a UID-based rule is added for that user.
// When no UID is provided, cgroup-based matching is tried first,
// falling back to a UID-based rule for the current user.
//
// This binary serves as a replacement for the complex bash firewall setup logic
// that was previously in job service scripts.
Expand All @@ -12,11 +16,12 @@ package firewall
import (
"errors"
"os"
"strconv"

boshlog "github.com/cloudfoundry/bosh-utils/logger"
)

func EnableMonitAccess(logger boshlog.Logger, command string) {
func EnableMonitAccess(logger boshlog.Logger, command string, uidArg string) {
logger.UseTags([]boshlog.LogTag{{Name: "monit-access", LogLevel: boshlog.LevelDebug}})

mgr, err := NewNftablesFirewall(logger)
Expand All @@ -30,10 +35,21 @@ func EnableMonitAccess(logger boshlog.Logger, command string) {
}
defer mgr.Cleanup() //nolint:errcheck

// Setup mode: add firewall rule
logger.Info(command, "Setting up monit firewall rule")

err = mgr.EnableMonitAccess()
var uid *uint32
if uidArg != "" {
parsed, err := strconv.ParseUint(uidArg, 10, 32)
if err != nil {
logger.Error(command, "Invalid UID argument %q: %v. Usage: %s [<uid>]", uidArg, err, command)
os.Exit(1)
}
u := uint32(parsed)
uid = &u
logger.Info(command, "UID argument provided: %d", *uid)
}

err = mgr.EnableMonitAccess(uid)
if err != nil {
logger.Error(command, "Failed to enable monit access: %v", err)
os.Exit(1)
Expand Down
14 changes: 8 additions & 6 deletions platform/firewall/nftables_firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"net"
gonetURL "net/url"
"os"
"strconv"
"strings"

Expand Down Expand Up @@ -172,7 +171,7 @@ func (f *NftablesFirewall) SetupMonitFirewall() error {
return nil
}

func (f *NftablesFirewall) EnableMonitAccess() error {
func (f *NftablesFirewall) EnableMonitAccess(uid *uint32) error {
// 1. Check if jobs chain exists
err := f.getMonitJobsChainAndTable()
if err != nil {
Expand Down Expand Up @@ -200,11 +199,14 @@ func (f *NftablesFirewall) EnableMonitAccess() error {
f.logger.Error(f.logTag, "Could not detect cgroup: %v", err)
}

// 3. Fallback to UID-based rule
uid := uint32(os.Getuid())
f.logger.Info(f.logTag, "Falling back to UID rule for UID: %d", uid)
// 3. Fallback to UID-based rule if a UID was provided
if uid == nil {
return fmt.Errorf("cgroup matching failed and no UID was provided")
}

f.logger.Info(f.logTag, "Falling back to UID rule for UID: %d", *uid)

return f.addUIDRule(uid)
return f.addUIDRule(*uid)
}

// SetupNATSFirewall creates firewall rules to protect NATS.
Expand Down
60 changes: 45 additions & 15 deletions platform/firewall/nftables_firewall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/binary"
"errors"
"net"
"os"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -172,7 +171,10 @@ var _ = Describe("NftablesFirewall", func() {
})

Context("when the jobs chain exists", func() {
var vcapUID uint32

BeforeEach(func() {
vcapUID = uint32(1000)
fakeConn.ListTablesReturns([]*nftables.Table{boshTable}, nil)
fakeConn.ListChainsReturns([]*nftables.Chain{
{
Expand All @@ -183,15 +185,15 @@ var _ = Describe("NftablesFirewall", func() {
})

It("adds a rule and flushes", func() {
err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(&vcapUID)
Expect(err).NotTo(HaveOccurred())

Expect(fakeConn.AddRuleCallCount()).To(Equal(1))
Expect(fakeConn.FlushCallCount()).To(Equal(1))
})

It("adds the rule to the monit_access_jobs chain", func() {
err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(&vcapUID)
Expect(err).NotTo(HaveOccurred())

Expect(fakeConn.AddRuleCallCount()).To(Equal(1))
Expand All @@ -200,7 +202,7 @@ var _ = Describe("NftablesFirewall", func() {
})

It("adds a rule targeting loopback and monit port", func() {
err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(&vcapUID)
Expect(err).NotTo(HaveOccurred())

rule := fakeConn.AddRuleArgsForCall(0)
Expand All @@ -215,10 +217,41 @@ var _ = Describe("NftablesFirewall", func() {
Expect(hasAcceptVerdict).To(BeTrue(), "rule should have an accept verdict")
})

It("adds a UID rule matching the provided UID", func() {
uid := uint32(1001)
err := manager.EnableMonitAccess(&uid)
Expect(err).NotTo(HaveOccurred())

Expect(fakeConn.AddRuleCallCount()).To(Equal(1))
rule := fakeConn.AddRuleArgsForCall(0)

uidBytes := make([]byte, 4)
binary.NativeEndian.PutUint32(uidBytes, uid)

hasUIDMatch := false
for _, e := range rule.Exprs {
if cmpExpr, ok := e.(*expr.Cmp); ok {
if cmpExpr.Op == expr.CmpOpEq && len(cmpExpr.Data) == 4 {
if binary.NativeEndian.Uint32(cmpExpr.Data) == uid {
hasUIDMatch = true
}
}
}
}
Expect(hasUIDMatch).To(BeTrue(), "rule should match on the provided UID")
})

Context("when no UID is provided and cgroup detection fails", func() {
It("returns an error", func() {
err := manager.EnableMonitAccess(nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no UID was provided"))
})
})

Context("when the rule already exists (idempotency)", func() {
It("does not add a duplicate UID rule", func() {
// Simulate an existing UID rule matching the current UID
uid := uint32(os.Getuid())
uid := uint32(1000)
uidBytes := make([]byte, 4)
binary.NativeEndian.PutUint32(uidBytes, uid)

Expand All @@ -229,12 +262,9 @@ var _ = Describe("NftablesFirewall", func() {
},
}

// First call to GetRules is from cleanupStaleJobRules (cgroup path),
// subsequent calls check for existing rules.
// The UID fallback path calls GetRules to check for duplicates.
fakeConn.GetRulesReturns([]*nftables.Rule{existingRule}, nil)

err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(&uid)
Expect(err).NotTo(HaveOccurred())
Expect(fakeConn.AddRuleCallCount()).To(Equal(0))
Expect(fakeConn.FlushCallCount()).To(Equal(0))
Expand All @@ -245,7 +275,7 @@ var _ = Describe("NftablesFirewall", func() {
It("returns an error", func() {
fakeConn.FlushReturns(errors.New("flush failed"))

err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(&vcapUID)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("flush"))
})
Expand All @@ -258,7 +288,7 @@ var _ = Describe("NftablesFirewall", func() {
})

It("returns an error", func() {
err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("Failed to check if jobs chain exists"))
})
Expand All @@ -272,7 +302,7 @@ var _ = Describe("NftablesFirewall", func() {
})

It("returns bosh_agent table not found error", func() {
err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("bosh_agent table not found"))
Expect(fakeConn.AddRuleCallCount()).To(Equal(0))
Expand All @@ -286,7 +316,7 @@ var _ = Describe("NftablesFirewall", func() {
})

It("returns an error", func() {
err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("Failed to check if jobs chain exists"))
Expect(err.Error()).To(ContainSubstring("listing chains"))
Expand All @@ -300,7 +330,7 @@ var _ = Describe("NftablesFirewall", func() {
})

It("returns monit_access_jobs chain not found error", func() {
err := manager.EnableMonitAccess()
err := manager.EnableMonitAccess(nil)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("monit_access_jobs chain not found"))
})
Expand Down