From 96862aaa5f9b9475a528845e56485cfcb506eb36 Mon Sep 17 00:00:00 2001 From: Maria Shaldybin Date: Fri, 27 Feb 2026 23:29:55 +0000 Subject: [PATCH] Add optional UID argument to bosh-enable-monit-access command --- main/agent.go | 6 +- platform/firewall/firewall.go | 6 +- .../firewall/firewallfakes/fake_manager.go | 21 +++++-- platform/firewall/monit_access.go | 24 ++++++-- platform/firewall/nftables_firewall.go | 14 +++-- platform/firewall/nftables_firewall_test.go | 60 ++++++++++++++----- 6 files changed, 97 insertions(+), 34 deletions(-) diff --git a/main/agent.go b/main/agent.go index c74f0df8a..48fc768ab 100644 --- a/main/agent.go +++ b/main/agent.go @@ -83,7 +83,11 @@ func main() { switch binaryName(os.Args[0]) { case "bosh-enable-monit-access": - firewall.EnableMonitAccess(logger, "enable-monit-access") + uidArg := "" + if len(os.Args) > 1 { + uidArg = os.Args[1] + } + firewall.EnableMonitAccess(logger, "enable-monit-access", uidArg) return } diff --git a/platform/firewall/firewall.go b/platform/firewall/firewall.go index 66517768c..0844d0123 100644 --- a/platform/firewall/firewall.go +++ b/platform/firewall/firewall.go @@ -39,8 +39,10 @@ type Manager interface { SetupMonitFirewall() error // EnableMonitAccess enables monit access by adding firewall rules. - // It first tries to use cgroup-based matching, then falls back to UID-based matching. - EnableMonitAccess() error + // Cgroup-based matching is always tried first for better isolation. + // If cgroup matching fails and uid is provided, a UID-based rule is added + // as a fallback. If uid is nil and cgroup matching fails, an error is returned. + EnableMonitAccess(uid *uint32) error // SetupNATSFirewall creates firewall rules to protect NATS. // Only root (UID 0) is allowed to connect to the resolved NATS address. diff --git a/platform/firewall/firewallfakes/fake_manager.go b/platform/firewall/firewallfakes/fake_manager.go index 2cd07221a..36d04afb1 100644 --- a/platform/firewall/firewallfakes/fake_manager.go +++ b/platform/firewall/firewallfakes/fake_manager.go @@ -18,9 +18,10 @@ type FakeManager struct { cleanupReturnsOnCall map[int]struct { result1 error } - EnableMonitAccessStub func() error + EnableMonitAccessStub func(*uint32) error enableMonitAccessMutex sync.RWMutex enableMonitAccessArgsForCall []struct { + arg1 *uint32 } enableMonitAccessReturns struct { result1 error @@ -106,17 +107,18 @@ func (fake *FakeManager) CleanupReturnsOnCall(i int, result1 error) { }{result1} } -func (fake *FakeManager) EnableMonitAccess() error { +func (fake *FakeManager) EnableMonitAccess(arg1 *uint32) error { fake.enableMonitAccessMutex.Lock() ret, specificReturn := fake.enableMonitAccessReturnsOnCall[len(fake.enableMonitAccessArgsForCall)] fake.enableMonitAccessArgsForCall = append(fake.enableMonitAccessArgsForCall, struct { - }{}) + arg1 *uint32 + }{arg1}) stub := fake.EnableMonitAccessStub fakeReturns := fake.enableMonitAccessReturns - fake.recordInvocation("EnableMonitAccess", []interface{}{}) + fake.recordInvocation("EnableMonitAccess", []interface{}{arg1}) fake.enableMonitAccessMutex.Unlock() if stub != nil { - return stub() + return stub(arg1) } if specificReturn { return ret.result1 @@ -130,12 +132,19 @@ func (fake *FakeManager) EnableMonitAccessCallCount() int { return len(fake.enableMonitAccessArgsForCall) } -func (fake *FakeManager) EnableMonitAccessCalls(stub func() error) { +func (fake *FakeManager) EnableMonitAccessCalls(stub func(*uint32) error) { fake.enableMonitAccessMutex.Lock() defer fake.enableMonitAccessMutex.Unlock() fake.EnableMonitAccessStub = stub } +func (fake *FakeManager) EnableMonitAccessArgsForCall(i int) *uint32 { + fake.enableMonitAccessMutex.RLock() + defer fake.enableMonitAccessMutex.RUnlock() + argsForCall := fake.enableMonitAccessArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeManager) EnableMonitAccessReturns(result1 error) { fake.enableMonitAccessMutex.Lock() defer fake.enableMonitAccessMutex.Unlock() diff --git a/platform/firewall/monit_access.go b/platform/firewall/monit_access.go index d648c717c..cbb619819 100644 --- a/platform/firewall/monit_access.go +++ b/platform/firewall/monit_access.go @@ -3,7 +3,11 @@ // // Usage: // -// bosh-agent enable-monit-access # Add firewall rule (cgroup preferred, UID fallback) +// bosh-enable-monit-access [] +// +// When a UID is provided, a UID-based rule is added for that user. +// When no UID is provided, cgroup-based matching is tried first, +// falling back to a UID-based rule for the current user. // // This binary serves as a replacement for the complex bash firewall setup logic // that was previously in job service scripts. @@ -12,11 +16,12 @@ package firewall import ( "errors" "os" + "strconv" boshlog "github.com/cloudfoundry/bosh-utils/logger" ) -func EnableMonitAccess(logger boshlog.Logger, command string) { +func EnableMonitAccess(logger boshlog.Logger, command string, uidArg string) { logger.UseTags([]boshlog.LogTag{{Name: "monit-access", LogLevel: boshlog.LevelDebug}}) mgr, err := NewNftablesFirewall(logger) @@ -30,10 +35,21 @@ func EnableMonitAccess(logger boshlog.Logger, command string) { } defer mgr.Cleanup() //nolint:errcheck - // Setup mode: add firewall rule logger.Info(command, "Setting up monit firewall rule") - err = mgr.EnableMonitAccess() + var uid *uint32 + if uidArg != "" { + parsed, err := strconv.ParseUint(uidArg, 10, 32) + if err != nil { + logger.Error(command, "Invalid UID argument %q: %v. Usage: %s []", uidArg, err, command) + os.Exit(1) + } + u := uint32(parsed) + uid = &u + logger.Info(command, "UID argument provided: %d", *uid) + } + + err = mgr.EnableMonitAccess(uid) if err != nil { logger.Error(command, "Failed to enable monit access: %v", err) os.Exit(1) diff --git a/platform/firewall/nftables_firewall.go b/platform/firewall/nftables_firewall.go index c78fb49ae..c5cfe3c65 100644 --- a/platform/firewall/nftables_firewall.go +++ b/platform/firewall/nftables_firewall.go @@ -7,7 +7,6 @@ import ( "fmt" "net" gonetURL "net/url" - "os" "strconv" "strings" @@ -172,7 +171,7 @@ func (f *NftablesFirewall) SetupMonitFirewall() error { return nil } -func (f *NftablesFirewall) EnableMonitAccess() error { +func (f *NftablesFirewall) EnableMonitAccess(uid *uint32) error { // 1. Check if jobs chain exists err := f.getMonitJobsChainAndTable() if err != nil { @@ -200,11 +199,14 @@ func (f *NftablesFirewall) EnableMonitAccess() error { f.logger.Error(f.logTag, "Could not detect cgroup: %v", err) } - // 3. Fallback to UID-based rule - uid := uint32(os.Getuid()) - f.logger.Info(f.logTag, "Falling back to UID rule for UID: %d", uid) + // 3. Fallback to UID-based rule if a UID was provided + if uid == nil { + return fmt.Errorf("cgroup matching failed and no UID was provided") + } + + f.logger.Info(f.logTag, "Falling back to UID rule for UID: %d", *uid) - return f.addUIDRule(uid) + return f.addUIDRule(*uid) } // SetupNATSFirewall creates firewall rules to protect NATS. diff --git a/platform/firewall/nftables_firewall_test.go b/platform/firewall/nftables_firewall_test.go index a02c4221b..7c2095139 100644 --- a/platform/firewall/nftables_firewall_test.go +++ b/platform/firewall/nftables_firewall_test.go @@ -6,7 +6,6 @@ import ( "encoding/binary" "errors" "net" - "os" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -172,7 +171,10 @@ var _ = Describe("NftablesFirewall", func() { }) Context("when the jobs chain exists", func() { + var vcapUID uint32 + BeforeEach(func() { + vcapUID = uint32(1000) fakeConn.ListTablesReturns([]*nftables.Table{boshTable}, nil) fakeConn.ListChainsReturns([]*nftables.Chain{ { @@ -183,7 +185,7 @@ var _ = Describe("NftablesFirewall", func() { }) It("adds a rule and flushes", func() { - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(&vcapUID) Expect(err).NotTo(HaveOccurred()) Expect(fakeConn.AddRuleCallCount()).To(Equal(1)) @@ -191,7 +193,7 @@ var _ = Describe("NftablesFirewall", func() { }) It("adds the rule to the monit_access_jobs chain", func() { - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(&vcapUID) Expect(err).NotTo(HaveOccurred()) Expect(fakeConn.AddRuleCallCount()).To(Equal(1)) @@ -200,7 +202,7 @@ var _ = Describe("NftablesFirewall", func() { }) It("adds a rule targeting loopback and monit port", func() { - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(&vcapUID) Expect(err).NotTo(HaveOccurred()) rule := fakeConn.AddRuleArgsForCall(0) @@ -215,10 +217,41 @@ var _ = Describe("NftablesFirewall", func() { Expect(hasAcceptVerdict).To(BeTrue(), "rule should have an accept verdict") }) + It("adds a UID rule matching the provided UID", func() { + uid := uint32(1001) + err := manager.EnableMonitAccess(&uid) + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddRuleCallCount()).To(Equal(1)) + rule := fakeConn.AddRuleArgsForCall(0) + + uidBytes := make([]byte, 4) + binary.NativeEndian.PutUint32(uidBytes, uid) + + hasUIDMatch := false + for _, e := range rule.Exprs { + if cmpExpr, ok := e.(*expr.Cmp); ok { + if cmpExpr.Op == expr.CmpOpEq && len(cmpExpr.Data) == 4 { + if binary.NativeEndian.Uint32(cmpExpr.Data) == uid { + hasUIDMatch = true + } + } + } + } + Expect(hasUIDMatch).To(BeTrue(), "rule should match on the provided UID") + }) + + Context("when no UID is provided and cgroup detection fails", func() { + It("returns an error", func() { + err := manager.EnableMonitAccess(nil) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no UID was provided")) + }) + }) + Context("when the rule already exists (idempotency)", func() { It("does not add a duplicate UID rule", func() { - // Simulate an existing UID rule matching the current UID - uid := uint32(os.Getuid()) + uid := uint32(1000) uidBytes := make([]byte, 4) binary.NativeEndian.PutUint32(uidBytes, uid) @@ -229,12 +262,9 @@ var _ = Describe("NftablesFirewall", func() { }, } - // First call to GetRules is from cleanupStaleJobRules (cgroup path), - // subsequent calls check for existing rules. - // The UID fallback path calls GetRules to check for duplicates. fakeConn.GetRulesReturns([]*nftables.Rule{existingRule}, nil) - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(&uid) Expect(err).NotTo(HaveOccurred()) Expect(fakeConn.AddRuleCallCount()).To(Equal(0)) Expect(fakeConn.FlushCallCount()).To(Equal(0)) @@ -245,7 +275,7 @@ var _ = Describe("NftablesFirewall", func() { It("returns an error", func() { fakeConn.FlushReturns(errors.New("flush failed")) - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(&vcapUID) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("flush")) }) @@ -258,7 +288,7 @@ var _ = Describe("NftablesFirewall", func() { }) It("returns an error", func() { - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(nil) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("Failed to check if jobs chain exists")) }) @@ -272,7 +302,7 @@ var _ = Describe("NftablesFirewall", func() { }) It("returns bosh_agent table not found error", func() { - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(nil) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("bosh_agent table not found")) Expect(fakeConn.AddRuleCallCount()).To(Equal(0)) @@ -286,7 +316,7 @@ var _ = Describe("NftablesFirewall", func() { }) It("returns an error", func() { - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(nil) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("Failed to check if jobs chain exists")) Expect(err.Error()).To(ContainSubstring("listing chains")) @@ -300,7 +330,7 @@ var _ = Describe("NftablesFirewall", func() { }) It("returns monit_access_jobs chain not found error", func() { - err := manager.EnableMonitAccess() + err := manager.EnableMonitAccess(nil) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("monit_access_jobs chain not found")) })