diff --git a/agent/bootstrap.go b/agent/bootstrap.go index c2830187e..9d80097b5 100644 --- a/agent/bootstrap.go +++ b/agent/bootstrap.go @@ -183,6 +183,10 @@ func (boot bootstrap) Run() (err error) { //nolint:gocyclo } } + if err = boot.platform.SetupFirewall(); err != nil { + return bosherr.WrapError(err, "Setting up firewall") + } + if err = boot.platform.SetupMonitUser(); err != nil { return bosherr.WrapError(err, "Setting up monit user") } diff --git a/go.mod b/go.mod index 743dc5831..8d5bd92ca 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/coreos/go-iptables v0.8.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/golang/mock v1.6.0 + github.com/google/nftables v0.3.0 github.com/google/uuid v1.6.0 github.com/kevinburke/ssh_config v1.4.0 github.com/masterzen/winrm v0.0.0-20250927112105-5f8e6c707321 @@ -70,6 +71,8 @@ require ( github.com/jpillora/backoff v1.0.0 // indirect github.com/klauspost/compress v1.18.3 // indirect github.com/masterzen/simplexml v0.0.0-20190410153822-31eea3082786 // indirect + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect + github.com/mdlayher/socket v0.5.0 // indirect github.com/moby/sys/userns v0.1.0 // indirect github.com/nats-io/nkeys v0.4.15 // indirect github.com/nats-io/nuid v1.0.1 // indirect diff --git a/go.sum b/go.sum index 6aa4e27d7..2f6665c1d 100644 --- a/go.sum +++ b/go.sum @@ -109,6 +109,8 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/nftables v0.3.0 h1:bkyZ0cbpVeMHXOrtlFc8ISmfVqq5gPJukoYieyVmITg= +github.com/google/nftables v0.3.0/go.mod h1:BCp9FsrbF1Fn/Yu6CLUc9GGZFw/+hsxfluNXXmxBfRM= github.com/google/pprof v0.0.0-20260202012954-cb029daf43ef h1:xpF9fUHpoIrrjX24DURVKiwHcFpw19ndIs+FwTSMbno= github.com/google/pprof v0.0.0-20260202012954-cb029daf43ef/go.mod h1:MxpfABSjhmINe3F1It9d+8exIHFvUqtLIRCdOGNXqiI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -157,6 +159,10 @@ github.com/masterzen/winrm v0.0.0-20250927112105-5f8e6c707321 h1:AKIJL2PfBX2uie0 github.com/masterzen/winrm v0.0.0-20250927112105-5f8e6c707321/go.mod h1:JajVhkiG2bYSNYYPYuWG7WZHr42CTjMTcCjfInRNCqc= github.com/maxbrunsfeld/counterfeiter/v6 v6.12.1 h1:D4O2wLxB384TS3ohBJMfolnxb4qGmoZ1PnWNtit8LYo= github.com/maxbrunsfeld/counterfeiter/v6 v6.12.1/go.mod h1:RuJdxo0oI6dClIaMzdl3hewq3a065RH65dofJP03h8I= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= +github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= +github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI= github.com/mfridman/tparse v0.18.0 h1:wh6dzOKaIwkUGyKgOntDW4liXSo37qg5AXbIhkMV3vE= github.com/mfridman/tparse v0.18.0/go.mod h1:gEvqZTuCgEhPbYk/2lS3Kcxg1GmTxxU7kTC8DvP0i/A= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -220,6 +226,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde h1:AMNpJRc7P+GTwVbl8DkK2I9I8BBUzNiHuH/tlxrpan0= github.com/tidwall/transform v0.0.0-20201103190739-32f242e2dbde/go.mod h1:MvrEmduDUz4ST5pGZ7CABCnOU5f3ZiOAZzT6b1A6nX8= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/mbus/nats_handler.go b/mbus/nats_handler.go index f08553157..dd080df2a 100644 --- a/mbus/nats_handler.go +++ b/mbus/nats_handler.go @@ -99,14 +99,29 @@ func NewNatsHandler( func (h *natsHandler) arpClean() { connectionInfo, err := h.getConnectionInfo() if err != nil { - h.logger.Error(h.logTag, "%v", bosherr.WrapError(err, "Getting connection info")) + h.logger.Error(h.logTag, "Failed to get connection info for ARP clean: %v", err) + return } - err = h.platform.DeleteARPEntryWithIP(connectionInfo.IP) - if err != nil { + if err := h.platform.DeleteARPEntryWithIP(connectionInfo.IP); err != nil { h.logger.Error(h.logTag, "Cleaning ip-mac address cache for: %s. Error: %v", connectionInfo.IP, err) } +} - h.logger.Debug(h.logTag, "Cleaned ip-mac address cache for: %s.", connectionInfo.IP) +// updateFirewallForNATS calls the firewall hook to update NATS rules before connection/reconnection. +// This allows DNS to be re-resolved, supporting HA failover where the director may have moved. +func (h *natsHandler) updateFirewallForNATS() { + hook := h.platform.GetNatsFirewallHook() + if hook == nil { + return + } + + settings := h.settingsService.GetSettings() + mbusURL := settings.GetMbusURL() + + if err := hook.BeforeConnect(mbusURL); err != nil { + // Log but don't fail - firewall update failure shouldn't prevent connection attempt + h.logger.Warn(h.logTag, "Failed to update NATS firewall rules: %v", err) + } } func (h *natsHandler) Run(handlerFunc boshhandler.Func) error { @@ -131,11 +146,21 @@ func (h *natsHandler) Start(handlerFunc boshhandler.Func) error { if net.ParseIP(connectionInfo.IP) != nil { h.arpClean() } + + // Update firewall rules before initial connection + h.updateFirewallForNATS() + var natsOptions = []nats.Option{ nats.RetryOnFailedConnect(true), nats.DisconnectErrHandler(func(c *nats.Conn, err error) { - h.logger.Debug(natsHandlerLogTag, "Nats disconnected with Error: %v", err.Error()) + if err != nil { + h.logger.Debug(natsHandlerLogTag, "Nats disconnected with Error: %v", err.Error()) + } else { + h.logger.Debug(natsHandlerLogTag, "Nats disconnected") + } h.logger.Debug(natsHandlerLogTag, "Attempting to reconnect: %v", c.IsReconnecting()) + // Update firewall rules before reconnection attempts (allows DNS re-resolution) + h.updateFirewallForNATS() for c.IsReconnecting() { h.arpClean() h.logger.Debug(natsHandlerLogTag, "Waiting to reconnect to nats.. Current attempt: %v, Connected: %v", c.Reconnects, c.IsConnected()) @@ -146,7 +171,11 @@ func (h *natsHandler) Start(handlerFunc boshhandler.Func) error { h.logger.Debug(natsHandlerLogTag, "Reconnected to %v", c.ConnectedAddr()) }), nats.ClosedHandler(func(c *nats.Conn) { - h.logger.Debug(natsHandlerLogTag, "Connection Closed with: %v", c.LastError().Error()) + if err := c.LastError(); err != nil { + h.logger.Debug(natsHandlerLogTag, "Connection Closed with: %v", err.Error()) + } else { + h.logger.Debug(natsHandlerLogTag, "Connection Closed") + } }), nats.ErrorHandler(func(c *nats.Conn, s *nats.Subscription, err error) { h.logger.Debug(natsHandlerLogTag, err.Error()) diff --git a/mbus/nats_handler_test.go b/mbus/nats_handler_test.go index c68b552df..c1ad397c8 100644 --- a/mbus/nats_handler_test.go +++ b/mbus/nats_handler_test.go @@ -18,6 +18,7 @@ import ( boshhandler "github.com/cloudfoundry/bosh-agent/v2/handler" "github.com/cloudfoundry/bosh-agent/v2/mbus" "github.com/cloudfoundry/bosh-agent/v2/mbus/mbusfakes" + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall/firewallfakes" "github.com/cloudfoundry/bosh-agent/v2/platform/platformfakes" boshsettings "github.com/cloudfoundry/bosh-agent/v2/settings" fakesettings "github.com/cloudfoundry/bosh-agent/v2/settings/fakes" @@ -407,6 +408,54 @@ func init() { //nolint:funlen,gochecknoinits }) }) }) + + Context("Firewall hook", func() { + var fakeFirewallHook *firewallfakes.FakeNatsFirewallHook + + BeforeEach(func() { + fakeFirewallHook = &firewallfakes.FakeNatsFirewallHook{} + platform.GetNatsFirewallHookReturns(fakeFirewallHook) + }) + + It("calls GetNatsFirewallHook on Start", func() { + err := handler.Start(func(req boshhandler.Request) (res boshhandler.Response) { return }) + Expect(err).NotTo(HaveOccurred()) + defer handler.Stop() + + Expect(platform.GetNatsFirewallHookCallCount()).To(BeNumerically(">=", 1)) + }) + + It("calls BeforeConnect with the mbus URL before initial connection", func() { + err := handler.Start(func(req boshhandler.Request) (res boshhandler.Response) { return }) + Expect(err).NotTo(HaveOccurred()) + defer handler.Stop() + + Expect(fakeFirewallHook.BeforeConnectCallCount()).To(Equal(1)) + mbusURL := fakeFirewallHook.BeforeConnectArgsForCall(0) + Expect(mbusURL).To(Equal("nats://fake-username:fake-password@127.0.0.1:1234")) + }) + + It("does not fail if hook returns nil", func() { + platform.GetNatsFirewallHookReturns(nil) + + err := handler.Start(func(req boshhandler.Request) (res boshhandler.Response) { return }) + Expect(err).NotTo(HaveOccurred()) + defer handler.Stop() + }) + + It("logs warning but does not fail if BeforeConnect returns error", func() { + fakeFirewallHook.BeforeConnectReturns(errors.New("firewall update failed")) + loggerOutBuf = bytes.NewBufferString("") + logger = boshlog.NewWriterLogger(boshlog.LevelWarn, loggerOutBuf) + handler = mbus.NewNatsHandler(settingsService, connector, logger, platform) + + err := handler.Start(func(req boshhandler.Request) (res boshhandler.Response) { return }) + Expect(err).NotTo(HaveOccurred()) + defer handler.Stop() + + Expect(loggerOutBuf.String()).To(ContainSubstring("Failed to update NATS firewall rules")) + }) + }) }) Describe("Send", func() { diff --git a/platform/dummy_platform.go b/platform/dummy_platform.go index c4a5632c0..8c2eaa00a 100644 --- a/platform/dummy_platform.go +++ b/platform/dummy_platform.go @@ -15,6 +15,7 @@ import ( boshlogstarprovider "github.com/cloudfoundry/bosh-agent/v2/agent/logstarprovider" boshdpresolv "github.com/cloudfoundry/bosh-agent/v2/infrastructure/devicepathresolver" boshcert "github.com/cloudfoundry/bosh-agent/v2/platform/cert" + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" boship "github.com/cloudfoundry/bosh-agent/v2/platform/net/ip" boshstats "github.com/cloudfoundry/bosh-agent/v2/platform/stats" boshvitals "github.com/cloudfoundry/bosh-agent/v2/platform/vitals" @@ -562,6 +563,14 @@ func (p dummyPlatform) SetupRecordsJSONPermission(path string) error { return nil } +func (p dummyPlatform) SetupFirewall() error { + return nil +} + func (p dummyPlatform) Shutdown() error { return nil } + +func (p dummyPlatform) GetNatsFirewallHook() firewall.NatsFirewallHook { + return nil +} diff --git a/platform/firewall/firewall.go b/platform/firewall/firewall.go new file mode 100644 index 000000000..abfe20204 --- /dev/null +++ b/platform/firewall/firewall.go @@ -0,0 +1,39 @@ +// Package firewall provides nftables-based firewall management for the BOSH agent. +// +// The firewall protects access to: +// - Monit (port 2822 on localhost): Used by the agent to manage job processes +// - NATS (director's message bus): Used for agent-director communication +// +// Security Model: +// The firewall uses UID-based matching (meta skuid 0) to allow only root processes +// to access these services. This blocks non-root BOSH job workloads (vcap user) +// while allowing the agent and operators to access monit/NATS. +// +// This approach is simpler and more reliable than cgroup-based matching, which +// fails in nested container environments due to cgroup filesystem bind-mount issues. +package firewall + +// Manager handles firewall setup +// +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate +//counterfeiter:generate . Manager +type Manager interface { + // SetupMonitFirewall creates firewall rules to protect monit (port 2822). + // Only root (UID 0) is allowed to connect. + SetupMonitFirewall() error + + // SetupNATSFirewall creates firewall rules to protect NATS. + // Only root (UID 0) is allowed to connect to the resolved NATS address. + // This method resolves DNS and should be called before each connection attempt. + SetupNATSFirewall(mbusURL string) error +} + +// NatsFirewallHook is called by the NATS handler before connection/reconnection. +// This allows DNS to be re-resolved, supporting HA failover scenarios. +// +//counterfeiter:generate . NatsFirewallHook +type NatsFirewallHook interface { + // BeforeConnect is called before each NATS connection/reconnection attempt. + // It resolves the NATS URL and updates firewall rules with the resolved IP. + BeforeConnect(mbusURL string) error +} diff --git a/platform/firewall/firewall_suite_test.go b/platform/firewall/firewall_suite_test.go new file mode 100644 index 000000000..678fbdf47 --- /dev/null +++ b/platform/firewall/firewall_suite_test.go @@ -0,0 +1,15 @@ +//go:build linux + +package firewall_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestFirewall(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Firewall Suite") +} diff --git a/platform/firewall/firewallfakes/fake_dnsresolver.go b/platform/firewall/firewallfakes/fake_dnsresolver.go new file mode 100644 index 000000000..deea519e1 --- /dev/null +++ b/platform/firewall/firewallfakes/fake_dnsresolver.go @@ -0,0 +1,117 @@ +//go:build linux + +// Code generated by counterfeiter. DO NOT EDIT. +package firewallfakes + +import ( + "net" + "sync" + + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" +) + +type FakeDNSResolver struct { + LookupIPStub func(string) ([]net.IP, error) + lookupIPMutex sync.RWMutex + lookupIPArgsForCall []struct { + arg1 string + } + lookupIPReturns struct { + result1 []net.IP + result2 error + } + lookupIPReturnsOnCall map[int]struct { + result1 []net.IP + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeDNSResolver) LookupIP(arg1 string) ([]net.IP, error) { + fake.lookupIPMutex.Lock() + ret, specificReturn := fake.lookupIPReturnsOnCall[len(fake.lookupIPArgsForCall)] + fake.lookupIPArgsForCall = append(fake.lookupIPArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.LookupIPStub + fakeReturns := fake.lookupIPReturns + fake.recordInvocation("LookupIP", []interface{}{arg1}) + fake.lookupIPMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeDNSResolver) LookupIPCallCount() int { + fake.lookupIPMutex.RLock() + defer fake.lookupIPMutex.RUnlock() + return len(fake.lookupIPArgsForCall) +} + +func (fake *FakeDNSResolver) LookupIPCalls(stub func(string) ([]net.IP, error)) { + fake.lookupIPMutex.Lock() + defer fake.lookupIPMutex.Unlock() + fake.LookupIPStub = stub +} + +func (fake *FakeDNSResolver) LookupIPArgsForCall(i int) string { + fake.lookupIPMutex.RLock() + defer fake.lookupIPMutex.RUnlock() + argsForCall := fake.lookupIPArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDNSResolver) LookupIPReturns(result1 []net.IP, result2 error) { + fake.lookupIPMutex.Lock() + defer fake.lookupIPMutex.Unlock() + fake.LookupIPStub = nil + fake.lookupIPReturns = struct { + result1 []net.IP + result2 error + }{result1, result2} +} + +func (fake *FakeDNSResolver) LookupIPReturnsOnCall(i int, result1 []net.IP, result2 error) { + fake.lookupIPMutex.Lock() + defer fake.lookupIPMutex.Unlock() + fake.LookupIPStub = nil + if fake.lookupIPReturnsOnCall == nil { + fake.lookupIPReturnsOnCall = make(map[int]struct { + result1 []net.IP + result2 error + }) + } + fake.lookupIPReturnsOnCall[i] = struct { + result1 []net.IP + result2 error + }{result1, result2} +} + +func (fake *FakeDNSResolver) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeDNSResolver) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ firewall.DNSResolver = new(FakeDNSResolver) diff --git a/platform/firewall/firewallfakes/fake_manager.go b/platform/firewall/firewallfakes/fake_manager.go new file mode 100644 index 000000000..64177995a --- /dev/null +++ b/platform/firewall/firewallfakes/fake_manager.go @@ -0,0 +1,172 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package firewallfakes + +import ( + "sync" + + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" +) + +type FakeManager struct { + SetupMonitFirewallStub func() error + setupMonitFirewallMutex sync.RWMutex + setupMonitFirewallArgsForCall []struct { + } + setupMonitFirewallReturns struct { + result1 error + } + setupMonitFirewallReturnsOnCall map[int]struct { + result1 error + } + SetupNATSFirewallStub func(string) error + setupNATSFirewallMutex sync.RWMutex + setupNATSFirewallArgsForCall []struct { + arg1 string + } + setupNATSFirewallReturns struct { + result1 error + } + setupNATSFirewallReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeManager) SetupMonitFirewall() error { + fake.setupMonitFirewallMutex.Lock() + ret, specificReturn := fake.setupMonitFirewallReturnsOnCall[len(fake.setupMonitFirewallArgsForCall)] + fake.setupMonitFirewallArgsForCall = append(fake.setupMonitFirewallArgsForCall, struct { + }{}) + stub := fake.SetupMonitFirewallStub + fakeReturns := fake.setupMonitFirewallReturns + fake.recordInvocation("SetupMonitFirewall", []interface{}{}) + fake.setupMonitFirewallMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeManager) SetupMonitFirewallCallCount() int { + fake.setupMonitFirewallMutex.RLock() + defer fake.setupMonitFirewallMutex.RUnlock() + return len(fake.setupMonitFirewallArgsForCall) +} + +func (fake *FakeManager) SetupMonitFirewallCalls(stub func() error) { + fake.setupMonitFirewallMutex.Lock() + defer fake.setupMonitFirewallMutex.Unlock() + fake.SetupMonitFirewallStub = stub +} + +func (fake *FakeManager) SetupMonitFirewallReturns(result1 error) { + fake.setupMonitFirewallMutex.Lock() + defer fake.setupMonitFirewallMutex.Unlock() + fake.SetupMonitFirewallStub = nil + fake.setupMonitFirewallReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeManager) SetupMonitFirewallReturnsOnCall(i int, result1 error) { + fake.setupMonitFirewallMutex.Lock() + defer fake.setupMonitFirewallMutex.Unlock() + fake.SetupMonitFirewallStub = nil + if fake.setupMonitFirewallReturnsOnCall == nil { + fake.setupMonitFirewallReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setupMonitFirewallReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeManager) SetupNATSFirewall(arg1 string) error { + fake.setupNATSFirewallMutex.Lock() + ret, specificReturn := fake.setupNATSFirewallReturnsOnCall[len(fake.setupNATSFirewallArgsForCall)] + fake.setupNATSFirewallArgsForCall = append(fake.setupNATSFirewallArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.SetupNATSFirewallStub + fakeReturns := fake.setupNATSFirewallReturns + fake.recordInvocation("SetupNATSFirewall", []interface{}{arg1}) + fake.setupNATSFirewallMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeManager) SetupNATSFirewallCallCount() int { + fake.setupNATSFirewallMutex.RLock() + defer fake.setupNATSFirewallMutex.RUnlock() + return len(fake.setupNATSFirewallArgsForCall) +} + +func (fake *FakeManager) SetupNATSFirewallCalls(stub func(string) error) { + fake.setupNATSFirewallMutex.Lock() + defer fake.setupNATSFirewallMutex.Unlock() + fake.SetupNATSFirewallStub = stub +} + +func (fake *FakeManager) SetupNATSFirewallArgsForCall(i int) string { + fake.setupNATSFirewallMutex.RLock() + defer fake.setupNATSFirewallMutex.RUnlock() + argsForCall := fake.setupNATSFirewallArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeManager) SetupNATSFirewallReturns(result1 error) { + fake.setupNATSFirewallMutex.Lock() + defer fake.setupNATSFirewallMutex.Unlock() + fake.SetupNATSFirewallStub = nil + fake.setupNATSFirewallReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeManager) SetupNATSFirewallReturnsOnCall(i int, result1 error) { + fake.setupNATSFirewallMutex.Lock() + defer fake.setupNATSFirewallMutex.Unlock() + fake.SetupNATSFirewallStub = nil + if fake.setupNATSFirewallReturnsOnCall == nil { + fake.setupNATSFirewallReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setupNATSFirewallReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeManager) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeManager) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ firewall.Manager = new(FakeManager) diff --git a/platform/firewall/firewallfakes/fake_nats_firewall_hook.go b/platform/firewall/firewallfakes/fake_nats_firewall_hook.go new file mode 100644 index 000000000..6d59eb72d --- /dev/null +++ b/platform/firewall/firewallfakes/fake_nats_firewall_hook.go @@ -0,0 +1,109 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package firewallfakes + +import ( + "sync" + + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" +) + +type FakeNatsFirewallHook struct { + BeforeConnectStub func(string) error + beforeConnectMutex sync.RWMutex + beforeConnectArgsForCall []struct { + arg1 string + } + beforeConnectReturns struct { + result1 error + } + beforeConnectReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeNatsFirewallHook) BeforeConnect(arg1 string) error { + fake.beforeConnectMutex.Lock() + ret, specificReturn := fake.beforeConnectReturnsOnCall[len(fake.beforeConnectArgsForCall)] + fake.beforeConnectArgsForCall = append(fake.beforeConnectArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.BeforeConnectStub + fakeReturns := fake.beforeConnectReturns + fake.recordInvocation("BeforeConnect", []interface{}{arg1}) + fake.beforeConnectMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeNatsFirewallHook) BeforeConnectCallCount() int { + fake.beforeConnectMutex.RLock() + defer fake.beforeConnectMutex.RUnlock() + return len(fake.beforeConnectArgsForCall) +} + +func (fake *FakeNatsFirewallHook) BeforeConnectCalls(stub func(string) error) { + fake.beforeConnectMutex.Lock() + defer fake.beforeConnectMutex.Unlock() + fake.BeforeConnectStub = stub +} + +func (fake *FakeNatsFirewallHook) BeforeConnectArgsForCall(i int) string { + fake.beforeConnectMutex.RLock() + defer fake.beforeConnectMutex.RUnlock() + argsForCall := fake.beforeConnectArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeNatsFirewallHook) BeforeConnectReturns(result1 error) { + fake.beforeConnectMutex.Lock() + defer fake.beforeConnectMutex.Unlock() + fake.BeforeConnectStub = nil + fake.beforeConnectReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeNatsFirewallHook) BeforeConnectReturnsOnCall(i int, result1 error) { + fake.beforeConnectMutex.Lock() + defer fake.beforeConnectMutex.Unlock() + fake.BeforeConnectStub = nil + if fake.beforeConnectReturnsOnCall == nil { + fake.beforeConnectReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.beforeConnectReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeNatsFirewallHook) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeNatsFirewallHook) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ firewall.NatsFirewallHook = new(FakeNatsFirewallHook) diff --git a/platform/firewall/firewallfakes/fake_nftables_conn.go b/platform/firewall/firewallfakes/fake_nftables_conn.go new file mode 100644 index 000000000..739e88a0b --- /dev/null +++ b/platform/firewall/firewallfakes/fake_nftables_conn.go @@ -0,0 +1,356 @@ +//go:build linux + +// Code generated by counterfeiter. DO NOT EDIT. +package firewallfakes + +import ( + "sync" + + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" + "github.com/google/nftables" +) + +type FakeNftablesConn struct { + AddChainStub func(*nftables.Chain) *nftables.Chain + addChainMutex sync.RWMutex + addChainArgsForCall []struct { + arg1 *nftables.Chain + } + addChainReturns struct { + result1 *nftables.Chain + } + addChainReturnsOnCall map[int]struct { + result1 *nftables.Chain + } + AddRuleStub func(*nftables.Rule) *nftables.Rule + addRuleMutex sync.RWMutex + addRuleArgsForCall []struct { + arg1 *nftables.Rule + } + addRuleReturns struct { + result1 *nftables.Rule + } + addRuleReturnsOnCall map[int]struct { + result1 *nftables.Rule + } + AddTableStub func(*nftables.Table) *nftables.Table + addTableMutex sync.RWMutex + addTableArgsForCall []struct { + arg1 *nftables.Table + } + addTableReturns struct { + result1 *nftables.Table + } + addTableReturnsOnCall map[int]struct { + result1 *nftables.Table + } + FlushStub func() error + flushMutex sync.RWMutex + flushArgsForCall []struct { + } + flushReturns struct { + result1 error + } + flushReturnsOnCall map[int]struct { + result1 error + } + FlushChainStub func(*nftables.Chain) + flushChainMutex sync.RWMutex + flushChainArgsForCall []struct { + arg1 *nftables.Chain + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeNftablesConn) AddChain(arg1 *nftables.Chain) *nftables.Chain { + fake.addChainMutex.Lock() + ret, specificReturn := fake.addChainReturnsOnCall[len(fake.addChainArgsForCall)] + fake.addChainArgsForCall = append(fake.addChainArgsForCall, struct { + arg1 *nftables.Chain + }{arg1}) + stub := fake.AddChainStub + fakeReturns := fake.addChainReturns + fake.recordInvocation("AddChain", []interface{}{arg1}) + fake.addChainMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeNftablesConn) AddChainCallCount() int { + fake.addChainMutex.RLock() + defer fake.addChainMutex.RUnlock() + return len(fake.addChainArgsForCall) +} + +func (fake *FakeNftablesConn) AddChainCalls(stub func(*nftables.Chain) *nftables.Chain) { + fake.addChainMutex.Lock() + defer fake.addChainMutex.Unlock() + fake.AddChainStub = stub +} + +func (fake *FakeNftablesConn) AddChainArgsForCall(i int) *nftables.Chain { + fake.addChainMutex.RLock() + defer fake.addChainMutex.RUnlock() + argsForCall := fake.addChainArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeNftablesConn) AddChainReturns(result1 *nftables.Chain) { + fake.addChainMutex.Lock() + defer fake.addChainMutex.Unlock() + fake.AddChainStub = nil + fake.addChainReturns = struct { + result1 *nftables.Chain + }{result1} +} + +func (fake *FakeNftablesConn) AddChainReturnsOnCall(i int, result1 *nftables.Chain) { + fake.addChainMutex.Lock() + defer fake.addChainMutex.Unlock() + fake.AddChainStub = nil + if fake.addChainReturnsOnCall == nil { + fake.addChainReturnsOnCall = make(map[int]struct { + result1 *nftables.Chain + }) + } + fake.addChainReturnsOnCall[i] = struct { + result1 *nftables.Chain + }{result1} +} + +func (fake *FakeNftablesConn) AddRule(arg1 *nftables.Rule) *nftables.Rule { + fake.addRuleMutex.Lock() + ret, specificReturn := fake.addRuleReturnsOnCall[len(fake.addRuleArgsForCall)] + fake.addRuleArgsForCall = append(fake.addRuleArgsForCall, struct { + arg1 *nftables.Rule + }{arg1}) + stub := fake.AddRuleStub + fakeReturns := fake.addRuleReturns + fake.recordInvocation("AddRule", []interface{}{arg1}) + fake.addRuleMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeNftablesConn) AddRuleCallCount() int { + fake.addRuleMutex.RLock() + defer fake.addRuleMutex.RUnlock() + return len(fake.addRuleArgsForCall) +} + +func (fake *FakeNftablesConn) AddRuleCalls(stub func(*nftables.Rule) *nftables.Rule) { + fake.addRuleMutex.Lock() + defer fake.addRuleMutex.Unlock() + fake.AddRuleStub = stub +} + +func (fake *FakeNftablesConn) AddRuleArgsForCall(i int) *nftables.Rule { + fake.addRuleMutex.RLock() + defer fake.addRuleMutex.RUnlock() + argsForCall := fake.addRuleArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeNftablesConn) AddRuleReturns(result1 *nftables.Rule) { + fake.addRuleMutex.Lock() + defer fake.addRuleMutex.Unlock() + fake.AddRuleStub = nil + fake.addRuleReturns = struct { + result1 *nftables.Rule + }{result1} +} + +func (fake *FakeNftablesConn) AddRuleReturnsOnCall(i int, result1 *nftables.Rule) { + fake.addRuleMutex.Lock() + defer fake.addRuleMutex.Unlock() + fake.AddRuleStub = nil + if fake.addRuleReturnsOnCall == nil { + fake.addRuleReturnsOnCall = make(map[int]struct { + result1 *nftables.Rule + }) + } + fake.addRuleReturnsOnCall[i] = struct { + result1 *nftables.Rule + }{result1} +} + +func (fake *FakeNftablesConn) AddTable(arg1 *nftables.Table) *nftables.Table { + fake.addTableMutex.Lock() + ret, specificReturn := fake.addTableReturnsOnCall[len(fake.addTableArgsForCall)] + fake.addTableArgsForCall = append(fake.addTableArgsForCall, struct { + arg1 *nftables.Table + }{arg1}) + stub := fake.AddTableStub + fakeReturns := fake.addTableReturns + fake.recordInvocation("AddTable", []interface{}{arg1}) + fake.addTableMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeNftablesConn) AddTableCallCount() int { + fake.addTableMutex.RLock() + defer fake.addTableMutex.RUnlock() + return len(fake.addTableArgsForCall) +} + +func (fake *FakeNftablesConn) AddTableCalls(stub func(*nftables.Table) *nftables.Table) { + fake.addTableMutex.Lock() + defer fake.addTableMutex.Unlock() + fake.AddTableStub = stub +} + +func (fake *FakeNftablesConn) AddTableArgsForCall(i int) *nftables.Table { + fake.addTableMutex.RLock() + defer fake.addTableMutex.RUnlock() + argsForCall := fake.addTableArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeNftablesConn) AddTableReturns(result1 *nftables.Table) { + fake.addTableMutex.Lock() + defer fake.addTableMutex.Unlock() + fake.AddTableStub = nil + fake.addTableReturns = struct { + result1 *nftables.Table + }{result1} +} + +func (fake *FakeNftablesConn) AddTableReturnsOnCall(i int, result1 *nftables.Table) { + fake.addTableMutex.Lock() + defer fake.addTableMutex.Unlock() + fake.AddTableStub = nil + if fake.addTableReturnsOnCall == nil { + fake.addTableReturnsOnCall = make(map[int]struct { + result1 *nftables.Table + }) + } + fake.addTableReturnsOnCall[i] = struct { + result1 *nftables.Table + }{result1} +} + +func (fake *FakeNftablesConn) Flush() error { + fake.flushMutex.Lock() + ret, specificReturn := fake.flushReturnsOnCall[len(fake.flushArgsForCall)] + fake.flushArgsForCall = append(fake.flushArgsForCall, struct { + }{}) + stub := fake.FlushStub + fakeReturns := fake.flushReturns + fake.recordInvocation("Flush", []interface{}{}) + fake.flushMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeNftablesConn) FlushCallCount() int { + fake.flushMutex.RLock() + defer fake.flushMutex.RUnlock() + return len(fake.flushArgsForCall) +} + +func (fake *FakeNftablesConn) FlushCalls(stub func() error) { + fake.flushMutex.Lock() + defer fake.flushMutex.Unlock() + fake.FlushStub = stub +} + +func (fake *FakeNftablesConn) FlushReturns(result1 error) { + fake.flushMutex.Lock() + defer fake.flushMutex.Unlock() + fake.FlushStub = nil + fake.flushReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeNftablesConn) FlushReturnsOnCall(i int, result1 error) { + fake.flushMutex.Lock() + defer fake.flushMutex.Unlock() + fake.FlushStub = nil + if fake.flushReturnsOnCall == nil { + fake.flushReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.flushReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeNftablesConn) FlushChain(arg1 *nftables.Chain) { + fake.flushChainMutex.Lock() + fake.flushChainArgsForCall = append(fake.flushChainArgsForCall, struct { + arg1 *nftables.Chain + }{arg1}) + stub := fake.FlushChainStub + fake.recordInvocation("FlushChain", []interface{}{arg1}) + fake.flushChainMutex.Unlock() + if stub != nil { + fake.FlushChainStub(arg1) + } +} + +func (fake *FakeNftablesConn) FlushChainCallCount() int { + fake.flushChainMutex.RLock() + defer fake.flushChainMutex.RUnlock() + return len(fake.flushChainArgsForCall) +} + +func (fake *FakeNftablesConn) FlushChainCalls(stub func(*nftables.Chain)) { + fake.flushChainMutex.Lock() + defer fake.flushChainMutex.Unlock() + fake.FlushChainStub = stub +} + +func (fake *FakeNftablesConn) FlushChainArgsForCall(i int) *nftables.Chain { + fake.flushChainMutex.RLock() + defer fake.flushChainMutex.RUnlock() + argsForCall := fake.flushChainArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeNftablesConn) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeNftablesConn) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ firewall.NftablesConn = new(FakeNftablesConn) diff --git a/platform/firewall/firewallfakes/linux_build_constraint.txt b/platform/firewall/firewallfakes/linux_build_constraint.txt new file mode 100644 index 000000000..88452a54f --- /dev/null +++ b/platform/firewall/firewallfakes/linux_build_constraint.txt @@ -0,0 +1,2 @@ +//go:build linux + diff --git a/platform/firewall/nftables_firewall.go b/platform/firewall/nftables_firewall.go new file mode 100644 index 000000000..b74388b73 --- /dev/null +++ b/platform/firewall/nftables_firewall.go @@ -0,0 +1,490 @@ +//go:build linux + +package firewall + +import ( + "encoding/binary" + "fmt" + "net" + gonetURL "net/url" + "strconv" + "strings" + + bosherr "github.com/cloudfoundry/bosh-utils/errors" + boshlog "github.com/cloudfoundry/bosh-utils/logger" + "github.com/google/nftables" + "github.com/google/nftables/expr" + "golang.org/x/sys/unix" +) + +const ( + TableName = "bosh_agent" + MonitChainName = "monit_access" + MonitJobsChainName = "monit_access_jobs" + NATSChainName = "nats_access" + MonitPort = 2822 +) + +// NftablesConn abstracts the nftables connection for testing +// +//counterfeiter:generate -header ./firewallfakes/linux_build_constraint.txt . NftablesConn +type NftablesConn interface { + AddTable(t *nftables.Table) *nftables.Table + AddChain(c *nftables.Chain) *nftables.Chain + AddRule(r *nftables.Rule) *nftables.Rule + FlushChain(c *nftables.Chain) + Flush() error +} + +// DNSResolver abstracts DNS resolution for testing +// +//counterfeiter:generate -header ./firewallfakes/linux_build_constraint.txt . DNSResolver +type DNSResolver interface { + LookupIP(host string) ([]net.IP, error) +} + +// realDNSResolver uses the standard library for DNS resolution +type realDNSResolver struct{} + +func (r *realDNSResolver) LookupIP(host string) ([]net.IP, error) { + return net.LookupIP(host) +} + +// realNftablesConn wraps the actual nftables.Conn +type realNftablesConn struct { + conn *nftables.Conn +} + +func (r *realNftablesConn) AddTable(t *nftables.Table) *nftables.Table { + return r.conn.AddTable(t) +} + +func (r *realNftablesConn) AddChain(c *nftables.Chain) *nftables.Chain { + return r.conn.AddChain(c) +} + +func (r *realNftablesConn) AddRule(rule *nftables.Rule) *nftables.Rule { + return r.conn.AddRule(rule) +} + +func (r *realNftablesConn) FlushChain(c *nftables.Chain) { + r.conn.FlushChain(c) +} + +func (r *realNftablesConn) Flush() error { + return r.conn.Flush() +} + +// NftablesFirewall implements Manager and NatsFirewallHook using nftables with UID-based matching +type NftablesFirewall struct { + conn NftablesConn + resolver DNSResolver + logger boshlog.Logger + logTag string + table *nftables.Table + monitChain *nftables.Chain + monitJobsChain *nftables.Chain + natsChain *nftables.Chain +} + +// NewNftablesFirewall creates a new nftables-based firewall manager +func NewNftablesFirewall(logger boshlog.Logger) (Manager, error) { + conn, err := nftables.New() + if err != nil { + return nil, bosherr.WrapError(err, "Creating nftables connection") + } + + return NewNftablesFirewallWithDeps( + &realNftablesConn{conn: conn}, + &realDNSResolver{}, + logger, + ), nil +} + +// NewNftablesFirewallWithDeps creates a firewall manager with injected dependencies (for testing) +func NewNftablesFirewallWithDeps(conn NftablesConn, resolver DNSResolver, logger boshlog.Logger) Manager { + return &NftablesFirewall{ + conn: conn, + resolver: resolver, + logger: logger, + logTag: "NftablesFirewall", + } +} + +// SetupMonitFirewall creates firewall rules to protect monit (port 2822). +// Only root (UID 0) is allowed to connect by default. +// Jobs can add their own access rules to the monit_access_jobs chain. +// +// Architecture: +// - monit_access_jobs: Regular chain for job-managed rules (never flushed by agent) +// - monit_access: Base chain with hook that jumps to jobs chain, then applies agent rules +// +// This allows job rules to persist across agent restarts while ensuring +// agent rules are always up-to-date. +func (f *NftablesFirewall) SetupMonitFirewall() error { + f.logger.Info(f.logTag, "Setting up monit firewall rules (UID-based matching)") + + // Create or get our table + f.ensureTable() + + // Create jobs chain if it doesn't exist (never flush it - job rules persist) + f.ensureMonitJobsChain() + + // Create monit chain + f.ensureMonitChain() + + // Flush existing agent rules to ensure idempotency on restart + f.conn.FlushChain(f.monitChain) + + // Add jump to jobs chain first (so job rules are checked before agent rules) + f.addJumpToJobsChain() + + // Add allow rule for root (UID 0) + f.addMonitAllowRule() + + // Add block rule for everyone else + f.addMonitBlockRule() + + // Commit all rules + if err := f.conn.Flush(); err != nil { + return bosherr.WrapError(err, "Flushing nftables rules") + } + + f.logger.Info(f.logTag, "Successfully set up monit firewall rules") + return nil +} + +// SetupNATSFirewall creates firewall rules to protect NATS. +// This resolves DNS and should be called before each connection attempt. +func (f *NftablesFirewall) SetupNATSFirewall(mbusURL string) error { + // Parse URL to get host and port + host, port, err := parseNATSURL(mbusURL) + if err != nil { + // Not an error for https URLs or empty URLs (create-env case) + f.logger.Info(f.logTag, "Skipping NATS firewall: %s", err) + return nil + } + + // Resolve host to IP addresses + var addrs []net.IP + if ip := net.ParseIP(host); ip != nil { + addrs = []net.IP{ip} + } else { + addrs, err = f.resolver.LookupIP(host) + if err != nil { + f.logger.Warn(f.logTag, "DNS resolution failed for %s: %s", host, err) + return nil + } + } + + f.logger.Debug(f.logTag, "Setting up NATS firewall for %s:%d (resolved to %v)", host, port, addrs) + + // Ensure table exists + f.ensureTable() + + // Ensure NATS chain exists + f.ensureNATSChain() + + // Flush NATS chain (removes old rules for previous IPs) + f.conn.FlushChain(f.natsChain) + + // Add rules for each resolved IP + for _, addr := range addrs { + f.addNATSAllowRule(addr, port) + f.addNATSBlockRule(addr, port) + } + + // Commit + if err := f.conn.Flush(); err != nil { + return bosherr.WrapError(err, "Flushing nftables rules") + } + + f.logger.Info(f.logTag, "Updated NATS firewall rules for %s:%d", host, port) + return nil +} + +// BeforeConnect implements NatsFirewallHook. Called before each NATS connection attempt. +func (f *NftablesFirewall) BeforeConnect(mbusURL string) error { + return f.SetupNATSFirewall(mbusURL) +} + +func (f *NftablesFirewall) ensureTable() { + f.table = &nftables.Table{ + Family: nftables.TableFamilyINet, + Name: TableName, + } + f.conn.AddTable(f.table) +} + +func (f *NftablesFirewall) ensureMonitChain() { + priority := nftables.ChainPriority(*nftables.ChainPriorityFilter - 1) + + f.monitChain = &nftables.Chain{ + Name: MonitChainName, + Table: f.table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookOutput, + Priority: &priority, + Policy: policyPtr(nftables.ChainPolicyAccept), + } + f.conn.AddChain(f.monitChain) +} + +// ensureMonitJobsChain creates a regular chain (no hook) for job-managed rules. +// This chain is never flushed by the agent, allowing job rules to persist across agent restarts. +// Jobs can add rules to this chain via pre-start scripts using the nft CLI. +func (f *NftablesFirewall) ensureMonitJobsChain() { + f.monitJobsChain = &nftables.Chain{ + Name: MonitJobsChainName, + Table: f.table, + // No Type, Hooknum, Priority, or Policy - this is a regular chain + // that can only be reached via jump from monit_access + } + f.conn.AddChain(f.monitJobsChain) +} + +func (f *NftablesFirewall) ensureNATSChain() { + priority := nftables.ChainPriority(*nftables.ChainPriorityFilter - 1) + + f.natsChain = &nftables.Chain{ + Name: NATSChainName, + Table: f.table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookOutput, + Priority: &priority, + Policy: policyPtr(nftables.ChainPolicyAccept), + } + f.conn.AddChain(f.natsChain) +} + +func (f *NftablesFirewall) addMonitAllowRule() { + // Rule: meta skuid 0 ip daddr 127.0.0.1 tcp dport 2822 accept + exprs := f.buildUIDMatchExprs(0) + exprs = append(exprs, f.buildLoopbackDestExprs()...) + exprs = append(exprs, f.buildTCPDestPortExprs(MonitPort)...) + exprs = append(exprs, &expr.Verdict{Kind: expr.VerdictAccept}) + + f.conn.AddRule(&nftables.Rule{ + Table: f.table, + Chain: f.monitChain, + Exprs: exprs, + }) +} + +func (f *NftablesFirewall) addMonitBlockRule() { + // Rule: ip daddr 127.0.0.1 tcp dport 2822 drop + exprs := f.buildLoopbackDestExprs() + exprs = append(exprs, f.buildTCPDestPortExprs(MonitPort)...) + exprs = append(exprs, &expr.Verdict{Kind: expr.VerdictDrop}) + + f.conn.AddRule(&nftables.Rule{ + Table: f.table, + Chain: f.monitChain, + Exprs: exprs, + }) +} + +// addJumpToJobsChain adds a jump rule to the monit_access_jobs chain. +// This must be the first rule in monit_access so job rules are evaluated first. +func (f *NftablesFirewall) addJumpToJobsChain() { + f.conn.AddRule(&nftables.Rule{ + Table: f.table, + Chain: f.monitChain, + Exprs: []expr.Any{ + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: MonitJobsChainName, + }, + }, + }) +} + +func (f *NftablesFirewall) addNATSAllowRule(addr net.IP, port int) { + // Rule: meta skuid 0 ip daddr tcp dport accept + exprs := f.buildUIDMatchExprs(0) + exprs = append(exprs, f.buildDestIPExprs(addr)...) + exprs = append(exprs, f.buildTCPDestPortExprs(port)...) + exprs = append(exprs, &expr.Verdict{Kind: expr.VerdictAccept}) + + f.conn.AddRule(&nftables.Rule{ + Table: f.table, + Chain: f.natsChain, + Exprs: exprs, + }) +} + +func (f *NftablesFirewall) addNATSBlockRule(addr net.IP, port int) { + // Rule: ip daddr tcp dport drop + exprs := f.buildDestIPExprs(addr) + exprs = append(exprs, f.buildTCPDestPortExprs(port)...) + exprs = append(exprs, &expr.Verdict{Kind: expr.VerdictDrop}) + + f.conn.AddRule(&nftables.Rule{ + Table: f.table, + Chain: f.natsChain, + Exprs: exprs, + }) +} + +// buildUIDMatchExprs creates expressions for matching socket UID +func (f *NftablesFirewall) buildUIDMatchExprs(uid uint32) []expr.Any { + uidBytes := make([]byte, 4) + binary.NativeEndian.PutUint32(uidBytes, uid) + + return []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeySKUID, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: uidBytes, + }, + } +} + +// buildLoopbackDestExprs creates expressions for matching IPv4 loopback destination. +// Note: IPv6 loopback (::1) is intentionally not protected because monit only +// binds to 127.0.0.1:2822 (see jobsupervisor/monit/provider.go). +func (f *NftablesFirewall) buildLoopbackDestExprs() []expr.Any { + return []expr.Any{ + // Check this is IPv4 + &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.NFPROTO_IPV4}, + }, + // Load destination IP + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, // Destination IP offset in IPv4 header + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: net.ParseIP("127.0.0.1").To4(), + }, + } +} + +func (f *NftablesFirewall) buildDestIPExprs(ip net.IP) []expr.Any { + if ip4 := ip.To4(); ip4 != nil { + return []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.NFPROTO_IPV4}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip4, + }, + } + } + + // IPv6 + return []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.NFPROTO_IPV6}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 24, // Destination IP offset in IPv6 header + Len: 16, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ip.To16(), + }, + } +} + +func (f *NftablesFirewall) buildTCPDestPortExprs(port int) []expr.Any { + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(port)) + + return []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, // Destination port offset in TCP header + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: portBytes, + }, + } +} + +func policyPtr(p nftables.ChainPolicy) *nftables.ChainPolicy { + return &p +} + +func parseNATSURL(mbusURL string) (string, int, error) { + if mbusURL == "" || strings.HasPrefix(mbusURL, "https://") { + return "", 0, fmt.Errorf("skipping URL: %s", mbusURL) + } + + u, err := gonetURL.Parse(mbusURL) + if err != nil { + return "", 0, err + } + + if u.Hostname() == "" { + return "", 0, fmt.Errorf("empty hostname in URL") + } + + host, portStr, err := net.SplitHostPort(u.Host) + if err != nil { + host = u.Hostname() + portStr = "4222" + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, fmt.Errorf("parsing port: %w", err) + } + + if port < 1 || port > 65535 { + return "", 0, fmt.Errorf("port %d out of valid range (1-65535)", port) + } + + return host, port, nil +} diff --git a/platform/firewall/nftables_firewall_other.go b/platform/firewall/nftables_firewall_other.go new file mode 100644 index 000000000..39baede89 --- /dev/null +++ b/platform/firewall/nftables_firewall_other.go @@ -0,0 +1,14 @@ +//go:build !linux + +package firewall + +import ( + "errors" + + boshlog "github.com/cloudfoundry/bosh-utils/logger" +) + +// NewNftablesFirewall returns an error on non-Linux platforms +func NewNftablesFirewall(logger boshlog.Logger) (Manager, error) { + return nil, errors.New("nftables firewall is only supported on Linux") +} diff --git a/platform/firewall/nftables_firewall_test.go b/platform/firewall/nftables_firewall_test.go new file mode 100644 index 000000000..ea11238d2 --- /dev/null +++ b/platform/firewall/nftables_firewall_test.go @@ -0,0 +1,321 @@ +//go:build linux + +package firewall_test + +import ( + "errors" + "net" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + boshlog "github.com/cloudfoundry/bosh-utils/logger" + "github.com/google/nftables" + "github.com/google/nftables/expr" + + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall/firewallfakes" +) + +var _ = Describe("NftablesFirewall", func() { + var ( + fakeConn *firewallfakes.FakeNftablesConn + fakeResolver *firewallfakes.FakeDNSResolver + logger boshlog.Logger + manager firewall.Manager + ) + + BeforeEach(func() { + fakeConn = &firewallfakes.FakeNftablesConn{} + fakeResolver = &firewallfakes.FakeDNSResolver{} + logger = boshlog.NewWriterLogger(boshlog.LevelDebug, GinkgoWriter) + manager = firewall.NewNftablesFirewallWithDeps(fakeConn, fakeResolver, logger) + }) + + Describe("SetupMonitFirewall", func() { + It("creates table, chains, and rules successfully", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddTableCallCount()).To(Equal(1)) + Expect(fakeConn.AddChainCallCount()).To(Equal(2)) // jobs chain + monit chain + Expect(fakeConn.FlushChainCallCount()).To(Equal(1)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(3)) // jump + allow + block + Expect(fakeConn.FlushCallCount()).To(Equal(1)) + }) + + It("creates table with correct configuration", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + table := fakeConn.AddTableArgsForCall(0) + Expect(table.Name).To(Equal("bosh_agent")) + Expect(table.Family).To(Equal(nftables.TableFamilyINet)) + }) + + It("creates jobs chain as regular chain (no hook)", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + // First chain is the jobs chain + jobsChain := fakeConn.AddChainArgsForCall(0) + Expect(jobsChain.Name).To(Equal("monit_access_jobs")) + Expect(jobsChain.Type).To(Equal(nftables.ChainType(""))) // Regular chain has no type + Expect(jobsChain.Hooknum).To(BeNil()) // Regular chain has no hook + Expect(jobsChain.Priority).To(BeNil()) // Regular chain has no priority + }) + + It("creates monit chain with correct configuration", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + // Second chain is the monit chain (base chain with hook) + monitChain := fakeConn.AddChainArgsForCall(1) + Expect(monitChain.Name).To(Equal("monit_access")) + Expect(monitChain.Type).To(Equal(nftables.ChainTypeFilter)) + Expect(monitChain.Hooknum).NotTo(BeNil()) + Expect(*monitChain.Hooknum).To(Equal(*nftables.ChainHookOutput)) + }) + + It("adds jump to jobs chain as first rule", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + // First rule should be the jump rule + jumpRule := fakeConn.AddRuleArgsForCall(0) + Expect(jumpRule.Chain.Name).To(Equal("monit_access")) + Expect(jumpRule.Exprs).To(HaveLen(1)) + + verdict, ok := jumpRule.Exprs[0].(*expr.Verdict) + Expect(ok).To(BeTrue()) + Expect(verdict.Kind).To(Equal(expr.VerdictJump)) + Expect(verdict.Chain).To(Equal("monit_access_jobs")) + }) + + It("adds allow rule for UID 0 after jump rule", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + // Second rule should be the allow rule (has UID match expressions) + allowRule := fakeConn.AddRuleArgsForCall(1) + Expect(allowRule.Chain.Name).To(Equal("monit_access")) + // The allow rule has more expressions (UID match + loopback + port + accept) + // Block rule has fewer (loopback + port + drop) + blockRule := fakeConn.AddRuleArgsForCall(2) + Expect(len(allowRule.Exprs)).To(BeNumerically(">", len(blockRule.Exprs))) + }) + + It("flushes monit chain before adding rules", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.FlushChainCallCount()).To(Equal(1)) + flushedChain := fakeConn.FlushChainArgsForCall(0) + Expect(flushedChain.Name).To(Equal("monit_access")) + }) + + It("never flushes jobs chain to preserve job-managed rules", func() { + // Call SetupMonitFirewall multiple times to simulate agent restarts + for i := 0; i < 3; i++ { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + } + + // Verify that all FlushChain calls were on monit_access, never on monit_access_jobs + flushCount := fakeConn.FlushChainCallCount() + Expect(flushCount).To(Equal(3)) // Once per call + + for i := 0; i < flushCount; i++ { + flushedChain := fakeConn.FlushChainArgsForCall(i) + Expect(flushedChain.Name).To(Equal("monit_access"), + "FlushChain should only be called on monit_access, not monit_access_jobs") + } + }) + + Context("when called multiple times", func() { + It("flushes monit chain each time to prevent duplicate rules", func() { + err := manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + Expect(fakeConn.FlushChainCallCount()).To(Equal(1)) + + err = manager.SetupMonitFirewall() + Expect(err).NotTo(HaveOccurred()) + Expect(fakeConn.FlushChainCallCount()).To(Equal(2)) + + // Both flush calls should be on the monit chain, not the jobs chain + Expect(fakeConn.FlushChainArgsForCall(0).Name).To(Equal("monit_access")) + Expect(fakeConn.FlushChainArgsForCall(1).Name).To(Equal("monit_access")) + }) + }) + + Context("when Flush fails", func() { + It("returns an error", func() { + fakeConn.FlushReturns(errors.New("flush failed")) + + err := manager.SetupMonitFirewall() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("Flushing nftables rules")) + }) + }) + }) + + Describe("SetupNATSFirewall", func() { + Context("with an IPv4 address URL", func() { + It("creates rules for the IPv4 address", func() { + err := manager.SetupNATSFirewall("nats://user:pass@192.168.1.100:4222") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddTableCallCount()).To(Equal(1)) + Expect(fakeConn.AddChainCallCount()).To(Equal(1)) + Expect(fakeConn.FlushChainCallCount()).To(Equal(1)) + // One allow rule + one block rule + Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + Expect(fakeConn.FlushCallCount()).To(Equal(1)) + }) + + It("creates chain with correct configuration", func() { + err := manager.SetupNATSFirewall("nats://192.168.1.100:4222") + Expect(err).NotTo(HaveOccurred()) + + chain := fakeConn.AddChainArgsForCall(0) + Expect(chain.Name).To(Equal("nats_access")) + Expect(chain.Type).To(Equal(nftables.ChainTypeFilter)) + Expect(chain.Hooknum).To(Equal(nftables.ChainHookOutput)) + }) + }) + + Context("with an IPv6 address URL", func() { + It("creates rules for the IPv6 address", func() { + err := manager.SetupNATSFirewall("nats://user:pass@[2001:db8::1]:4222") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + Expect(fakeConn.FlushCallCount()).To(Equal(1)) + }) + }) + + Context("with a hostname URL", func() { + It("resolves DNS and creates rules for resolved IPs", func() { + fakeResolver.LookupIPReturns([]net.IP{ + net.ParseIP("10.0.0.1"), + net.ParseIP("10.0.0.2"), + }, nil) + + err := manager.SetupNATSFirewall("nats://user:pass@nats.example.com:4222") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeResolver.LookupIPCallCount()).To(Equal(1)) + Expect(fakeResolver.LookupIPArgsForCall(0)).To(Equal("nats.example.com")) + + // Two IPs * 2 rules each = 4 rules + Expect(fakeConn.AddRuleCallCount()).To(Equal(4)) + }) + + It("handles DNS resolution failure gracefully", func() { + fakeResolver.LookupIPReturns(nil, errors.New("dns lookup failed")) + + err := manager.SetupNATSFirewall("nats://user:pass@nats.example.com:4222") + Expect(err).NotTo(HaveOccurred()) // Should not return error, just log warning + + Expect(fakeResolver.LookupIPCallCount()).To(Equal(1)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(0)) // No rules added + }) + }) + + Context("with default port", func() { + It("uses port 4222 when not specified", func() { + err := manager.SetupNATSFirewall("nats://192.168.1.100") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + }) + }) + + Context("with custom port", func() { + It("uses the specified port", func() { + err := manager.SetupNATSFirewall("nats://192.168.1.100:5222") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + }) + }) + + Context("with https URL", func() { + It("skips setup and returns nil", func() { + err := manager.SetupNATSFirewall("https://director.example.com:25555") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddTableCallCount()).To(Equal(0)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(0)) + }) + }) + + Context("with empty URL", func() { + It("skips setup and returns nil", func() { + err := manager.SetupNATSFirewall("") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddTableCallCount()).To(Equal(0)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(0)) + }) + }) + + Context("when called multiple times", func() { + It("flushes existing NATS chain before adding new rules", func() { + err := manager.SetupNATSFirewall("nats://192.168.1.100:4222") + Expect(err).NotTo(HaveOccurred()) + Expect(fakeConn.FlushChainCallCount()).To(Equal(1)) + + err = manager.SetupNATSFirewall("nats://192.168.1.200:4222") + Expect(err).NotTo(HaveOccurred()) + Expect(fakeConn.FlushChainCallCount()).To(Equal(2)) + }) + }) + + Context("when Flush fails", func() { + It("returns an error", func() { + fakeConn.FlushReturns(errors.New("flush failed")) + + err := manager.SetupNATSFirewall("nats://192.168.1.100:4222") + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("Flushing nftables rules")) + }) + }) + }) + + Describe("BeforeConnect", func() { + var hook firewall.NatsFirewallHook + + BeforeEach(func() { + hook = manager.(firewall.NatsFirewallHook) + }) + + It("delegates to SetupNATSFirewall", func() { + err := hook.BeforeConnect("nats://192.168.1.100:4222") + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeConn.AddTableCallCount()).To(Equal(1)) + Expect(fakeConn.AddChainCallCount()).To(Equal(1)) + Expect(fakeConn.AddRuleCallCount()).To(Equal(2)) + }) + + It("returns nil on success", func() { + err := hook.BeforeConnect("nats://192.168.1.100:4222") + Expect(err).NotTo(HaveOccurred()) + }) + + It("returns error when SetupNATSFirewall fails", func() { + fakeConn.FlushReturns(errors.New("flush failed")) + + err := hook.BeforeConnect("nats://192.168.1.100:4222") + Expect(err).To(HaveOccurred()) + }) + }) + + Describe("Manager interface implementation", func() { + It("implements NatsFirewallHook interface", func() { + hook := manager.(firewall.NatsFirewallHook) + Expect(hook).NotTo(BeNil()) + }) + }) +}) diff --git a/platform/linux_platform.go b/platform/linux_platform.go index f9c33400a..045802333 100644 --- a/platform/linux_platform.go +++ b/platform/linux_platform.go @@ -23,6 +23,7 @@ import ( "github.com/cloudfoundry/bosh-agent/v2/platform/cdrom" boshcert "github.com/cloudfoundry/bosh-agent/v2/platform/cert" boshdisk "github.com/cloudfoundry/bosh-agent/v2/platform/disk" + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" boshnet "github.com/cloudfoundry/bosh-agent/v2/platform/net" boship "github.com/cloudfoundry/bosh-agent/v2/platform/net/ip" boshstats "github.com/cloudfoundry/bosh-agent/v2/platform/stats" @@ -113,6 +114,7 @@ type linux struct { auditLogger AuditLogger logsTarProvider boshlogstarprovider.LogsTarProvider serviceManager servicemanager.ServiceManager + firewallManager firewall.Manager } func NewLinuxPlatform( @@ -1800,3 +1802,38 @@ func prepareDiskLabelPrefix(labelPrefix string) string { return labelPrefix } + +// SetupFirewall initializes the nftables-based firewall for protecting monit and NATS. +// This should be called during platform setup, before NATS connections are attempted. +func (p *linux) SetupFirewall() error { + mgr, err := firewall.NewNftablesFirewall(p.logger) + if err != nil { + p.logger.Warn(logTag, "Failed to create firewall manager: %s", err) + // Not a fatal error - continue without firewall protection + return nil + } + p.firewallManager = mgr + + // Set up monit firewall rules + if err := mgr.SetupMonitFirewall(); err != nil { + p.logger.Warn(logTag, "Failed to set up monit firewall: %s", err) + // Not a fatal error - continue without monit firewall protection + } + + return nil +} + +// GetNatsFirewallHook returns the firewall hook for NATS connection management. +// The hook is called before each NATS connection/reconnection to update firewall +// rules with resolved DNS addresses. Returns nil if firewall was not set up. +func (p *linux) GetNatsFirewallHook() firewall.NatsFirewallHook { + if p.firewallManager == nil { + return nil + } + + // The firewall manager implements NatsFirewallHook + if hook, ok := p.firewallManager.(firewall.NatsFirewallHook); ok { + return hook + } + return nil +} diff --git a/platform/net/firewall_provider.go b/platform/net/firewall_provider.go index 32a6c3e4b..9b9996724 100644 --- a/platform/net/firewall_provider.go +++ b/platform/net/firewall_provider.go @@ -1,10 +1,10 @@ -//go:build !windows && !linux +//go:build !windows package net -// SetupNatsFirewall is does nothing, except on Linux and Windows +// SetupNatsFirewall is a no-op on non-Windows platforms. +// On Linux, the nftables-based firewall in platform/firewall/ is used instead. +// On Windows, a Windows Firewall rule is set up. func SetupNatsFirewall(mbus string) error { - // NOTE: If we return a "not supported" err here, unit tests would fail. - //return errors.New("not supported") return nil } diff --git a/platform/net/firewall_provider_linux.go b/platform/net/firewall_provider_linux.go deleted file mode 100644 index 71aed0226..000000000 --- a/platform/net/firewall_provider_linux.go +++ /dev/null @@ -1,158 +0,0 @@ -//go:build linux - -package net - -import ( - "errors" - "fmt" - "net" - gonetURL "net/url" - "os" - "strings" - - bosherr "github.com/cloudfoundry/bosh-utils/errors" - cgroups "github.com/containerd/cgroups/v3" - "github.com/containerd/cgroups/v3/cgroup1" - "github.com/coreos/go-iptables/iptables" - "github.com/opencontainers/runtime-spec/specs-go" -) - -const ( - /* "natsIsolationClassID" This is the integer value of the argument "0xb0540002", which is - b054:0002 . The major number (the left-hand side) is "BOSH", leet-ified. - The minor number (the right-hand side) is 2, indicating that this is the - second thing in our "BOSH" classid namespace. - - _Hopefully_ noone uses a major number of "b054", and we avoid collisions _forever_! - If you need to select new classids for firewall rules or traffic control rules, keep - the major number "b054" for bosh stuff, unless there's a good reason to not. - - The net_cls.classid structure is described in more detail here: - https://www.kernel.org/doc/Documentation/cgroup-v1/net_cls.txt - */ - natsIsolationClassID uint32 = 2958295042 -) - -// SetupNatsFirewall will setup the outgoing cgroup based rule that prevents everything except the agent to open connections to the nats api -func SetupNatsFirewall(mbus string) error { - // We have decided to remove the NATS firewall starting with Noble because we have - // ephemeral NATS credentials implemented in the Bosh Director which is a better solution - // to the problem. This allows us to remove all of this code after Jammy support ends - if cgroups.Mode() == cgroups.Unified { - return nil - } - - // return early if - // we get a https url for mbus. case for create-env - // we get an empty string. case for http_metadata_service (responsible to extract the agent-settings.json from the metadata endpoint) - // we find that v1cgroups are not mounted (warden stemcells) - if mbus == "" || strings.HasPrefix(mbus, "https://") { - return nil - } - - mbusURL, err := gonetURL.Parse(mbus) - if err != nil || mbusURL.Hostname() == "" { - return bosherr.WrapError(err, "Error parsing MbusURL") - } - - host, port, err := net.SplitHostPort(mbusURL.Host) - if err != nil { - return bosherr.WrapError(err, "Error getting Port") - } - - // Run the lookup for Host as it could be potentially a Hostname | IPv4 | IPv6 - // the return for LookupIP will be a list of IP Addr and in case of the Input being an IP Addr, - // it will only contain one element with the Input IP - addr_array, err := net.LookupIP(host) - if err != nil { - return bosherr.WrapError(err, fmt.Sprintf("Error resolving mbus host: %v", host)) - } - - return SetupIptables(host, port, addr_array) -} - -func SetupIptables(host, port string, addr_array []net.IP) error { - _, err := cgroup1.Default() - if err != nil { - if errors.Is(err, cgroup1.ErrMountPointNotExist) { - return nil // v1cgroups are not mounted (warden stemcells) - } - return bosherr.WrapError(err, "Error retrieving cgroups mount point") - } - - ipt, err := iptables.New() - if err != nil { - return bosherr.WrapError(err, "Creating Iptables Error") - } - // Even on a V6 VM, Monit will listen to only V4 loopback - // First create Monit V4 rules for natsIsolationClassID - exists, err := ipt.Exists("mangle", "POSTROUTING", - "-d", "127.0.0.1", - "-p", "tcp", - "--dport", "2822", - "-m", "cgroup", - "--cgroup", fmt.Sprintf("%v", natsIsolationClassID), - "-j", "ACCEPT", - ) - if err != nil { - return bosherr.WrapError(err, "Iptables Error checking for monit rule") - } - if !exists { - err = ipt.Insert("mangle", "POSTROUTING", 1, - "-d", "127.0.0.1", - "-p", "tcp", - "--dport", "2822", - "-m", "cgroup", - "--cgroup", fmt.Sprintf("%v", natsIsolationClassID), - "-j", "ACCEPT", - ) - if err != nil { - return bosherr.WrapError(err, "Iptables Error inserting for monit rule") - } - } - - // For nats iptables rules we default to V4 unless below dns resolution gives us a V6 target - ipVersion := iptables.ProtocolIPv4 - // Check if we're dealing with a V4 Target - if addr_array[0].To4() == nil { - ipVersion = iptables.ProtocolIPv6 - } - ipt, err = iptables.NewWithProtocol(ipVersion) - if err != nil { - return bosherr.WrapError(err, "Creating Iptables Error") - } - - err = ipt.AppendUnique("mangle", "POSTROUTING", - "-d", host, - "-p", "tcp", - "--dport", port, - "-m", "cgroup", - "--cgroup", fmt.Sprintf("%v", natsIsolationClassID), - "-j", "ACCEPT", - ) - if err != nil { - return bosherr.WrapError(err, "Iptables Error inserting for agent ACCEPT rule") - } - err = ipt.AppendUnique("mangle", "POSTROUTING", - "-d", host, - "-p", "tcp", - "--dport", port, - "-j", "DROP", - ) - if err != nil { - return bosherr.WrapError(err, "Iptables Error inserting for non-agent DROP rule") - } - - var isolationClassID = natsIsolationClassID - natsAPICgroup, err := cgroup1.New(cgroup1.StaticPath("/nats-api-access"), &specs.LinuxResources{ - Network: &specs.LinuxNetwork{ - ClassID: &isolationClassID, - }, - }) - if err != nil { - return bosherr.WrapError(err, "Error setting up cgroups for nats api access") - } - - err = natsAPICgroup.AddProc(uint64(os.Getpid()), cgroup1.NetCLS) - return err -} diff --git a/platform/net/firewall_provider_test.go b/platform/net/firewall_provider_test.go deleted file mode 100644 index afe77cc3c..000000000 --- a/platform/net/firewall_provider_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package net - -import ( - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("SetupFirewall Linux", func() { - // covers the case for http_metadata_service where on some IaaSs we cannot yet know the contents of - // agent-settings.json since http_metadata_service is responsible for pulling the data. - When("mbus url is empty", func() { - It("returns early without an error", func() { - err := SetupNatsFirewall("") - Expect(err).ToNot(HaveOccurred()) - }) - }) - // create no rule on a create-env - When("mbus url starts with https://", func() { - It("returns early without an error", func() { - err := SetupNatsFirewall("https://") - Expect(err).ToNot(HaveOccurred()) - }) - }) -}) diff --git a/platform/net/ubuntu_net_manager.go b/platform/net/ubuntu_net_manager.go index 1a38d0823..99a1632c3 100644 --- a/platform/net/ubuntu_net_manager.go +++ b/platform/net/ubuntu_net_manager.go @@ -116,11 +116,6 @@ func (net UbuntuNetManager) SetupNetworking(networks boshsettings.Networks, mbus if err != nil { return err } - err = SetupNatsFirewall(mbus) - if err != nil { - return bosherr.WrapError(err, "Setting up Nats Firewall") - } - net.logger.Info(UbuntuNetManagerLogTag, "Successfully set up outgoing nats api firewall") return nil } staticConfigs, dhcpConfigs, dnsServers, err := net.ComputeNetworkConfig(networks) @@ -184,11 +179,6 @@ func (net UbuntuNetManager) SetupNetworking(networks boshsettings.Networks, mbus } go net.addressBroadcaster.BroadcastMACAddresses(append(staticAddressesWithoutVirtual, dynamicAddresses...)) - err = SetupNatsFirewall(mbus) - if err != nil { - return bosherr.WrapError(err, "Setting up nats firewall") - } - net.logger.Info(UbuntuNetManagerLogTag, "Successfully set up outgoing nats api firewall") return nil } func (net UbuntuNetManager) ComputeNetworkConfig(networks boshsettings.Networks) ([]StaticInterfaceConfiguration, []DHCPInterfaceConfiguration, []string, error) { diff --git a/platform/platform_interface.go b/platform/platform_interface.go index e4cc77c38..84e12bd18 100644 --- a/platform/platform_interface.go +++ b/platform/platform_interface.go @@ -4,6 +4,7 @@ import ( "log" "github.com/cloudfoundry/bosh-agent/v2/platform/cert" + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" boshcmd "github.com/cloudfoundry/bosh-utils/fileutil" boshsys "github.com/cloudfoundry/bosh-utils/system" @@ -77,6 +78,7 @@ type Platform interface { SetupLoggingAndAuditing() (err error) SetupOptDir() (err error) SetupRecordsJSONPermission(path string) error + SetupFirewall() error // Disk management AdjustPersistentDiskPartitioning(diskSettings boshsettings.DiskSettings, mountPoint string) error @@ -110,4 +112,9 @@ type Platform interface { RemoveStaticLibraries(packageFileListPath string) error Shutdown() error + + // Firewall management + // GetNatsFirewallHook returns a hook that is called before NATS connection/reconnection + // to update firewall rules with resolved DNS. Returns nil if firewall is not supported. + GetNatsFirewallHook() firewall.NatsFirewallHook } diff --git a/platform/platformfakes/fake_platform.go b/platform/platformfakes/fake_platform.go index e7fa3ff80..730f67ff6 100644 --- a/platform/platformfakes/fake_platform.go +++ b/platform/platformfakes/fake_platform.go @@ -8,6 +8,7 @@ import ( "github.com/cloudfoundry/bosh-agent/v2/infrastructure/devicepathresolver" "github.com/cloudfoundry/bosh-agent/v2/platform" "github.com/cloudfoundry/bosh-agent/v2/platform/cert" + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" "github.com/cloudfoundry/bosh-agent/v2/platform/net/ip" "github.com/cloudfoundry/bosh-agent/v2/platform/vitals" "github.com/cloudfoundry/bosh-agent/v2/servicemanager" @@ -270,6 +271,16 @@ type FakePlatform struct { result2 string result3 error } + GetNatsFirewallHookStub func() firewall.NatsFirewallHook + getNatsFirewallHookMutex sync.RWMutex + getNatsFirewallHookArgsForCall []struct { + } + getNatsFirewallHookReturns struct { + result1 firewall.NatsFirewallHook + } + getNatsFirewallHookReturnsOnCall map[int]struct { + result1 firewall.NatsFirewallHook + } GetPersistentDiskSettingsPathStub func(bool) string getPersistentDiskSettingsPathMutex sync.RWMutex getPersistentDiskSettingsPathArgsForCall []struct { @@ -509,6 +520,16 @@ type FakePlatform struct { setupEphemeralDiskWithPathReturnsOnCall map[int]struct { result1 error } + SetupFirewallStub func() error + setupFirewallMutex sync.RWMutex + setupFirewallArgsForCall []struct { + } + setupFirewallReturns struct { + result1 error + } + setupFirewallReturnsOnCall map[int]struct { + result1 error + } SetupHomeDirStub func() error setupHomeDirMutex sync.RWMutex setupHomeDirArgsForCall []struct { @@ -2011,6 +2032,59 @@ func (fake *FakePlatform) GetMonitCredentialsReturnsOnCall(i int, result1 string }{result1, result2, result3} } +func (fake *FakePlatform) GetNatsFirewallHook() firewall.NatsFirewallHook { + fake.getNatsFirewallHookMutex.Lock() + ret, specificReturn := fake.getNatsFirewallHookReturnsOnCall[len(fake.getNatsFirewallHookArgsForCall)] + fake.getNatsFirewallHookArgsForCall = append(fake.getNatsFirewallHookArgsForCall, struct { + }{}) + stub := fake.GetNatsFirewallHookStub + fakeReturns := fake.getNatsFirewallHookReturns + fake.recordInvocation("GetNatsFirewallHook", []interface{}{}) + fake.getNatsFirewallHookMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePlatform) GetNatsFirewallHookCallCount() int { + fake.getNatsFirewallHookMutex.RLock() + defer fake.getNatsFirewallHookMutex.RUnlock() + return len(fake.getNatsFirewallHookArgsForCall) +} + +func (fake *FakePlatform) GetNatsFirewallHookCalls(stub func() firewall.NatsFirewallHook) { + fake.getNatsFirewallHookMutex.Lock() + defer fake.getNatsFirewallHookMutex.Unlock() + fake.GetNatsFirewallHookStub = stub +} + +func (fake *FakePlatform) GetNatsFirewallHookReturns(result1 firewall.NatsFirewallHook) { + fake.getNatsFirewallHookMutex.Lock() + defer fake.getNatsFirewallHookMutex.Unlock() + fake.GetNatsFirewallHookStub = nil + fake.getNatsFirewallHookReturns = struct { + result1 firewall.NatsFirewallHook + }{result1} +} + +func (fake *FakePlatform) GetNatsFirewallHookReturnsOnCall(i int, result1 firewall.NatsFirewallHook) { + fake.getNatsFirewallHookMutex.Lock() + defer fake.getNatsFirewallHookMutex.Unlock() + fake.GetNatsFirewallHookStub = nil + if fake.getNatsFirewallHookReturnsOnCall == nil { + fake.getNatsFirewallHookReturnsOnCall = make(map[int]struct { + result1 firewall.NatsFirewallHook + }) + } + fake.getNatsFirewallHookReturnsOnCall[i] = struct { + result1 firewall.NatsFirewallHook + }{result1} +} + func (fake *FakePlatform) GetPersistentDiskSettingsPath(arg1 bool) string { fake.getPersistentDiskSettingsPathMutex.Lock() ret, specificReturn := fake.getPersistentDiskSettingsPathReturnsOnCall[len(fake.getPersistentDiskSettingsPathArgsForCall)] @@ -3260,6 +3334,59 @@ func (fake *FakePlatform) SetupEphemeralDiskWithPathReturnsOnCall(i int, result1 }{result1} } +func (fake *FakePlatform) SetupFirewall() error { + fake.setupFirewallMutex.Lock() + ret, specificReturn := fake.setupFirewallReturnsOnCall[len(fake.setupFirewallArgsForCall)] + fake.setupFirewallArgsForCall = append(fake.setupFirewallArgsForCall, struct { + }{}) + stub := fake.SetupFirewallStub + fakeReturns := fake.setupFirewallReturns + fake.recordInvocation("SetupFirewall", []interface{}{}) + fake.setupFirewallMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakePlatform) SetupFirewallCallCount() int { + fake.setupFirewallMutex.RLock() + defer fake.setupFirewallMutex.RUnlock() + return len(fake.setupFirewallArgsForCall) +} + +func (fake *FakePlatform) SetupFirewallCalls(stub func() error) { + fake.setupFirewallMutex.Lock() + defer fake.setupFirewallMutex.Unlock() + fake.SetupFirewallStub = stub +} + +func (fake *FakePlatform) SetupFirewallReturns(result1 error) { + fake.setupFirewallMutex.Lock() + defer fake.setupFirewallMutex.Unlock() + fake.SetupFirewallStub = nil + fake.setupFirewallReturns = struct { + result1 error + }{result1} +} + +func (fake *FakePlatform) SetupFirewallReturnsOnCall(i int, result1 error) { + fake.setupFirewallMutex.Lock() + defer fake.setupFirewallMutex.Unlock() + fake.SetupFirewallStub = nil + if fake.setupFirewallReturnsOnCall == nil { + fake.setupFirewallReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setupFirewallReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakePlatform) SetupHomeDir() error { fake.setupHomeDirMutex.Lock() ret, specificReturn := fake.setupHomeDirReturnsOnCall[len(fake.setupHomeDirArgsForCall)] diff --git a/platform/windows_platform.go b/platform/windows_platform.go index b0a20b46e..47d2dd567 100644 --- a/platform/windows_platform.go +++ b/platform/windows_platform.go @@ -20,6 +20,7 @@ import ( boshlogstarprovider "github.com/cloudfoundry/bosh-agent/v2/agent/logstarprovider" boshdpresolv "github.com/cloudfoundry/bosh-agent/v2/infrastructure/devicepathresolver" boshcert "github.com/cloudfoundry/bosh-agent/v2/platform/cert" + "github.com/cloudfoundry/bosh-agent/v2/platform/firewall" boshnet "github.com/cloudfoundry/bosh-agent/v2/platform/net" boship "github.com/cloudfoundry/bosh-agent/v2/platform/net/ip" boshstats "github.com/cloudfoundry/bosh-agent/v2/platform/stats" @@ -772,6 +773,14 @@ func (p WindowsPlatform) SetupRecordsJSONPermission(path string) error { return nil } +func (p WindowsPlatform) SetupFirewall() error { + return nil +} + func (p WindowsPlatform) Shutdown() error { return nil } + +func (p WindowsPlatform) GetNatsFirewallHook() firewall.NatsFirewallHook { + return nil +} diff --git a/vendor/github.com/google/nftables/CONTRIBUTING.md b/vendor/github.com/google/nftables/CONTRIBUTING.md new file mode 100644 index 000000000..ae319c70a --- /dev/null +++ b/vendor/github.com/google/nftables/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. diff --git a/vendor/github.com/google/nftables/LICENSE b/vendor/github.com/google/nftables/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/vendor/github.com/google/nftables/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/google/nftables/README.md b/vendor/github.com/google/nftables/README.md new file mode 100644 index 000000000..cb633c718 --- /dev/null +++ b/vendor/github.com/google/nftables/README.md @@ -0,0 +1,24 @@ +[![Build Status](https://github.com/google/nftables/actions/workflows/push.yml/badge.svg)](https://github.com/google/nftables/actions/workflows/push.yml) +[![GoDoc](https://godoc.org/github.com/google/nftables?status.svg)](https://godoc.org/github.com/google/nftables) + +**This is not the correct repository for issues with the Linux nftables +project!** This repository contains a third-party Go package to programmatically +interact with nftables. Find the official nftables website at +https://wiki.nftables.org/ + +This package manipulates Linux nftables (the iptables successor). It is +implemented in pure Go, i.e. does not wrap libnftnl. + +This is not an official Google product. + +## Breaking changes + +This package is in very early stages, and only contains enough data types and +functions to install very basic nftables rules. It is likely that mistakes with +the data types/API will be identified as more functionality is added. + +## Contributions + +Contributions are very welcome! + + diff --git a/vendor/github.com/google/nftables/alignedbuff/alignedbuff.go b/vendor/github.com/google/nftables/alignedbuff/alignedbuff.go new file mode 100644 index 000000000..a97214649 --- /dev/null +++ b/vendor/github.com/google/nftables/alignedbuff/alignedbuff.go @@ -0,0 +1,300 @@ +// Package alignedbuff implements encoding and decoding aligned data elements +// to/from buffers in native endianess. +// +// # Note +// +// The alignment/padding as implemented in this package must match that of +// kernel's and user space C implementations for a particular architecture (bit +// size). Please see also the "dummy structure" _xt_align +// (https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/x_tables.h#L93) +// as well as the associated XT_ALIGN C preprocessor macro. +// +// In particular, we rely on the Go compiler to follow the same architecture +// alignments as the C compiler(s) on Linux. +package alignedbuff + +import ( + "bytes" + "errors" + "fmt" + "unsafe" + + "github.com/google/nftables/binaryutil" +) + +// ErrEOF signals trying to read beyond the available payload information. +var ErrEOF = errors.New("not enough data left") + +// AlignedBuff implements marshalling and unmarshalling information in +// platform/architecture native endianess and data type alignment. It +// additionally covers some of the nftables-xtables translation-specific +// idiosyncracies to the extend needed in order to properly marshal and +// unmarshal Match and Target expressions, and their Info payload in particular. +type AlignedBuff struct { + data []byte + pos int +} + +// New returns a new AlignedBuff for marshalling aligned data in native +// endianess. +func New() AlignedBuff { + return AlignedBuff{} +} + +// NewWithData returns a new AlignedBuff for unmarshalling the passed data in +// native endianess. +func NewWithData(data []byte) AlignedBuff { + return AlignedBuff{data: data} +} + +// Data returns the properly padded info payload data written before by calling +// the various Uint8, Uint16, ... marshalling functions. +func (a *AlignedBuff) Data() []byte { + // The Linux kernel expects payloads to be padded to the next uint64 + // alignment. + a.alignWrite(uint64AlignMask) + return a.data +} + +// BytesAligned32 unmarshals the given amount of bytes starting with the native +// alignment for uint32 data types. It returns ErrEOF when trying to read beyond +// the payload. +// +// BytesAligned32 is used to unmarshal IP addresses for different IP versions, +// which are always aligned the same way as the native alignment for uint32. +func (a *AlignedBuff) BytesAligned32(size int) ([]byte, error) { + if err := a.alignCheckedRead(uint32AlignMask); err != nil { + return nil, err + } + if a.pos > len(a.data)-size { + return nil, ErrEOF + } + data := a.data[a.pos : a.pos+size] + a.pos += size + return data, nil +} + +// Uint8 unmarshals an uint8 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint8() (uint8, error) { + if a.pos >= len(a.data) { + return 0, ErrEOF + } + v := a.data[a.pos] + a.pos++ + return v, nil +} + +// Uint16 unmarshals an uint16 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint16() (uint16, error) { + if err := a.alignCheckedRead(uint16AlignMask); err != nil { + return 0, err + } + v := binaryutil.NativeEndian.Uint16(a.data[a.pos : a.pos+2]) + a.pos += 2 + return v, nil +} + +// Uint16BE unmarshals an uint16 in "network" (=big endian) endianess and native +// uint16 alignment. It returns ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint16BE() (uint16, error) { + if err := a.alignCheckedRead(uint16AlignMask); err != nil { + return 0, err + } + v := binaryutil.BigEndian.Uint16(a.data[a.pos : a.pos+2]) + a.pos += 2 + return v, nil +} + +// Uint32 unmarshals an uint32 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint32() (uint32, error) { + if err := a.alignCheckedRead(uint32AlignMask); err != nil { + return 0, err + } + v := binaryutil.NativeEndian.Uint32(a.data[a.pos : a.pos+4]) + a.pos += 4 + return v, nil +} + +// Uint64 unmarshals an uint64 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Uint64() (uint64, error) { + if err := a.alignCheckedRead(uint64AlignMask); err != nil { + return 0, err + } + v := binaryutil.NativeEndian.Uint64(a.data[a.pos : a.pos+8]) + a.pos += 8 + return v, nil +} + +// Int32 unmarshals an int32 in native endianess and alignment. It returns +// ErrEOF when trying to read beyond the payload. +func (a *AlignedBuff) Int32() (int32, error) { + if err := a.alignCheckedRead(int32AlignMask); err != nil { + return 0, err + } + v := binaryutil.Int32(a.data[a.pos : a.pos+4]) + a.pos += 4 + return v, nil +} + +// String unmarshals a null terminated string +func (a *AlignedBuff) String() (string, error) { + len := 0 + for { + if a.data[a.pos+len] == 0x00 { + break + } + len++ + } + + v := binaryutil.String(a.data[a.pos : a.pos+len]) + a.pos += len + return v, nil +} + +// StringWithLength unmarshals a string of a given length (for non-null +// terminated strings) +func (a *AlignedBuff) StringWithLength(len int) (string, error) { + v := binaryutil.String(a.data[a.pos : a.pos+len]) + a.pos += len + return v, nil +} + +// Uint unmarshals an uint in native endianess and alignment for the C "unsigned +// int" type. It returns ErrEOF when trying to read beyond the payload. Please +// note that on 64bit platforms, the size and alignment of C's and Go's unsigned +// integer data types differ, so we encapsulate this difference here. +func (a *AlignedBuff) Uint() (uint, error) { + switch uintSize { + case 2: + v, err := a.Uint16() + return uint(v), err + case 4: + v, err := a.Uint32() + return uint(v), err + case 8: + v, err := a.Uint64() + return uint(v), err + default: + panic(fmt.Sprintf("unsupported uint size %d", uintSize)) + } +} + +// PutBytesAligned32 marshals the given bytes starting with the native alignment +// for uint32 data types. It additionaly adds padding to reach the specified +// size. +// +// PutBytesAligned32 is used to marshal IP addresses for different IP versions, +// which are always aligned the same way as the native alignment for uint32. +func (a *AlignedBuff) PutBytesAligned32(data []byte, size int) { + a.alignWrite(uint32AlignMask) + a.data = append(a.data, data...) + a.pos += len(data) + if len(data) < size { + padding := size - len(data) + a.data = append(a.data, bytes.Repeat([]byte{0}, padding)...) + a.pos += padding + } +} + +// PutUint8 marshals an uint8 in native endianess and alignment. +func (a *AlignedBuff) PutUint8(v uint8) { + a.data = append(a.data, v) + a.pos++ +} + +// PutUint16 marshals an uint16 in native endianess and alignment. +func (a *AlignedBuff) PutUint16(v uint16) { + a.alignWrite(uint16AlignMask) + a.data = append(a.data, binaryutil.NativeEndian.PutUint16(v)...) + a.pos += 2 +} + +// PutUint16BE marshals an uint16 in "network" (=big endian) endianess and +// native uint16 alignment. +func (a *AlignedBuff) PutUint16BE(v uint16) { + a.alignWrite(uint16AlignMask) + a.data = append(a.data, binaryutil.BigEndian.PutUint16(v)...) + a.pos += 2 +} + +// PutUint32 marshals an uint32 in native endianess and alignment. +func (a *AlignedBuff) PutUint32(v uint32) { + a.alignWrite(uint32AlignMask) + a.data = append(a.data, binaryutil.NativeEndian.PutUint32(v)...) + a.pos += 4 +} + +// PutUint64 marshals an uint64 in native endianess and alignment. +func (a *AlignedBuff) PutUint64(v uint64) { + a.alignWrite(uint64AlignMask) + a.data = append(a.data, binaryutil.NativeEndian.PutUint64(v)...) + a.pos += 8 +} + +// PutInt32 marshals an int32 in native endianess and alignment. +func (a *AlignedBuff) PutInt32(v int32) { + a.alignWrite(int32AlignMask) + a.data = append(a.data, binaryutil.PutInt32(v)...) + a.pos += 4 +} + +// PutString marshals a string. +func (a *AlignedBuff) PutString(v string) { + a.data = append(a.data, binaryutil.PutString(v)...) + a.pos += len(v) +} + +// PutUint marshals an uint in native endianess and alignment for the C +// "unsigned int" type. Please note that on 64bit platforms, the size and +// alignment of C's and Go's unsigned integer data types differ, so we +// encapsulate this difference here. +func (a *AlignedBuff) PutUint(v uint) { + switch uintSize { + case 2: + a.PutUint16(uint16(v)) + case 4: + a.PutUint32(uint32(v)) + case 8: + a.PutUint64(uint64(v)) + default: + panic(fmt.Sprintf("unsupported uint size %d", uintSize)) + } +} + +// alignCheckedRead aligns the (read) position if necessary and suitable +// according to the specified alignment mask. alignCheckedRead returns an error +// if after any necessary alignment there isn't enough data left to be read into +// a value of the size corresponding to the specified alignment mask. +func (a *AlignedBuff) alignCheckedRead(m int) error { + a.pos = (a.pos + m) & ^m + if a.pos > len(a.data)-(m+1) { + return ErrEOF + } + return nil +} + +// alignWrite aligns the (write) position if necessary and suitable according to +// the specified alignment mask. It doubles as final payload padding helpmate in +// order to keep the kernel happy. +func (a *AlignedBuff) alignWrite(m int) { + pos := (a.pos + m) & ^m + if pos != a.pos { + a.data = append(a.data, padding[:pos-a.pos]...) + a.pos = pos + } +} + +// This is ... ugly. +var uint16AlignMask = int(unsafe.Alignof(uint16(0)) - 1) +var uint32AlignMask = int(unsafe.Alignof(uint32(0)) - 1) +var uint64AlignMask = int(unsafe.Alignof(uint64(0)) - 1) +var padding = bytes.Repeat([]byte{0}, uint64AlignMask) + +var int32AlignMask = int(unsafe.Alignof(int32(0)) - 1) + +// And this even worse. +var uintSize = unsafe.Sizeof(uint32(0)) diff --git a/vendor/github.com/google/nftables/binaryutil/binaryutil.go b/vendor/github.com/google/nftables/binaryutil/binaryutil.go new file mode 100644 index 000000000..e61973f07 --- /dev/null +++ b/vendor/github.com/google/nftables/binaryutil/binaryutil.go @@ -0,0 +1,125 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package binaryutil contains convenience wrappers around encoding/binary. +package binaryutil + +import ( + "bytes" + "encoding/binary" + "unsafe" +) + +// ByteOrder is like binary.ByteOrder, but allocates memory and returns byte +// slices, for convenience. +type ByteOrder interface { + PutUint16(v uint16) []byte + PutUint32(v uint32) []byte + PutUint64(v uint64) []byte + Uint16(b []byte) uint16 + Uint32(b []byte) uint32 + Uint64(b []byte) uint64 +} + +// NativeEndian is either little endian or big endian, depending on the native +// endian-ness, and allocates memory and returns byte slices, for convenience. +var NativeEndian ByteOrder = &nativeEndian{} + +type nativeEndian struct{} + +func (nativeEndian) PutUint16(v uint16) []byte { + buf := make([]byte, 2) + *(*uint16)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func (nativeEndian) PutUint32(v uint32) []byte { + buf := make([]byte, 4) + *(*uint32)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func (nativeEndian) PutUint64(v uint64) []byte { + buf := make([]byte, 8) + *(*uint64)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func (nativeEndian) Uint16(b []byte) uint16 { + return *(*uint16)(unsafe.Pointer(&b[0])) +} + +func (nativeEndian) Uint32(b []byte) uint32 { + return *(*uint32)(unsafe.Pointer(&b[0])) +} + +func (nativeEndian) Uint64(b []byte) uint64 { + return *(*uint64)(unsafe.Pointer(&b[0])) +} + +// BigEndian is like binary.BigEndian, but allocates memory and returns byte +// slices, for convenience. +var BigEndian ByteOrder = &bigEndian{} + +type bigEndian struct{} + +func (bigEndian) PutUint16(v uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, v) + return buf +} + +func (bigEndian) PutUint32(v uint32) []byte { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, v) + return buf +} + +func (bigEndian) PutUint64(v uint64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, v) + return buf +} + +func (bigEndian) Uint16(b []byte) uint16 { + return binary.BigEndian.Uint16(b) +} + +func (bigEndian) Uint32(b []byte) uint32 { + return binary.BigEndian.Uint32(b) +} + +func (bigEndian) Uint64(b []byte) uint64 { + return binary.BigEndian.Uint64(b) +} + +// For dealing with types not supported by the encoding/binary interface + +func PutInt32(v int32) []byte { + buf := make([]byte, 4) + *(*int32)(unsafe.Pointer(&buf[0])) = v + return buf +} + +func Int32(b []byte) int32 { + return *(*int32)(unsafe.Pointer(&b[0])) +} + +func PutString(s string) []byte { + return []byte(s) +} + +func String(b []byte) string { + return string(bytes.TrimRight(b, "\x00")) +} diff --git a/vendor/github.com/google/nftables/chain.go b/vendor/github.com/google/nftables/chain.go new file mode 100644 index 000000000..4f4c0a532 --- /dev/null +++ b/vendor/github.com/google/nftables/chain.go @@ -0,0 +1,328 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// ChainHook specifies at which step in packet processing the Chain should be +// executed. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_hooks +type ChainHook uint32 + +// Possible ChainHook values. +var ( + ChainHookPrerouting *ChainHook = ChainHookRef(unix.NF_INET_PRE_ROUTING) + ChainHookInput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_IN) + ChainHookForward *ChainHook = ChainHookRef(unix.NF_INET_FORWARD) + ChainHookOutput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_OUT) + ChainHookPostrouting *ChainHook = ChainHookRef(unix.NF_INET_POST_ROUTING) + ChainHookIngress *ChainHook = ChainHookRef(unix.NF_NETDEV_INGRESS) + ChainHookEgress *ChainHook = ChainHookRef(unix.NF_NETDEV_EGRESS) +) + +// ChainHookRef returns a pointer to a ChainHookRef value. +func ChainHookRef(h ChainHook) *ChainHook { + return &h +} + +// ChainPriority orders the chain relative to Netfilter internal operations. See +// also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_priority +type ChainPriority int32 + +// Possible ChainPriority values. +var ( // from /usr/include/linux/netfilter_ipv4.h + ChainPriorityFirst *ChainPriority = ChainPriorityRef(math.MinInt32) + ChainPriorityConntrackDefrag *ChainPriority = ChainPriorityRef(-400) + ChainPriorityRaw *ChainPriority = ChainPriorityRef(-300) + ChainPrioritySELinuxFirst *ChainPriority = ChainPriorityRef(-225) + ChainPriorityConntrack *ChainPriority = ChainPriorityRef(-200) + ChainPriorityMangle *ChainPriority = ChainPriorityRef(-150) + ChainPriorityNATDest *ChainPriority = ChainPriorityRef(-100) + ChainPriorityFilter *ChainPriority = ChainPriorityRef(0) + ChainPrioritySecurity *ChainPriority = ChainPriorityRef(50) + ChainPriorityNATSource *ChainPriority = ChainPriorityRef(100) + ChainPrioritySELinuxLast *ChainPriority = ChainPriorityRef(225) + ChainPriorityConntrackHelper *ChainPriority = ChainPriorityRef(300) + ChainPriorityConntrackConfirm *ChainPriority = ChainPriorityRef(math.MaxInt32) + ChainPriorityLast *ChainPriority = ChainPriorityRef(math.MaxInt32) +) + +// ChainPriorityRef returns a pointer to a ChainPriority value. +func ChainPriorityRef(p ChainPriority) *ChainPriority { + return &p +} + +// ChainType defines what this chain will be used for. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_types +type ChainType string + +// Possible ChainType values. +const ( + ChainTypeFilter ChainType = "filter" + ChainTypeRoute ChainType = "route" + ChainTypeNAT ChainType = "nat" +) + +// ChainPolicy defines what this chain default policy will be. +type ChainPolicy uint32 + +// Possible ChainPolicy values. +const ( + ChainPolicyDrop ChainPolicy = iota + ChainPolicyAccept +) + +// A Chain contains Rules. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains +type Chain struct { + Name string + Table *Table + Hooknum *ChainHook + Priority *ChainPriority + Type ChainType + Policy *ChainPolicy + Device string +} + +// AddChain adds the specified Chain. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains +func (cc *Conn) AddChain(c *Chain) *Chain { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, + }) + + if c.Hooknum != nil && c.Priority != nil { + hookAttr := []netlink.Attribute{ + {Type: unix.NFTA_HOOK_HOOKNUM, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Hooknum))}, + {Type: unix.NFTA_HOOK_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Priority))}, + } + + if c.Device != "" { + hookAttr = append(hookAttr, netlink.Attribute{Type: unix.NFTA_HOOK_DEV, Data: []byte(c.Device + "\x00")}) + } + + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_CHAIN_HOOK, Data: cc.marshalAttr(hookAttr)}, + })...) + } + + if c.Policy != nil { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_POLICY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Policy))}, + })...) + } + if c.Type != "" { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, + })...) + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(c.Table.Family), 0), data...), + }) + + return c +} + +// DelChain deletes the specified Chain. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Deleting_chains +func (cc *Conn) DelChain(c *Chain) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_CHAIN_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, + }) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(c.Table.Family), 0), data...), + }) +} + +// FlushChain removes all rules within the specified Chain. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Flushing_chain +func (cc *Conn) FlushChain(c *Chain) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(c.Table.Family), 0), data...), + }) +} + +// ListChains returns currently configured chains in the kernel +func (cc *Conn) ListChains() ([]*Chain, error) { + return cc.ListChainsOfTableFamily(TableFamilyUnspecified) +} + +// ListChain returns a single chain configured in the specified table +func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + attrs := []netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(table.Name + "\x00")}, + {Type: unix.NFTA_CHAIN_NAME, Data: []byte(chain + "\x00")}, + } + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), + Flags: netlink.Request, + }, + Data: append(extraHeader(uint8(table.Family), 0), cc.marshalAttr(attrs)...), + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, fmt.Errorf("conn.Execute failed: %v", err) + } + + if got, want := len(response), 1; got != want { + return nil, fmt.Errorf("expected %d response message for chain, got %d", want, got) + } + + ch, err := chainFromMsg(response[0]) + if err != nil { + return nil, err + } + + return ch, nil +} + +// ListChainsOfTableFamily returns currently configured chains for the specified +// family in the kernel. It lists all chains ins all tables if family is +// TableFamilyUnspecified. +func (cc *Conn) ListChainsOfTableFamily(family TableFamily) ([]*Chain, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN), + Flags: netlink.Request | netlink.Dump, + }, + Data: extraHeader(uint8(family), 0), + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, err + } + + var chains []*Chain + for _, m := range response { + c, err := chainFromMsg(m) + if err != nil { + return nil, err + } + + chains = append(chains, c) + } + + return chains, nil +} + +func chainFromMsg(msg netlink.Message) (*Chain, error) { + newChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN) + delChainHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN) + if got, want1, want2 := msg.Header.Type, newChainHeaderType, delChainHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) + } + + var c Chain + + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CHAIN_NAME: + c.Name = ad.String() + case unix.NFTA_TABLE_NAME: + c.Table = &Table{Name: ad.String()} + // msg[0] carries TableFamily byte indicating whether it is IPv4, IPv6 or something else + c.Table.Family = TableFamily(msg.Data[0]) + case unix.NFTA_CHAIN_TYPE: + c.Type = ChainType(ad.String()) + case unix.NFTA_CHAIN_POLICY: + policy := ChainPolicy(binaryutil.BigEndian.Uint32(ad.Bytes())) + c.Policy = &policy + case unix.NFTA_CHAIN_HOOK: + ad.Do(func(b []byte) error { + c.Hooknum, c.Priority, err = hookFromMsg(b) + return err + }) + } + } + + return &c, nil +} + +func hookFromMsg(b []byte) (*ChainHook, *ChainPriority, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, nil, err + } + + ad.ByteOrder = binary.BigEndian + + var hooknum ChainHook + var prio ChainPriority + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_HOOK_HOOKNUM: + hooknum = ChainHook(ad.Uint32()) + case unix.NFTA_HOOK_PRIORITY: + prio = ChainPriority(ad.Uint32()) + } + } + + return &hooknum, &prio, nil +} diff --git a/vendor/github.com/google/nftables/compat_policy.go b/vendor/github.com/google/nftables/compat_policy.go new file mode 100644 index 000000000..c1f390855 --- /dev/null +++ b/vendor/github.com/google/nftables/compat_policy.go @@ -0,0 +1,89 @@ +package nftables + +import ( + "fmt" + + "github.com/google/nftables/expr" + "golang.org/x/sys/unix" +) + +const nft_RULE_COMPAT_F_INV uint32 = (1 << 1) +const nft_RULE_COMPAT_F_MASK uint32 = nft_RULE_COMPAT_F_INV + +// Used by xt match or target like xt_tcpudp to set compat policy between xtables and nftables +// https://elixir.bootlin.com/linux/v5.12/source/net/netfilter/nft_compat.c#L187 +type compatPolicy struct { + Proto uint32 + Flag uint32 +} + +var xtMatchCompatMap map[string]*compatPolicy = map[string]*compatPolicy{ + "tcp": { + Proto: unix.IPPROTO_TCP, + }, + "udp": { + Proto: unix.IPPROTO_UDP, + }, + "udplite": { + Proto: unix.IPPROTO_UDPLITE, + }, + "tcpmss": { + Proto: unix.IPPROTO_TCP, + }, + "sctp": { + Proto: unix.IPPROTO_SCTP, + }, + "osf": { + Proto: unix.IPPROTO_TCP, + }, + "ipcomp": { + Proto: unix.IPPROTO_COMP, + }, + "esp": { + Proto: unix.IPPROTO_ESP, + }, +} + +var xtTargetCompatMap map[string]*compatPolicy = map[string]*compatPolicy{ + "TCPOPTSTRIP": { + Proto: unix.IPPROTO_TCP, + }, + "TCPMSS": { + Proto: unix.IPPROTO_TCP, + }, +} + +func getCompatPolicy(exprs []expr.Any) (*compatPolicy, error) { + var exprItem expr.Any + var compat *compatPolicy + + for _, iter := range exprs { + var tmpExprItem expr.Any + var tmpCompat *compatPolicy + switch item := iter.(type) { + case *expr.Match: + if compat, ok := xtMatchCompatMap[item.Name]; ok { + tmpCompat = compat + tmpExprItem = item + } else { + continue + } + case *expr.Target: + if compat, ok := xtTargetCompatMap[item.Name]; ok { + tmpCompat = compat + tmpExprItem = item + } else { + continue + } + default: + continue + } + if compat == nil { + compat = tmpCompat + exprItem = tmpExprItem + } else if *compat != *tmpCompat { + return nil, fmt.Errorf("%#v and %#v has conflict compat policy %#v vs %#v", exprItem, tmpExprItem, compat, tmpCompat) + } + } + return compat, nil +} diff --git a/vendor/github.com/google/nftables/conn.go b/vendor/github.com/google/nftables/conn.go new file mode 100644 index 000000000..fef9c2a4e --- /dev/null +++ b/vendor/github.com/google/nftables/conn.go @@ -0,0 +1,371 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "errors" + "fmt" + "os" + "sync" + "syscall" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nltest" + "golang.org/x/sys/unix" +) + +// A Conn represents a netlink connection of the nftables family. +// +// All methods return their input, so that variables can be defined from string +// literals when desired. +// +// Commands are buffered. Flush sends all buffered commands in a single batch. +type Conn struct { + TestDial nltest.Func // for testing only; passed to nltest.Dial + NetNS int // fd referencing the network namespace netlink will interact with. + + lasting bool // establish a lasting connection to be used across multiple netlink operations. + mu sync.Mutex // protects the following state + messages []netlink.Message + err error + nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol. + sockOptions []SockOption +} + +// ConnOption is an option to change the behavior of the nftables Conn returned by Open. +type ConnOption func(*Conn) + +// SockOption is an option to change the behavior of the netlink socket used by the nftables Conn. +type SockOption func(*netlink.Conn) error + +// New returns a netlink connection for querying and modifying nftables. Some +// aspects of the new netlink connection can be configured using the options +// WithNetNSFd, WithTestDial, and AsLasting. +// +// A lasting netlink connection should be closed by calling CloseLasting() to +// close the underlying lasting netlink connection, cancelling all pending +// operations using this connection. +func New(opts ...ConnOption) (*Conn, error) { + cc := &Conn{} + for _, opt := range opts { + opt(cc) + } + + if !cc.lasting { + return cc, nil + } + + nlconn, err := cc.dialNetlink() + if err != nil { + return nil, err + } + cc.nlconn = nlconn + return cc, nil +} + +// AsLasting creates the new netlink connection as a lasting connection that is +// reused across multiple netlink operations, instead of opening and closing the +// underlying netlink connection only for the duration of a single netlink +// operation. +func AsLasting() ConnOption { + return func(cc *Conn) { + // We cannot create the underlying connection yet, as we are called + // anywhere in the option processing chain and there might be later + // options still modifying connection behavior. + cc.lasting = true + } +} + +// WithNetNSFd sets the network namespace to create a new netlink connection to: +// the fd must reference a network namespace. +func WithNetNSFd(fd int) ConnOption { + return func(cc *Conn) { + cc.NetNS = fd + } +} + +// WithTestDial sets the specified nltest.Func when creating a new netlink +// connection. +func WithTestDial(f nltest.Func) ConnOption { + return func(cc *Conn) { + cc.TestDial = f + } +} + +// WithSockOptions sets the specified socket options when creating a new netlink +// connection. +func WithSockOptions(opts ...SockOption) ConnOption { + return func(cc *Conn) { + cc.sockOptions = append(cc.sockOptions, opts...) + } +} + +// netlinkCloser is returned by netlinkConn(UnderLock) and must be called after +// being done with the returned netlink connection in order to properly close +// this connection, if necessary. +type netlinkCloser func() error + +// netlinkConn returns a netlink connection together with a netlinkCloser that +// later must be called by the caller when it doesn't need the returned netlink +// connection anymore. The netlinkCloser will close the netlink connection when +// necessary. If New has been told to create a lasting connection, then this +// lasting netlink connection will be returned, otherwise a new "transient" +// netlink connection will be opened and returned instead. netlinkConn must not +// be called while the Conn.mu lock is currently helt (this will cause a +// deadlock). Use netlinkConnUnderLock instead in such situations. +func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.netlinkConnUnderLock() +} + +// netlinkConnUnderLock works like netlinkConn but must be called while holding +// the Conn.mu lock. +func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) { + if cc.nlconn != nil { + return cc.nlconn, func() error { return nil }, nil + } + nlconn, err := cc.dialNetlink() + if err != nil { + return nil, nil, err + } + return nlconn, func() error { return nlconn.Close() }, nil +} + +func receiveAckAware(nlconn *netlink.Conn, sentMsgFlags netlink.HeaderFlags) ([]netlink.Message, error) { + if nlconn == nil { + return nil, errors.New("netlink conn is not initialized") + } + + // first receive will be the message that we expect + reply, err := nlconn.Receive() + if err != nil { + return nil, err + } + + if (sentMsgFlags & netlink.Acknowledge) == 0 { + // we did not request an ack + return reply, nil + } + + if (sentMsgFlags & netlink.Dump) == netlink.Dump { + // sent message has Dump flag set, there will be no acks + // https://github.com/torvalds/linux/blob/7e062cda7d90543ac8c7700fc7c5527d0c0f22ad/net/netlink/af_netlink.c#L2387-L2390 + return reply, nil + } + + if len(reply) != 0 { + last := reply[len(reply)-1] + for re := last.Header.Type; (re&netlink.Overrun) == netlink.Overrun && (re&netlink.Done) != netlink.Done; re = last.Header.Type { + // we are not finished, the message is overrun + r, err := nlconn.Receive() + if err != nil { + return nil, err + } + reply = append(reply, r...) + last = reply[len(reply)-1] + } + + if last.Header.Type == netlink.Error && binaryutil.BigEndian.Uint32(last.Data[:4]) == 0 { + // we have already collected an ack + return reply, nil + } + } + + // Now we expect an ack + ack, err := nlconn.Receive() + if err != nil { + return nil, err + } + + if len(ack) == 0 { + // received an empty ack? + return reply, nil + } + + msg := ack[0] + if msg.Header.Type != netlink.Error { + // acks should be delivered as NLMSG_ERROR + return nil, fmt.Errorf("expected header %v, but got %v", netlink.Error, msg.Header.Type) + } + + if binaryutil.BigEndian.Uint32(msg.Data[:4]) != 0 { + // if errno field is not set to 0 (success), this is an error + return nil, fmt.Errorf("error delivered in message: %v", msg.Data) + } + + return reply, nil +} + +// CloseLasting closes the lasting netlink connection that has been opened using +// AsLasting option when creating this connection. If either no lasting netlink +// connection has been opened or the lasting connection is already in the +// process of closing or has been closed, CloseLasting will immediately return +// without any error. +// +// CloseLasting will terminate all pending netlink operations using the lasting +// connection. +// +// After closing a lasting connection, the connection will revert to using +// on-demand transient netlink connections when calling further netlink +// operations (such as GetTables). +func (cc *Conn) CloseLasting() error { + // Don't acquire the lock for the whole duration of the CloseLasting + // operation, but instead only so long as to make sure to only run the + // netlink socket close on the first time with a lasting netlink socket. As + // there is only the New() constructor, but no Open() method, it's + // impossible to reopen a lasting connection. + cc.mu.Lock() + nlconn := cc.nlconn + cc.nlconn = nil + cc.mu.Unlock() + if nlconn != nil { + return nlconn.Close() + } + return nil +} + +// Flush sends all buffered commands in a single batch to nftables. +func (cc *Conn) Flush() error { + cc.mu.Lock() + defer func() { + cc.messages = nil + cc.mu.Unlock() + }() + if len(cc.messages) == 0 { + // Messages were already programmed, returning nil + return nil + } + if cc.err != nil { + return cc.err // serialization error + } + conn, closer, err := cc.netlinkConnUnderLock() + if err != nil { + return err + } + defer func() { _ = closer() }() + + if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + return fmt.Errorf("SendMessages: %w", err) + } + + var errs error + // Fetch the requested acknowledgement for each message we sent. + for _, msg := range cc.messages { + if _, err := receiveAckAware(conn, msg.Header.Flags); err != nil { + if errors.Is(err, os.ErrPermission) || errors.Is(err, syscall.ENOBUFS) { + // Kernel will only send one error to user space. + return err + } + errs = errors.Join(errs, err) + } + } + + if errs != nil { + return fmt.Errorf("conn.Receive: %w", errs) + } + + return nil +} + +// FlushRuleset flushes the entire ruleset. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level +func (cc *Conn) FlushRuleset() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: extraHeader(0, 0), + }) +} + +func (cc *Conn) dialNetlink() (*netlink.Conn, error) { + var ( + conn *netlink.Conn + err error + ) + + if cc.TestDial != nil { + conn = nltest.Dial(cc.TestDial) + } else { + conn, err = netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS}) + } + + if err != nil { + return nil, err + } + + for _, opt := range cc.sockOptions { + if err := opt(conn); err != nil { + return nil, err + } + } + + return conn, nil +} + +func (cc *Conn) setErr(err error) { + if cc.err != nil { + return + } + cc.err = err +} + +func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte { + b, err := netlink.MarshalAttributes(attrs) + if err != nil { + cc.setErr(err) + return nil + } + return b +} + +func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte { + b, err := expr.Marshal(fam, e) + if err != nil { + cc.setErr(err) + return nil + } + return b +} + +func batch(messages []netlink.Message) []netlink.Message { + batch := []netlink.Message{ + { + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), + Flags: netlink.Request, + }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + }, + } + + batch = append(batch, messages...) + + batch = append(batch, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), + Flags: netlink.Request, + }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + }) + + return batch +} diff --git a/vendor/github.com/google/nftables/counter.go b/vendor/github.com/google/nftables/counter.go new file mode 100644 index 000000000..d18fc49c5 --- /dev/null +++ b/vendor/github.com/google/nftables/counter.go @@ -0,0 +1,63 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "github.com/google/nftables/expr" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type CounterObj struct { + Table *Table + Name string // e.g. “fwded” + + Bytes uint64 + Packets uint64 +} + +func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error { + for ad.Next() { + switch ad.Type() { + case unix.NFTA_COUNTER_BYTES: + c.Bytes = ad.Uint64() + case unix.NFTA_COUNTER_PACKETS: + c.Packets = ad.Uint64() + } + } + return ad.Err() +} + +func (c *CounterObj) data() expr.Any { + return &expr.Counter{ + Bytes: c.Bytes, + Packets: c.Packets, + } +} + +func (c *CounterObj) name() string { + return c.Name +} +func (c *CounterObj) objType() ObjType { + return ObjTypeCounter +} + +func (c *CounterObj) table() *Table { + return c.Table +} + +func (c *CounterObj) family() TableFamily { + return c.Table.Family +} diff --git a/vendor/github.com/google/nftables/doc.go b/vendor/github.com/google/nftables/doc.go new file mode 100644 index 000000000..41985b35e --- /dev/null +++ b/vendor/github.com/google/nftables/doc.go @@ -0,0 +1,16 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nftables manipulates Linux nftables (the iptables successor). +package nftables diff --git a/vendor/github.com/google/nftables/expr/bitwise.go b/vendor/github.com/google/nftables/expr/bitwise.go new file mode 100644 index 000000000..5f3cdeac7 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/bitwise.go @@ -0,0 +1,106 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Bitwise struct { + SourceRegister uint32 + DestRegister uint32 + Len uint32 + Mask []byte + Xor []byte +} + +func (e *Bitwise) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("bitwise\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Bitwise) marshalData(fam byte) ([]byte, error) { + mask, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Mask}, + }) + if err != nil { + return nil, err + } + xor, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Xor}, + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_BITWISE_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + {Type: unix.NFTA_BITWISE_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + {Type: unix.NFTA_BITWISE_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + {Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_MASK, Data: mask}, + {Type: unix.NLA_F_NESTED | unix.NFTA_BITWISE_XOR, Data: xor}, + }) +} + +func (e *Bitwise) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_BITWISE_SREG: + e.SourceRegister = ad.Uint32() + case unix.NFTA_BITWISE_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_BITWISE_LEN: + e.Len = ad.Uint32() + case unix.NFTA_BITWISE_MASK: + // Since NFTA_BITWISE_MASK is nested, it requires additional decoding + ad.Nested(func(nad *netlink.AttributeDecoder) error { + for nad.Next() { + switch nad.Type() { + case unix.NFTA_DATA_VALUE: + e.Mask = nad.Bytes() + } + } + return nil + }) + case unix.NFTA_BITWISE_XOR: + // Since NFTA_BITWISE_XOR is nested, it requires additional decoding + ad.Nested(func(nad *netlink.AttributeDecoder) error { + for nad.Next() { + switch nad.Type() { + case unix.NFTA_DATA_VALUE: + e.Xor = nad.Bytes() + } + } + return nil + }) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/byteorder.go b/vendor/github.com/google/nftables/expr/byteorder.go new file mode 100644 index 000000000..cf9e2fe58 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/byteorder.go @@ -0,0 +1,63 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type ByteorderOp uint32 + +const ( + ByteorderNtoh ByteorderOp = unix.NFT_BYTEORDER_NTOH + ByteorderHton ByteorderOp = unix.NFT_BYTEORDER_HTON +) + +type Byteorder struct { + SourceRegister uint32 + DestRegister uint32 + Op ByteorderOp + Len uint32 + Size uint32 +} + +func (e *Byteorder) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("byteorder\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Byteorder) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + {Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + {Type: unix.NFTA_BYTEORDER_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, + {Type: unix.NFTA_BYTEORDER_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + {Type: unix.NFTA_BYTEORDER_SIZE, Data: binaryutil.BigEndian.PutUint32(e.Size)}, + }) +} + +func (e *Byteorder) unmarshal(fam byte, data []byte) error { + return fmt.Errorf("not yet implemented") +} diff --git a/vendor/github.com/google/nftables/expr/connlimit.go b/vendor/github.com/google/nftables/expr/connlimit.go new file mode 100644 index 000000000..11bd07bf4 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/connlimit.go @@ -0,0 +1,74 @@ +// Copyright 2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const ( + // Per https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1167 + NFTA_CONNLIMIT_UNSPEC = iota + NFTA_CONNLIMIT_COUNT + NFTA_CONNLIMIT_FLAGS + NFT_CONNLIMIT_F_INV = 1 +) + +// Per https://git.netfilter.org/libnftnl/tree/src/expr/connlimit.c?id=84d12cfacf8ddd857a09435f3d982ab6250d250c +type Connlimit struct { + Count uint32 + Flags uint32 +} + +func (e *Connlimit) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("connlimit\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Connlimit) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: NFTA_CONNLIMIT_COUNT, Data: binaryutil.BigEndian.PutUint32(e.Count)}, + {Type: NFTA_CONNLIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}, + }) +} + +func (e *Connlimit) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_CONNLIMIT_COUNT: + e.Count = binaryutil.BigEndian.Uint32(ad.Bytes()) + case NFTA_CONNLIMIT_FLAGS: + e.Flags = binaryutil.BigEndian.Uint32(ad.Bytes()) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/counter.go b/vendor/github.com/google/nftables/expr/counter.go new file mode 100644 index 000000000..7483ee45b --- /dev/null +++ b/vendor/github.com/google/nftables/expr/counter.go @@ -0,0 +1,64 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Counter struct { + Bytes uint64 + Packets uint64 +} + +func (e *Counter) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("counter\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Counter) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(e.Bytes)}, + {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(e.Packets)}, + }) +} + +func (e *Counter) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_COUNTER_BYTES: + e.Bytes = ad.Uint64() + case unix.NFTA_COUNTER_PACKETS: + e.Packets = ad.Uint64() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/ct.go b/vendor/github.com/google/nftables/expr/ct.go new file mode 100644 index 000000000..127b6fdd0 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/ct.go @@ -0,0 +1,400 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// CtKey specifies which piece of conntrack information should be loaded. See +// also https://wiki.nftables.org/wiki-nftables/index.php/Matching_connection_tracking_stateful_metainformation +type CtKey uint32 + +// Possible CtKey values. +const ( + CtKeySTATE CtKey = unix.NFT_CT_STATE + CtKeyDIRECTION CtKey = unix.NFT_CT_DIRECTION + CtKeySTATUS CtKey = unix.NFT_CT_STATUS + CtKeyMARK CtKey = unix.NFT_CT_MARK + CtKeySECMARK CtKey = unix.NFT_CT_SECMARK + CtKeyEXPIRATION CtKey = unix.NFT_CT_EXPIRATION + CtKeyHELPER CtKey = unix.NFT_CT_HELPER + CtKeyL3PROTOCOL CtKey = unix.NFT_CT_L3PROTOCOL + CtKeySRC CtKey = unix.NFT_CT_SRC + CtKeyDST CtKey = unix.NFT_CT_DST + CtKeyPROTOCOL CtKey = unix.NFT_CT_PROTOCOL + CtKeyPROTOSRC CtKey = unix.NFT_CT_PROTO_SRC + CtKeyPROTODST CtKey = unix.NFT_CT_PROTO_DST + CtKeyLABELS CtKey = unix.NFT_CT_LABELS + CtKeyPKTS CtKey = unix.NFT_CT_PKTS + CtKeyBYTES CtKey = unix.NFT_CT_BYTES + CtKeyAVGPKT CtKey = unix.NFT_CT_AVGPKT + CtKeyZONE CtKey = unix.NFT_CT_ZONE + CtKeyEVENTMASK CtKey = unix.NFT_CT_EVENTMASK + + // https://sources.debian.org/src//nftables/0.9.8-3/src/ct.c/?hl=39#L39 + CtStateBitINVALID uint32 = 1 + CtStateBitESTABLISHED uint32 = 2 + CtStateBitRELATED uint32 = 4 + CtStateBitNEW uint32 = 8 + CtStateBitUNTRACKED uint32 = 64 +) + +// Missing ct timeout consts +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1592 +const ( + NFTA_CT_TIMEOUT_L3PROTO = 0x01 + NFTA_CT_TIMEOUT_L4PROTO = 0x02 + NFTA_CT_TIMEOUT_DATA = 0x03 +) + +type CtStatePolicyTimeout map[uint16]uint32 + +const ( + // https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n24 + CtStateTCPSYNSENT = iota + CtStateTCPSYNRECV + CtStateTCPESTABLISHED + CtStateTCPFINWAIT + CtStateTCPCLOSEWAIT + CtStateTCPLASTACK + CtStateTCPTIMEWAIT + CtStateTCPCLOSE + CtStateTCPSYNSENT2 + CtStateTCPRETRANS + CtStateTCPUNACK +) + +// https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n38 +var CtStateTCPTimeoutDefaults CtStatePolicyTimeout = map[uint16]uint32{ + CtStateTCPSYNSENT: 120, + CtStateTCPSYNRECV: 60, + CtStateTCPESTABLISHED: 43200, + CtStateTCPFINWAIT: 120, + CtStateTCPCLOSEWAIT: 60, + CtStateTCPLASTACK: 30, + CtStateTCPTIMEWAIT: 120, + CtStateTCPCLOSE: 10, + CtStateTCPSYNSENT2: 120, + CtStateTCPRETRANS: 300, + CtStateTCPUNACK: 300, +} + +const ( + // https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n57 + CtStateUDPUNREPLIED = iota + CtStateUDPREPLIED +) + +// https://git.netfilter.org/libnftnl/tree/src/obj/ct_timeout.c?id=116e95aa7b6358c917de8c69f6f173874030b46b#n57 +var CtStateUDPTimeoutDefaults CtStatePolicyTimeout = map[uint16]uint32{ + CtStateUDPUNREPLIED: 30, + CtStateUDPREPLIED: 180, +} + +// Ct defines type for NFT connection tracking +type Ct struct { + Register uint32 + SourceRegister bool + Key CtKey + Direction uint32 +} + +func (e *Ct) marshal(fam byte) ([]byte, error) { + exprData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("ct\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Ct) marshalData(fam byte) ([]byte, error) { + var regData []byte + exprData, err := netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_CT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + }, + ) + if err != nil { + return nil, err + } + if e.SourceRegister { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_CT_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } else { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_CT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } + if err != nil { + return nil, err + } + exprData = append(exprData, regData...) + + switch e.Key { + case CtKeySRC, CtKeyDST, CtKeyPROTOSRC, CtKeyPROTODST: + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_CT_DIRECTION, Data: binaryutil.BigEndian.PutUint32(e.Direction)}, + }, + ) + if err != nil { + return nil, err + } + exprData = append(exprData, regData...) + } + + return exprData, nil +} + +func (e *Ct) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CT_KEY: + e.Key = CtKey(ad.Uint32()) + case unix.NFTA_CT_DREG: + e.Register = ad.Uint32() + case unix.NFTA_CT_DIRECTION: + e.Direction = ad.Uint32() + } + } + return ad.Err() +} + +type CtHelper struct { + Name string + L3Proto uint16 + L4Proto uint8 +} + +func (c *CtHelper) marshal(fam byte) ([]byte, error) { + exprData, err := c.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("cthelper\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (c *CtHelper) marshalData(fam byte) ([]byte, error) { + exprData := []netlink.Attribute{ + {Type: unix.NFTA_CT_HELPER_NAME, Data: []byte(c.Name)}, + } + + if c.L3Proto != 0 { + exprData = append(exprData, netlink.Attribute{ + Type: unix.NFTA_CT_HELPER_L3PROTO, Data: binaryutil.BigEndian.PutUint16(c.L3Proto), + }) + } + if c.L4Proto != 0 { + exprData = append(exprData, netlink.Attribute{ + Type: unix.NFTA_CT_HELPER_L4PROTO, Data: []byte{c.L4Proto}, + }) + } + + return netlink.MarshalAttributes(exprData) +} + +func (c *CtHelper) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CT_HELPER_NAME: + c.Name = ad.String() + case unix.NFTA_CT_HELPER_L3PROTO: + c.L3Proto = ad.Uint16() + case unix.NFTA_CT_HELPER_L4PROTO: + c.L4Proto = ad.Uint8() + } + } + return ad.Err() +} + +// From https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1601 +// Currently not available in sys/unix +const ( + NFTA_CT_EXPECT_L3PROTO = 0x01 + NFTA_CT_EXPECT_L4PROTO = 0x02 + NFTA_CT_EXPECT_DPORT = 0x03 + NFTA_CT_EXPECT_TIMEOUT = 0x04 + NFTA_CT_EXPECT_SIZE = 0x05 +) + +type CtExpect struct { + L3Proto uint16 + L4Proto uint8 + DPort uint16 + Timeout uint32 + Size uint8 +} + +func (c *CtExpect) marshal(fam byte) ([]byte, error) { + exprData, err := c.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("ctexpect\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (c *CtExpect) marshalData(fam byte) ([]byte, error) { + // all elements except l3proto must be defined + // per https://git.netfilter.org/nftables/tree/doc/stateful-objects.txt?id=db70959a5ccf2952b218f51c3d529e186a5a43bb#n119 + // from man page: l3proto is derived from the table family by default + exprData := []netlink.Attribute{ + {Type: NFTA_CT_EXPECT_L4PROTO, Data: []byte{c.L4Proto}}, + {Type: NFTA_CT_EXPECT_DPORT, Data: binaryutil.BigEndian.PutUint16(c.DPort)}, + {Type: NFTA_CT_EXPECT_TIMEOUT, Data: binaryutil.BigEndian.PutUint32(c.Timeout)}, + {Type: NFTA_CT_EXPECT_SIZE, Data: []byte{c.Size}}, + } + + if c.L3Proto != 0 { + attr := netlink.Attribute{Type: NFTA_CT_EXPECT_L3PROTO, Data: binaryutil.BigEndian.PutUint16(c.L3Proto)} + exprData = append(exprData, attr) + } + return netlink.MarshalAttributes(exprData) +} + +func (c *CtExpect) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_CT_EXPECT_L3PROTO: + c.L3Proto = ad.Uint16() + case NFTA_CT_EXPECT_L4PROTO: + c.L4Proto = ad.Uint8() + case NFTA_CT_EXPECT_DPORT: + c.DPort = ad.Uint16() + case NFTA_CT_EXPECT_TIMEOUT: + c.Timeout = ad.Uint32() + case NFTA_CT_EXPECT_SIZE: + c.Size = ad.Uint8() + } + } + return ad.Err() +} + +type CtTimeout struct { + L3Proto uint16 + L4Proto uint8 + Policy CtStatePolicyTimeout +} + +func (c *CtTimeout) marshal(fam byte) ([]byte, error) { + exprData, err := c.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("cttimeout\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (c *CtTimeout) marshalData(fam byte) ([]byte, error) { + var policy CtStatePolicyTimeout + switch c.L4Proto { + case unix.IPPROTO_UDP: + policy = CtStateUDPTimeoutDefaults + default: + policy = CtStateTCPTimeoutDefaults + } + + for k, v := range c.Policy { + policy[k] = v + } + + var policyAttrs []netlink.Attribute + for k, v := range policy { + policyAttrs = append(policyAttrs, netlink.Attribute{Type: k + 1, Data: binaryutil.BigEndian.PutUint32(v)}) + } + policyData, err := netlink.MarshalAttributes(policyAttrs) + if err != nil { + return nil, err + } + + exprData := []netlink.Attribute{ + {Type: NFTA_CT_TIMEOUT_L3PROTO, Data: binaryutil.BigEndian.PutUint16(c.L3Proto)}, + {Type: NFTA_CT_TIMEOUT_L4PROTO, Data: []byte{c.L4Proto}}, + {Type: unix.NLA_F_NESTED | NFTA_CT_TIMEOUT_DATA, Data: policyData}, + } + + return netlink.MarshalAttributes(exprData) +} + +func (c *CtTimeout) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_CT_TIMEOUT_L3PROTO: + c.L3Proto = ad.Uint16() + case NFTA_CT_TIMEOUT_L4PROTO: + c.L4Proto = ad.Uint8() + case NFTA_CT_TIMEOUT_DATA: + decoder, err := netlink.NewAttributeDecoder(ad.Bytes()) + decoder.ByteOrder = binary.BigEndian + if err != nil { + return err + } + for decoder.Next() { + switch c.L4Proto { + case unix.IPPROTO_UDP: + c.Policy = CtStateUDPTimeoutDefaults + default: + c.Policy = CtStateTCPTimeoutDefaults + } + c.Policy[decoder.Type()-1] = decoder.Uint32() + } + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/dup.go b/vendor/github.com/google/nftables/expr/dup.go new file mode 100644 index 000000000..9012fdaca --- /dev/null +++ b/vendor/github.com/google/nftables/expr/dup.go @@ -0,0 +1,70 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Dup struct { + RegAddr uint32 + RegDev uint32 + IsRegDevSet bool +} + +func (e *Dup) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("dup\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Dup) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_DUP_SREG_ADDR, Data: binaryutil.BigEndian.PutUint32(e.RegAddr)}, + } + + if e.IsRegDevSet { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_DUP_SREG_DEV, Data: binaryutil.BigEndian.PutUint32(e.RegDev)}) + } + + return netlink.MarshalAttributes(attrs) +} + +func (e *Dup) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_DUP_SREG_ADDR: + e.RegAddr = ad.Uint32() + case unix.NFTA_DUP_SREG_DEV: + e.RegDev = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/dynset.go b/vendor/github.com/google/nftables/expr/dynset.go new file mode 100644 index 000000000..0b3d9bf13 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/dynset.go @@ -0,0 +1,153 @@ +// Copyright 2020 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + "time" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/internal/parseexprfunc" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Not yet supported by unix package +// https://cs.opensource.google/go/x/sys/+/c6bc011c:unix/ztypes_linux.go;l=2027-2036 +const ( + NFTA_DYNSET_EXPRESSIONS = 0xa + NFT_DYNSET_F_EXPR = (1 << 1) +) + +// Dynset represent a rule dynamically adding or updating a set or a map based on an incoming packet. +type Dynset struct { + SrcRegKey uint32 + SrcRegData uint32 + SetID uint32 + SetName string + Operation uint32 + Timeout time.Duration + Invert bool + Exprs []Any +} + +func (e *Dynset) marshal(fam byte) ([]byte, error) { + opData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("dynset\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: opData}, + }) +} + +func (e *Dynset) marshalData(fam byte) ([]byte, error) { + // See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c + var opAttrs []netlink.Attribute + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)}) + if e.SrcRegData != 0 { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_DATA, Data: binaryutil.BigEndian.PutUint32(e.SrcRegData)}) + } + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_OP, Data: binaryutil.BigEndian.PutUint32(e.Operation)}) + if e.Timeout != 0 { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(e.Timeout.Milliseconds()))}) + } + var flags uint32 + if e.Invert { + flags |= unix.NFT_DYNSET_F_INV + } + + opAttrs = append(opAttrs, + netlink.Attribute{Type: unix.NFTA_DYNSET_SET_NAME, Data: []byte(e.SetName + "\x00")}, + netlink.Attribute{Type: unix.NFTA_DYNSET_SET_ID, Data: binaryutil.BigEndian.PutUint32(e.SetID)}) + + // Per https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n170 + if len(e.Exprs) > 0 { + switch len(e.Exprs) { + case 1: + exprData, err := Marshal(fam, e.Exprs[0]) + if err != nil { + return nil, err + } + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_EXPR, Data: exprData}) + default: + flags |= NFT_DYNSET_F_EXPR + var elemAttrs []netlink.Attribute + for _, ex := range e.Exprs { + exprData, err := Marshal(fam, ex) + if err != nil { + return nil, err + } + elemAttrs = append(elemAttrs, netlink.Attribute{Type: unix.NFTA_LIST_ELEM, Data: exprData}) + } + elemData, err := netlink.MarshalAttributes(elemAttrs) + if err != nil { + return nil, err + } + opAttrs = append(opAttrs, netlink.Attribute{Type: NFTA_DYNSET_EXPRESSIONS, Data: elemData}) + } + } + + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}) + return netlink.MarshalAttributes(opAttrs) +} + +func (e *Dynset) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_DYNSET_SET_NAME: + e.SetName = ad.String() + case unix.NFTA_DYNSET_SET_ID: + e.SetID = ad.Uint32() + case unix.NFTA_DYNSET_SREG_KEY: + e.SrcRegKey = ad.Uint32() + case unix.NFTA_DYNSET_SREG_DATA: + e.SrcRegData = ad.Uint32() + case unix.NFTA_DYNSET_OP: + e.Operation = ad.Uint32() + case unix.NFTA_DYNSET_TIMEOUT: + e.Timeout = time.Duration(time.Millisecond * time.Duration(ad.Uint64())) + case unix.NFTA_DYNSET_FLAGS: + e.Invert = (ad.Uint32() & unix.NFT_DYNSET_F_INV) != 0 + case unix.NFTA_DYNSET_EXPR: + exprs, err := parseexprfunc.ParseExprBytesFunc(fam, ad) + if err != nil { + return err + } + e.setInterfaceExprs(exprs) + case NFTA_DYNSET_EXPRESSIONS: + exprs, err := parseexprfunc.ParseExprMsgFunc(fam, ad.Bytes()) + if err != nil { + return err + } + e.setInterfaceExprs(exprs) + } + } + return ad.Err() +} + +func (e *Dynset) setInterfaceExprs(exprs []interface{}) { + e.Exprs = make([]Any, len(exprs)) + for i := range exprs { + e.Exprs[i] = exprs[i].(Any) + } +} diff --git a/vendor/github.com/google/nftables/expr/expr.go b/vendor/github.com/google/nftables/expr/expr.go new file mode 100644 index 000000000..0a363233b --- /dev/null +++ b/vendor/github.com/google/nftables/expr/expr.go @@ -0,0 +1,502 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package expr provides nftables rule expressions. +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/internal/parseexprfunc" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func init() { + parseExprBytesCommonFunc := func(exprsFromBytesFunc func() ([]Any, error)) ([]interface{}, error) { + exprs, err := exprsFromBytesFunc() + if err != nil { + return nil, err + } + result := make([]interface{}, len(exprs)) + for idx, expr := range exprs { + result[idx] = expr + } + return result, nil + } + + parseexprfunc.ParseExprBytesFromNameFunc = func(fam byte, ad *netlink.AttributeDecoder, exprName string) ([]interface{}, error) { + return parseExprBytesCommonFunc(func() ([]Any, error) { + return exprsBytesFromName(fam, ad, exprName) + }) + } + parseexprfunc.ParseExprBytesFunc = func(fam byte, ad *netlink.AttributeDecoder) ([]interface{}, error) { + return parseExprBytesCommonFunc(func() ([]Any, error) { + return exprsFromBytes(fam, ad) + }) + } + parseexprfunc.ParseExprMsgFunc = func(fam byte, b []byte) ([]interface{}, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + var exprs []interface{} + for ad.Next() { + e, err := parseexprfunc.ParseExprBytesFunc(fam, ad) + if err != nil { + return e, err + } + exprs = append(exprs, e...) + } + return exprs, ad.Err() + } +} + +// Marshal serializes the specified expression into a byte slice. +func Marshal(fam byte, e Any) ([]byte, error) { + return e.marshal(fam) +} + +func MarshalExprData(fam byte, e Any) ([]byte, error) { + return e.marshalData(fam) +} + +// Unmarshal fills an expression from the specified byte slice. +func Unmarshal(fam byte, data []byte, e Any) error { + return e.unmarshal(fam, data) +} + +// exprsBytesFromName parses raw expressions bytes +// based on provided expr name +func exprsBytesFromName(fam byte, ad *netlink.AttributeDecoder, name string) ([]Any, error) { + var exprs []Any + e := exprFromName(name) + ad.Do(func(b []byte) error { + if err := Unmarshal(fam, b, e); err != nil { + return err + } + exprs = append(exprs, e) + return nil + }) + return exprs, ad.Err() +} + +// exprsFromBytes parses nested raw expressions bytes +// to construct nftables expressions +func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder) ([]Any, error) { + var exprs []Any + + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + var name string + for ad.Next() { + switch ad.Type() { + case unix.NFTA_EXPR_NAME: + name = ad.String() + if name == "notrack" { + e := &Notrack{} + exprs = append(exprs, e) + } + case unix.NFTA_EXPR_DATA: + e := exprFromName(name) + if e == nil { + // TODO: introduce an opaque expression type so that users know + // something is here. + continue // unsupported expression type + } + ad.Do(func(b []byte) error { + if err := Unmarshal(fam, b, e); err != nil { + return err + } + // Verdict expressions are a special-case of immediate expressions, so + // if the expression is an immediate writing nothing into the verdict + // register (invalid), re-parse it as a verdict expression. + if imm, isImmediate := e.(*Immediate); isImmediate && imm.Register == unix.NFT_REG_VERDICT && len(imm.Data) == 0 { + e = &Verdict{} + if err := Unmarshal(fam, b, e); err != nil { + return err + } + } + exprs = append(exprs, e) + return nil + }) + } + } + return ad.Err() + }) + return exprs, ad.Err() +} + +func exprFromName(name string) Any { + var e Any + switch name { + case "ct": + e = &Ct{} + case "range": + e = &Range{} + case "meta": + e = &Meta{} + case "cmp": + e = &Cmp{} + case "counter": + e = &Counter{} + case "objref": + e = &Objref{} + case "payload": + e = &Payload{} + case "lookup": + e = &Lookup{} + case "immediate": + e = &Immediate{} + case "bitwise": + e = &Bitwise{} + case "redir": + e = &Redir{} + case "nat": + e = &NAT{} + case "limit": + e = &Limit{} + case "quota": + e = &Quota{} + case "dynset": + e = &Dynset{} + case "log": + e = &Log{} + case "exthdr": + e = &Exthdr{} + case "match": + e = &Match{} + case "target": + e = &Target{} + case "connlimit": + e = &Connlimit{} + case "queue": + e = &Queue{} + case "flow_offload": + e = &FlowOffload{} + case "reject": + e = &Reject{} + case "masq": + e = &Masq{} + case "hash": + e = &Hash{} + case "cthelper": + e = &CtHelper{} + case "synproxy": + e = &SynProxy{} + case "ctexpect": + e = &CtExpect{} + case "secmark": + e = &SecMark{} + case "cttimeout": + e = &CtTimeout{} + case "fib": + e = &Fib{} + case "numgen": + e = &Numgen{} + } + return e +} + +// Any is an interface implemented by any expression type. +type Any interface { + marshal(fam byte) ([]byte, error) + marshalData(fam byte) ([]byte, error) + unmarshal(fam byte, data []byte) error +} + +// MetaKey specifies which piece of meta information should be loaded. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Matching_packet_metainformation +type MetaKey uint32 + +// Possible MetaKey values. +const ( + MetaKeyLEN MetaKey = unix.NFT_META_LEN + MetaKeyPROTOCOL MetaKey = unix.NFT_META_PROTOCOL + MetaKeyPRIORITY MetaKey = unix.NFT_META_PRIORITY + MetaKeyMARK MetaKey = unix.NFT_META_MARK + MetaKeyIIF MetaKey = unix.NFT_META_IIF + MetaKeyOIF MetaKey = unix.NFT_META_OIF + MetaKeyIIFNAME MetaKey = unix.NFT_META_IIFNAME + MetaKeyOIFNAME MetaKey = unix.NFT_META_OIFNAME + MetaKeyIIFTYPE MetaKey = unix.NFT_META_IIFTYPE + MetaKeyOIFTYPE MetaKey = unix.NFT_META_OIFTYPE + MetaKeySKUID MetaKey = unix.NFT_META_SKUID + MetaKeySKGID MetaKey = unix.NFT_META_SKGID + MetaKeyNFTRACE MetaKey = unix.NFT_META_NFTRACE + MetaKeyRTCLASSID MetaKey = unix.NFT_META_RTCLASSID + MetaKeySECMARK MetaKey = unix.NFT_META_SECMARK + MetaKeyNFPROTO MetaKey = unix.NFT_META_NFPROTO + MetaKeyL4PROTO MetaKey = unix.NFT_META_L4PROTO + MetaKeyBRIIIFNAME MetaKey = unix.NFT_META_BRI_IIFNAME + MetaKeyBRIOIFNAME MetaKey = unix.NFT_META_BRI_OIFNAME + MetaKeyPKTTYPE MetaKey = unix.NFT_META_PKTTYPE + MetaKeyCPU MetaKey = unix.NFT_META_CPU + MetaKeyIIFGROUP MetaKey = unix.NFT_META_IIFGROUP + MetaKeyOIFGROUP MetaKey = unix.NFT_META_OIFGROUP + MetaKeyCGROUP MetaKey = unix.NFT_META_CGROUP + MetaKeyPRANDOM MetaKey = unix.NFT_META_PRANDOM +) + +// Meta loads packet meta information for later comparisons. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Matching_packet_metainformation +type Meta struct { + Key MetaKey + SourceRegister bool + Register uint32 +} + +func (e *Meta) marshal(fam byte) ([]byte, error) { + exprData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("meta\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Meta) marshalData(fam byte) ([]byte, error) { + var regData []byte + exprData, err := netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_META_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + }, + ) + if err != nil { + return nil, err + } + if e.SourceRegister { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_META_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } else { + regData, err = netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: unix.NFTA_META_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }, + ) + } + if err != nil { + return nil, err + } + exprData = append(exprData, regData...) + return exprData, nil +} + +func (e *Meta) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_META_SREG: + e.Register = ad.Uint32() + e.SourceRegister = true + case unix.NFTA_META_DREG: + e.Register = ad.Uint32() + case unix.NFTA_META_KEY: + e.Key = MetaKey(ad.Uint32()) + } + } + return ad.Err() +} + +// Masq (Masquerade) is a special case of SNAT, where the source address is +// automagically set to the address of the output interface. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Performing_Network_Address_Translation_(NAT)#Masquerading +type Masq struct { + Random bool + FullyRandom bool + Persistent bool + ToPorts bool + RegProtoMin uint32 + RegProtoMax uint32 +} + +const ( + // NF_NAT_RANGE_PROTO_RANDOM defines flag for a random masquerade + NF_NAT_RANGE_PROTO_RANDOM = unix.NF_NAT_RANGE_PROTO_RANDOM + // NF_NAT_RANGE_PROTO_RANDOM_FULLY defines flag for a fully random masquerade + NF_NAT_RANGE_PROTO_RANDOM_FULLY = unix.NF_NAT_RANGE_PROTO_RANDOM_FULLY + // NF_NAT_RANGE_PERSISTENT defines flag for a persistent masquerade + NF_NAT_RANGE_PERSISTENT = unix.NF_NAT_RANGE_PERSISTENT + // NF_NAT_RANGE_PREFIX defines flag for a prefix masquerade + NF_NAT_RANGE_PREFIX = unix.NF_NAT_RANGE_NETMAP + // NF_NAT_RANGE_PROTO_SPECIFIED defines flag for a specified range + NF_NAT_RANGE_PROTO_SPECIFIED = unix.NF_NAT_RANGE_PROTO_SPECIFIED +) + +func (e *Masq) marshal(fam byte) ([]byte, error) { + msgData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("masq\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: msgData}, + }) +} + +func (e *Masq) marshalData(fam byte) ([]byte, error) { + msgData := []byte{} + if !e.ToPorts { + flags := uint32(0) + if e.Random { + flags |= NF_NAT_RANGE_PROTO_RANDOM + } + if e.FullyRandom { + flags |= NF_NAT_RANGE_PROTO_RANDOM_FULLY + } + if e.Persistent { + flags |= NF_NAT_RANGE_PERSISTENT + } + if flags != 0 { + flagsData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_MASQ_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}}) + if err != nil { + return nil, err + } + msgData = append(msgData, flagsData...) + } + } else { + regsData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_MASQ_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMin)}}) + if err != nil { + return nil, err + } + msgData = append(msgData, regsData...) + if e.RegProtoMax != 0 { + regsData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_MASQ_REG_PROTO_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMax)}}) + if err != nil { + return nil, err + } + msgData = append(msgData, regsData...) + } + } + return msgData, nil +} + +func (e *Masq) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_MASQ_REG_PROTO_MIN: + e.ToPorts = true + e.RegProtoMin = ad.Uint32() + case unix.NFTA_MASQ_REG_PROTO_MAX: + e.RegProtoMax = ad.Uint32() + case unix.NFTA_MASQ_FLAGS: + flags := ad.Uint32() + e.Persistent = (flags & NF_NAT_RANGE_PERSISTENT) != 0 + e.Random = (flags & NF_NAT_RANGE_PROTO_RANDOM) != 0 + e.FullyRandom = (flags & NF_NAT_RANGE_PROTO_RANDOM_FULLY) != 0 + } + } + return ad.Err() +} + +// CmpOp specifies which type of comparison should be performed. +type CmpOp uint32 + +// Possible CmpOp values. +const ( + CmpOpEq CmpOp = unix.NFT_CMP_EQ + CmpOpNeq CmpOp = unix.NFT_CMP_NEQ + CmpOpLt CmpOp = unix.NFT_CMP_LT + CmpOpLte CmpOp = unix.NFT_CMP_LTE + CmpOpGt CmpOp = unix.NFT_CMP_GT + CmpOpGte CmpOp = unix.NFT_CMP_GTE +) + +// Cmp compares a register with the specified data. +type Cmp struct { + Op CmpOp + Register uint32 + Data []byte +} + +func (e *Cmp) marshal(fam byte) ([]byte, error) { + exprData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("cmp\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Cmp) marshalData(fam byte) ([]byte, error) { + cmpData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, + }) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_CMP_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: unix.NFTA_CMP_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, + {Type: unix.NLA_F_NESTED | unix.NFTA_CMP_DATA, Data: cmpData}, + }) +} + +func (e *Cmp) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_CMP_SREG: + e.Register = ad.Uint32() + case unix.NFTA_CMP_OP: + e.Op = CmpOp(ad.Uint32()) + case unix.NFTA_CMP_DATA: + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE { + ad.Do(func(b []byte) error { + e.Data = b + return nil + }) + } + return ad.Err() + }) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/exthdr.go b/vendor/github.com/google/nftables/expr/exthdr.go new file mode 100644 index 000000000..0a9d9fcde --- /dev/null +++ b/vendor/github.com/google/nftables/expr/exthdr.go @@ -0,0 +1,106 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type ExthdrOp uint32 + +const ( + ExthdrOpIpv6 ExthdrOp = unix.NFT_EXTHDR_OP_IPV6 + ExthdrOpTcpopt ExthdrOp = unix.NFT_EXTHDR_OP_TCPOPT +) + +type Exthdr struct { + DestRegister uint32 + Type uint8 + Offset uint32 + Len uint32 + Flags uint32 + Op ExthdrOp + SourceRegister uint32 +} + +func (e *Exthdr) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("exthdr\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Exthdr) marshalData(fam byte) ([]byte, error) { + var attr []netlink.Attribute + + // Operations are differentiated by the Op and whether the SourceRegister + // or DestRegister is set. Mixing them results in EOPNOTSUPP. + if e.SourceRegister != 0 { + attr = []netlink.Attribute{ + {Type: unix.NFTA_EXTHDR_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}} + } else { + attr = []netlink.Attribute{ + {Type: unix.NFTA_EXTHDR_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}} + } + + attr = append(attr, + netlink.Attribute{Type: unix.NFTA_EXTHDR_TYPE, Data: []byte{e.Type}}, + netlink.Attribute{Type: unix.NFTA_EXTHDR_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)}, + netlink.Attribute{Type: unix.NFTA_EXTHDR_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + netlink.Attribute{Type: unix.NFTA_EXTHDR_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}) + + // Flags is only set if DREG is set + if e.DestRegister != 0 { + attr = append(attr, + netlink.Attribute{Type: unix.NFTA_EXTHDR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}) + } + + return netlink.MarshalAttributes(attr) +} + +func (e *Exthdr) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_EXTHDR_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_EXTHDR_TYPE: + e.Type = ad.Uint8() + case unix.NFTA_EXTHDR_OFFSET: + e.Offset = ad.Uint32() + case unix.NFTA_EXTHDR_LEN: + e.Len = ad.Uint32() + case unix.NFTA_EXTHDR_FLAGS: + e.Flags = ad.Uint32() + case unix.NFTA_EXTHDR_OP: + e.Op = ExthdrOp(ad.Uint32()) + case unix.NFTA_EXTHDR_SREG: + e.SourceRegister = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/fib.go b/vendor/github.com/google/nftables/expr/fib.go new file mode 100644 index 000000000..ea6c059b5 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/fib.go @@ -0,0 +1,140 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Fib defines fib expression structure +type Fib struct { + Register uint32 + ResultOIF bool + ResultOIFNAME bool + ResultADDRTYPE bool + FlagSADDR bool + FlagDADDR bool + FlagMARK bool + FlagIIF bool + FlagOIF bool + FlagPRESENT bool +} + +func (e *Fib) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("fib\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Fib) marshalData(fam byte) ([]byte, error) { + data := []byte{} + reg, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_FIB_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }) + if err != nil { + return nil, err + } + data = append(data, reg...) + flags := uint32(0) + if e.FlagSADDR { + flags |= unix.NFTA_FIB_F_SADDR + } + if e.FlagDADDR { + flags |= unix.NFTA_FIB_F_DADDR + } + if e.FlagMARK { + flags |= unix.NFTA_FIB_F_MARK + } + if e.FlagIIF { + flags |= unix.NFTA_FIB_F_IIF + } + if e.FlagOIF { + flags |= unix.NFTA_FIB_F_OIF + } + if e.FlagPRESENT { + flags |= unix.NFTA_FIB_F_PRESENT + } + if flags != 0 { + flg, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_FIB_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, + }) + if err != nil { + return nil, err + } + data = append(data, flg...) + } + results := uint32(0) + if e.ResultOIF { + results |= unix.NFT_FIB_RESULT_OIF + } + if e.ResultOIFNAME { + results |= unix.NFT_FIB_RESULT_OIFNAME + } + if e.ResultADDRTYPE { + results |= unix.NFT_FIB_RESULT_ADDRTYPE + } + if results != 0 { + rslt, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_FIB_RESULT, Data: binaryutil.BigEndian.PutUint32(results)}, + }) + if err != nil { + return nil, err + } + data = append(data, rslt...) + } + return data, nil +} + +func (e *Fib) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_FIB_DREG: + e.Register = ad.Uint32() + case unix.NFTA_FIB_RESULT: + result := ad.Uint32() + switch result { + case unix.NFT_FIB_RESULT_OIF: + e.ResultOIF = true + case unix.NFT_FIB_RESULT_OIFNAME: + e.ResultOIFNAME = true + case unix.NFT_FIB_RESULT_ADDRTYPE: + e.ResultADDRTYPE = true + } + case unix.NFTA_FIB_FLAGS: + flags := ad.Uint32() + e.FlagSADDR = (flags & unix.NFTA_FIB_F_SADDR) != 0 + e.FlagDADDR = (flags & unix.NFTA_FIB_F_DADDR) != 0 + e.FlagMARK = (flags & unix.NFTA_FIB_F_MARK) != 0 + e.FlagIIF = (flags & unix.NFTA_FIB_F_IIF) != 0 + e.FlagOIF = (flags & unix.NFTA_FIB_F_OIF) != 0 + e.FlagPRESENT = (flags & unix.NFTA_FIB_F_PRESENT) != 0 + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/flow_offload.go b/vendor/github.com/google/nftables/expr/flow_offload.go new file mode 100644 index 000000000..de4949a6d --- /dev/null +++ b/vendor/github.com/google/nftables/expr/flow_offload.go @@ -0,0 +1,63 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const NFTNL_EXPR_FLOW_TABLE_NAME = 1 + +type FlowOffload struct { + Name string +} + +func (e *FlowOffload) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("flow_offload\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *FlowOffload) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: NFTNL_EXPR_FLOW_TABLE_NAME, Data: []byte(e.Name)}, + }) +} + +func (e *FlowOffload) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTNL_EXPR_FLOW_TABLE_NAME: + e.Name = ad.String() + } + } + + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/hash.go b/vendor/github.com/google/nftables/expr/hash.go new file mode 100644 index 000000000..92b9eea34 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/hash.go @@ -0,0 +1,98 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type HashType uint32 + +const ( + HashTypeJenkins HashType = unix.NFT_HASH_JENKINS + HashTypeSym HashType = unix.NFT_HASH_SYM +) + +// Hash defines type for nftables internal hashing functions +type Hash struct { + SourceRegister uint32 + DestRegister uint32 + Length uint32 + Modulus uint32 + Seed uint32 + Offset uint32 + Type HashType +} + +func (e *Hash) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("hash\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Hash) marshalData(fam byte) ([]byte, error) { + hashAttrs := []netlink.Attribute{ + {Type: unix.NFTA_HASH_SREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.SourceRegister))}, + {Type: unix.NFTA_HASH_DREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.DestRegister))}, + {Type: unix.NFTA_HASH_LEN, Data: binaryutil.BigEndian.PutUint32(uint32(e.Length))}, + {Type: unix.NFTA_HASH_MODULUS, Data: binaryutil.BigEndian.PutUint32(uint32(e.Modulus))}, + } + if e.Seed != 0 { + hashAttrs = append(hashAttrs, netlink.Attribute{ + Type: unix.NFTA_HASH_SEED, Data: binaryutil.BigEndian.PutUint32(uint32(e.Seed)), + }) + } + hashAttrs = append(hashAttrs, []netlink.Attribute{ + {Type: unix.NFTA_HASH_OFFSET, Data: binaryutil.BigEndian.PutUint32(uint32(e.Offset))}, + {Type: unix.NFTA_HASH_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, + }...) + return netlink.MarshalAttributes(hashAttrs) +} + +func (e *Hash) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_HASH_SREG: + e.SourceRegister = ad.Uint32() + case unix.NFTA_HASH_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_HASH_LEN: + e.Length = ad.Uint32() + case unix.NFTA_HASH_MODULUS: + e.Modulus = ad.Uint32() + case unix.NFTA_HASH_SEED: + e.Seed = ad.Uint32() + case unix.NFTA_HASH_OFFSET: + e.Offset = ad.Uint32() + case unix.NFTA_HASH_TYPE: + e.Type = HashType(ad.Uint32()) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/immediate.go b/vendor/github.com/google/nftables/expr/immediate.go new file mode 100644 index 000000000..19eea4480 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/immediate.go @@ -0,0 +1,83 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Immediate struct { + Register uint32 + Data []byte +} + +func (e *Immediate) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Immediate) marshalData(fam byte) ([]byte, error) { + immData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_IMMEDIATE_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: unix.NLA_F_NESTED | unix.NFTA_IMMEDIATE_DATA, Data: immData}, + }) +} + +func (e *Immediate) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_IMMEDIATE_DREG: + e.Register = ad.Uint32() + case unix.NFTA_IMMEDIATE_DATA: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) + } + for nestedAD.Next() { + switch nestedAD.Type() { + case unix.NFTA_DATA_VALUE: + e.Data = nestedAD.Bytes() + } + } + if nestedAD.Err() != nil { + return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) + } + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/limit.go b/vendor/github.com/google/nftables/expr/limit.go new file mode 100644 index 000000000..1e170ac31 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/limit.go @@ -0,0 +1,132 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + "errors" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// LimitType represents the type of the limit expression. +type LimitType uint32 + +// Imported from the nft_limit_type enum in netfilter/nf_tables.h. +const ( + LimitTypePkts LimitType = unix.NFT_LIMIT_PKTS + LimitTypePktBytes LimitType = unix.NFT_LIMIT_PKT_BYTES +) + +// LimitTime represents the limit unit. +type LimitTime uint64 + +// Possible limit unit values. +const ( + LimitTimeSecond LimitTime = 1 + LimitTimeMinute LimitTime = 60 + LimitTimeHour LimitTime = 60 * 60 + LimitTimeDay LimitTime = 60 * 60 * 24 + LimitTimeWeek LimitTime = 60 * 60 * 24 * 7 +) + +func limitTime(value uint64) (LimitTime, error) { + switch LimitTime(value) { + case LimitTimeSecond: + return LimitTimeSecond, nil + case LimitTimeMinute: + return LimitTimeMinute, nil + case LimitTimeHour: + return LimitTimeHour, nil + case LimitTimeDay: + return LimitTimeDay, nil + case LimitTimeWeek: + return LimitTimeWeek, nil + default: + return 0, fmt.Errorf("expr: invalid limit unit value %d", value) + } +} + +// Limit represents a rate limit expression. +type Limit struct { + Type LimitType + Rate uint64 + Over bool + Unit LimitTime + Burst uint32 +} + +func (l *Limit) marshal(fam byte) ([]byte, error) { + data, err := l.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("limit\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (l *Limit) marshalData(fam byte) ([]byte, error) { + var flags uint32 + if l.Over { + flags = unix.NFT_LIMIT_F_INV + } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_LIMIT_RATE, Data: binaryutil.BigEndian.PutUint64(l.Rate)}, + {Type: unix.NFTA_LIMIT_UNIT, Data: binaryutil.BigEndian.PutUint64(uint64(l.Unit))}, + {Type: unix.NFTA_LIMIT_BURST, Data: binaryutil.BigEndian.PutUint32(l.Burst)}, + {Type: unix.NFTA_LIMIT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(l.Type))}, + {Type: unix.NFTA_LIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, + } + + return netlink.MarshalAttributes(attrs) +} + +func (l *Limit) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_LIMIT_RATE: + l.Rate = ad.Uint64() + case unix.NFTA_LIMIT_UNIT: + l.Unit, err = limitTime(ad.Uint64()) + if err != nil { + return err + } + case unix.NFTA_LIMIT_BURST: + l.Burst = ad.Uint32() + case unix.NFTA_LIMIT_TYPE: + l.Type = LimitType(ad.Uint32()) + if l.Type != LimitTypePkts && l.Type != LimitTypePktBytes { + return fmt.Errorf("expr: invalid limit type %d", l.Type) + } + case unix.NFTA_LIMIT_FLAGS: + l.Over = (ad.Uint32() & unix.NFT_LIMIT_F_INV) != 0 + default: + return errors.New("expr: unhandled limit netlink attribute") + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/log.go b/vendor/github.com/google/nftables/expr/log.go new file mode 100644 index 000000000..eaa057a63 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/log.go @@ -0,0 +1,154 @@ +// Copyright 2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type LogLevel uint32 + +const ( + // See https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=5b364657a35f4e4cd5d220ba2a45303d729c8eca#n1226 + LogLevelEmerg LogLevel = iota + LogLevelAlert + LogLevelCrit + LogLevelErr + LogLevelWarning + LogLevelNotice + LogLevelInfo + LogLevelDebug + LogLevelAudit +) + +type LogFlags uint32 + +const ( + // See https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_log.h?id=5b364657a35f4e4cd5d220ba2a45303d729c8eca + LogFlagsTCPSeq LogFlags = 0x01 << iota + LogFlagsTCPOpt + LogFlagsIPOpt + LogFlagsUID + LogFlagsNFLog + LogFlagsMACDecode + LogFlagsMask LogFlags = 0x2f +) + +// Log defines type for NFT logging +// See https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n25 +type Log struct { + Level LogLevel + // Refers to log flags (flags all, flags ip options, ...) + Flags LogFlags + // Equivalent to expression flags. + // Indicates that an option is set by setting a bit + // on index referred by the NFTA_LOG_* value. + // See https://cs.opensource.google/go/x/sys/+/3681064d:unix/ztypes_linux.go;l=2126;drc=3681064d51587c1db0324b3d5c23c2ddbcff6e8f + Key uint32 + Snaplen uint32 + Group uint16 + QThreshold uint16 + // Log prefix string content + Data []byte +} + +func (e *Log) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("log\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Log) marshalData(fam byte) ([]byte, error) { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n129 + attrs := make([]netlink.Attribute, 0) + if e.Key&(1<= /* sic! */ XTablesExtensionNameMaxLen { + name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00. + } + // Marshalling assumes that the correct Info type for the particular table + // family and Match revision has been set. + info, err := xt.Marshal(xt.TableFamily(fam), e.Rev, e.Info) + if err != nil { + return nil, err + } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_MATCH_NAME, Data: []byte(name + "\x00")}, + {Type: unix.NFTA_MATCH_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)}, + {Type: unix.NFTA_MATCH_INFO, Data: info}, + } + return netlink.MarshalAttributes(attrs) +} + +func (e *Match) unmarshal(fam byte, data []byte) error { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/match.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65 + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + var info []byte + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_MATCH_NAME: + // We are forgiving here, accepting any length and even missing terminating \x00. + e.Name = string(bytes.TrimRight(ad.Bytes(), "\x00")) + case unix.NFTA_MATCH_REV: + e.Rev = ad.Uint32() + case unix.NFTA_MATCH_INFO: + info = ad.Bytes() + } + } + if err = ad.Err(); err != nil { + return err + } + e.Info, err = xt.Unmarshal(e.Name, xt.TableFamily(fam), e.Rev, info) + return err +} diff --git a/vendor/github.com/google/nftables/expr/nat.go b/vendor/github.com/google/nftables/expr/nat.go new file mode 100644 index 000000000..3f28967be --- /dev/null +++ b/vendor/github.com/google/nftables/expr/nat.go @@ -0,0 +1,141 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type NATType uint32 + +// Possible NATType values. +const ( + NATTypeSourceNAT NATType = unix.NFT_NAT_SNAT + NATTypeDestNAT NATType = unix.NFT_NAT_DNAT +) + +type NAT struct { + Type NATType + Family uint32 // TODO: typed const + RegAddrMin uint32 + RegAddrMax uint32 + RegProtoMin uint32 + RegProtoMax uint32 + Random bool + FullyRandom bool + Persistent bool + Prefix bool + Specified bool +} + +// |00048|N-|00001| |len |flags| type| +// |00008|--|00001| |len |flags| type| +// | 6e 61 74 00 | | data | n a t +// |00036|N-|00002| |len |flags| type| +// |00008|--|00001| |len |flags| type| NFTA_NAT_TYPE +// | 00 00 00 01 | | data | NFT_NAT_DNAT +// |00008|--|00002| |len |flags| type| NFTA_NAT_FAMILY +// | 00 00 00 02 | | data | NFPROTO_IPV4 +// |00008|--|00003| |len |flags| type| NFTA_NAT_REG_ADDR_MIN +// | 00 00 00 01 | | data | reg 1 +// |00008|--|00005| |len |flags| type| NFTA_NAT_REG_PROTO_MIN +// | 00 00 00 02 | | data | reg 2 + +func (e *NAT) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("nat\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *NAT) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_NAT_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, + {Type: unix.NFTA_NAT_FAMILY, Data: binaryutil.BigEndian.PutUint32(e.Family)}, + } + if e.RegAddrMin != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_ADDR_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegAddrMin)}) + if e.RegAddrMax != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_ADDR_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegAddrMax)}) + } + } + if e.RegProtoMin != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMin)}) + if e.RegProtoMax != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_REG_PROTO_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegProtoMax)}) + } + } + flags := uint32(0) + if e.Random { + flags |= NF_NAT_RANGE_PROTO_RANDOM + } + if e.FullyRandom { + flags |= NF_NAT_RANGE_PROTO_RANDOM_FULLY + } + if e.Persistent { + flags |= NF_NAT_RANGE_PERSISTENT + } + if e.Prefix { + flags |= NF_NAT_RANGE_PREFIX + } + if e.Specified { + flags |= NF_NAT_RANGE_PROTO_SPECIFIED + } + if flags != 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}) + } + + return netlink.MarshalAttributes(attrs) +} + +func (e *NAT) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_NAT_TYPE: + e.Type = NATType(ad.Uint32()) + case unix.NFTA_NAT_FAMILY: + e.Family = ad.Uint32() + case unix.NFTA_NAT_REG_ADDR_MIN: + e.RegAddrMin = ad.Uint32() + case unix.NFTA_NAT_REG_ADDR_MAX: + e.RegAddrMax = ad.Uint32() + case unix.NFTA_NAT_REG_PROTO_MIN: + e.RegProtoMin = ad.Uint32() + case unix.NFTA_NAT_REG_PROTO_MAX: + e.RegProtoMax = ad.Uint32() + case unix.NFTA_NAT_FLAGS: + flags := ad.Uint32() + e.Persistent = (flags & NF_NAT_RANGE_PERSISTENT) != 0 + e.Random = (flags & NF_NAT_RANGE_PROTO_RANDOM) != 0 + e.FullyRandom = (flags & NF_NAT_RANGE_PROTO_RANDOM_FULLY) != 0 + e.Prefix = (flags & NF_NAT_RANGE_PREFIX) != 0 + e.Specified = (flags & NF_NAT_RANGE_PROTO_SPECIFIED) != 0 + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/notrack.go b/vendor/github.com/google/nftables/expr/notrack.go new file mode 100644 index 000000000..c17d7b586 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/notrack.go @@ -0,0 +1,42 @@ +// Copyright 2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Notrack struct{} + +func (e *Notrack) marshal(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("notrack\x00")}, + }) +} + +func (e *Notrack) marshalData(fam byte) ([]byte, error) { + return []byte("notrack\x00"), nil +} + +func (e *Notrack) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + + if err != nil { + return err + } + + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/numgen.go b/vendor/github.com/google/nftables/expr/numgen.go new file mode 100644 index 000000000..362ca0682 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/numgen.go @@ -0,0 +1,82 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Numgen defines Numgen expression structure +type Numgen struct { + Register uint32 + Modulus uint32 + Type uint32 + Offset uint32 +} + +func (e *Numgen) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("numgen\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Numgen) marshalData(fam byte) ([]byte, error) { + // Currently only two types are supported, failing if Type is not of two known types + switch e.Type { + case unix.NFT_NG_INCREMENTAL: + case unix.NFT_NG_RANDOM: + default: + return nil, fmt.Errorf("unsupported numgen type %d", e.Type) + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_NG_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: unix.NFTA_NG_MODULUS, Data: binaryutil.BigEndian.PutUint32(e.Modulus)}, + {Type: unix.NFTA_NG_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, + {Type: unix.NFTA_NG_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)}, + }) +} + +func (e *Numgen) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_NG_DREG: + e.Register = ad.Uint32() + case unix.NFTA_NG_MODULUS: + e.Modulus = ad.Uint32() + case unix.NFTA_NG_TYPE: + e.Type = ad.Uint32() + case unix.NFTA_NG_OFFSET: + e.Offset = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/objref.go b/vendor/github.com/google/nftables/expr/objref.go new file mode 100644 index 000000000..e6d59749a --- /dev/null +++ b/vendor/github.com/google/nftables/expr/objref.go @@ -0,0 +1,64 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Objref struct { + Type int // TODO: enum + Name string +} + +func (e *Objref) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("objref\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Objref) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_OBJREF_IMM_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Type))}, + {Type: unix.NFTA_OBJREF_IMM_NAME, Data: []byte(e.Name)}, // NOT \x00-terminated?! + }) +} + +func (e *Objref) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_OBJREF_IMM_TYPE: + e.Type = int(ad.Uint32()) + case unix.NFTA_OBJREF_IMM_NAME: + e.Name = ad.String() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/payload.go b/vendor/github.com/google/nftables/expr/payload.go new file mode 100644 index 000000000..b25128b55 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/payload.go @@ -0,0 +1,134 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type PayloadBase uint32 +type PayloadCsumType uint32 +type PayloadOperationType uint32 + +// Possible PayloadBase values. +const ( + PayloadBaseLLHeader PayloadBase = unix.NFT_PAYLOAD_LL_HEADER + PayloadBaseNetworkHeader PayloadBase = unix.NFT_PAYLOAD_NETWORK_HEADER + PayloadBaseTransportHeader PayloadBase = unix.NFT_PAYLOAD_TRANSPORT_HEADER +) + +// Possible PayloadCsumType values. +const ( + CsumTypeNone PayloadCsumType = unix.NFT_PAYLOAD_CSUM_NONE + CsumTypeInet PayloadCsumType = unix.NFT_PAYLOAD_CSUM_INET +) + +// Possible PayloadOperationType values. +const ( + PayloadLoad PayloadOperationType = iota + PayloadWrite +) + +type Payload struct { + OperationType PayloadOperationType + DestRegister uint32 + SourceRegister uint32 + Base PayloadBase + Offset uint32 + Len uint32 + CsumType PayloadCsumType + CsumOffset uint32 + CsumFlags uint32 +} + +func (e *Payload) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("payload\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Payload) marshalData(fam byte) ([]byte, error) { + var attrs []netlink.Attribute + + if e.OperationType == PayloadWrite { + attrs = []netlink.Attribute{ + {Type: unix.NFTA_PAYLOAD_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + } + } else { + attrs = []netlink.Attribute{ + {Type: unix.NFTA_PAYLOAD_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + } + } + + attrs = append(attrs, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_BASE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Base))}, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)}, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, + ) + + if e.CsumType > 0 { + attrs = append(attrs, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_CSUM_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(e.CsumType))}, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_CSUM_OFFSET, Data: binaryutil.BigEndian.PutUint32(uint32(e.CsumOffset))}, + ) + if e.CsumFlags > 0 { + attrs = append(attrs, + netlink.Attribute{Type: unix.NFTA_PAYLOAD_CSUM_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.CsumFlags)}, + ) + } + } + + return netlink.MarshalAttributes(attrs) +} + +func (e *Payload) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_PAYLOAD_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_PAYLOAD_SREG: + e.SourceRegister = ad.Uint32() + e.OperationType = PayloadWrite + case unix.NFTA_PAYLOAD_BASE: + e.Base = PayloadBase(ad.Uint32()) + case unix.NFTA_PAYLOAD_OFFSET: + e.Offset = ad.Uint32() + case unix.NFTA_PAYLOAD_LEN: + e.Len = ad.Uint32() + case unix.NFTA_PAYLOAD_CSUM_TYPE: + e.CsumType = PayloadCsumType(ad.Uint32()) + case unix.NFTA_PAYLOAD_CSUM_OFFSET: + e.CsumOffset = ad.Uint32() + case unix.NFTA_PAYLOAD_CSUM_FLAGS: + e.CsumFlags = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/queue.go b/vendor/github.com/google/nftables/expr/queue.go new file mode 100644 index 000000000..0e9d8f43e --- /dev/null +++ b/vendor/github.com/google/nftables/expr/queue.go @@ -0,0 +1,85 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type QueueAttribute uint16 + +type QueueFlag uint16 + +// Possible QueueAttribute values +const ( + QueueNum QueueAttribute = unix.NFTA_QUEUE_NUM + QueueTotal QueueAttribute = unix.NFTA_QUEUE_TOTAL + QueueFlags QueueAttribute = unix.NFTA_QUEUE_FLAGS + + QueueFlagBypass QueueFlag = unix.NFT_QUEUE_FLAG_BYPASS + QueueFlagFanout QueueFlag = unix.NFT_QUEUE_FLAG_CPU_FANOUT + QueueFlagMask QueueFlag = unix.NFT_QUEUE_FLAG_MASK +) + +type Queue struct { + Num uint16 + Total uint16 + Flag QueueFlag +} + +func (e *Queue) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("queue\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Queue) marshalData(fam byte) ([]byte, error) { + if e.Total == 0 { + e.Total = 1 // The total default value is 1 + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_QUEUE_NUM, Data: binaryutil.BigEndian.PutUint16(e.Num)}, + {Type: unix.NFTA_QUEUE_TOTAL, Data: binaryutil.BigEndian.PutUint16(e.Total)}, + {Type: unix.NFTA_QUEUE_FLAGS, Data: binaryutil.BigEndian.PutUint16(uint16(e.Flag))}, + }) +} + +func (e *Queue) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_QUEUE_NUM: + e.Num = ad.Uint16() + case unix.NFTA_QUEUE_TOTAL: + e.Total = ad.Uint16() + case unix.NFTA_QUEUE_FLAGS: + e.Flag = QueueFlag(ad.Uint16()) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/quota.go b/vendor/github.com/google/nftables/expr/quota.go new file mode 100644 index 000000000..ca55f6c91 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/quota.go @@ -0,0 +1,80 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Quota defines a threshold against a number of bytes. +type Quota struct { + Bytes uint64 + Consumed uint64 + Over bool +} + +func (q *Quota) marshal(fam byte) ([]byte, error) { + data, err := q.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("quota\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (q *Quota) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_QUOTA_BYTES, Data: binaryutil.BigEndian.PutUint64(q.Bytes)}, + {Type: unix.NFTA_QUOTA_CONSUMED, Data: binaryutil.BigEndian.PutUint64(q.Consumed)}, + } + + flags := uint32(0) + if q.Over { + flags = unix.NFT_QUOTA_F_INV + } + attrs = append(attrs, netlink.Attribute{ + Type: unix.NFTA_QUOTA_FLAGS, + Data: binaryutil.BigEndian.PutUint32(flags), + }) + + return netlink.MarshalAttributes(attrs) +} + +func (q *Quota) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_QUOTA_BYTES: + q.Bytes = ad.Uint64() + case unix.NFTA_QUOTA_CONSUMED: + q.Consumed = ad.Uint64() + case unix.NFTA_QUOTA_FLAGS: + q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) != 0 + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/range.go b/vendor/github.com/google/nftables/expr/range.go new file mode 100644 index 000000000..bd6972aaf --- /dev/null +++ b/vendor/github.com/google/nftables/expr/range.go @@ -0,0 +1,132 @@ +// Copyright 2019 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// Range implements range expression +type Range struct { + Op CmpOp + Register uint32 + FromData []byte + ToData []byte +} + +func (e *Range) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("range\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Range) marshalData(fam byte) ([]byte, error) { + var attrs []netlink.Attribute + var err error + var rangeFromData, rangeToData []byte + + if e.Register > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_RANGE_SREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}) + } + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_RANGE_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}) + if len(e.FromData) > 0 { + rangeFromData, err = nestedAttr(e.FromData, unix.NFTA_RANGE_FROM_DATA) + if err != nil { + return nil, err + } + } + if len(e.ToData) > 0 { + rangeToData, err = nestedAttr(e.ToData, unix.NFTA_RANGE_TO_DATA) + if err != nil { + return nil, err + } + } + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + data = append(data, rangeFromData...) + data = append(data, rangeToData...) + return data, nil +} + +func (e *Range) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_RANGE_OP: + e.Op = CmpOp(ad.Uint32()) + case unix.NFTA_RANGE_SREG: + e.Register = ad.Uint32() + case unix.NFTA_RANGE_FROM_DATA: + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE { + ad.Do(func(b []byte) error { + e.FromData = b + return nil + }) + } + return ad.Err() + }) + case unix.NFTA_RANGE_TO_DATA: + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + if ad.Next() && ad.Type() == unix.NFTA_DATA_VALUE { + ad.Do(func(b []byte) error { + e.ToData = b + return nil + }) + } + return ad.Err() + }) + } + } + return ad.Err() +} + +func nestedAttr(data []byte, attrType uint16) ([]byte, error) { + ae := netlink.NewAttributeEncoder() + ae.Do(unix.NLA_F_NESTED|attrType, func() ([]byte, error) { + nae := netlink.NewAttributeEncoder() + nae.ByteOrder = binary.BigEndian + nae.Bytes(unix.NFTA_DATA_VALUE, data) + + return nae.Encode() + }) + return ae.Encode() +} diff --git a/vendor/github.com/google/nftables/expr/redirect.go b/vendor/github.com/google/nftables/expr/redirect.go new file mode 100644 index 000000000..ea535f3c7 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/redirect.go @@ -0,0 +1,75 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Redir struct { + RegisterProtoMin uint32 + RegisterProtoMax uint32 + Flags uint32 +} + +func (e *Redir) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("redir\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Redir) marshalData(fam byte) ([]byte, error) { + var attrs []netlink.Attribute + if e.RegisterProtoMin > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMin)}) + } + if e.RegisterProtoMax > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MAX, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMax)}) + } + if e.Flags > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}) + } + + return netlink.MarshalAttributes(attrs) +} + +func (e *Redir) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_REDIR_REG_PROTO_MIN: + e.RegisterProtoMin = ad.Uint32() + case unix.NFTA_REDIR_REG_PROTO_MAX: + e.RegisterProtoMax = ad.Uint32() + case unix.NFTA_REDIR_FLAGS: + e.Flags = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/reject.go b/vendor/github.com/google/nftables/expr/reject.go new file mode 100644 index 000000000..7fe216db2 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/reject.go @@ -0,0 +1,63 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type Reject struct { + Type uint32 + Code uint8 +} + +func (e *Reject) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("reject\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Reject) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_REJECT_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, + {Type: unix.NFTA_REJECT_ICMP_CODE, Data: []byte{e.Code}}, + }) +} + +func (e *Reject) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_REJECT_TYPE: + e.Type = ad.Uint32() + case unix.NFTA_REJECT_ICMP_CODE: + e.Code = ad.Uint8() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/rt.go b/vendor/github.com/google/nftables/expr/rt.go new file mode 100644 index 000000000..21c3a637e --- /dev/null +++ b/vendor/github.com/google/nftables/expr/rt.go @@ -0,0 +1,59 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type RtKey uint32 + +const ( + RtClassid RtKey = unix.NFT_RT_CLASSID + RtNexthop4 RtKey = unix.NFT_RT_NEXTHOP4 + RtNexthop6 RtKey = unix.NFT_RT_NEXTHOP6 + RtTCPMSS RtKey = unix.NFT_RT_TCPMSS +) + +type Rt struct { + Register uint32 + Key RtKey +} + +func (e *Rt) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("rt\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Rt) marshalData(fam byte) ([]byte, error) { + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + {Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + }) +} + +func (e *Rt) unmarshal(fam byte, data []byte) error { + return fmt.Errorf("not yet implemented") +} diff --git a/vendor/github.com/google/nftables/expr/secmark.go b/vendor/github.com/google/nftables/expr/secmark.go new file mode 100644 index 000000000..3faf87f5c --- /dev/null +++ b/vendor/github.com/google/nftables/expr/secmark.go @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// From https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1338 +const ( + NFTA_SECMARK_CTX = 0x01 +) + +type SecMark struct { + Ctx string +} + +func (e *SecMark) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("secmark\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *SecMark) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: NFTA_SECMARK_CTX, Data: []byte(e.Ctx)}, + } + return netlink.MarshalAttributes(attrs) +} + +func (e *SecMark) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_SECMARK_CTX: + e.Ctx = ad.String() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/socket.go b/vendor/github.com/google/nftables/expr/socket.go new file mode 100644 index 000000000..e3843ccde --- /dev/null +++ b/vendor/github.com/google/nftables/expr/socket.go @@ -0,0 +1,92 @@ +// Copyright 2023 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "golang.org/x/sys/unix" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" +) + +type Socket struct { + Key SocketKey + Level uint32 + Register uint32 +} + +type SocketKey uint32 + +const ( + // TODO, Once the constants below are available in golang.org/x/sys/unix, switch to use those. + NFTA_SOCKET_KEY = 1 + NFTA_SOCKET_DREG = 2 + NFTA_SOCKET_LEVEL = 3 + + NFT_SOCKET_TRANSPARENT = 0 + NFT_SOCKET_MARK = 1 + NFT_SOCKET_WILDCARD = 2 + NFT_SOCKET_CGROUPV2 = 3 + + SocketKeyTransparent SocketKey = NFT_SOCKET_TRANSPARENT + SocketKeyMark SocketKey = NFT_SOCKET_MARK + SocketKeyWildcard SocketKey = NFT_SOCKET_WILDCARD + SocketKeyCgroupv2 SocketKey = NFT_SOCKET_CGROUPV2 +) + +func (e *Socket) marshal(fam byte) ([]byte, error) { + exprData, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("socket\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, + }) +} + +func (e *Socket) marshalData(fam byte) ([]byte, error) { + // NOTE: Socket.Level is only used when Socket.Key == SocketKeyCgroupv2. But `nft` always encoding it. Check link below: + // http://git.netfilter.org/nftables/tree/src/netlink_linearize.c?id=0583bac241ea18c9d7f61cb20ca04faa1e043b78#n319 + return netlink.MarshalAttributes( + []netlink.Attribute{ + {Type: NFTA_SOCKET_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, + {Type: NFTA_SOCKET_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, + {Type: NFTA_SOCKET_LEVEL, Data: binaryutil.BigEndian.PutUint32(uint32(e.Level))}, + }, + ) +} + +func (e *Socket) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_SOCKET_DREG: + e.Register = ad.Uint32() + case NFTA_SOCKET_KEY: + e.Key = SocketKey(ad.Uint32()) + case NFTA_SOCKET_LEVEL: + e.Level = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/synproxy.go b/vendor/github.com/google/nftables/expr/synproxy.go new file mode 100644 index 000000000..a93b8bb21 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/synproxy.go @@ -0,0 +1,118 @@ +// Copyright 2024 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type SynProxy struct { + Mss uint16 + Wscale uint8 + Timestamp bool + SackPerm bool + // Probably not expected to be set by users + // https://github.com/torvalds/linux/blob/521b1e7f4cf0b05a47995b103596978224b380a8/net/netfilter/nft_synproxy.c#L30-L31 + Ecn bool + // True when Mss is set to a value or if 0 is an intended value of Mss + MssValueSet bool + // True when Wscale is set to a value or if 0 is an intended value of Wscale + WscaleValueSet bool +} + +// From https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1723 +// Currently not available in golang.org/x/sys/unix +const ( + NFTA_SYNPROXY_MSS = 0x01 + NFTA_SYNPROXY_WSCALE = 0x02 + NFTA_SYNPROXY_FLAGS = 0x03 +) + +// From https://github.com/torvalds/linux/blob/521b1e7f4cf0b05a47995b103596978224b380a8/include/uapi/linux/netfilter/nf_synproxy.h#L7-L15 +// Currently not available in golang.org/x/sys/unix +const ( + NF_SYNPROXY_OPT_MSS = 0x01 + NF_SYNPROXY_OPT_WSCALE = 0x02 + NF_SYNPROXY_OPT_SACK_PERM = 0x04 + NF_SYNPROXY_OPT_TIMESTAMP = 0x08 + NF_SYNPROXY_OPT_ECN = 0x10 +) + +func (e *SynProxy) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("synproxy\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *SynProxy) marshalData(fam byte) ([]byte, error) { + var flags uint32 + if e.Mss != 0 || e.MssValueSet { + flags |= NF_SYNPROXY_OPT_MSS + } + if e.Wscale != 0 || e.WscaleValueSet { + flags |= NF_SYNPROXY_OPT_WSCALE + } + if e.SackPerm { + flags |= NF_SYNPROXY_OPT_SACK_PERM + } + if e.Timestamp { + flags |= NF_SYNPROXY_OPT_TIMESTAMP + } + if e.Ecn { + flags |= NF_SYNPROXY_OPT_ECN + } + attrs := []netlink.Attribute{ + {Type: NFTA_SYNPROXY_MSS, Data: binaryutil.BigEndian.PutUint16(e.Mss)}, + {Type: NFTA_SYNPROXY_WSCALE, Data: []byte{e.Wscale}}, + {Type: NFTA_SYNPROXY_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, + } + return netlink.MarshalAttributes(attrs) +} + +func (e *SynProxy) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_SYNPROXY_MSS: + e.Mss = ad.Uint16() + case NFTA_SYNPROXY_WSCALE: + e.Wscale = ad.Uint8() + case NFTA_SYNPROXY_FLAGS: + flags := ad.Uint32() + checkFlag := func(flag uint32) bool { + return (flags & flag) == flag + } + e.MssValueSet = checkFlag(NF_SYNPROXY_OPT_MSS) + e.WscaleValueSet = checkFlag(NF_SYNPROXY_OPT_WSCALE) + e.SackPerm = checkFlag(NF_SYNPROXY_OPT_SACK_PERM) + e.Timestamp = checkFlag(NF_SYNPROXY_OPT_TIMESTAMP) + e.Ecn = checkFlag(NF_SYNPROXY_OPT_ECN) + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/target.go b/vendor/github.com/google/nftables/expr/target.go new file mode 100644 index 000000000..d1c800bbe --- /dev/null +++ b/vendor/github.com/google/nftables/expr/target.go @@ -0,0 +1,82 @@ +package expr + +import ( + "bytes" + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/xt" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// See https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n28 +const XTablesExtensionNameMaxLen = 29 + +// See https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n30 +type Target struct { + Name string + Rev uint32 + Info xt.InfoAny +} + +func (e *Target) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("target\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Target) marshalData(fam byte) ([]byte, error) { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38 + name := e.Name + // limit the extension name as (some) user-space tools do and leave room for + // trailing \x00 + if len(name) >= /* sic! */ XTablesExtensionNameMaxLen { + name = name[:XTablesExtensionNameMaxLen-1] // leave room for trailing \x00. + } + // Marshalling assumes that the correct Info type for the particular table + // family and Match revision has been set. + info, err := xt.Marshal(xt.TableFamily(fam), e.Rev, e.Info) + if err != nil { + return nil, err + } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_TARGET_NAME, Data: []byte(name + "\x00")}, + {Type: unix.NFTA_TARGET_REV, Data: binaryutil.BigEndian.PutUint32(e.Rev)}, + {Type: unix.NFTA_TARGET_INFO, Data: info}, + } + + return netlink.MarshalAttributes(attrs) +} + +func (e *Target) unmarshal(fam byte, data []byte) error { + // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n65 + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + var info []byte + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_TARGET_NAME: + // We are forgiving here, accepting any length and even missing terminating \x00. + e.Name = string(bytes.TrimRight(ad.Bytes(), "\x00")) + case unix.NFTA_TARGET_REV: + e.Rev = ad.Uint32() + case unix.NFTA_TARGET_INFO: + info = ad.Bytes() + } + } + if err = ad.Err(); err != nil { + return err + } + e.Info, err = xt.Unmarshal(e.Name, xt.TableFamily(fam), e.Rev, info) + return err +} diff --git a/vendor/github.com/google/nftables/expr/tproxy.go b/vendor/github.com/google/nftables/expr/tproxy.go new file mode 100644 index 000000000..142740c4d --- /dev/null +++ b/vendor/github.com/google/nftables/expr/tproxy.go @@ -0,0 +1,86 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "encoding/binary" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const ( + // NFTA_TPROXY_FAMILY defines attribute for a table family + NFTA_TPROXY_FAMILY = 0x01 + // NFTA_TPROXY_REG_ADDR defines attribute for a register carrying redirection address value + NFTA_TPROXY_REG_ADDR = 0x02 + // NFTA_TPROXY_REG_PORT defines attribute for a register carrying redirection port value + NFTA_TPROXY_REG_PORT = 0x03 +) + +// TProxy defines struct with parameters for the transparent proxy +type TProxy struct { + Family byte + TableFamily byte + RegAddr uint32 + RegPort uint32 +} + +func (e *TProxy) marshal(fam byte) ([]byte, error) { + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("tproxy\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *TProxy) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: NFTA_TPROXY_FAMILY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Family))}, + {Type: NFTA_TPROXY_REG_PORT, Data: binaryutil.BigEndian.PutUint32(e.RegPort)}, + } + + if e.RegAddr != 0 { + attrs = append(attrs, netlink.Attribute{ + Type: NFTA_TPROXY_REG_ADDR, + Data: binaryutil.BigEndian.PutUint32(e.RegAddr), + }) + } + + return netlink.MarshalAttributes(attrs) +} + +func (e *TProxy) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case NFTA_TPROXY_FAMILY: + e.Family = ad.Uint8() + case NFTA_TPROXY_REG_PORT: + e.RegPort = ad.Uint32() + case NFTA_TPROXY_REG_ADDR: + e.RegAddr = ad.Uint32() + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/expr/verdict.go b/vendor/github.com/google/nftables/expr/verdict.go new file mode 100644 index 000000000..239b40878 --- /dev/null +++ b/vendor/github.com/google/nftables/expr/verdict.go @@ -0,0 +1,131 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expr + +import ( + "bytes" + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +// This code assembles the verdict structure, as expected by the +// nftables netlink API. +// For further information, consult: +// - netfilter.h (Linux kernel) +// - net/netfilter/nf_tables_api.c (Linux kernel) +// - src/expr/data_reg.c (linbnftnl) + +type Verdict struct { + Kind VerdictKind + Chain string +} + +type VerdictKind int64 + +// Verdicts, as per netfilter.h and netfilter/nf_tables.h. +const ( + VerdictReturn VerdictKind = iota - 5 + VerdictGoto + VerdictJump + VerdictBreak + VerdictContinue + VerdictDrop + VerdictAccept + VerdictStolen + VerdictQueue + VerdictRepeat + VerdictStop +) + +func (e *Verdict) marshal(fam byte) ([]byte, error) { + // A verdict is a tree of netlink attributes structured as follows: + // NFTA_LIST_ELEM | NLA_F_NESTED { + // NFTA_EXPR_NAME { "immediate\x00" } + // NFTA_EXPR_DATA | NLA_F_NESTED { + // NFTA_IMMEDIATE_DREG { NFT_REG_VERDICT } + // NFTA_IMMEDIATE_DATA | NLA_F_NESTED { + // the verdict code + // } + // } + // } + data, err := e.marshalData(fam) + if err != nil { + return nil, err + } + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, + {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, + }) +} + +func (e *Verdict) marshalData(fam byte) ([]byte, error) { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_VERDICT_CODE, Data: binaryutil.BigEndian.PutUint32(uint32(e.Kind))}, + } + if e.Chain != "" { + attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_VERDICT_CHAIN, Data: []byte(e.Chain + "\x00")}) + } + codeData, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + immData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_DATA_VERDICT, Data: codeData}, + }) + if err != nil { + return nil, err + } + + return netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_IMMEDIATE_DREG, Data: binaryutil.BigEndian.PutUint32(unix.NFT_REG_VERDICT)}, + {Type: unix.NLA_F_NESTED | unix.NFTA_IMMEDIATE_DATA, Data: immData}, + }) +} + +func (e *Verdict) unmarshal(fam byte, data []byte) error { + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_IMMEDIATE_DATA: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) + } + for nestedAD.Next() { + switch nestedAD.Type() { + case unix.NFTA_DATA_VERDICT: + e.Kind = VerdictKind(int32(binaryutil.BigEndian.Uint32(nestedAD.Bytes()[4:8]))) + if len(nestedAD.Bytes()) > 12 { + e.Chain = string(bytes.Trim(nestedAD.Bytes()[12:], "\x00")) + } + } + } + if nestedAD.Err() != nil { + return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) + } + } + } + return ad.Err() +} diff --git a/vendor/github.com/google/nftables/flowtable.go b/vendor/github.com/google/nftables/flowtable.go new file mode 100644 index 000000000..93dbcb55f --- /dev/null +++ b/vendor/github.com/google/nftables/flowtable.go @@ -0,0 +1,306 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +const ( + // not in ztypes_linux.go, added here + // https://cs.opensource.google/go/x/sys/+/c6bc011c:unix/ztypes_linux.go;l=1870-1892 + NFT_MSG_NEWFLOWTABLE = 0x16 + NFT_MSG_GETFLOWTABLE = 0x17 + NFT_MSG_DELFLOWTABLE = 0x18 +) + +const ( + // not in ztypes_linux.go, added here + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1634 + _ = iota + NFTA_FLOWTABLE_TABLE + NFTA_FLOWTABLE_NAME + NFTA_FLOWTABLE_HOOK + NFTA_FLOWTABLE_USE + NFTA_FLOWTABLE_HANDLE + NFTA_FLOWTABLE_PAD + NFTA_FLOWTABLE_FLAGS +) + +const ( + // not in ztypes_linux.go, added here + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1657 + _ = iota + NFTA_FLOWTABLE_HOOK_NUM + NFTA_FLOWTABLE_PRIORITY + NFTA_FLOWTABLE_DEVS +) + +const ( + // not in ztypes_linux.go, added here, used for flowtable device name specification + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1709 + NFTA_DEVICE_NAME = 1 +) + +type FlowtableFlags uint32 + +const ( + _ FlowtableFlags = iota + FlowtableFlagsHWOffload + FlowtableFlagsCounter + FlowtableFlagsMask = (FlowtableFlagsHWOffload | FlowtableFlagsCounter) +) + +type FlowtableHook uint32 + +func FlowtableHookRef(h FlowtableHook) *FlowtableHook { + return &h +} + +var ( + // Only ingress is supported + // https://github.com/torvalds/linux/blob/b72018ab8236c3ae427068adeb94bdd3f20454ec/net/netfilter/nf_tables_api.c#L7378-L7379 + FlowtableHookIngress *FlowtableHook = FlowtableHookRef(unix.NF_NETDEV_INGRESS) +) + +type FlowtablePriority int32 + +func FlowtablePriorityRef(p FlowtablePriority) *FlowtablePriority { + return &p +} + +var ( + // As per man page: + // The priority can be a signed integer or filter which stands for 0. Addition and subtraction can be used to set relative priority, e.g. filter + 5 equals to 5. + // https://git.netfilter.org/nftables/tree/doc/nft.txt?id=8c600a843b7c0c1cc275ecc0603bd1fc57773e98#n712 + FlowtablePriorityFilter *FlowtablePriority = FlowtablePriorityRef(0) +) + +type Flowtable struct { + Table *Table + Name string + Hooknum *FlowtableHook + Priority *FlowtablePriority + Devices []string + Use uint32 + // Bitmask flags, can be HW_OFFLOAD or COUNTER + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n1621 + Flags FlowtableFlags + Handle uint64 +} + +func (cc *Conn) AddFlowtable(f *Flowtable) *Flowtable { + cc.mu.Lock() + defer cc.mu.Unlock() + + data := cc.marshalAttr([]netlink.Attribute{ + {Type: NFTA_FLOWTABLE_TABLE, Data: []byte(f.Table.Name)}, + {Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)}, + {Type: NFTA_FLOWTABLE_FLAGS, Data: binaryutil.BigEndian.PutUint32(uint32(f.Flags))}, + }) + + if f.Hooknum == nil { + f.Hooknum = FlowtableHookIngress + } + + if f.Priority == nil { + f.Priority = FlowtablePriorityFilter + } + + hookAttr := []netlink.Attribute{ + {Type: NFTA_FLOWTABLE_HOOK_NUM, Data: binaryutil.BigEndian.PutUint32(uint32(*f.Hooknum))}, + {Type: NFTA_FLOWTABLE_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*f.Priority))}, + } + if len(f.Devices) > 0 { + devs := make([]netlink.Attribute, len(f.Devices)) + for i, d := range f.Devices { + devs[i] = netlink.Attribute{Type: NFTA_DEVICE_NAME, Data: []byte(d)} + } + hookAttr = append(hookAttr, netlink.Attribute{ + Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_DEVS, + Data: cc.marshalAttr(devs), + }) + } + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | NFTA_FLOWTABLE_HOOK, Data: cc.marshalAttr(hookAttr)}, + })...) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(f.Table.Family), 0), data...), + }) + + return f +} + +func (cc *Conn) DelFlowtable(f *Flowtable) { + cc.mu.Lock() + defer cc.mu.Unlock() + + data := cc.marshalAttr([]netlink.Attribute{ + {Type: NFTA_FLOWTABLE_TABLE, Data: []byte(f.Table.Name)}, + {Type: NFTA_FLOWTABLE_NAME, Data: []byte(f.Name)}, + }) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_DELFLOWTABLE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(f.Table.Family), 0), data...), + }) +} + +func (cc *Conn) ListFlowtables(t *Table) ([]*Flowtable, error) { + reply, err := cc.getFlowtables(t) + if err != nil { + return nil, err + } + + var fts []*Flowtable + for _, msg := range reply { + f, err := ftsFromMsg(msg) + if err != nil { + return nil, err + } + f.Table = t + fts = append(fts, f) + } + + return fts, nil +} + +func (cc *Conn) getFlowtables(t *Table) ([]netlink.Message, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + attrs := []netlink.Attribute{ + {Type: NFTA_FLOWTABLE_TABLE, Data: []byte(t.Name + "\x00")}, + } + data, err := netlink.MarshalAttributes(attrs) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_GETFLOWTABLE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + + return reply, nil +} + +func ftsFromMsg(msg netlink.Message) (*Flowtable, error) { + flowHeaderType := netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | NFT_MSG_NEWFLOWTABLE) + if got, want := msg.Header.Type, flowHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + var ft Flowtable + for ad.Next() { + switch ad.Type() { + case NFTA_FLOWTABLE_NAME: + ft.Name = ad.String() + case NFTA_FLOWTABLE_USE: + ft.Use = ad.Uint32() + case NFTA_FLOWTABLE_HANDLE: + ft.Handle = ad.Uint64() + case NFTA_FLOWTABLE_FLAGS: + ft.Flags = FlowtableFlags(ad.Uint32()) + case NFTA_FLOWTABLE_HOOK: + ad.Do(func(b []byte) error { + ft.Hooknum, ft.Priority, ft.Devices, err = ftsHookFromMsg(b) + return err + }) + } + } + return &ft, nil +} + +func ftsHookFromMsg(b []byte) (*FlowtableHook, *FlowtablePriority, []string, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, nil, nil, err + } + + ad.ByteOrder = binary.BigEndian + + var hooknum FlowtableHook + var prio FlowtablePriority + var devices []string + + for ad.Next() { + switch ad.Type() { + case NFTA_FLOWTABLE_HOOK_NUM: + hooknum = FlowtableHook(ad.Uint32()) + case NFTA_FLOWTABLE_PRIORITY: + prio = FlowtablePriority(ad.Uint32()) + case NFTA_FLOWTABLE_DEVS: + ad.Do(func(b []byte) error { + devices, err = devsFromMsg(b) + return err + }) + } + } + + return &hooknum, &prio, devices, nil +} + +func devsFromMsg(b []byte) ([]string, error) { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return nil, err + } + + ad.ByteOrder = binary.BigEndian + + devs := make([]string, 0) + for ad.Next() { + switch ad.Type() { + case NFTA_DEVICE_NAME: + devs = append(devs, ad.String()) + } + } + + return devs, nil +} diff --git a/vendor/github.com/google/nftables/gen.go b/vendor/github.com/google/nftables/gen.go new file mode 100644 index 000000000..4ebcfbeb3 --- /dev/null +++ b/vendor/github.com/google/nftables/gen.go @@ -0,0 +1,45 @@ +package nftables + +import ( + "encoding/binary" + "fmt" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type GenMsg struct { + ID uint32 + ProcPID uint32 + ProcComm string // [16]byte - max 16bytes - kernel TASK_COMM_LEN +} + +var genHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWGEN) + +func genFromMsg(msg netlink.Message) (*GenMsg, error) { + if got, want := msg.Header.Type, genHeaderType; got != want { + return nil, fmt.Errorf("unexpected header type: got %v, want %v", got, want) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + msgOut := &GenMsg{} + for ad.Next() { + switch ad.Type() { + case unix.NFTA_GEN_ID: + msgOut.ID = ad.Uint32() + case unix.NFTA_GEN_PROC_PID: + msgOut.ProcPID = ad.Uint32() + case unix.NFTA_GEN_PROC_NAME: + msgOut.ProcComm = ad.String() + default: + return nil, fmt.Errorf("Unknown attribute: %d %v\n", ad.Type(), ad.Bytes()) + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return msgOut, nil +} diff --git a/vendor/github.com/google/nftables/internal/parseexprfunc/parseexprfunc.go b/vendor/github.com/google/nftables/internal/parseexprfunc/parseexprfunc.go new file mode 100644 index 000000000..586146178 --- /dev/null +++ b/vendor/github.com/google/nftables/internal/parseexprfunc/parseexprfunc.go @@ -0,0 +1,11 @@ +package parseexprfunc + +import ( + "github.com/mdlayher/netlink" +) + +var ( + ParseExprBytesFromNameFunc func(fam byte, ad *netlink.AttributeDecoder, exprName string) ([]interface{}, error) + ParseExprBytesFunc func(fam byte, ad *netlink.AttributeDecoder) ([]interface{}, error) + ParseExprMsgFunc func(fam byte, b []byte) ([]interface{}, error) +) diff --git a/vendor/github.com/google/nftables/monitor.go b/vendor/github.com/google/nftables/monitor.go new file mode 100644 index 000000000..7a25e3b43 --- /dev/null +++ b/vendor/github.com/google/nftables/monitor.go @@ -0,0 +1,377 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "math" + "strings" + "sync" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type MonitorAction uint8 + +// Possible MonitorAction values. +const ( + MonitorActionNew MonitorAction = 1 << iota + MonitorActionDel + MonitorActionMask MonitorAction = (1 << iota) - 1 + MonitorActionAny MonitorAction = MonitorActionMask +) + +type MonitorObject uint32 + +// Possible MonitorObject values. +const ( + MonitorObjectTables MonitorObject = 1 << iota + MonitorObjectChains + MonitorObjectSets + MonitorObjectRules + MonitorObjectElements + MonitorObjectRuleset + MonitorObjectMask MonitorObject = (1 << iota) - 1 + MonitorObjectAny MonitorObject = MonitorObjectMask +) + +var ( + monitorFlags = map[MonitorAction]map[MonitorObject]uint32{ + MonitorActionAny: { + MonitorObjectAny: 0xffffffff, + MonitorObjectTables: 1<>8 != netlink.HeaderType(unix.NFNL_SUBSYS_NFTABLES) { + continue + } + msgType := msg.Header.Type & 0x00ff + if monitor.monitorFlags&1< 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data}) + } + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(o.family()), 0), cc.marshalAttr(attrs)...), + }) + return o +} + +// DeleteObject deletes the specified Obj +func (cc *Conn) DeleteObject(o Obj) { + cc.mu.Lock() + defer cc.mu.Unlock() + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + data := cc.marshalAttr(attrs) + data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(o.family()), 0), data...), + }) +} + +// GetObj is a legacy method that return all Obj that belongs +// to the same table as the given one +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. +func (cc *Conn) GetObj(o Obj) ([]Obj, error) { + return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ, cc.useLegacyObjType(o)) +} + +// GetObjReset is a legacy method that reset all Obj that belongs +// the same table as the given one +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. +func (cc *Conn) GetObjReset(o Obj) ([]Obj, error) { + return cc.getObjWithLegacyType(nil, o.table(), unix.NFT_MSG_GETOBJ_RESET, cc.useLegacyObjType(o)) +} + +// GetObject gets the specified Object +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. +func (cc *Conn) GetObject(o Obj) (Obj, error) { + objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ) + + if len(objs) == 0 { + return nil, err + } + + return objs[0], err +} + +// GetObjects get all the Obj that belongs to the given table +// This function will always return legacy QuotaObj/CounterObj +// types for backwards compatibility +func (cc *Conn) GetObjects(t *Table) ([]Obj, error) { + return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ) +} + +// GetNamedObjects get all the Obj that belongs to the given table +// This function always return NamedObj types +func (cc *Conn) GetNamedObjects(t *Table) ([]Obj, error) { + return cc.getObjWithLegacyType(nil, t, unix.NFT_MSG_GETOBJ, false) +} + +// ResetObject reset the given Obj +// This function returns the same concrete type as passed, +// e.g. QuotaObj, CounterObj or NamedObj. Prefer using the more +// generic NamedObj over the legacy QuotaObj and CounterObj types. +func (cc *Conn) ResetObject(o Obj) (Obj, error) { + objs, err := cc.getObj(o, o.table(), unix.NFT_MSG_GETOBJ_RESET) + + if len(objs) == 0 { + return nil, err + } + + return objs[0], err +} + +// ResetObjects reset all the Obj that belongs to the given table +// This function will always return legacy QuotaObj/CounterObj +// types for backwards compatibility +func (cc *Conn) ResetObjects(t *Table) ([]Obj, error) { + return cc.getObj(nil, t, unix.NFT_MSG_GETOBJ_RESET) +} + +// ResetNamedObjects reset all the Obj that belongs to the given table +// This function always return NamedObj types +func (cc *Conn) ResetNamedObjects(t *Table) ([]Obj, error) { + return cc.getObjWithLegacyType(nil, t, unix.NFT_MSG_GETOBJ_RESET, false) +} + +func objFromMsg(msg netlink.Message, returnLegacyType bool) (Obj, error) { + if got, want1, want2 := msg.Header.Type, newObjHeaderType, delObjHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + var ( + table *Table + name string + objectType uint32 + ) + for ad.Next() { + switch ad.Type() { + case unix.NFTA_OBJ_TABLE: + table = &Table{Name: ad.String(), Family: TableFamily(msg.Data[0])} + case unix.NFTA_OBJ_NAME: + name = ad.String() + case unix.NFTA_OBJ_TYPE: + objectType = ad.Uint32() + case unix.NFTA_OBJ_DATA: + if returnLegacyType { + return objDataFromMsgLegacy(ad, table, name, objectType) + } + + o := NamedObj{ + Table: table, + Name: name, + Type: ObjType(objectType), + } + + objs, err := parseexprfunc.ParseExprBytesFromNameFunc(byte(o.family()), ad, objByObjTypeMagic[o.Type]) + if err != nil { + return nil, err + } + if len(objs) == 0 { + return nil, fmt.Errorf("objFromMsg: objs is empty for obj %v", o) + } + exprs := make([]expr.Any, len(objs)) + for i := range exprs { + exprs[i] = objs[i].(expr.Any) + } + + o.Obj = exprs[0] + return &o, ad.Err() + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return nil, fmt.Errorf("malformed stateful object") +} + +func objDataFromMsgLegacy(ad *netlink.AttributeDecoder, table *Table, name string, objectType uint32) (Obj, error) { + switch objectType { + case unix.NFT_OBJECT_COUNTER: + o := CounterObj{ + Table: table, + Name: name, + } + + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + return o.unmarshal(ad) + }) + return &o, ad.Err() + case unix.NFT_OBJECT_QUOTA: + o := QuotaObj{ + Table: table, + Name: name, + } + + ad.Do(func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + return o.unmarshal(ad) + }) + return &o, ad.Err() + } + if err := ad.Err(); err != nil { + return nil, err + } + return nil, fmt.Errorf("malformed stateful object") +} + +func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { + return cc.getObjWithLegacyType(o, t, msgType, cc.useLegacyObjType(o)) +} + +func (cc *Conn) getObjWithLegacyType(o Obj, t *Table, msgType uint16, returnLegacyObjType bool) ([]Obj, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + var data []byte + var flags netlink.HeaderFlags + + if o != nil { + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + data = cc.marshalAttr(attrs) + } else { + flags = netlink.Dump + data, err = netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + }) + } + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | msgType), + Flags: netlink.Request | netlink.Acknowledge | flags, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + var objs []Obj + for _, msg := range reply { + o, err := objFromMsg(msg, returnLegacyObjType) + if err != nil { + return nil, err + } + objs = append(objs, o) + } + + return objs, nil +} + +func (cc *Conn) useLegacyObjType(o Obj) bool { + useLegacyType := true + if o != nil { + switch o.(type) { + case *NamedObj: + useLegacyType = false + } + } + return useLegacyType +} diff --git a/vendor/github.com/google/nftables/quota.go b/vendor/github.com/google/nftables/quota.go new file mode 100644 index 000000000..a8be63465 --- /dev/null +++ b/vendor/github.com/google/nftables/quota.go @@ -0,0 +1,67 @@ +// Copyright 2023 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "github.com/google/nftables/expr" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +type QuotaObj struct { + Table *Table + Name string + Bytes uint64 + Consumed uint64 + Over bool +} + +func (q *QuotaObj) unmarshal(ad *netlink.AttributeDecoder) error { + for ad.Next() { + switch ad.Type() { + case unix.NFTA_QUOTA_BYTES: + q.Bytes = ad.Uint64() + case unix.NFTA_QUOTA_CONSUMED: + q.Consumed = ad.Uint64() + case unix.NFTA_QUOTA_FLAGS: + q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) != 0 + } + } + return nil +} + +func (q *QuotaObj) table() *Table { + return q.Table +} + +func (q *QuotaObj) family() TableFamily { + return q.Table.Family +} + +func (q *QuotaObj) data() expr.Any { + return &expr.Quota{ + Bytes: q.Bytes, + Consumed: q.Consumed, + Over: q.Over, + } +} + +func (q *QuotaObj) name() string { + return q.Name +} + +func (q *QuotaObj) objType() ObjType { + return ObjTypeQuota +} diff --git a/vendor/github.com/google/nftables/rule.go b/vendor/github.com/google/nftables/rule.go new file mode 100644 index 000000000..07068344e --- /dev/null +++ b/vendor/github.com/google/nftables/rule.go @@ -0,0 +1,270 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "encoding/binary" + "fmt" + + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/google/nftables/internal/parseexprfunc" + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +var ( + newRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWRULE) + delRuleHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE) +) + +type ruleOperation uint32 + +// Possible PayloadOperationType values. +const ( + operationAdd ruleOperation = iota + operationInsert + operationReplace +) + +// A Rule does something with a packet. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Simple_rule_management +type Rule struct { + Table *Table + Chain *Chain + Position uint64 + Handle uint64 + // The list of possible flags are specified by nftnl_rule_attr, see + // https://git.netfilter.org/libnftnl/tree/include/libnftnl/rule.h#n21 + // Current nftables go implementation supports only + // NFTNL_RULE_POSITION flag for setting rule at position 0 + Flags uint32 + Exprs []expr.Any + UserData []byte +} + +// GetRule returns the rules in the specified table and chain. +// +// Deprecated: use GetRules instead. +func (cc *Conn) GetRule(t *Table, c *Chain) ([]*Rule, error) { + return cc.GetRules(t, c) +} + +// GetRules returns the rules in the specified table and chain. +func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETRULE), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump | unix.NLM_F_ECHO, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + var rules []*Rule + for _, msg := range reply { + r, err := ruleFromMsg(t.Family, msg) + if err != nil { + return nil, err + } + rules = append(rules, r) + } + + return rules, nil +} + +// AddRule adds the specified Rule +func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { + cc.mu.Lock() + defer cc.mu.Unlock() + exprAttrs := make([]netlink.Attribute, len(r.Exprs)) + for idx, expr := range r.Exprs { + exprAttrs[idx] = netlink.Attribute{ + Type: unix.NLA_F_NESTED | unix.NFTA_LIST_ELEM, + Data: cc.marshalExpr(byte(r.Table.Family), expr), + } + } + + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(r.Table.Name + "\x00")}, + {Type: unix.NFTA_RULE_CHAIN, Data: []byte(r.Chain.Name + "\x00")}, + }) + + if r.Handle != 0 { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_HANDLE, Data: binaryutil.BigEndian.PutUint64(r.Handle)}, + })...) + } + + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_EXPRESSIONS, Data: cc.marshalAttr(exprAttrs)}, + })...) + + if compatPolicy, err := getCompatPolicy(r.Exprs); err != nil { + cc.setErr(err) + } else if compatPolicy != nil { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NLA_F_NESTED | unix.NFTA_RULE_COMPAT, Data: cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_COMPAT_PROTO, Data: binaryutil.BigEndian.PutUint32(compatPolicy.Proto)}, + {Type: unix.NFTA_RULE_COMPAT_FLAGS, Data: binaryutil.BigEndian.PutUint32(compatPolicy.Flag & nft_RULE_COMPAT_F_MASK)}, + })}, + })...) + } + + msgData := []byte{} + + msgData = append(msgData, data...) + var flags netlink.HeaderFlags + if r.UserData != nil { + msgData = append(msgData, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_USERDATA, Data: r.UserData}, + })...) + } + + switch op { + case operationAdd: + flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND + case operationInsert: + flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO + case operationReplace: + flags = netlink.Request | netlink.Acknowledge | netlink.Replace | unix.NLM_F_ECHO | unix.NLM_F_REPLACE + } + + if r.Position != 0 || (r.Flags&(1< 32/SetConcatTypeBits { + return SetDatatype{}, ErrTooManyTypes + } + + var magic, bytes uint32 + names := make([]string, len(types)) + for i, t := range types { + bytes += t.Bytes + // concatenated types pad the length to multiples of the register size (4 bytes) + // see https://git.netfilter.org/nftables/tree/src/datatype.c?id=488356b895024d0944b20feb1f930558726e0877#n1162 + if t.Bytes%4 != 0 { + bytes += 4 - (t.Bytes % 4) + } + names[i] = t.Name + + magic <<= SetConcatTypeBits + magic |= t.nftMagic & SetConcatTypeMask + } + return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil +} + +// ConcatSetTypeElements uses the ConcatSetType name to calculate and return +// a list of base types which were used to construct the concatenated type +func ConcatSetTypeElements(t SetDatatype) []SetDatatype { + names := strings.Split(t.Name, " . ") + types := make([]SetDatatype, len(names)) + for i, n := range names { + types[i] = nftDatatypesByName[n] + } + return types +} + +// Set represents an nftables set. Anonymous sets are only valid within the +// context of a single batch. +type Set struct { + Table *Table + ID uint32 + Name string + Anonymous bool + Constant bool + Interval bool + AutoMerge bool + IsMap bool + HasTimeout bool + Counter bool + // Can be updated per evaluation path, per `nft list ruleset` + // indicates that set contains "flags dynamic" + // https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=84d12cfacf8ddd857a09435f3d982ab6250d250c#n298 + Dynamic bool + // Indicates that the set contains a concatenation + // https://git.netfilter.org/nftables/tree/include/linux/netfilter/nf_tables.h?id=d1289bff58e1878c3162f574c603da993e29b113#n306 + Concatenation bool + Timeout time.Duration + KeyType SetDatatype + DataType SetDatatype + // Either host (binaryutil.NativeEndian) or big (binaryutil.BigEndian) endian as per + // https://git.netfilter.org/nftables/tree/include/datatype.h?id=d486c9e626405e829221b82d7355558005b26d8a#n109 + KeyByteOrder binaryutil.ByteOrder + Comment string + // Indicates that the set has "size" specifier + Size uint32 +} + +// SetElement represents a data point within a set. +type SetElement struct { + Key []byte + Val []byte + // Field used for definition of ending interval value in concatenated types + // https://git.netfilter.org/libnftnl/tree/include/set_elem.h?id=e2514c0eff4da7e8e0aabd410f7b7d0b7564c880#n11 + KeyEnd []byte + IntervalEnd bool + // To support vmap, a caller must be able to pass Verdict type of data. + // If IsMap is true and VerdictData is not nil, then Val of SetElement will be ignored + // and VerdictData will be wrapped into Attribute data. + VerdictData *expr.Verdict + // To support aging of set elements + Timeout time.Duration + + // Life left of the "timeout" elements + Expires time.Duration + + Counter *expr.Counter + Comment string +} + +func (s *SetElement) decode(fam byte) func(b []byte) error { + return func(b []byte) error { + ad, err := netlink.NewAttributeDecoder(b) + if err != nil { + return fmt.Errorf("failed to create nested attribute decoder: %v", err) + } + ad.ByteOrder = binary.BigEndian + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_ELEM_KEY: + s.Key, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case NFTA_SET_ELEM_KEY_END: + s.KeyEnd, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case unix.NFTA_SET_ELEM_DATA: + s.Val, err = decodeElement(ad.Bytes()) + if err != nil { + return err + } + case unix.NFTA_SET_ELEM_FLAGS: + flags := ad.Uint32() + s.IntervalEnd = (flags & unix.NFT_SET_ELEM_INTERVAL_END) != 0 + case unix.NFTA_SET_ELEM_TIMEOUT: + s.Timeout = time.Millisecond * time.Duration(ad.Uint64()) + case unix.NFTA_SET_ELEM_EXPIRATION: + s.Expires = time.Millisecond * time.Duration(ad.Uint64()) + case unix.NFTA_SET_ELEM_USERDATA: + userData := ad.Bytes() + // Try to extract comment from userdata if present + if comment, ok := userdata.GetString(userData, userdata.NFTNL_UDATA_SET_ELEM_COMMENT); ok { + s.Comment = comment + } + case unix.NFTA_SET_ELEM_EXPR: + elems, err := parseexprfunc.ParseExprBytesFunc(fam, ad) + if err != nil { + return err + } + + for _, elem := range elems { + switch item := elem.(type) { + case *expr.Counter: + s.Counter = item + } + } + } + } + return ad.Err() + } +} + +func decodeElement(d []byte) ([]byte, error) { + ad, err := netlink.NewAttributeDecoder(d) + if err != nil { + return nil, fmt.Errorf("failed to create nested attribute decoder: %v", err) + } + ad.ByteOrder = binary.BigEndian + var b []byte + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_ELEM_KEY: + fallthrough + case unix.NFTA_SET_ELEM_DATA: + b = ad.Bytes() + } + } + if err := ad.Err(); err != nil { + return nil, err + } + return b, nil +} + +// SetAddElements applies data points to an nftables set. +func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { + cc.mu.Lock() + defer cc.mu.Unlock() + if s.Anonymous { + return errors.New("anonymous sets cannot be updated") + } + + elements, err := s.makeElemList(vals, s.ID) + if err != nil { + return err + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), + }) + + return nil +} + +func (s *Set) makeElemList(vals []SetElement, id uint32) ([]netlink.Attribute, error) { + var elements []netlink.Attribute + + for i, v := range vals { + item := make([]netlink.Attribute, 0) + var flags uint32 + if v.IntervalEnd { + flags |= unix.NFT_SET_ELEM_INTERVAL_END + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_FLAGS | unix.NLA_F_NESTED, Data: binaryutil.BigEndian.PutUint32(flags)}) + } + + encodedKey, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Key}}) + if err != nil { + return nil, fmt.Errorf("marshal key %d: %v", i, err) + } + + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_KEY | unix.NLA_F_NESTED, Data: encodedKey}) + if len(v.KeyEnd) > 0 { + encodedKeyEnd, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.KeyEnd}}) + if err != nil { + return nil, fmt.Errorf("marshal key end %d: %v", i, err) + } + item = append(item, netlink.Attribute{Type: NFTA_SET_ELEM_KEY_END | unix.NLA_F_NESTED, Data: encodedKeyEnd}) + } + if s.HasTimeout && v.Timeout != 0 { + // Set has Timeout flag set, which means an individual element can specify its own timeout. + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(v.Timeout.Milliseconds()))}) + } + // The following switch statement deal with 3 different types of elements. + // 1. v is an element of vmap + // 2. v is an element of a regular map + // 3. v is an element of a regular set (default) + switch { + case v.VerdictData != nil: + // Since VerdictData is not nil, v is vmap element, need to add to the attributes + encodedVal := []byte{} + encodedKind, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(v.VerdictData.Kind))}, + }) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + encodedVal = append(encodedVal, encodedKind...) + if len(v.VerdictData.Chain) != 0 { + encodedChain, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_ELEM_DATA, Data: []byte(v.VerdictData.Chain + "\x00")}, + }) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + encodedVal = append(encodedVal, encodedChain...) + } + encodedVerdict, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}}) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVerdict}) + case len(v.Val) > 0: + // Since v.Val's length is not 0 then, v is a regular map element, need to add to the attributes + encodedVal, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NFTA_DATA_VALUE, Data: v.Val}}) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_DATA | unix.NLA_F_NESTED, Data: encodedVal}) + default: + // If niether of previous cases matche, it means 'e' is an element of a regular Set, no need to add to the attributes + } + + // Add comment to userdata if present + if len(v.Comment) > 0 { + userData := userdata.AppendString(nil, userdata.NFTNL_UDATA_SET_ELEM_COMMENT, v.Comment) + item = append(item, netlink.Attribute{Type: unix.NFTA_SET_ELEM_USERDATA, Data: userData}) + } + + encodedItem, err := netlink.MarshalAttributes(item) + if err != nil { + return nil, fmt.Errorf("marshal item %d: %v", i, err) + } + elements = append(elements, netlink.Attribute{Type: uint16(i+1) | unix.NLA_F_NESTED, Data: encodedItem}) + } + + encodedElem, err := netlink.MarshalAttributes(elements) + if err != nil { + return nil, fmt.Errorf("marshal elements: %v", err) + } + + return []netlink.Attribute{ + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + {Type: unix.NFTA_LOOKUP_SET_ID, Data: binaryutil.BigEndian.PutUint32(id)}, + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_ELEM_LIST_ELEMENTS | unix.NLA_F_NESTED, Data: encodedElem}, + }, nil +} + +// AddSet adds the specified Set. +func (cc *Conn) AddSet(s *Set, vals []SetElement) error { + cc.mu.Lock() + defer cc.mu.Unlock() + // Based on nft implementation & linux source. + // Link: https://github.com/torvalds/linux/blob/49a57857aeea06ca831043acbb0fa5e0f50602fd/net/netfilter/nf_tables_api.c#L3395 + // Another reference: https://git.netfilter.org/nftables/tree/src + + if s.Anonymous && !s.Constant { + return errors.New("anonymous structs must be constant") + } + + if s.ID == 0 { + allocSetID++ + s.ID = allocSetID + if s.Anonymous { + s.Name = "__set%d" + if s.IsMap { + s.Name = "__map%d" + } + } + } + + var flags uint32 + if s.Anonymous { + flags |= unix.NFT_SET_ANONYMOUS + } + if s.Constant { + flags |= unix.NFT_SET_CONSTANT + } + if s.Interval { + flags |= unix.NFT_SET_INTERVAL + } + if s.IsMap { + flags |= unix.NFT_SET_MAP + } + if s.HasTimeout { + flags |= unix.NFT_SET_TIMEOUT + } + if s.Dynamic { + flags |= unix.NFT_SET_EVAL + } + if s.Concatenation { + flags |= NFT_SET_CONCAT + } + tableInfo := []netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + {Type: unix.NFTA_SET_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}, + {Type: unix.NFTA_SET_KEY_TYPE, Data: binaryutil.BigEndian.PutUint32(s.KeyType.nftMagic)}, + {Type: unix.NFTA_SET_KEY_LEN, Data: binaryutil.BigEndian.PutUint32(s.KeyType.Bytes)}, + {Type: unix.NFTA_SET_ID, Data: binaryutil.BigEndian.PutUint32(s.ID)}, + } + if s.IsMap { + // Check if it is vmap case + if s.DataType.nftMagic == 1 { + // For Verdict data type, the expected magic is 0xfffff0 + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(unix.NFT_DATA_VERDICT))}, + netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) + } else { + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_DATA_TYPE, Data: binaryutil.BigEndian.PutUint32(s.DataType.nftMagic)}, + netlink.Attribute{Type: unix.NFTA_SET_DATA_LEN, Data: binaryutil.BigEndian.PutUint32(s.DataType.Bytes)}) + } + } + if s.HasTimeout && s.Timeout != 0 { + // If Set's global timeout is specified, add it to set's attributes + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(s.Timeout.Milliseconds()))}) + } + if s.Constant { + // nft cli tool adds the number of elements to set/map's descriptor + // It make sense to do only if a set or map are constant, otherwise skip NFTA_SET_DESC attribute + numberOfElements, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(uint32(len(vals)))}, + }) + if err != nil { + return fmt.Errorf("fail to marshal number of elements %d: %v", len(vals), err) + } + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements}) + } + + var descBytes []byte + + if s.Size > 0 { + // Marshal set size description + descSizeBytes, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)}, + }) + if err != nil { + return fmt.Errorf("fail to marshal set size description: %w", err) + } + + descBytes = append(descBytes, descSizeBytes...) + } + + if s.Concatenation { + // Length of concatenated types is a must, otherwise segfaults when executing nft list ruleset + var concatDefinition []byte + elements := ConcatSetTypeElements(s.KeyType) + for i, v := range elements { + // Marshal base type size value + valData, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_DATA_VALUE, Data: binaryutil.BigEndian.PutUint32(v.Bytes)}, + }) + if err != nil { + return fmt.Errorf("fail to marshal element key size %d: %v", i, err) + } + // Marshal base type size description + descSize, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_DESC_SIZE, Data: valData}, + }) + if err != nil { + return fmt.Errorf("fail to marshal base type size description: %w", err) + } + concatDefinition = append(concatDefinition, descSize...) + } + // Marshal all base type descriptions into concatenation size description + concatBytes, err := netlink.MarshalAttributes([]netlink.Attribute{{Type: unix.NLA_F_NESTED | NFTA_SET_DESC_CONCAT, Data: concatDefinition}}) + if err != nil { + return fmt.Errorf("fail to marshal concat definition %v", err) + } + + descBytes = append(descBytes, concatBytes...) + } + + if len(descBytes) > 0 { + // Marshal set description + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: descBytes}) + } + + // https://git.netfilter.org/libnftnl/tree/include/udata.h#n17 + var userData []byte + + if s.Anonymous || s.Constant || s.Interval || s.KeyByteOrder == binaryutil.BigEndian { + // Semantically useless - kept for binary compatability with nft + userData = userdata.AppendUint32(userData, userdata.NFTNL_UDATA_SET_KEYBYTEORDER, 2) + } else if s.KeyByteOrder == binaryutil.NativeEndian { + // Per https://git.netfilter.org/nftables/tree/src/mnl.c?id=187c6d01d35722618c2711bbc49262c286472c8f#n1165 + userData = userdata.AppendUint32(userData, userdata.NFTNL_UDATA_SET_KEYBYTEORDER, 1) + } + + if s.Interval && s.AutoMerge { + // https://git.netfilter.org/nftables/tree/src/mnl.c?id=187c6d01d35722618c2711bbc49262c286472c8f#n1174 + userData = userdata.AppendUint32(userData, userdata.NFTNL_UDATA_SET_MERGE_ELEMENTS, 1) + } + + if len(s.Comment) != 0 { + userData = userdata.AppendString(userData, userdata.NFTNL_UDATA_SET_COMMENT, s.Comment) + } + + if len(userData) > 0 { + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: userData}) + } + + if s.Counter { + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_LIST_ELEM, Data: []byte("counter\x00")}, + {Type: unix.NFTA_SET_ELEM_PAD | unix.NFTA_SET_ELEM_DATA, Data: []byte{}}, + }) + if err != nil { + return err + } + tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | NFTA_SET_ELEM_EXPRESSIONS, Data: data}) + } + + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(tableInfo)...), + }) + + // Set the values of the set if initial values were provided. + if len(vals) > 0 { + hdrType := unix.NFT_MSG_NEWSETELEM + elements, err := s.makeElemList(vals, s.ID) + if err != nil { + return err + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), + }) + } + + return nil +} + +// DelSet deletes a specific set, along with all elements it contains. +func (cc *Conn) DelSet(s *Set) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), data...), + }) +} + +// SetDeleteElements deletes data points from an nftables set. +func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { + cc.mu.Lock() + defer cc.mu.Unlock() + if s.Anonymous { + return errors.New("anonymous sets cannot be updated") + } + + elements, err := s.makeElemList(vals, s.ID) + if err != nil { + return err + } + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), + Flags: netlink.Request | netlink.Acknowledge | netlink.Create, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), cc.marshalAttr(elements)...), + }) + + return nil +} + +// FlushSet deletes all data points from an nftables set. +func (cc *Conn) FlushSet(s *Set) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), data...), + }) +} + +var ( + newSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET) + delSetHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET) +) + +func setsFromMsg(msg netlink.Message) (*Set, error) { + if got, want1, want2 := msg.Header.Type, newSetHeaderType, delSetHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + var set Set + for ad.Next() { + switch ad.Type() { + case unix.NFTA_SET_NAME: + set.Name = ad.String() + case unix.NFTA_SET_ID: + set.ID = binary.BigEndian.Uint32(ad.Bytes()) + case unix.NFTA_SET_TIMEOUT: + set.Timeout = time.Duration(time.Millisecond * time.Duration(binary.BigEndian.Uint64(ad.Bytes()))) + set.HasTimeout = true + case unix.NFTA_SET_FLAGS: + flags := ad.Uint32() + set.Constant = (flags & unix.NFT_SET_CONSTANT) != 0 + set.Anonymous = (flags & unix.NFT_SET_ANONYMOUS) != 0 + set.Interval = (flags & unix.NFT_SET_INTERVAL) != 0 + set.IsMap = (flags & unix.NFT_SET_MAP) != 0 + set.HasTimeout = (flags & unix.NFT_SET_TIMEOUT) != 0 + set.Dynamic = (flags & unix.NFT_SET_EVAL) != 0 + set.Concatenation = (flags & NFT_SET_CONCAT) != 0 + case unix.NFTA_SET_KEY_TYPE: + nftMagic := ad.Uint32() + dt, err := parseSetDatatype(nftMagic) + if err != nil { + return nil, fmt.Errorf("could not determine data type: %w", err) + } + set.KeyType = dt + case unix.NFTA_SET_KEY_LEN: + set.KeyType.Bytes = binary.BigEndian.Uint32(ad.Bytes()) + case unix.NFTA_SET_DATA_TYPE: + nftMagic := ad.Uint32() + // Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1 + if nftMagic == 0xffffff00 { + set.KeyType = TypeVerdict + break + } + dt, err := parseSetDatatype(nftMagic) + if err != nil { + return nil, fmt.Errorf("could not determine data type: %w", err) + } + set.DataType = dt + case unix.NFTA_SET_DATA_LEN: + set.DataType.Bytes = binary.BigEndian.Uint32(ad.Bytes()) + case unix.NFTA_SET_USERDATA: + data := ad.Bytes() + value, ok := userdata.GetUint32(data, userdata.NFTNL_UDATA_SET_MERGE_ELEMENTS) + set.AutoMerge = ok && value == 1 + case unix.NFTA_SET_DESC: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return nil, fmt.Errorf("nested NewAttributeDecoder() failed: %w", err) + } + for nestedAD.Next() { + switch nestedAD.Type() { + case unix.NFTA_SET_DESC_SIZE: + set.Size = binary.BigEndian.Uint32(nestedAD.Bytes()) + } + } + if nestedAD.Err() != nil { + return nil, fmt.Errorf("decoding set description: %w", nestedAD.Err()) + } + } + } + return &set, nil +} + +func parseSetDatatype(magic uint32) (SetDatatype, error) { + types := make([]SetDatatype, 0, 32/SetConcatTypeBits) + for magic != 0 { + t := magic & SetConcatTypeMask + magic = magic >> SetConcatTypeBits + dt, ok := nftDatatypesByMagic[t] + if !ok { + return TypeInvalid, fmt.Errorf("could not determine data type %+v", dt) + } + // Because we start with the last type, we insert the later types at the front. + types = append([]SetDatatype{dt}, types...) + } + + dt, err := ConcatSetType(types...) + if err != nil { + return TypeInvalid, fmt.Errorf("could not create data type: %w", err) + } + return dt, nil +} + +var ( + newElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM) + delElemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM) +) + +func elementsFromMsg(fam byte, msg netlink.Message) ([]SetElement, error) { + if got, want1, want2 := msg.Header.Type, newElemHeaderType, delElemHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + var elements []SetElement + for ad.Next() { + b := ad.Bytes() + if ad.Type() == unix.NFTA_SET_ELEM_LIST_ELEMENTS { + ad, err := netlink.NewAttributeDecoder(b) + + if err != nil { + return nil, err + } + ad.ByteOrder = binary.BigEndian + + for ad.Next() { + var elem SetElement + switch ad.Type() { + case unix.NFTA_LIST_ELEM: + ad.Do(elem.decode(fam)) + } + + elements = append(elements, elem) + } + } + } + return elements, nil +} + +// GetSets returns the sets in the specified table. +func (cc *Conn) GetSets(t *Table) ([]*Set, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSET), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + var sets []*Set + for _, msg := range reply { + s, err := setsFromMsg(msg) + if err != nil { + return nil, err + } + s.Table = &Table{Name: t.Name, Use: t.Use, Flags: t.Flags, Family: t.Family} + sets = append(sets, s) + } + + return sets, nil +} + +// GetSetByName returns the set in the specified table if matching name is found. +func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSET), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %w", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %w", err) + } + + if len(reply) != 1 { + return nil, fmt.Errorf("receiveAckAware: expected to receive 1 message but got %d", len(reply)) + } + rs, err := setsFromMsg(reply[0]) + if err != nil { + return nil, err + } + rs.Table = &Table{Name: t.Name, Use: t.Use, Flags: t.Flags, Family: t.Family} + + return rs, nil +} + +// GetSetElements returns the elements in the specified set. +func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data, err := netlink.MarshalAttributes([]netlink.Attribute{ + {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, + {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, + }) + if err != nil { + return nil, err + } + + message := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETSETELEM), + Flags: netlink.Request | netlink.Acknowledge | netlink.Dump, + }, + Data: append(extraHeader(uint8(s.Table.Family), 0), data...), + } + + if _, err := conn.SendMessages([]netlink.Message{message}); err != nil { + return nil, fmt.Errorf("SendMessages: %v", err) + } + + reply, err := receiveAckAware(conn, message.Header.Flags) + if err != nil { + return nil, fmt.Errorf("receiveAckAware: %v", err) + } + var elems []SetElement + for _, msg := range reply { + s, err := elementsFromMsg(uint8(s.Table.Family), msg) + if err != nil { + return nil, err + } + elems = append(elems, s...) + } + + return elems, nil +} diff --git a/vendor/github.com/google/nftables/table.go b/vendor/github.com/google/nftables/table.go new file mode 100644 index 000000000..c391b7be2 --- /dev/null +++ b/vendor/github.com/google/nftables/table.go @@ -0,0 +1,212 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "fmt" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +var ( + newTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE) + delTableHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE) +) + +// TableFamily specifies the address family for this table. +type TableFamily byte + +// Possible TableFamily values. +const ( + TableFamilyUnspecified TableFamily = unix.NFPROTO_UNSPEC + TableFamilyINet TableFamily = unix.NFPROTO_INET + TableFamilyIPv4 TableFamily = unix.NFPROTO_IPV4 + TableFamilyIPv6 TableFamily = unix.NFPROTO_IPV6 + TableFamilyARP TableFamily = unix.NFPROTO_ARP + TableFamilyNetdev TableFamily = unix.NFPROTO_NETDEV + TableFamilyBridge TableFamily = unix.NFPROTO_BRIDGE +) + +// A Table contains Chains. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables +type Table struct { + Name string // NFTA_TABLE_NAME + Use uint32 // NFTA_TABLE_USE (Number of chains in table) + Flags uint32 // NFTA_TABLE_FLAGS + Family TableFamily +} + +// DelTable deletes a specific table, along with all chains/rules it contains. +func (cc *Conn) DelTable(t *Table) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + }) +} + +func (cc *Conn) addTable(t *Table, flag netlink.HeaderFlags) *Table { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, + {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), + Flags: netlink.Request | netlink.Acknowledge | flag, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + }) + return t +} + +// AddTable adds the specified Table, just like `nft add table ...`. +// See also https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables +func (cc *Conn) AddTable(t *Table) *Table { + return cc.addTable(t, netlink.Create) +} + +// CreateTable create the specified Table if it do not existed. +// just like `nft create table ...`. +func (cc *Conn) CreateTable(t *Table) *Table { + return cc.addTable(t, netlink.Excl) +} + +// FlushTable removes all rules in all chains within the specified Table. See also +// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables#Flushing_tables +func (cc *Conn) FlushTable(t *Table) { + cc.mu.Lock() + defer cc.mu.Unlock() + data := cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, + }) + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), + Flags: netlink.Request | netlink.Acknowledge, + }, + Data: append(extraHeader(uint8(t.Family), 0), data...), + }) +} + +// ListTable returns table found for the specified name. Searches for +// the table under IPv4 family. As per nft man page: "When no address +// family is specified, ip is used by default." +func (cc *Conn) ListTable(name string) (*Table, error) { + return cc.ListTableOfFamily(name, TableFamilyIPv4) +} + +// ListTableOfFamily returns table found for the specified name and table family +func (cc *Conn) ListTableOfFamily(name string, family TableFamily) (*Table, error) { + t, err := cc.listTablesOfNameAndFamily(name, family) + if err != nil { + return nil, err + } + if got, want := len(t), 1; got != want { + return nil, fmt.Errorf("expected table count %d, got %d", want, got) + } + return t[0], nil +} + +// ListTables returns currently configured tables in the kernel +func (cc *Conn) ListTables() ([]*Table, error) { + return cc.ListTablesOfFamily(TableFamilyUnspecified) +} + +// ListTablesOfFamily returns currently configured tables for the specified table family +// in the kernel. It lists all tables if family is TableFamilyUnspecified. +func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) { + return cc.listTablesOfNameAndFamily("", family) +} + +func (cc *Conn) listTablesOfNameAndFamily(name string, family TableFamily) ([]*Table, error) { + conn, closer, err := cc.netlinkConn() + if err != nil { + return nil, err + } + defer func() { _ = closer() }() + + data := extraHeader(uint8(family), 0) + flags := netlink.Request | netlink.Dump + if name != "" { + data = append(data, cc.marshalAttr([]netlink.Attribute{ + {Type: unix.NFTA_TABLE_NAME, Data: []byte(name + "\x00")}, + })...) + flags = netlink.Request + } + + msg := netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE), + Flags: flags, + }, + Data: data, + } + + response, err := conn.Execute(msg) + if err != nil { + return nil, err + } + + var tables []*Table + for _, m := range response { + t, err := tableFromMsg(m) + if err != nil { + return nil, err + } + + tables = append(tables, t) + } + + return tables, nil +} + +func tableFromMsg(msg netlink.Message) (*Table, error) { + if got, want1, want2 := msg.Header.Type, newTableHeaderType, delTableHeaderType; got != want1 && got != want2 { + return nil, fmt.Errorf("unexpected header type: got %v, want %v or %v", got, want1, want2) + } + + var t Table + t.Family = TableFamily(msg.Data[0]) + + ad, err := netlink.NewAttributeDecoder(msg.Data[4:]) + if err != nil { + return nil, err + } + + for ad.Next() { + switch ad.Type() { + case unix.NFTA_TABLE_NAME: + t.Name = ad.String() + case unix.NFTA_TABLE_USE: + t.Use = ad.Uint32() + case unix.NFTA_TABLE_FLAGS: + t.Flags = ad.Uint32() + } + } + + return &t, nil +} diff --git a/vendor/github.com/google/nftables/userdata/userdata.go b/vendor/github.com/google/nftables/userdata/userdata.go new file mode 100644 index 000000000..984d84c8d --- /dev/null +++ b/vendor/github.com/google/nftables/userdata/userdata.go @@ -0,0 +1,115 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package userdata implements a TLV parser/serializer for libnftables-compatible comments +package userdata + +import ( + "bytes" + "encoding/binary" +) + +type Type byte + +// TLV type values are defined in: +// https://git.netfilter.org/iptables/tree/iptables/nft.c?id=73611d5582e72367a698faf1b5301c836e981465#n1659 +const ( + TypeComment Type = iota + TypeEbtablesPolicy + + TypesCount +) + +// TLV type values are defined in: +// https://git.netfilter.org/libnftnl/tree/include/libnftnl/udata.h#n39 +const ( + NFTNL_UDATA_SET_KEYBYTEORDER Type = iota + NFTNL_UDATA_SET_DATABYTEORDER + NFTNL_UDATA_SET_MERGE_ELEMENTS + NFTNL_UDATA_SET_KEY_TYPEOF + NFTNL_UDATA_SET_DATA_TYPEOF + NFTNL_UDATA_SET_EXPR + NFTNL_UDATA_SET_DATA_INTERVAL + NFTNL_UDATA_SET_COMMENT + + NFTNL_UDATA_SET_MAX +) + +// Set element userdata types +const ( + NFTNL_UDATA_SET_ELEM_COMMENT Type = iota + NFTNL_UDATA_SET_ELEM_FLAGS +) + +func Append(udata []byte, typ Type, data []byte) []byte { + udata = append(udata, byte(typ), byte(len(data))) + udata = append(udata, data...) + + return udata +} + +func Get(udata []byte, styp Type) []byte { + for { + if len(udata) < 2 { + break + } + + typ := Type(udata[0]) + length := int(udata[1]) + data := udata[2 : 2+length] + + if styp == typ { + return data + } + + if len(udata) < 2+length { + break + } else { + udata = udata[2+length:] + } + } + + return nil +} + +func AppendUint32(udata []byte, typ Type, num uint32) []byte { + data := binary.LittleEndian.AppendUint32(nil, num) + + return Append(udata, typ, data) +} + +func GetUint32(udata []byte, typ Type) (uint32, bool) { + data := Get(udata, typ) + if data == nil { + return 0, false + } + + return binary.LittleEndian.Uint32(data), true +} + +func AppendString(udata []byte, typ Type, str string) []byte { + data := append([]byte(str), 0) + return Append(udata, typ, data) +} + +func GetString(udata []byte, typ Type) (string, bool) { + data := Get(udata, typ) + if data == nil { + return "", false + } + + data, _ = bytes.CutSuffix(data, []byte{0}) + + return string(data), true +} diff --git a/vendor/github.com/google/nftables/util.go b/vendor/github.com/google/nftables/util.go new file mode 100644 index 000000000..b040ae4c2 --- /dev/null +++ b/vendor/github.com/google/nftables/util.go @@ -0,0 +1,89 @@ +// Copyright 2018 Google LLC. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nftables + +import ( + "encoding/binary" + "net" + + "github.com/google/nftables/binaryutil" + "golang.org/x/sys/unix" +) + +func extraHeader(family uint8, resID uint16) []byte { + return append([]byte{ + family, + unix.NFNETLINK_V0, + }, binaryutil.BigEndian.PutUint16(resID)...) +} + +// General form of address family dependent message, see +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nfnetlink.h#29 +type NFGenMsg struct { + NFGenFamily uint8 + Version uint8 + ResourceID uint16 +} + +func (genmsg *NFGenMsg) Decode(b []byte) { + if len(b) < 4 { + return + } + genmsg.NFGenFamily = b[0] + genmsg.Version = b[1] + genmsg.ResourceID = binary.BigEndian.Uint16(b[2:]) +} + +// NetFirstAndLastIP takes the beginning address of an entire network in CIDR +// notation (e.g. 192.168.1.0/24) and returns the first and last IP addresses +// within the network (e.g. first 192.168.1.0, last 192.168.1.255). +// +// Note that these are the first and last IP addresses, not the first and last +// *usable* IP addresses (which would be 192.168.1.1 and 192.168.1.254, +// respectively, for 192.168.1.0/24). +func NetFirstAndLastIP(networkCIDR string) (first, last net.IP, err error) { + _, subnet, err := net.ParseCIDR(networkCIDR) + if err != nil { + return nil, nil, err + } + + first = make(net.IP, len(subnet.IP)) + last = make(net.IP, len(subnet.IP)) + + switch len(subnet.IP) { + case net.IPv4len: + mask := binary.BigEndian.Uint32(subnet.Mask) + ip := binary.BigEndian.Uint32(subnet.IP) + // To achieve the first IP address, we need to AND the IP with the mask. + // The AND operation will set all bits in the host part to 0. + binary.BigEndian.PutUint32(first, ip&mask) + // To achieve the last IP address, we need to OR the IP network with the inverted mask. + // The AND between the IP and the mask will set all bits in the host part to 0, keeping the network part. + // The XOR between the mask and 0xffffffff will set all bits in the host part to 1, and the network part to 0. + // The OR operation will keep the host part unchanged, and sets the host part to all 1. + binary.BigEndian.PutUint32(last, (ip&mask)|(mask^0xffffffff)) + case net.IPv6len: + mask1 := binary.BigEndian.Uint64(subnet.Mask[:8]) + mask2 := binary.BigEndian.Uint64(subnet.Mask[8:]) + ip1 := binary.BigEndian.Uint64(subnet.IP[:8]) + ip2 := binary.BigEndian.Uint64(subnet.IP[8:]) + binary.BigEndian.PutUint64(first[:8], ip1&mask1) + binary.BigEndian.PutUint64(first[8:], ip2&mask2) + binary.BigEndian.PutUint64(last[:8], (ip1&mask1)|(mask1^0xffffffffffffffff)) + binary.BigEndian.PutUint64(last[8:], (ip2&mask2)|(mask2^0xffffffffffffffff)) + } + + return first, last, nil +} diff --git a/vendor/github.com/google/nftables/xt/comment.go b/vendor/github.com/google/nftables/xt/comment.go new file mode 100644 index 000000000..830773a95 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/comment.go @@ -0,0 +1,34 @@ +package xt + +import ( + "bytes" + "fmt" +) + +// CommentSize is the fixed size of a comment info xt blob, see: +// https://elixir.bootlin.com/linux/v6.8.7/source/include/uapi/linux/netfilter/xt_comment.h#L5 +const CommentSize = 256 + +// Comment gets marshalled and unmarshalled as a fixed-sized char array, filled +// with zeros as necessary, see: +// https://elixir.bootlin.com/linux/v6.8.7/source/include/uapi/linux/netfilter/xt_comment.h#L7 +type Comment string + +func (c *Comment) marshal(fam TableFamily, rev uint32) ([]byte, error) { + if len(*c) >= CommentSize { + return nil, fmt.Errorf("comment must be less than %d bytes, got %d bytes", + CommentSize, len(*c)) + } + data := make([]byte, CommentSize) + copy(data, []byte(*c)) + return data, nil +} + +func (c *Comment) unmarshal(fam TableFamily, rev uint32, data []byte) error { + if len(data) != CommentSize { + return fmt.Errorf("malformed comment: got %d bytes, expected exactly %d bytes", + len(data), CommentSize) + } + *c = Comment(bytes.TrimRight(data, "\x00")) + return nil +} diff --git a/vendor/github.com/google/nftables/xt/info.go b/vendor/github.com/google/nftables/xt/info.go new file mode 100644 index 000000000..4706ba5fc --- /dev/null +++ b/vendor/github.com/google/nftables/xt/info.go @@ -0,0 +1,97 @@ +package xt + +import ( + "golang.org/x/sys/unix" +) + +// TableFamily specifies the address family of the table Match or Target Info +// data is contained in. On purpose, we don't import the expr package here in +// order to keep the option open to import this package instead into expr. +type TableFamily byte + +// InfoAny is a (un)marshaling implemented by any info type. +type InfoAny interface { + marshal(fam TableFamily, rev uint32) ([]byte, error) + unmarshal(fam TableFamily, rev uint32, data []byte) error +} + +// Marshal a Match or Target Info type into its binary representation. +func Marshal(fam TableFamily, rev uint32, info InfoAny) ([]byte, error) { + return info.marshal(fam, rev) +} + +// Unmarshal Info binary payload into its corresponding dedicated type as +// indicated by the name argument. In several cases, unmarshalling depends on +// the specific table family the Target or Match expression with the info +// payload belongs to, as well as the specific info structure revision. +func Unmarshal(name string, fam TableFamily, rev uint32, data []byte) (InfoAny, error) { + var i InfoAny + switch name { + case "addrtype": + switch rev { + case 0: + i = &AddrType{} + case 1: + i = &AddrTypeV1{} + } + case "comment": + var c Comment + i = &c + case "conntrack": + switch rev { + case 1: + i = &ConntrackMtinfo1{} + case 2: + i = &ConntrackMtinfo2{} + case 3: + i = &ConntrackMtinfo3{} + } + case "tcp": + i = &Tcp{} + case "udp": + i = &Udp{} + case "SNAT": + if fam == unix.NFPROTO_IPV4 { + i = &NatIPv4MultiRangeCompat{} + } + case "DNAT": + switch fam { + case unix.NFPROTO_IPV4: + if rev == 0 { + i = &NatIPv4MultiRangeCompat{} + break + } + fallthrough + case unix.NFPROTO_IPV6: + switch rev { + case 1: + i = &NatRange{} + case 2: + i = &NatRange2{} + } + } + case "MASQUERADE": + switch fam { + case unix.NFPROTO_IPV4: + i = &NatIPv4MultiRangeCompat{} + } + case "REDIRECT": + switch fam { + case unix.NFPROTO_IPV4: + if rev == 0 { + i = &NatIPv4MultiRangeCompat{} + break + } + fallthrough + case unix.NFPROTO_IPV6: + i = &NatRange{} + } + } + if i == nil { + i = &Unknown{} + } + if err := i.unmarshal(fam, rev, data); err != nil { + return nil, err + } + return i, nil +} diff --git a/vendor/github.com/google/nftables/xt/match_addrtype.go b/vendor/github.com/google/nftables/xt/match_addrtype.go new file mode 100644 index 000000000..3e21057a1 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_addrtype.go @@ -0,0 +1,89 @@ +package xt + +import ( + "github.com/google/nftables/alignedbuff" +) + +// Rev. 0, see https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_addrtype.h#L38 +type AddrType struct { + Source uint16 + Dest uint16 + InvertSource bool + InvertDest bool +} + +type AddrTypeFlags uint32 + +const ( + AddrTypeUnspec AddrTypeFlags = 1 << iota + AddrTypeUnicast + AddrTypeLocal + AddrTypeBroadcast + AddrTypeAnycast + AddrTypeMulticast + AddrTypeBlackhole + AddrTypeUnreachable + AddrTypeProhibit + AddrTypeThrow + AddrTypeNat + AddrTypeXresolve +) + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_addrtype.h#L31 +type AddrTypeV1 struct { + Source uint16 + Dest uint16 + Flags AddrTypeFlags +} + +func (x *AddrType) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.Source) + ab.PutUint16(x.Dest) + putBool32(&ab, x.InvertSource) + putBool32(&ab, x.InvertDest) + return ab.Data(), nil +} + +func (x *AddrType) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.Source, err = ab.Uint16(); err != nil { + return nil + } + if x.Dest, err = ab.Uint16(); err != nil { + return nil + } + if x.InvertSource, err = bool32(&ab); err != nil { + return nil + } + if x.InvertDest, err = bool32(&ab); err != nil { + return nil + } + return nil +} + +func (x *AddrTypeV1) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.Source) + ab.PutUint16(x.Dest) + ab.PutUint32(uint32(x.Flags)) + return ab.Data(), nil +} + +func (x *AddrTypeV1) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.Source, err = ab.Uint16(); err != nil { + return nil + } + if x.Dest, err = ab.Uint16(); err != nil { + return nil + } + var flags uint32 + if flags, err = ab.Uint32(); err != nil { + return nil + } + x.Flags = AddrTypeFlags(flags) + return nil +} diff --git a/vendor/github.com/google/nftables/xt/match_conntrack.go b/vendor/github.com/google/nftables/xt/match_conntrack.go new file mode 100644 index 000000000..69c51bd80 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_conntrack.go @@ -0,0 +1,260 @@ +package xt + +import ( + "net" + + "github.com/google/nftables/alignedbuff" +) + +type ConntrackFlags uint16 + +const ( + ConntrackState ConntrackFlags = 1 << iota + ConntrackProto + ConntrackOrigSrc + ConntrackOrigDst + ConntrackReplSrc + ConntrackReplDst + ConntrackStatus + ConntrackExpires + ConntrackOrigSrcPort + ConntrackOrigDstPort + ConntrackReplSrcPort + ConntrackReplDstPrt + ConntrackDirection + ConntrackStateAlias +) + +type ConntrackMtinfoBase struct { + OrigSrcAddr net.IP + OrigSrcMask net.IPMask + OrigDstAddr net.IP + OrigDstMask net.IPMask + ReplSrcAddr net.IP + ReplSrcMask net.IPMask + ReplDstAddr net.IP + ReplDstMask net.IPMask + ExpiresMin uint32 + ExpiresMax uint32 + L4Proto uint16 + OrigSrcPort uint16 + OrigDstPort uint16 + ReplSrcPort uint16 + ReplDstPort uint16 + MatchFlags uint16 + InvertFlags uint16 +} + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L38 +type ConntrackMtinfo1 struct { + ConntrackMtinfoBase + StateMask uint8 + StatusMask uint8 +} + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L51 +type ConntrackMtinfo2 struct { + ConntrackMtinfoBase + StateMask uint16 + StatusMask uint16 +} + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_conntrack.h#L64 +type ConntrackMtinfo3 struct { + ConntrackMtinfo2 + OrigSrcPortHigh uint16 + OrigDstPortHigh uint16 + ReplSrcPortHigh uint16 + ReplDstPortHigh uint16 +} + +func (x *ConntrackMtinfoBase) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + if err := putIPv46(ab, fam, x.OrigSrcAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.OrigSrcMask); err != nil { + return err + } + if err := putIPv46(ab, fam, x.OrigDstAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.OrigDstMask); err != nil { + return err + } + if err := putIPv46(ab, fam, x.ReplSrcAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.ReplSrcMask); err != nil { + return err + } + if err := putIPv46(ab, fam, x.ReplDstAddr); err != nil { + return err + } + if err := putIPv46Mask(ab, fam, x.ReplDstMask); err != nil { + return err + } + ab.PutUint32(x.ExpiresMin) + ab.PutUint32(x.ExpiresMax) + ab.PutUint16(x.L4Proto) + ab.PutUint16(x.OrigSrcPort) + ab.PutUint16(x.OrigDstPort) + ab.PutUint16(x.ReplSrcPort) + ab.PutUint16(x.ReplDstPort) + ab.PutUint16(x.MatchFlags) + ab.PutUint16(x.InvertFlags) + return nil +} + +func (x *ConntrackMtinfoBase) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if x.OrigSrcAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.OrigSrcMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.OrigDstAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.OrigDstMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.ReplSrcAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.ReplSrcMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.ReplDstAddr, err = iPv46(ab, fam); err != nil { + return err + } + if x.ReplDstMask, err = iPv46Mask(ab, fam); err != nil { + return err + } + if x.ExpiresMin, err = ab.Uint32(); err != nil { + return err + } + if x.ExpiresMax, err = ab.Uint32(); err != nil { + return err + } + if x.L4Proto, err = ab.Uint16(); err != nil { + return err + } + if x.OrigSrcPort, err = ab.Uint16(); err != nil { + return err + } + if x.OrigDstPort, err = ab.Uint16(); err != nil { + return err + } + if x.ReplSrcPort, err = ab.Uint16(); err != nil { + return err + } + if x.ReplDstPort, err = ab.Uint16(); err != nil { + return err + } + if x.MatchFlags, err = ab.Uint16(); err != nil { + return err + } + if x.InvertFlags, err = ab.Uint16(); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo1) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.ConntrackMtinfoBase.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + ab.PutUint8(x.StateMask) + ab.PutUint8(x.StatusMask) + return ab.Data(), nil +} + +func (x *ConntrackMtinfo1) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.ConntrackMtinfoBase.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + if x.StateMask, err = ab.Uint8(); err != nil { + return err + } + if x.StatusMask, err = ab.Uint8(); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo2) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + if err := x.ConntrackMtinfoBase.marshalAB(fam, rev, ab); err != nil { + return err + } + ab.PutUint16(x.StateMask) + ab.PutUint16(x.StatusMask) + return nil +} + +func (x *ConntrackMtinfo2) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + return ab.Data(), nil +} + +func (x *ConntrackMtinfo2) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if err = x.ConntrackMtinfoBase.unmarshalAB(fam, rev, ab); err != nil { + return err + } + if x.StateMask, err = ab.Uint16(); err != nil { + return err + } + if x.StatusMask, err = ab.Uint16(); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo2) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + return nil +} + +func (x *ConntrackMtinfo3) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.ConntrackMtinfo2.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + ab.PutUint16(x.OrigSrcPortHigh) + ab.PutUint16(x.OrigDstPortHigh) + ab.PutUint16(x.ReplSrcPortHigh) + ab.PutUint16(x.ReplDstPortHigh) + return ab.Data(), nil +} + +func (x *ConntrackMtinfo3) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.ConntrackMtinfo2.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + if x.OrigSrcPortHigh, err = ab.Uint16(); err != nil { + return err + } + if x.OrigDstPortHigh, err = ab.Uint16(); err != nil { + return err + } + if x.ReplSrcPortHigh, err = ab.Uint16(); err != nil { + return err + } + if x.ReplDstPortHigh, err = ab.Uint16(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/google/nftables/xt/match_tcp.go b/vendor/github.com/google/nftables/xt/match_tcp.go new file mode 100644 index 000000000..d991f1276 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_tcp.go @@ -0,0 +1,74 @@ +package xt + +import ( + "github.com/google/nftables/alignedbuff" +) + +// Tcp is the Match.Info payload for the tcp xtables extension +// (https://wiki.nftables.org/wiki-nftables/index.php/Supported_features_compared_to_xtables#tcp). +// +// See +// https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_tcpudp.h#L8 +type Tcp struct { + SrcPorts [2]uint16 // min, max source port range + DstPorts [2]uint16 // min, max destination port range + Option uint8 // TCP option if non-zero + FlagsMask uint8 // TCP flags mask + FlagsCmp uint8 // TCP flags compare + InvFlags TcpInvFlagset // Inverse flags +} + +type TcpInvFlagset uint8 + +const ( + TcpInvSrcPorts TcpInvFlagset = 1 << iota + TcpInvDestPorts + TcpInvFlags + TcpInvOption + TcpInvMask TcpInvFlagset = (1 << iota) - 1 +) + +func (x *Tcp) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.SrcPorts[0]) + ab.PutUint16(x.SrcPorts[1]) + ab.PutUint16(x.DstPorts[0]) + ab.PutUint16(x.DstPorts[1]) + ab.PutUint8(x.Option) + ab.PutUint8(x.FlagsMask) + ab.PutUint8(x.FlagsCmp) + ab.PutUint8(byte(x.InvFlags)) + return ab.Data(), nil +} + +func (x *Tcp) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.SrcPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.SrcPorts[1], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[1], err = ab.Uint16(); err != nil { + return err + } + if x.Option, err = ab.Uint8(); err != nil { + return err + } + if x.FlagsMask, err = ab.Uint8(); err != nil { + return err + } + if x.FlagsCmp, err = ab.Uint8(); err != nil { + return err + } + var invFlags uint8 + if invFlags, err = ab.Uint8(); err != nil { + return err + } + x.InvFlags = TcpInvFlagset(invFlags) + return nil +} diff --git a/vendor/github.com/google/nftables/xt/match_udp.go b/vendor/github.com/google/nftables/xt/match_udp.go new file mode 100644 index 000000000..68ce12a06 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/match_udp.go @@ -0,0 +1,57 @@ +package xt + +import ( + "github.com/google/nftables/alignedbuff" +) + +// Tcp is the Match.Info payload for the tcp xtables extension +// (https://wiki.nftables.org/wiki-nftables/index.php/Supported_features_compared_to_xtables#tcp). +// +// See +// https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/xt_tcpudp.h#L25 +type Udp struct { + SrcPorts [2]uint16 // min, max source port range + DstPorts [2]uint16 // min, max destination port range + InvFlags UdpInvFlagset // Inverse flags +} + +type UdpInvFlagset uint8 + +const ( + UdpInvSrcPorts UdpInvFlagset = 1 << iota + UdpInvDestPorts + UdpInvMask UdpInvFlagset = (1 << iota) - 1 +) + +func (x *Udp) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + ab.PutUint16(x.SrcPorts[0]) + ab.PutUint16(x.SrcPorts[1]) + ab.PutUint16(x.DstPorts[0]) + ab.PutUint16(x.DstPorts[1]) + ab.PutUint8(byte(x.InvFlags)) + return ab.Data(), nil +} + +func (x *Udp) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if x.SrcPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.SrcPorts[1], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[0], err = ab.Uint16(); err != nil { + return err + } + if x.DstPorts[1], err = ab.Uint16(); err != nil { + return err + } + var invFlags uint8 + if invFlags, err = ab.Uint8(); err != nil { + return err + } + x.InvFlags = UdpInvFlagset(invFlags) + return nil +} diff --git a/vendor/github.com/google/nftables/xt/target_dnat.go b/vendor/github.com/google/nftables/xt/target_dnat.go new file mode 100644 index 000000000..b54e8fbef --- /dev/null +++ b/vendor/github.com/google/nftables/xt/target_dnat.go @@ -0,0 +1,106 @@ +package xt + +import ( + "net" + + "github.com/google/nftables/alignedbuff" +) + +type NatRangeFlags uint + +// See: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L8 +const ( + NatRangeMapIPs NatRangeFlags = (1 << iota) + NatRangeProtoSpecified + NatRangeProtoRandom + NatRangePersistent + NatRangeProtoRandomFully + NatRangeProtoOffset + NatRangeNetmap + + NatRangeMask NatRangeFlags = (1 << iota) - 1 + + NatRangeProtoRandomAll = NatRangeProtoRandom | NatRangeProtoRandomFully +) + +// see: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L38 +type NatRange struct { + Flags uint // sic! platform/arch/compiler-dependent uint size + MinIP net.IP // always taking up space for an IPv6 address + MaxIP net.IP // dito + MinPort uint16 + MaxPort uint16 +} + +// see: https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L46 +type NatRange2 struct { + NatRange + BasePort uint16 +} + +func (x *NatRange) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + return ab.Data(), nil +} + +func (x *NatRange) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + ab.PutUint(x.Flags) + if err := putIPv46(ab, fam, x.MinIP); err != nil { + return err + } + if err := putIPv46(ab, fam, x.MaxIP); err != nil { + return err + } + ab.PutUint16BE(x.MinPort) + ab.PutUint16BE(x.MaxPort) + return nil +} + +func (x *NatRange) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + return x.unmarshalAB(fam, rev, &ab) +} + +func (x *NatRange) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if x.Flags, err = ab.Uint(); err != nil { + return err + } + if x.MinIP, err = iPv46(ab, fam); err != nil { + return err + } + if x.MaxIP, err = iPv46(ab, fam); err != nil { + return err + } + if x.MinPort, err = ab.Uint16BE(); err != nil { + return err + } + if x.MaxPort, err = ab.Uint16BE(); err != nil { + return err + } + return nil +} + +func (x *NatRange2) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if err := x.NatRange.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + ab.PutUint16BE(x.BasePort) + return ab.Data(), nil +} + +func (x *NatRange2) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + var err error + if err = x.NatRange.unmarshalAB(fam, rev, &ab); err != nil { + return err + } + if x.BasePort, err = ab.Uint16BE(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/google/nftables/xt/target_masquerade_ip.go b/vendor/github.com/google/nftables/xt/target_masquerade_ip.go new file mode 100644 index 000000000..411d3beaa --- /dev/null +++ b/vendor/github.com/google/nftables/xt/target_masquerade_ip.go @@ -0,0 +1,86 @@ +package xt + +import ( + "errors" + "net" + + "github.com/google/nftables/alignedbuff" +) + +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L25 +type NatIPv4Range struct { + Flags uint // sic! + MinIP net.IP + MaxIP net.IP + MinPort uint16 + MaxPort uint16 +} + +// NatIPv4MultiRangeCompat despite being a slice of NAT IPv4 ranges is currently allowed to +// only hold exactly one element. +// +// See https://elixir.bootlin.com/linux/v5.17.7/source/include/uapi/linux/netfilter/nf_nat.h#L33 +type NatIPv4MultiRangeCompat []NatIPv4Range + +func (x *NatIPv4MultiRangeCompat) marshal(fam TableFamily, rev uint32) ([]byte, error) { + ab := alignedbuff.New() + if len(*x) != 1 { + return nil, errors.New("MasqueradeIp must contain exactly one NatIPv4Range") + } + ab.PutUint(uint(len(*x))) + for _, nat := range *x { + if err := nat.marshalAB(fam, rev, &ab); err != nil { + return nil, err + } + } + return ab.Data(), nil +} + +func (x *NatIPv4MultiRangeCompat) unmarshal(fam TableFamily, rev uint32, data []byte) error { + ab := alignedbuff.NewWithData(data) + l, err := ab.Uint() + if err != nil { + return err + } + nats := make(NatIPv4MultiRangeCompat, l) + for l > 0 { + l-- + if err := nats[l].unmarshalAB(fam, rev, &ab); err != nil { + return err + } + } + *x = nats + return nil +} + +func (x *NatIPv4Range) marshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + ab.PutUint(x.Flags) + ab.PutBytesAligned32(x.MinIP.To4(), 4) + ab.PutBytesAligned32(x.MaxIP.To4(), 4) + ab.PutUint16BE(x.MinPort) + ab.PutUint16BE(x.MaxPort) + return nil +} + +func (x *NatIPv4Range) unmarshalAB(fam TableFamily, rev uint32, ab *alignedbuff.AlignedBuff) error { + var err error + if x.Flags, err = ab.Uint(); err != nil { + return err + } + var ip []byte + if ip, err = ab.BytesAligned32(4); err != nil { + return err + } + x.MinIP = net.IP(ip) + if ip, err = ab.BytesAligned32(4); err != nil { + return err + } + x.MaxIP = net.IP(ip) + if x.MinPort, err = ab.Uint16BE(); err != nil { + return err + } + if x.MaxPort, err = ab.Uint16BE(); err != nil { + return err + } + return nil +} diff --git a/vendor/github.com/google/nftables/xt/unknown.go b/vendor/github.com/google/nftables/xt/unknown.go new file mode 100644 index 000000000..c648307c5 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/unknown.go @@ -0,0 +1,17 @@ +package xt + +// Unknown represents the bytes Info payload for unknown Info types where no +// dedicated match/target info type has (yet) been defined. +type Unknown []byte + +func (x *Unknown) marshal(fam TableFamily, rev uint32) ([]byte, error) { + // In case of unknown payload we assume its creator knows what she/he does + // and thus we don't do any alignment padding. Just take the payload "as + // is". + return *x, nil +} + +func (x *Unknown) unmarshal(fam TableFamily, rev uint32, data []byte) error { + *x = data + return nil +} diff --git a/vendor/github.com/google/nftables/xt/util.go b/vendor/github.com/google/nftables/xt/util.go new file mode 100644 index 000000000..673ac54f7 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/util.go @@ -0,0 +1,64 @@ +package xt + +import ( + "fmt" + "net" + + "github.com/google/nftables/alignedbuff" + "golang.org/x/sys/unix" +) + +func bool32(ab *alignedbuff.AlignedBuff) (bool, error) { + v, err := ab.Uint32() + if err != nil { + return false, err + } + if v != 0 { + return true, nil + } + return false, nil +} + +func putBool32(ab *alignedbuff.AlignedBuff, b bool) { + if b { + ab.PutUint32(1) + return + } + ab.PutUint32(0) +} + +func iPv46(ab *alignedbuff.AlignedBuff, fam TableFamily) (net.IP, error) { + ip, err := ab.BytesAligned32(16) + if err != nil { + return nil, err + } + switch fam { + case unix.NFPROTO_IPV4: + return net.IP(ip[:4]), nil + case unix.NFPROTO_IPV6: + return net.IP(ip), nil + default: + return nil, fmt.Errorf("unmarshal IP: unsupported table family %d", fam) + } +} + +func iPv46Mask(ab *alignedbuff.AlignedBuff, fam TableFamily) (net.IPMask, error) { + v, err := iPv46(ab, fam) + return net.IPMask(v), err +} + +func putIPv46(ab *alignedbuff.AlignedBuff, fam TableFamily, ip net.IP) error { + switch fam { + case unix.NFPROTO_IPV4: + ab.PutBytesAligned32(ip.To4(), 16) + case unix.NFPROTO_IPV6: + ab.PutBytesAligned32(ip.To16(), 16) + default: + return fmt.Errorf("marshal IP: unsupported table family %d", fam) + } + return nil +} + +func putIPv46Mask(ab *alignedbuff.AlignedBuff, fam TableFamily, mask net.IPMask) error { + return putIPv46(ab, fam, net.IP(mask)) +} diff --git a/vendor/github.com/google/nftables/xt/xt.go b/vendor/github.com/google/nftables/xt/xt.go new file mode 100644 index 000000000..d8977c1d0 --- /dev/null +++ b/vendor/github.com/google/nftables/xt/xt.go @@ -0,0 +1,48 @@ +/* +Package xt implements dedicated types for (some) of the "Info" payload in Match +and Target expressions that bridge between the nftables and xtables worlds. + +Bridging between the more unified world of nftables and the slightly +heterogenous world of xtables comes with some caveats. Unmarshalling the +extension/translation information in Match and Target expressions requires +information about the table family the information belongs to, as well as type +and type revision information. In consequence, unmarshalling the Match and +Target Info field payloads often (but not necessarily always) require the table +family and revision information, so it gets passed to the type-specific +unmarshallers. + +To complicate things more, even marshalling requires knowledge about the +enclosing table family. The NatRange/NatRange2 types are an example, where it is +necessary to differentiate between IPv4 and IPv6 address marshalling. Due to +Go's net.IP habit to normally store IPv4 addresses as IPv4-compatible IPv6 +addresses (see also RFC 4291, section 2.5.5.1) marshalling must be handled +differently in the context of an IPv6 table compared to an IPv4 table. In an +IPv4 table, an IPv4-compatible IPv6 address must be marshalled as a 32bit +address, whereas in an IPv6 table the IPv4 address must be marshalled as an +128bit IPv4-compatible IPv6 address. Not relying on heuristics here we avoid +behavior unexpected and most probably unknown to our API users. The net.IP habit +of storing IPv4 addresses in two different storage formats is already a source +for trouble, especially when comparing net.IPs from different Go module sources. +We won't add to this confusion. (...or maybe we can, because of it?) + +An important property of all types of Info extension/translation payloads is +that their marshalling and unmarshalling doesn't follow netlink's TLV +(tag-length-value) architecture. Instead, Info payloads a basically plain binary +blobs of their respective type-specific data structures, so host +platform/architecture alignment and data type sizes apply. The alignedbuff +package implements the different required data types alignments. + +Please note that Info payloads are always padded at their end to the next uint64 +alignment. Kernel code is checking for the padded payload size and will reject +payloads not correctly padded at their ends. + +Most of the time, we find explifcitly sized (unsigned integer) data types. +However, there are notable exceptions where "unsigned int" is used: on 64bit +platforms this mostly translates into 32bit(!). This differs from Go mapping +uint to uint64 instead. This package currently clamps its mapping of C's +"unsigned int" to Go's uint32 for marshalling and unmarshalling. If in the +future 128bit platforms with a differently sized C unsigned int should come into +production, then the alignedbuff package will need to be adapted accordingly, as +it abstracts away this data type handling. +*/ +package xt diff --git a/vendor/github.com/mdlayher/netlink/.gitignore b/vendor/github.com/mdlayher/netlink/.gitignore new file mode 100644 index 000000000..efc8a0a9c --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/.gitignore @@ -0,0 +1,4 @@ +internal/integration/integration.test +netlink.test +netlink-fuzz.zip +testdata/ diff --git a/vendor/github.com/mdlayher/netlink/CHANGELOG.md b/vendor/github.com/mdlayher/netlink/CHANGELOG.md new file mode 100644 index 000000000..eac8e924c --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/CHANGELOG.md @@ -0,0 +1,174 @@ +# CHANGELOG + +## v1.7.2 + +- [Improvement]: updated dependencies, test with Go 1.20. + +## v1.7.1 + +- [Bug Fix]: test only changes to avoid failures on big endian machines. + +## v1.7.0 + +**This is the first release of package netlink that only supports Go 1.18+. +Users on older versions of Go must use v1.6.2.** + +- [Improvement]: drop support for older versions of Go so we can begin using + modern versions of `x/sys` and other dependencies. + +## v1.6.2 + +**This is the last release of package netlink that supports Go 1.17 and below.** + +- [Bug Fix] [commit](https://github.com/mdlayher/netlink/commit/9f7f860d9865069cd1a6b4dee32a3095f0b841fc): + undo update to `golang.org/x/sys` which would force the minimum Go version of + this package to Go 1.17 due to use of `unsafe.Slice`. We encourage users to + use the latest stable version of Go where possible, but continue to maintain + some compatibility with older versions of Go as long as it is reasonable to do + so. + +## v1.6.1 + +- [Deprecation] [commit](https://github.com/mdlayher/netlink/commit/d1b69ea8697d721415c259ef8513ab699c6d3e96): + the `netlink.Socket` interface has been marked as deprecated. The abstraction + is awkward to use properly and disables much of the functionality of the Conn + type when the basic interface is implemented. Do not use. + +## v1.6.0 + +**This is the first release of package netlink that only supports Go 1.13+. +Users on older versions of Go must use v1.5.0.** + +- [New API] [commit](https://github.com/mdlayher/netlink/commit/ad9e2c41caa993e3f4b68831d6cb2cb05818275d): + the `netlink.Config.Strict` field can be used to apply a more strict default + set of options to a `netlink.Conn`. This is recommended for applications + running on modern Linux kernels, but cannot be enabled by default because the + options may require a more recent kernel than the minimum kernel version that + Go supports. See the documentation for details. +- [Improvement]: broke some integration tests into a separate Go module so the + default `go.mod` for package `netlink` has fewer dependencies. + +## v1.5.0 + +**This is the last release of package netlink that supports Go 1.12.** + +- [New API] [commit](https://github.com/mdlayher/netlink/commit/53a1c10065e51077659ceedf921c8f0807abe8c0): + the `netlink.Config.PID` field can be used to specify an explicit port ID when + binding the netlink socket. This is intended for advanced use cases and most + callers should leave this field set to 0. +- [Improvement]: more low-level functionality ported to + `github.com/mdlayher/socket`, reducing package complexity. + +## v1.4.2 + +- [Documentation] [commit](https://github.com/mdlayher/netlink/commit/177e6364fb170d465d681c7c8a6283417a6d3e49): + the `netlink.Config.DisableNSLockThread` now properly uses Go's deprecated + identifier convention. This option has been a noop for a long time and should + not be used. +- [Improvement] [#189](https://github.com/mdlayher/netlink/pull/189): the + package now uses Go 1.17's `//go:build` identifiers. Thanks @tklauser. +- [Bug Fix] + [commit](https://github.com/mdlayher/netlink/commit/fe6002e030928bd1f2a446c0b6c65e8f2df4ed5e): + the `netlink.AttributeEncoder`'s `Bytes`, `String`, and `Do` methods now + properly reject byte slices and strings which are too large to fit in the + value of a netlink attribute. Thanks @ubiquitousbyte for the report. + +## v1.4.1 + +- [Improvement]: significant runtime network poller integration cleanup through + the use of `github.com/mdlayher/socket`. + +## v1.4.0 + +- [New API] [#185](https://github.com/mdlayher/netlink/pull/185): the + `netlink.AttributeDecoder` and `netlink.AttributeEncoder` types now have + methods for dealing with signed integers: `Int8`, `Int16`, `Int32`, and + `Int64`. These are necessary for working with rtnetlink's XDP APIs. Thanks + @fbegyn. + +## v1.3.2 + +- [Improvement] + [commit](https://github.com/mdlayher/netlink/commit/ebc6e2e28bcf1a0671411288423d8116ff924d6d): + `github.com/google/go-cmp` is no longer a (non-test) dependency of this module. + +## v1.3.1 + +- [Improvement]: many internal cleanups and simplifications. The library is now + slimmer and features less internal indirection. There are no user-facing + changes in this release. + +## v1.3.0 + +- [New API] [#176](https://github.com/mdlayher/netlink/pull/176): + `netlink.OpError` now has `Message` and `Offset` fields which are populated + when the kernel returns netlink extended acknowledgement data along with an + error code. The caller can turn on this option by using + `netlink.Conn.SetOption(netlink.ExtendedAcknowledge, true)`. +- [New API] + [commit](https://github.com/mdlayher/netlink/commit/beba85e0372133b6d57221191d2c557727cd1499): + the `netlink.GetStrictCheck` option can be used to tell the kernel to be more + strict when parsing requests. This enables more safety checks and can allow + the kernel to perform more advanced request filtering in subsystems such as + route netlink. + +## v1.2.1 + +- [Bug Fix] + [commit](https://github.com/mdlayher/netlink/commit/d81418f81b0bfa2465f33790a85624c63d6afe3d): + `netlink.SetBPF` will no longer panic if an empty BPF filter is set. +- [Improvement] + [commit](https://github.com/mdlayher/netlink/commit/8014f9a7dbf4fd7b84a1783dd7b470db9113ff36): + the library now uses https://github.com/josharian/native to provide the + system's native endianness at compile time, rather than re-computing it many + times at runtime. + +## v1.2.0 + +**This is the first release of package netlink that only supports Go 1.12+. +Users on older versions of Go must use v1.1.1.** + +- [Improvement] [#173](https://github.com/mdlayher/netlink/pull/173): support + for Go 1.11 and below has been dropped. All users are highly recommended to + use a stable and supported release of Go for their applications. +- [Performance] [#171](https://github.com/mdlayher/netlink/pull/171): + `netlink.Conn` no longer requires a locked OS thread for the vast majority of + operations, which should result in a significant speedup for highly concurrent + callers. Thanks @ti-mo. +- [Bug Fix] [#169](https://github.com/mdlayher/netlink/pull/169): calls to + `netlink.Conn.Close` are now able to unblock concurrent calls to + `netlink.Conn.Receive` and other blocking operations. + +## v1.1.1 + +**This is the last release of package netlink that supports Go 1.11.** + +- [Improvement] [#165](https://github.com/mdlayher/netlink/pull/165): + `netlink.Conn` `SetReadBuffer` and `SetWriteBuffer` methods now attempt the + `SO_*BUFFORCE` socket options to possibly ignore system limits given elevated + caller permissions. Thanks @MarkusBauer. +- [Note] + [commit](https://github.com/mdlayher/netlink/commit/c5f8ab79aa345dcfcf7f14d746659ca1b80a0ecc): + `netlink.Conn.Close` has had a long-standing bug + [#162](https://github.com/mdlayher/netlink/pull/162) related to internal + concurrency handling where a call to `Close` is not sufficient to unblock + pending reads. To effectively fix this issue, it is necessary to drop support + for Go 1.11 and below. This will be fixed in a future release, but a + workaround is noted in the method documentation as of now. + +## v1.1.0 + +- [New API] [#157](https://github.com/mdlayher/netlink/pull/157): the + `netlink.AttributeDecoder.TypeFlags` method enables retrieval of the type bits + stored in a netlink attribute's type field, because the existing `Type` method + masks away these bits. Thanks @ti-mo! +- [Performance] [#157](https://github.com/mdlayher/netlink/pull/157): `netlink.AttributeDecoder` + now decodes netlink attributes on demand, enabling callers who only need a + limited number of attributes to exit early from decoding loops. Thanks @ti-mo! +- [Improvement] [#161](https://github.com/mdlayher/netlink/pull/161): `netlink.Conn` + system calls are now ready for Go 1.14+'s changes to goroutine preemption. + See the PR for details. + +## v1.0.0 + +- Initial stable commit. diff --git a/vendor/github.com/mdlayher/netlink/LICENSE.md b/vendor/github.com/mdlayher/netlink/LICENSE.md new file mode 100644 index 000000000..12f710585 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/LICENSE.md @@ -0,0 +1,9 @@ +# MIT License + +Copyright (C) 2016-2022 Matt Layher + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/mdlayher/netlink/README.md b/vendor/github.com/mdlayher/netlink/README.md new file mode 100644 index 000000000..c41de0d35 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/README.md @@ -0,0 +1,183 @@ +# netlink [![Test Status](https://github.com/mdlayher/netlink/workflows/Linux%20Test/badge.svg)](https://github.com/mdlayher/netlink/actions) [![Go Reference](https://pkg.go.dev/badge/github.com/mdlayher/netlink.svg)](https://pkg.go.dev/github.com/mdlayher/netlink) [![Go Report Card](https://goreportcard.com/badge/github.com/mdlayher/netlink)](https://goreportcard.com/report/github.com/mdlayher/netlink) + +Package `netlink` provides low-level access to Linux netlink sockets +(`AF_NETLINK`). MIT Licensed. + +For more information about how netlink works, check out my blog series +on [Linux, Netlink, and Go](https://mdlayher.com/blog/linux-netlink-and-go-part-1-netlink/). + +If you have any questions or you'd like some guidance, please join us on +[Gophers Slack](https://invite.slack.golangbridge.org) in the `#networking` +channel! + +## Stability + +See the [CHANGELOG](./CHANGELOG.md) file for a description of changes between +releases. + +This package has a stable v1 API and any future breaking changes will prompt +the release of a new major version. Features and bug fixes will continue to +occur in the v1.x.x series. + +This package only supports the two most recent major versions of Go, mirroring +Go's own release policy. Older versions of Go may lack critical features and bug +fixes which are necessary for this package to function correctly. + +## Design + +A number of netlink packages are already available for Go, but I wasn't able to +find one that aligned with what I wanted in a netlink package: + +- Straightforward, idiomatic API +- Well tested +- Well documented +- Doesn't use package/global variables or state +- Doesn't necessarily need root to work + +My goal for this package is to use it as a building block for the creation +of other netlink family packages. + +## Ecosystem + +Over time, an ecosystem of Go packages has developed around package `netlink`. +Many of these packages provide building blocks for further interactions with +various netlink families, such as `NETLINK_GENERIC` or `NETLINK_ROUTE`. + +To have your package included in this diagram, please send a pull request! + +```mermaid +flowchart LR + netlink["github.com/mdlayher/netlink"] + click netlink "https://github.com/mdlayher/netlink" + + subgraph "NETLINK_CONNECTOR" + direction LR + + garlic["github.com/fearful-symmetry/garlic"] + click garlic "https://github.com/fearful-symmetry/garlic" + end + + subgraph "NETLINK_CRYPTO" + direction LR + + cryptonl["github.com/mdlayher/cryptonl"] + click cryptonl "https://github.com/mdlayher/cryptonl" + end + + subgraph "NETLINK_GENERIC" + direction LR + + genetlink["github.com/mdlayher/genetlink"] + click genetlink "https://github.com/mdlayher/genetlink" + + devlink["github.com/mdlayher/devlink"] + click devlink "https://github.com/mdlayher/devlink" + + ethtool["github.com/mdlayher/ethtool"] + click ethtool "https://github.com/mdlayher/ethtool" + + go-openvswitch["github.com/digitalocean/go-openvswitch"] + click go-openvswitch "https://github.com/digitalocean/go-openvswitch" + + ipvs["github.com/cloudflare/ipvs"] + click ipvs "https://github.com/cloudflare/ipvs" + + l2tp["github.com/axatrax/l2tp"] + click l2tp "https://github.com/axatrax/l2tp" + + nbd["github.com/Merovius/nbd"] + click nbd "https://github.com/Merovius/nbd" + + quota["github.com/mdlayher/quota"] + click quota "https://github.com/mdlayher/quota" + + router7["github.com/rtr7/router7"] + click router7 "https://github.com/rtr7/router7" + + taskstats["github.com/mdlayher/taskstats"] + click taskstats "https://github.com/mdlayher/taskstats" + + u-bmc["github.com/u-root/u-bmc"] + click u-bmc "https://github.com/u-root/u-bmc" + + wgctrl["golang.zx2c4.com/wireguard/wgctrl"] + click wgctrl "https://golang.zx2c4.com/wireguard/wgctrl" + + wifi["github.com/mdlayher/wifi"] + click wifi "https://github.com/mdlayher/wifi" + + devlink & ethtool & go-openvswitch & ipvs --> genetlink + l2tp & nbd & quota & router7 & taskstats --> genetlink + u-bmc & wgctrl & wifi --> genetlink + end + + subgraph "NETLINK_KOBJECT_UEVENT" + direction LR + + kobject["github.com/mdlayher/kobject"] + click kobject "https://github.com/mdlayher/kobject" + end + + subgraph "NETLINK_NETFILTER" + direction LR + + go-conntrack["github.com/florianl/go-conntrack"] + click go-conntrack "https://github.com/florianl/go-conntrack" + + go-nflog["github.com/florianl/go-nflog"] + click go-nflog "https://github.com/florianl/go-nflog" + + go-nfqueue["github.com/florianl/go-nfqueue"] + click go-nfqueue "https://github.com/florianl/go-nfqueue" + + netfilter["github.com/ti-mo/netfilter"] + click netfilter "https://github.com/ti-mo/netfilter" + + nftables["github.com/google/nftables"] + click nftables "https://github.com/google/nftables" + + conntrack["github.com/ti-mo/conntrack"] + click conntrack "https://github.com/ti-mo/conntrack" + + conntrack --> netfilter + end + + subgraph "NETLINK_ROUTE" + direction LR + + go-tc["github.com/florianl/go-tc"] + click go-tc "https://github.com/florianl/go-tc" + + qdisc["github.com/ema/qdisc"] + click qdisc "https://github.com/ema/qdisc" + + rtnetlink["github.com/jsimonetti/rtnetlink"] + click rtnetlink "https://github.com/jsimonetti/rtnetlink" + + rtnl["gitlab.com/mergetb/tech/rtnl"] + click rtnl "https://gitlab.com/mergetb/tech/rtnl" + end + + subgraph "NETLINK_W1" + direction LR + + go-onewire["github.com/SpComb/go-onewire"] + click go-onewire "https://github.com/SpComb/go-onewire" + end + + subgraph "NETLINK_SOCK_DIAG" + direction LR + + go-diag["github.com/florianl/go-diag"] + click go-diag "https://github.com/florianl/go-diag" + end + + NETLINK_CONNECTOR --> netlink + NETLINK_CRYPTO --> netlink + NETLINK_GENERIC --> netlink + NETLINK_KOBJECT_UEVENT --> netlink + NETLINK_NETFILTER --> netlink + NETLINK_ROUTE --> netlink + NETLINK_SOCK_DIAG --> netlink + NETLINK_W1 --> netlink +``` diff --git a/vendor/github.com/mdlayher/netlink/align.go b/vendor/github.com/mdlayher/netlink/align.go new file mode 100644 index 000000000..20892c701 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/align.go @@ -0,0 +1,37 @@ +package netlink + +import "unsafe" + +// Functions and values used to properly align netlink messages, headers, +// and attributes. Definitions taken from Linux kernel source. + +// #define NLMSG_ALIGNTO 4U +const nlmsgAlignTo = 4 + +// #define NLMSG_ALIGN(len) ( ((len)+NLMSG_ALIGNTO-1) & ~(NLMSG_ALIGNTO-1) ) +func nlmsgAlign(len int) int { + return ((len) + nlmsgAlignTo - 1) & ^(nlmsgAlignTo - 1) +} + +// #define NLMSG_LENGTH(len) ((len) + NLMSG_HDRLEN) +func nlmsgLength(len int) int { + return len + nlmsgHeaderLen +} + +// #define NLMSG_HDRLEN ((int) NLMSG_ALIGN(sizeof(struct nlmsghdr))) +var nlmsgHeaderLen = nlmsgAlign(int(unsafe.Sizeof(Header{}))) + +// #define NLA_ALIGNTO 4 +const nlaAlignTo = 4 + +// #define NLA_ALIGN(len) (((len) + NLA_ALIGNTO - 1) & ~(NLA_ALIGNTO - 1)) +func nlaAlign(len int) int { + return ((len) + nlaAlignTo - 1) & ^(nlaAlignTo - 1) +} + +// Because this package's Attribute type contains a byte slice, unsafe.Sizeof +// can't be used to determine the correct length. +const sizeofAttribute = 4 + +// #define NLA_HDRLEN ((int) NLA_ALIGN(sizeof(struct nlattr))) +var nlaHeaderLen = nlaAlign(sizeofAttribute) diff --git a/vendor/github.com/mdlayher/netlink/attribute.go b/vendor/github.com/mdlayher/netlink/attribute.go new file mode 100644 index 000000000..1c81c323c --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/attribute.go @@ -0,0 +1,706 @@ +package netlink + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + + "github.com/mdlayher/netlink/nlenc" +) + +// errInvalidAttribute specifies if an Attribute's length is incorrect. +var errInvalidAttribute = errors.New("invalid attribute; length too short or too large") + +// An Attribute is a netlink attribute. Attributes are packed and unpacked +// to and from the Data field of Message for some netlink families. +type Attribute struct { + // Length of an Attribute, including this field and Type. + Length uint16 + + // The type of this Attribute, typically matched to a constant. Note that + // flags such as Nested and NetByteOrder must be handled manually when + // working with Attribute structures directly. + Type uint16 + + // An arbitrary payload which is specified by Type. + Data []byte +} + +// marshal marshals the contents of a into b and returns the number of bytes +// written to b, including attribute alignment padding. +func (a *Attribute) marshal(b []byte) (int, error) { + if int(a.Length) < nlaHeaderLen { + return 0, errInvalidAttribute + } + + nlenc.PutUint16(b[0:2], a.Length) + nlenc.PutUint16(b[2:4], a.Type) + n := copy(b[nlaHeaderLen:], a.Data) + + return nlaHeaderLen + nlaAlign(n), nil +} + +// unmarshal unmarshals the contents of a byte slice into an Attribute. +func (a *Attribute) unmarshal(b []byte) error { + if len(b) < nlaHeaderLen { + return errInvalidAttribute + } + + a.Length = nlenc.Uint16(b[0:2]) + a.Type = nlenc.Uint16(b[2:4]) + + if int(a.Length) > len(b) { + return errInvalidAttribute + } + + switch { + // No length, no data + case a.Length == 0: + a.Data = make([]byte, 0) + // Not enough length for any data + case int(a.Length) < nlaHeaderLen: + return errInvalidAttribute + // Data present + case int(a.Length) >= nlaHeaderLen: + a.Data = make([]byte, len(b[nlaHeaderLen:a.Length])) + copy(a.Data, b[nlaHeaderLen:a.Length]) + } + + return nil +} + +// MarshalAttributes packs a slice of Attributes into a single byte slice. +// In most cases, the Length field of each Attribute should be set to 0, so it +// can be calculated and populated automatically for each Attribute. +// +// It is recommend to use the AttributeEncoder type where possible instead of +// calling MarshalAttributes and using package nlenc functions directly. +func MarshalAttributes(attrs []Attribute) ([]byte, error) { + // Count how many bytes we should allocate to store each attribute's contents. + var c int + for _, a := range attrs { + c += nlaHeaderLen + nlaAlign(len(a.Data)) + } + + // Advance through b with idx to place attribute data at the correct offset. + var idx int + b := make([]byte, c) + for _, a := range attrs { + // Infer the length of attribute if zero. + if a.Length == 0 { + a.Length = uint16(nlaHeaderLen + len(a.Data)) + } + + // Marshal a into b and advance idx to show many bytes are occupied. + n, err := a.marshal(b[idx:]) + if err != nil { + return nil, err + } + idx += n + } + + return b, nil +} + +// UnmarshalAttributes unpacks a slice of Attributes from a single byte slice. +// +// It is recommend to use the AttributeDecoder type where possible instead of calling +// UnmarshalAttributes and using package nlenc functions directly. +func UnmarshalAttributes(b []byte) ([]Attribute, error) { + ad, err := NewAttributeDecoder(b) + if err != nil { + return nil, err + } + + // Return a nil slice when there are no attributes to decode. + if ad.Len() == 0 { + return nil, nil + } + + attrs := make([]Attribute, 0, ad.Len()) + + for ad.Next() { + if ad.a.Length != 0 { + attrs = append(attrs, ad.a) + } + } + + if err := ad.Err(); err != nil { + return nil, err + } + + return attrs, nil +} + +// An AttributeDecoder provides a safe, iterator-like, API around attribute +// decoding. +// +// It is recommend to use an AttributeDecoder where possible instead of calling +// UnmarshalAttributes and using package nlenc functions directly. +// +// The Err method must be called after the Next method returns false to determine +// if any errors occurred during iteration. +type AttributeDecoder struct { + // ByteOrder defines a specific byte order to use when processing integer + // attributes. ByteOrder should be set immediately after creating the + // AttributeDecoder: before any attributes are parsed. + // + // If not set, the native byte order will be used. + ByteOrder binary.ByteOrder + + // The current attribute being worked on. + a Attribute + + // The slice of input bytes and its iterator index. + b []byte + i int + + length int + + // Any error encountered while decoding attributes. + err error +} + +// NewAttributeDecoder creates an AttributeDecoder that unpacks Attributes +// from b and prepares the decoder for iteration. +func NewAttributeDecoder(b []byte) (*AttributeDecoder, error) { + ad := &AttributeDecoder{ + // By default, use native byte order. + ByteOrder: binary.NativeEndian, + + b: b, + } + + var err error + ad.length, err = ad.available() + if err != nil { + return nil, err + } + + return ad, nil +} + +// Next advances the decoder to the next netlink attribute. It returns false +// when no more attributes are present, or an error was encountered. +func (ad *AttributeDecoder) Next() bool { + if ad.err != nil { + // Hit an error, stop iteration. + return false + } + + // Exit if array pointer is at or beyond the end of the slice. + if ad.i >= len(ad.b) { + return false + } + + if err := ad.a.unmarshal(ad.b[ad.i:]); err != nil { + ad.err = err + return false + } + + // Advance the pointer by at least one header's length. + if int(ad.a.Length) < nlaHeaderLen { + ad.i += nlaHeaderLen + } else { + ad.i += nlaAlign(int(ad.a.Length)) + } + + return true +} + +// Type returns the Attribute.Type field of the current netlink attribute +// pointed to by the decoder. +// +// Type masks off the high bits of the netlink attribute type which may contain +// the Nested and NetByteOrder flags. These can be obtained by calling TypeFlags. +func (ad *AttributeDecoder) Type() uint16 { + // Mask off any flags stored in the high bits. + return ad.a.Type & attrTypeMask +} + +// TypeFlags returns the two high bits of the Attribute.Type field of the current +// netlink attribute pointed to by the decoder. +// +// These bits of the netlink attribute type are used for the Nested and NetByteOrder +// flags, available as the Nested and NetByteOrder constants in this package. +func (ad *AttributeDecoder) TypeFlags() uint16 { + return ad.a.Type & ^attrTypeMask +} + +// Len returns the number of netlink attributes pointed to by the decoder. +func (ad *AttributeDecoder) Len() int { return ad.length } + +// count scans the input slice to count the number of netlink attributes +// that could be decoded by Next(). +func (ad *AttributeDecoder) available() (int, error) { + var count int + for i := 0; i < len(ad.b); { + // Make sure there's at least a header's worth + // of data to read on each iteration. + if len(ad.b[i:]) < nlaHeaderLen { + return 0, errInvalidAttribute + } + + // Extract the length of the attribute. + l := int(nlenc.Uint16(ad.b[i : i+2])) + + // Ignore zero-length attributes. + if l != 0 { + count++ + } + + // Advance by at least a header's worth of bytes. + if l < nlaHeaderLen { + l = nlaHeaderLen + } + + i += nlaAlign(l) + } + + return count, nil +} + +// data returns the Data field of the current Attribute pointed to by the decoder. +func (ad *AttributeDecoder) data() []byte { return ad.a.Data } + +// Err returns the first error encountered by the decoder. +func (ad *AttributeDecoder) Err() error { return ad.err } + +// Bytes returns the raw bytes of the current Attribute's data. +func (ad *AttributeDecoder) Bytes() []byte { + src := ad.data() + dest := make([]byte, len(src)) + copy(dest, src) + return dest +} + +// String returns the string representation of the current Attribute's data. +func (ad *AttributeDecoder) String() string { + if ad.err != nil { + return "" + } + + return nlenc.String(ad.data()) +} + +// Uint8 returns the uint8 representation of the current Attribute's data. +func (ad *AttributeDecoder) Uint8() uint8 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 1 { + ad.err = fmt.Errorf("netlink: attribute %d is not a uint8; length: %d", ad.Type(), len(b)) + return 0 + } + + return uint8(b[0]) +} + +// Uint16 returns the uint16 representation of the current Attribute's data. +func (ad *AttributeDecoder) Uint16() uint16 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 2 { + ad.err = fmt.Errorf("netlink: attribute %d is not a uint16; length: %d", ad.Type(), len(b)) + return 0 + } + + return ad.ByteOrder.Uint16(b) +} + +// Uint32 returns the uint32 representation of the current Attribute's data. +func (ad *AttributeDecoder) Uint32() uint32 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 4 { + ad.err = fmt.Errorf("netlink: attribute %d is not a uint32; length: %d", ad.Type(), len(b)) + return 0 + } + + return ad.ByteOrder.Uint32(b) +} + +// Uint64 returns the uint64 representation of the current Attribute's data. +func (ad *AttributeDecoder) Uint64() uint64 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 8 { + ad.err = fmt.Errorf("netlink: attribute %d is not a uint64; length: %d", ad.Type(), len(b)) + return 0 + } + + return ad.ByteOrder.Uint64(b) +} + +// Int8 returns the Int8 representation of the current Attribute's data. +func (ad *AttributeDecoder) Int8() int8 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 1 { + ad.err = fmt.Errorf("netlink: attribute %d is not a int8; length: %d", ad.Type(), len(b)) + return 0 + } + + return int8(b[0]) +} + +// Int16 returns the Int16 representation of the current Attribute's data. +func (ad *AttributeDecoder) Int16() int16 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 2 { + ad.err = fmt.Errorf("netlink: attribute %d is not a int16; length: %d", ad.Type(), len(b)) + return 0 + } + + return int16(ad.ByteOrder.Uint16(b)) +} + +// Int32 returns the Int32 representation of the current Attribute's data. +func (ad *AttributeDecoder) Int32() int32 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 4 { + ad.err = fmt.Errorf("netlink: attribute %d is not a int32; length: %d", ad.Type(), len(b)) + return 0 + } + + return int32(ad.ByteOrder.Uint32(b)) +} + +// Int64 returns the Int64 representation of the current Attribute's data. +func (ad *AttributeDecoder) Int64() int64 { + if ad.err != nil { + return 0 + } + + b := ad.data() + if len(b) != 8 { + ad.err = fmt.Errorf("netlink: attribute %d is not a int64; length: %d", ad.Type(), len(b)) + return 0 + } + + return int64(ad.ByteOrder.Uint64(b)) +} + +// Flag returns a boolean representing the Attribute. +func (ad *AttributeDecoder) Flag() bool { + if ad.err != nil { + return false + } + + b := ad.data() + if len(b) != 0 { + ad.err = fmt.Errorf("netlink: attribute %d is not a flag; length: %d", ad.Type(), len(b)) + return false + } + + return true +} + +// Do is a general purpose function which allows access to the current data +// pointed to by the AttributeDecoder. +// +// Do can be used to allow parsing arbitrary data within the context of the +// decoder. Do is most useful when dealing with nested attributes, attribute +// arrays, or decoding arbitrary types (such as C structures) which don't fit +// cleanly into a typical unsigned integer value. +// +// The function fn should not retain any reference to the data b outside of the +// scope of the function. +func (ad *AttributeDecoder) Do(fn func(b []byte) error) { + if ad.err != nil { + return + } + + b := ad.data() + if err := fn(b); err != nil { + ad.err = err + } +} + +// Nested decodes data into a nested AttributeDecoder to handle nested netlink +// attributes. When calling Nested, the Err method does not need to be called on +// the nested AttributeDecoder. +// +// The nested AttributeDecoder nad inherits the same ByteOrder setting as the +// top-level AttributeDecoder ad. +func (ad *AttributeDecoder) Nested(fn func(nad *AttributeDecoder) error) { + // Because we are wrapping Do, there is no need to check ad.err immediately. + ad.Do(func(b []byte) error { + nad, err := NewAttributeDecoder(b) + if err != nil { + return err + } + nad.ByteOrder = ad.ByteOrder + + if err := fn(nad); err != nil { + return err + } + + return nad.Err() + }) +} + +// An AttributeEncoder provides a safe way to encode attributes. +// +// It is recommended to use an AttributeEncoder where possible instead of +// calling MarshalAttributes or using package nlenc directly. +// +// Errors from intermediate encoding steps are returned in the call to +// Encode. +type AttributeEncoder struct { + // ByteOrder defines a specific byte order to use when processing integer + // attributes. ByteOrder should be set immediately after creating the + // AttributeEncoder: before any attributes are encoded. + // + // If not set, the native byte order will be used. + ByteOrder binary.ByteOrder + + attrs []Attribute + err error +} + +// NewAttributeEncoder creates an AttributeEncoder that encodes Attributes. +func NewAttributeEncoder() *AttributeEncoder { + return &AttributeEncoder{ByteOrder: binary.NativeEndian} +} + +// Uint8 encodes uint8 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Uint8(typ uint16, v uint8) { + if ae.err != nil { + return + } + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: []byte{v}, + }) +} + +// Uint16 encodes uint16 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Uint16(typ uint16, v uint16) { + if ae.err != nil { + return + } + + b := make([]byte, 2) + ae.ByteOrder.PutUint16(b, v) + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Uint32 encodes uint32 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Uint32(typ uint16, v uint32) { + if ae.err != nil { + return + } + + b := make([]byte, 4) + ae.ByteOrder.PutUint32(b, v) + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Uint64 encodes uint64 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Uint64(typ uint16, v uint64) { + if ae.err != nil { + return + } + + b := make([]byte, 8) + ae.ByteOrder.PutUint64(b, v) + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Int8 encodes int8 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Int8(typ uint16, v int8) { + if ae.err != nil { + return + } + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: []byte{uint8(v)}, + }) +} + +// Int16 encodes int16 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Int16(typ uint16, v int16) { + if ae.err != nil { + return + } + + b := make([]byte, 2) + ae.ByteOrder.PutUint16(b, uint16(v)) + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Int32 encodes int32 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Int32(typ uint16, v int32) { + if ae.err != nil { + return + } + + b := make([]byte, 4) + ae.ByteOrder.PutUint32(b, uint32(v)) + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Int64 encodes int64 data into an Attribute specified by typ. +func (ae *AttributeEncoder) Int64(typ uint16, v int64) { + if ae.err != nil { + return + } + + b := make([]byte, 8) + ae.ByteOrder.PutUint64(b, uint64(v)) + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Flag encodes a flag into an Attribute specified by typ. +func (ae *AttributeEncoder) Flag(typ uint16, v bool) { + // Only set flag on no previous error or v == true. + if ae.err != nil || !v { + return + } + + // Flags have no length or data fields. + ae.attrs = append(ae.attrs, Attribute{Type: typ}) +} + +// String encodes string s as a null-terminated string into an Attribute +// specified by typ. +func (ae *AttributeEncoder) String(typ uint16, s string) { + if ae.err != nil { + return + } + + // Length checking, thanks ubiquitousbyte on GitHub. + if len(s) > math.MaxUint16-nlaHeaderLen { + ae.err = errors.New("string is too large to fit in a netlink attribute") + return + } + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: nlenc.Bytes(s), + }) +} + +// Bytes embeds raw byte data into an Attribute specified by typ. +func (ae *AttributeEncoder) Bytes(typ uint16, b []byte) { + if ae.err != nil { + return + } + + if len(b) > math.MaxUint16-nlaHeaderLen { + ae.err = errors.New("byte slice is too large to fit in a netlink attribute") + return + } + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Do is a general purpose function to encode arbitrary data into an attribute +// specified by typ. +// +// Do is especially helpful in encoding nested attributes, attribute arrays, +// or encoding arbitrary types (such as C structures) which don't fit cleanly +// into an unsigned integer value. +func (ae *AttributeEncoder) Do(typ uint16, fn func() ([]byte, error)) { + if ae.err != nil { + return + } + + b, err := fn() + if err != nil { + ae.err = err + return + } + + if len(b) > math.MaxUint16-nlaHeaderLen { + ae.err = errors.New("byte slice produced by Do is too large to fit in a netlink attribute") + return + } + + ae.attrs = append(ae.attrs, Attribute{ + Type: typ, + Data: b, + }) +} + +// Nested embeds data produced by a nested AttributeEncoder and flags that data +// with the Nested flag. When calling Nested, the Encode method should not be +// called on the nested AttributeEncoder. +// +// The nested AttributeEncoder nae inherits the same ByteOrder setting as the +// top-level AttributeEncoder ae. +func (ae *AttributeEncoder) Nested(typ uint16, fn func(nae *AttributeEncoder) error) { + // Because we are wrapping Do, there is no need to check ae.err immediately. + ae.Do(Nested|typ, func() ([]byte, error) { + nae := NewAttributeEncoder() + nae.ByteOrder = ae.ByteOrder + + if err := fn(nae); err != nil { + return nil, err + } + + return nae.Encode() + }) +} + +// Encode returns the encoded bytes representing the attributes. +func (ae *AttributeEncoder) Encode() ([]byte, error) { + if ae.err != nil { + return nil, ae.err + } + + return MarshalAttributes(ae.attrs) +} diff --git a/vendor/github.com/mdlayher/netlink/conn.go b/vendor/github.com/mdlayher/netlink/conn.go new file mode 100644 index 000000000..7138665bc --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/conn.go @@ -0,0 +1,593 @@ +package netlink + +import ( + "math/rand" + "sync" + "sync/atomic" + "syscall" + "time" + + "golang.org/x/net/bpf" +) + +// A Conn is a connection to netlink. A Conn can be used to send and +// receives messages to and from netlink. +// +// A Conn is safe for concurrent use, but to avoid contention in +// high-throughput applications, the caller should almost certainly create a +// pool of Conns and distribute them among workers. +// +// A Conn is capable of manipulating netlink subsystems from within a specific +// Linux network namespace, but special care must be taken when doing so. See +// the documentation of Config for details. +type Conn struct { + // Atomics must come first. + // + // seq is an atomically incremented integer used to provide sequence + // numbers when Conn.Send is called. + seq uint32 + + // mu serializes access to the netlink socket for the request/response + // transaction within Execute. + mu sync.RWMutex + + // sock is the operating system-specific implementation of + // a netlink sockets connection. + sock Socket + + // pid is the PID assigned by netlink. + pid uint32 + + // d provides debugging capabilities for a Conn if not nil. + d *debugger +} + +// A Socket is an operating-system specific implementation of netlink +// sockets used by Conn. +// +// Deprecated: the intent of Socket was to provide an abstraction layer for +// testing, but this abstraction is awkward to use properly and disables much of +// the functionality of the Conn type. Do not use. +type Socket interface { + Close() error + Send(m Message) error + SendMessages(m []Message) error + Receive() ([]Message, error) +} + +// Dial dials a connection to netlink, using the specified netlink family. +// Config specifies optional configuration for Conn. If config is nil, a default +// configuration will be used. +func Dial(family int, config *Config) (*Conn, error) { + // TODO(mdlayher): plumb in netlink.OpError wrapping? + + // Use OS-specific dial() to create Socket. + c, pid, err := dial(family, config) + if err != nil { + return nil, err + } + + return NewConn(c, pid), nil +} + +// NewConn creates a Conn using the specified Socket and PID for netlink +// communications. +// +// NewConn is primarily useful for tests. Most applications should use +// Dial instead. +func NewConn(sock Socket, pid uint32) *Conn { + // Seed the sequence number using a random number generator. + r := rand.New(rand.NewSource(time.Now().UnixNano())) + seq := r.Uint32() + + // Configure a debugger if arguments are set. + var d *debugger + if len(debugArgs) > 0 { + d = newDebugger(debugArgs) + } + + return &Conn{ + seq: seq, + sock: sock, + pid: pid, + d: d, + } +} + +// debug executes fn with the debugger if the debugger is not nil. +func (c *Conn) debug(fn func(d *debugger)) { + if c.d == nil { + return + } + + fn(c.d) +} + +// Close closes the connection and unblocks any pending read operations. +func (c *Conn) Close() error { + // Close does not acquire a lock because it must be able to interrupt any + // blocked system calls, such as when Receive is waiting on a multicast + // group message. + // + // We rely on the kernel to deal with concurrent operations to the netlink + // socket itself. + return newOpError("close", c.sock.Close()) +} + +// Execute sends a single Message to netlink using Send, receives one or more +// replies using Receive, and then checks the validity of the replies against +// the request using Validate. +// +// Execute acquires a lock for the duration of the function call which blocks +// concurrent calls to Send, SendMessages, and Receive, in order to ensure +// consistency between netlink request/reply messages. +// +// See the documentation of Send, Receive, and Validate for details about +// each function. +func (c *Conn) Execute(m Message) ([]Message, error) { + // Acquire the write lock and invoke the internal implementations of Send + // and Receive which require the lock already be held. + c.mu.Lock() + defer c.mu.Unlock() + + req, err := c.lockedSend(m) + if err != nil { + return nil, err + } + + res, err := c.lockedReceive() + if err != nil { + return nil, err + } + + if err := Validate(req, res); err != nil { + return nil, err + } + + return res, nil +} + +// SendMessages sends multiple Messages to netlink. The handling of +// a Header's Length, Sequence and PID fields is the same as when +// calling Send. +func (c *Conn) SendMessages(msgs []Message) ([]Message, error) { + // Wait for any concurrent calls to Execute to finish before proceeding. + c.mu.RLock() + defer c.mu.RUnlock() + + for i := range msgs { + c.fixMsg(&msgs[i], nlmsgLength(len(msgs[i].Data))) + } + + c.debug(func(d *debugger) { + for _, m := range msgs { + d.debugf(1, "send msgs: %+v", m) + } + }) + + if err := c.sock.SendMessages(msgs); err != nil { + c.debug(func(d *debugger) { + d.debugf(1, "send msgs: err: %v", err) + }) + + return nil, newOpError("send-messages", err) + } + + return msgs, nil +} + +// Send sends a single Message to netlink. In most cases, a Header's Length, +// Sequence, and PID fields should be set to 0, so they can be populated +// automatically before the Message is sent. On success, Send returns a copy +// of the Message with all parameters populated, for later validation. +// +// If Header.Length is 0, it will be automatically populated using the +// correct length for the Message, including its payload. +// +// If Header.Sequence is 0, it will be automatically populated using the +// next sequence number for this connection. +// +// If Header.PID is 0, it will be automatically populated using a PID +// assigned by netlink. +func (c *Conn) Send(m Message) (Message, error) { + // Wait for any concurrent calls to Execute to finish before proceeding. + c.mu.RLock() + defer c.mu.RUnlock() + + return c.lockedSend(m) +} + +// lockedSend implements Send, but must be called with c.mu acquired for reading. +// We rely on the kernel to deal with concurrent reads and writes to the netlink +// socket itself. +func (c *Conn) lockedSend(m Message) (Message, error) { + c.fixMsg(&m, nlmsgLength(len(m.Data))) + + c.debug(func(d *debugger) { + d.debugf(1, "send: %+v", m) + }) + + if err := c.sock.Send(m); err != nil { + c.debug(func(d *debugger) { + d.debugf(1, "send: err: %v", err) + }) + + return Message{}, newOpError("send", err) + } + + return m, nil +} + +// Receive receives one or more messages from netlink. Multi-part messages are +// handled transparently and returned as a single slice of Messages, with the +// final empty "multi-part done" message removed. +// +// If any of the messages indicate a netlink error, that error will be returned. +func (c *Conn) Receive() ([]Message, error) { + // Wait for any concurrent calls to Execute to finish before proceeding. + c.mu.RLock() + defer c.mu.RUnlock() + + return c.lockedReceive() +} + +// lockedReceive implements Receive, but must be called with c.mu acquired for reading. +// We rely on the kernel to deal with concurrent reads and writes to the netlink +// socket itself. +func (c *Conn) lockedReceive() ([]Message, error) { + msgs, err := c.receive() + if err != nil { + c.debug(func(d *debugger) { + d.debugf(1, "recv: err: %v", err) + }) + + return nil, err + } + + c.debug(func(d *debugger) { + for _, m := range msgs { + d.debugf(1, "recv: %+v", m) + } + }) + + // When using nltest, it's possible for zero messages to be returned by receive. + if len(msgs) == 0 { + return msgs, nil + } + + // Trim the final message with multi-part done indicator if + // present. + if m := msgs[len(msgs)-1]; m.Header.Flags&Multi != 0 && m.Header.Type == Done { + return msgs[:len(msgs)-1], nil + } + + return msgs, nil +} + +// receive is the internal implementation of Conn.Receive, which can be called +// recursively to handle multi-part messages. +func (c *Conn) receive() ([]Message, error) { + // NB: All non-nil errors returned from this function *must* be of type + // OpError in order to maintain the appropriate contract with callers of + // this package. + // + // This contract also applies to functions called within this function, + // such as checkMessage. + + var res []Message + for { + msgs, err := c.sock.Receive() + if err != nil { + return nil, newOpError("receive", err) + } + + // If this message is multi-part, we will need to continue looping to + // drain all the messages from the socket. + var multi bool + + for _, m := range msgs { + if err := checkMessage(m); err != nil { + return nil, err + } + + // Does this message indicate a multi-part message? + if m.Header.Flags&Multi == 0 { + // No, check the next messages. + continue + } + + // Does this message indicate the last message in a series of + // multi-part messages from a single read? + multi = m.Header.Type != Done + } + + res = append(res, msgs...) + + if !multi { + // No more messages coming. + return res, nil + } + } +} + +// A groupJoinLeaver is a Socket that supports joining and leaving +// netlink multicast groups. +type groupJoinLeaver interface { + Socket + JoinGroup(group uint32) error + LeaveGroup(group uint32) error +} + +// JoinGroup joins a netlink multicast group by its ID. +func (c *Conn) JoinGroup(group uint32) error { + conn, ok := c.sock.(groupJoinLeaver) + if !ok { + return notSupported("join-group") + } + + return newOpError("join-group", conn.JoinGroup(group)) +} + +// LeaveGroup leaves a netlink multicast group by its ID. +func (c *Conn) LeaveGroup(group uint32) error { + conn, ok := c.sock.(groupJoinLeaver) + if !ok { + return notSupported("leave-group") + } + + return newOpError("leave-group", conn.LeaveGroup(group)) +} + +// A bpfSetter is a Socket that supports setting and removing BPF filters. +type bpfSetter interface { + Socket + bpf.Setter + RemoveBPF() error +} + +// SetBPF attaches an assembled BPF program to a Conn. +func (c *Conn) SetBPF(filter []bpf.RawInstruction) error { + conn, ok := c.sock.(bpfSetter) + if !ok { + return notSupported("set-bpf") + } + + return newOpError("set-bpf", conn.SetBPF(filter)) +} + +// RemoveBPF removes a BPF filter from a Conn. +func (c *Conn) RemoveBPF() error { + conn, ok := c.sock.(bpfSetter) + if !ok { + return notSupported("remove-bpf") + } + + return newOpError("remove-bpf", conn.RemoveBPF()) +} + +// A deadlineSetter is a Socket that supports setting deadlines. +type deadlineSetter interface { + Socket + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +// SetDeadline sets the read and write deadlines associated with the connection. +func (c *Conn) SetDeadline(t time.Time) error { + conn, ok := c.sock.(deadlineSetter) + if !ok { + return notSupported("set-deadline") + } + + return newOpError("set-deadline", conn.SetDeadline(t)) +} + +// SetReadDeadline sets the read deadline associated with the connection. +func (c *Conn) SetReadDeadline(t time.Time) error { + conn, ok := c.sock.(deadlineSetter) + if !ok { + return notSupported("set-read-deadline") + } + + return newOpError("set-read-deadline", conn.SetReadDeadline(t)) +} + +// SetWriteDeadline sets the write deadline associated with the connection. +func (c *Conn) SetWriteDeadline(t time.Time) error { + conn, ok := c.sock.(deadlineSetter) + if !ok { + return notSupported("set-write-deadline") + } + + return newOpError("set-write-deadline", conn.SetWriteDeadline(t)) +} + +// A ConnOption is a boolean option that may be set for a Conn. +type ConnOption int + +// Possible ConnOption values. These constants are equivalent to the Linux +// setsockopt boolean options for netlink sockets. +const ( + PacketInfo ConnOption = iota + BroadcastError + NoENOBUFS + ListenAllNSID + CapAcknowledge + ExtendedAcknowledge + GetStrictCheck +) + +// An optionSetter is a Socket that supports setting netlink options. +type optionSetter interface { + Socket + SetOption(option ConnOption, enable bool) error +} + +// SetOption enables or disables a netlink socket option for the Conn. +func (c *Conn) SetOption(option ConnOption, enable bool) error { + conn, ok := c.sock.(optionSetter) + if !ok { + return notSupported("set-option") + } + + return newOpError("set-option", conn.SetOption(option, enable)) +} + +// A bufferSetter is a Socket that supports setting connection buffer sizes. +type bufferSetter interface { + Socket + SetReadBuffer(bytes int) error + SetWriteBuffer(bytes int) error +} + +// SetReadBuffer sets the size of the operating system's receive buffer +// associated with the Conn. +func (c *Conn) SetReadBuffer(bytes int) error { + conn, ok := c.sock.(bufferSetter) + if !ok { + return notSupported("set-read-buffer") + } + + return newOpError("set-read-buffer", conn.SetReadBuffer(bytes)) +} + +// SetWriteBuffer sets the size of the operating system's transmit buffer +// associated with the Conn. +func (c *Conn) SetWriteBuffer(bytes int) error { + conn, ok := c.sock.(bufferSetter) + if !ok { + return notSupported("set-write-buffer") + } + + return newOpError("set-write-buffer", conn.SetWriteBuffer(bytes)) +} + +// A syscallConner is a Socket that supports syscall.Conn. +type syscallConner interface { + Socket + syscall.Conn +} + +var _ syscall.Conn = &Conn{} + +// SyscallConn returns a raw network connection. This implements the +// syscall.Conn interface. +// +// SyscallConn is intended for advanced use cases, such as getting and setting +// arbitrary socket options using the netlink socket's file descriptor. +// +// Once invoked, it is the caller's responsibility to ensure that operations +// performed using Conn and the syscall.RawConn do not conflict with +// each other. +func (c *Conn) SyscallConn() (syscall.RawConn, error) { + sc, ok := c.sock.(syscallConner) + if !ok { + return nil, notSupported("syscall-conn") + } + + // TODO(mdlayher): mutex or similar to enforce syscall.RawConn contract of + // FD remaining valid for duration of calls? + + return sc.SyscallConn() +} + +// fixMsg updates the fields of m using the logic specified in Send. +func (c *Conn) fixMsg(m *Message, ml int) { + if m.Header.Length == 0 { + m.Header.Length = uint32(nlmsgAlign(ml)) + } + + if m.Header.Sequence == 0 { + m.Header.Sequence = c.nextSequence() + } + + if m.Header.PID == 0 { + m.Header.PID = c.pid + } +} + +// nextSequence atomically increments Conn's sequence number and returns +// the incremented value. +func (c *Conn) nextSequence() uint32 { + return atomic.AddUint32(&c.seq, 1) +} + +// Validate validates one or more reply Messages against a request Message, +// ensuring that they contain matching sequence numbers and PIDs. +func Validate(request Message, replies []Message) error { + for _, m := range replies { + // Check for mismatched sequence, unless: + // - request had no sequence, meaning we are probably validating + // a multicast reply + if m.Header.Sequence != request.Header.Sequence && request.Header.Sequence != 0 { + return newOpError("validate", errMismatchedSequence) + } + + // Check for mismatched PID, unless: + // - request had no PID, meaning we are either: + // - validating a multicast reply + // - netlink has not yet assigned us a PID + // - response had no PID, meaning it's from the kernel as a multicast reply + if m.Header.PID != request.Header.PID && request.Header.PID != 0 && m.Header.PID != 0 { + return newOpError("validate", errMismatchedPID) + } + } + + return nil +} + +// Config contains options for a Conn. +type Config struct { + // Groups is a bitmask which specifies multicast groups. If set to 0, + // no multicast group subscriptions will be made. + Groups uint32 + + // NetNS specifies the network namespace the Conn will operate in. + // + // If set (non-zero), Conn will enter the specified network namespace and + // an error will occur in Dial if the operation fails. + // + // If not set (zero), a best-effort attempt will be made to enter the + // network namespace of the calling thread: this means that any changes made + // to the calling thread's network namespace will also be reflected in Conn. + // If this operation fails (due to lack of permissions or because network + // namespaces are disabled by kernel configuration), Dial will not return + // an error, and the Conn will operate in the default network namespace of + // the process. This enables non-privileged use of Conn in applications + // which do not require elevated privileges. + // + // Entering a network namespace is a privileged operation (root or + // CAP_SYS_ADMIN are required), and most applications should leave this set + // to 0. + NetNS int + + // DisableNSLockThread is a no-op. + // + // Deprecated: internal changes have made this option obsolete and it has no + // effect. Do not use. + DisableNSLockThread bool + + // PID specifies the port ID used to bind the netlink socket. If set to 0, + // the kernel will assign a port ID on the caller's behalf. + // + // Most callers should leave this field set to 0. This option is intended + // for advanced use cases where the kernel expects a fixed unicast address + // destination for netlink messages. + PID uint32 + + // Strict applies a more strict default set of options to the Conn, + // including: + // - ExtendedAcknowledge: true + // - provides more useful error messages when supported by the kernel + // - GetStrictCheck: true + // - more strictly enforces request validation for some families such + // as rtnetlink which were historically misused + // + // If any of the options specified by Strict cannot be configured due to an + // outdated kernel or similar, an error will be returned. + // + // When possible, setting Strict to true is recommended for applications + // running on modern Linux kernels. + Strict bool +} diff --git a/vendor/github.com/mdlayher/netlink/conn_linux.go b/vendor/github.com/mdlayher/netlink/conn_linux.go new file mode 100644 index 000000000..4af18c99a --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/conn_linux.go @@ -0,0 +1,251 @@ +//go:build linux +// +build linux + +package netlink + +import ( + "context" + "os" + "syscall" + "time" + "unsafe" + + "github.com/mdlayher/socket" + "golang.org/x/net/bpf" + "golang.org/x/sys/unix" +) + +var _ Socket = &conn{} + +// A conn is the Linux implementation of a netlink sockets connection. +type conn struct { + s *socket.Conn +} + +// dial is the entry point for Dial. dial opens a netlink socket using +// system calls, and returns its PID. +func dial(family int, config *Config) (*conn, uint32, error) { + if config == nil { + config = &Config{} + } + + // Prepare the netlink socket. + s, err := socket.Socket( + unix.AF_NETLINK, + unix.SOCK_RAW, + family, + "netlink", + &socket.Config{NetNS: config.NetNS}, + ) + if err != nil { + return nil, 0, err + } + + return newConn(s, config) +} + +// newConn binds a connection to netlink using the input *socket.Conn. +func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) { + if config == nil { + config = &Config{} + } + + addr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: config.Groups, + Pid: config.PID, + } + + // Socket must be closed in the event of any system call errors, to avoid + // leaking file descriptors. + + if err := s.Bind(addr); err != nil { + _ = s.Close() + return nil, 0, err + } + + sa, err := s.Getsockname() + if err != nil { + _ = s.Close() + return nil, 0, err + } + + c := &conn{s: s} + if config.Strict { + // The caller has requested the strict option set. Historically we have + // recommended checking for ENOPROTOOPT if the kernel does not support + // the option in question, but that may result in a silent failure and + // unexpected behavior for the user. + // + // Treat any error here as a fatal error, and require the caller to deal + // with it. + for _, o := range []ConnOption{ExtendedAcknowledge, GetStrictCheck} { + if err := c.SetOption(o, true); err != nil { + _ = c.Close() + return nil, 0, err + } + } + } + + return c, sa.(*unix.SockaddrNetlink).Pid, nil +} + +// SendMessages serializes multiple Messages and sends them to netlink. +func (c *conn) SendMessages(messages []Message) error { + var buf []byte + for _, m := range messages { + b, err := m.MarshalBinary() + if err != nil { + return err + } + + buf = append(buf, b...) + } + + sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK} + _, err := c.s.Sendmsg(context.Background(), buf, nil, sa, 0) + return err +} + +// Send sends a single Message to netlink. +func (c *conn) Send(m Message) error { + b, err := m.MarshalBinary() + if err != nil { + return err + } + + sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK} + _, err = c.s.Sendmsg(context.Background(), b, nil, sa, 0) + return err +} + +// Receive receives one or more Messages from netlink. +func (c *conn) Receive() ([]Message, error) { + b := make([]byte, os.Getpagesize()) + for { + // Peek at the buffer to see how many bytes are available. + // + // TODO(mdlayher): deal with OOB message data if available, such as + // when PacketInfo ConnOption is true. + n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, unix.MSG_PEEK) + if err != nil { + return nil, err + } + + // Break when we can read all messages + if n < len(b) { + break + } + + // Double in size if not enough bytes + b = make([]byte, len(b)*2) + } + + // Read out all available messages + n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, 0) + if err != nil { + return nil, err + } + + raw, err := syscall.ParseNetlinkMessage(b[:nlmsgAlign(n)]) + if err != nil { + return nil, err + } + + msgs := make([]Message, 0, len(raw)) + for _, r := range raw { + m := Message{ + Header: sysToHeader(r.Header), + Data: r.Data, + } + + msgs = append(msgs, m) + } + + return msgs, nil +} + +// Close closes the connection. +func (c *conn) Close() error { return c.s.Close() } + +// JoinGroup joins a multicast group by ID. +func (c *conn) JoinGroup(group uint32) error { + return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, int(group)) +} + +// LeaveGroup leaves a multicast group by ID. +func (c *conn) LeaveGroup(group uint32) error { + return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, int(group)) +} + +// SetBPF attaches an assembled BPF program to a conn. +func (c *conn) SetBPF(filter []bpf.RawInstruction) error { return c.s.SetBPF(filter) } + +// RemoveBPF removes a BPF filter from a conn. +func (c *conn) RemoveBPF() error { return c.s.RemoveBPF() } + +// SetOption enables or disables a netlink socket option for the Conn. +func (c *conn) SetOption(option ConnOption, enable bool) error { + o, ok := linuxOption(option) + if !ok { + // Return the typical Linux error for an unknown ConnOption. + return os.NewSyscallError("setsockopt", unix.ENOPROTOOPT) + } + + var v int + if enable { + v = 1 + } + + return c.s.SetsockoptInt(unix.SOL_NETLINK, o, v) +} + +func (c *conn) SetDeadline(t time.Time) error { return c.s.SetDeadline(t) } +func (c *conn) SetReadDeadline(t time.Time) error { return c.s.SetReadDeadline(t) } +func (c *conn) SetWriteDeadline(t time.Time) error { return c.s.SetWriteDeadline(t) } + +// SetReadBuffer sets the size of the operating system's receive buffer +// associated with the Conn. +func (c *conn) SetReadBuffer(bytes int) error { return c.s.SetReadBuffer(bytes) } + +// SetReadBuffer sets the size of the operating system's transmit buffer +// associated with the Conn. +func (c *conn) SetWriteBuffer(bytes int) error { return c.s.SetWriteBuffer(bytes) } + +// SyscallConn returns a raw network connection. +func (c *conn) SyscallConn() (syscall.RawConn, error) { return c.s.SyscallConn() } + +// linuxOption converts a ConnOption to its Linux value. +func linuxOption(o ConnOption) (int, bool) { + switch o { + case PacketInfo: + return unix.NETLINK_PKTINFO, true + case BroadcastError: + return unix.NETLINK_BROADCAST_ERROR, true + case NoENOBUFS: + return unix.NETLINK_NO_ENOBUFS, true + case ListenAllNSID: + return unix.NETLINK_LISTEN_ALL_NSID, true + case CapAcknowledge: + return unix.NETLINK_CAP_ACK, true + case ExtendedAcknowledge: + return unix.NETLINK_EXT_ACK, true + case GetStrictCheck: + return unix.NETLINK_GET_STRICT_CHK, true + default: + return 0, false + } +} + +// sysToHeader converts a syscall.NlMsghdr to a Header. +func sysToHeader(r syscall.NlMsghdr) Header { + // NB: the memory layout of Header and syscall.NlMsgHdr must be + // exactly the same for this unsafe cast to work + return *(*Header)(unsafe.Pointer(&r)) +} + +// newError converts an error number from netlink into the appropriate +// system call error for Linux. +func newError(errno int) error { + return syscall.Errno(errno) +} diff --git a/vendor/github.com/mdlayher/netlink/conn_others.go b/vendor/github.com/mdlayher/netlink/conn_others.go new file mode 100644 index 000000000..4c5e739b9 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/conn_others.go @@ -0,0 +1,30 @@ +//go:build !linux +// +build !linux + +package netlink + +import ( + "fmt" + "runtime" +) + +// errUnimplemented is returned by all functions on platforms that +// cannot make use of netlink sockets. +var errUnimplemented = fmt.Errorf("netlink: not implemented on %s/%s", + runtime.GOOS, runtime.GOARCH) + +var _ Socket = &conn{} + +// A conn is the no-op implementation of a netlink sockets connection. +type conn struct{} + +// All cross-platform functions and Socket methods are unimplemented outside +// of Linux. + +func dial(_ int, _ *Config) (*conn, uint32, error) { return nil, 0, errUnimplemented } +func newError(_ int) error { return errUnimplemented } + +func (c *conn) Send(_ Message) error { return errUnimplemented } +func (c *conn) SendMessages(_ []Message) error { return errUnimplemented } +func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented } +func (c *conn) Close() error { return errUnimplemented } diff --git a/vendor/github.com/mdlayher/netlink/debug.go b/vendor/github.com/mdlayher/netlink/debug.go new file mode 100644 index 000000000..d39d66c58 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/debug.go @@ -0,0 +1,69 @@ +package netlink + +import ( + "fmt" + "log" + "os" + "strconv" + "strings" +) + +// Arguments used to create a debugger. +var debugArgs []string + +func init() { + // Is netlink debugging enabled? + s := os.Getenv("NLDEBUG") + if s == "" { + return + } + + debugArgs = strings.Split(s, ",") +} + +// A debugger is used to provide debugging information about a netlink connection. +type debugger struct { + Log *log.Logger + Level int +} + +// newDebugger creates a debugger by parsing key=value arguments. +func newDebugger(args []string) *debugger { + d := &debugger{ + Log: log.New(os.Stderr, "nl: ", 0), + Level: 1, + } + + for _, a := range args { + kv := strings.Split(a, "=") + if len(kv) != 2 { + // Ignore malformed pairs and assume callers wants defaults. + continue + } + + switch kv[0] { + // Select the log level for the debugger. + case "level": + level, err := strconv.Atoi(kv[1]) + if err != nil { + panicf("netlink: invalid NLDEBUG level: %q", a) + } + + d.Level = level + } + } + + return d +} + +// debugf prints debugging information at the specified level, if d.Level is +// high enough to print the message. +func (d *debugger) debugf(level int, format string, v ...interface{}) { + if d.Level >= level { + d.Log.Printf(format, v...) + } +} + +func panicf(format string, a ...interface{}) { + panic(fmt.Sprintf(format, a...)) +} diff --git a/vendor/github.com/mdlayher/netlink/doc.go b/vendor/github.com/mdlayher/netlink/doc.go new file mode 100644 index 000000000..98c744a5d --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/doc.go @@ -0,0 +1,33 @@ +// Package netlink provides low-level access to Linux netlink sockets +// (AF_NETLINK). +// +// If you have any questions or you'd like some guidance, please join us on +// Gophers Slack (https://invite.slack.golangbridge.org) in the #networking +// channel! +// +// # Network namespaces +// +// This package is aware of Linux network namespaces, and can enter different +// network namespaces either implicitly or explicitly, depending on +// configuration. The Config structure passed to Dial to create a Conn controls +// these behaviors. See the documentation of Config.NetNS for details. +// +// # Debugging +// +// This package supports rudimentary netlink connection debugging support. To +// enable this, run your binary with the NLDEBUG environment variable set. +// Debugging information will be output to stderr with a prefix of "nl:". +// +// To use the debugging defaults, use: +// +// $ NLDEBUG=1 ./nlctl +// +// To configure individual aspects of the debugger, pass key/value options such +// as: +// +// $ NLDEBUG=level=1 ./nlctl +// +// Available key/value debugger options include: +// +// level=N: specify the debugging level (only "1" is currently supported) +package netlink diff --git a/vendor/github.com/mdlayher/netlink/errors.go b/vendor/github.com/mdlayher/netlink/errors.go new file mode 100644 index 000000000..8c0fce7e4 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/errors.go @@ -0,0 +1,138 @@ +package netlink + +import ( + "errors" + "fmt" + "net" + "os" + "strings" +) + +// Error messages which can be returned by Validate. +var ( + errMismatchedSequence = errors.New("mismatched sequence in netlink reply") + errMismatchedPID = errors.New("mismatched PID in netlink reply") + errShortErrorMessage = errors.New("not enough data for netlink error code") +) + +// Errors which can be returned by a Socket that does not implement +// all exposed methods of Conn. + +var errNotSupported = errors.New("operation not supported") + +// notSupported provides a concise constructor for "not supported" errors. +func notSupported(op string) error { + return newOpError(op, errNotSupported) +} + +// IsNotExist determines if an error is produced as the result of querying some +// file, object, resource, etc. which does not exist. +// +// Deprecated: use errors.Unwrap and/or `errors.Is(err, os.Permission)` in Go +// 1.13+. +func IsNotExist(err error) bool { + switch err := err.(type) { + case *OpError: + // Unwrap the inner error and use the stdlib's logic. + return os.IsNotExist(err.Err) + default: + return os.IsNotExist(err) + } +} + +var ( + _ error = &OpError{} + _ net.Error = &OpError{} + // Ensure compatibility with Go 1.13+ errors package. + _ interface{ Unwrap() error } = &OpError{} +) + +// An OpError is an error produced as the result of a failed netlink operation. +type OpError struct { + // Op is the operation which caused this OpError, such as "send" + // or "receive". + Op string + + // Err is the underlying error which caused this OpError. + // + // If Err was produced by a system call error, Err will be of type + // *os.SyscallError. If Err was produced by an error code in a netlink + // message, Err will contain a raw error value type such as a unix.Errno. + // + // Most callers should inspect Err using errors.Is from the standard + // library. + Err error + + // Message and Offset contain additional error information provided by the + // kernel when the ExtendedAcknowledge option is set on a Conn and the + // kernel indicates the AcknowledgeTLVs flag in a response. If this option + // is not set, both of these fields will be empty. + Message string + Offset int +} + +// newOpError is a small wrapper for creating an OpError. As a convenience, it +// returns nil if the input err is nil: akin to os.NewSyscallError. +func newOpError(op string, err error) error { + if err == nil { + return nil + } + + return &OpError{ + Op: op, + Err: err, + } +} + +func (e *OpError) Error() string { + if e == nil { + return "" + } + + var sb strings.Builder + _, _ = sb.WriteString(fmt.Sprintf("netlink %s: %v", e.Op, e.Err)) + + if e.Message != "" || e.Offset != 0 { + _, _ = sb.WriteString(fmt.Sprintf(", offset: %d, message: %q", + e.Offset, e.Message)) + } + + return sb.String() +} + +// Unwrap unwraps the internal Err field for use with errors.Unwrap. +func (e *OpError) Unwrap() error { return e.Err } + +// Portions of this code taken from the Go standard library: +// +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +type timeout interface { + Timeout() bool +} + +// Timeout reports whether the error was caused by an I/O timeout. +func (e *OpError) Timeout() bool { + if ne, ok := e.Err.(*os.SyscallError); ok { + t, ok := ne.Err.(timeout) + return ok && t.Timeout() + } + t, ok := e.Err.(timeout) + return ok && t.Timeout() +} + +type temporary interface { + Temporary() bool +} + +// Temporary reports whether an operation may succeed if retried. +func (e *OpError) Temporary() bool { + if ne, ok := e.Err.(*os.SyscallError); ok { + t, ok := ne.Err.(temporary) + return ok && t.Temporary() + } + t, ok := e.Err.(temporary) + return ok && t.Temporary() +} diff --git a/vendor/github.com/mdlayher/netlink/fuzz.go b/vendor/github.com/mdlayher/netlink/fuzz.go new file mode 100644 index 000000000..fdd6b6498 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/fuzz.go @@ -0,0 +1,82 @@ +//go:build gofuzz +// +build gofuzz + +package netlink + +import "github.com/google/go-cmp/cmp" + +func fuzz(b1 []byte) int { + // 1. unmarshal, marshal, unmarshal again to check m1 and m2 for equality + // after a round trip. checkMessage is also used because there is a fair + // amount of tricky logic around testing for presence of error headers and + // extended acknowledgement attributes. + var m1 Message + if err := m1.UnmarshalBinary(b1); err != nil { + return 0 + } + + if err := checkMessage(m1); err != nil { + return 0 + } + + b2, err := m1.MarshalBinary() + if err != nil { + panicf("failed to marshal m1: %v", err) + } + + var m2 Message + if err := m2.UnmarshalBinary(b2); err != nil { + panicf("failed to unmarshal m2: %v", err) + } + + if err := checkMessage(m2); err != nil { + panicf("failed to check m2: %v", err) + } + + if diff := cmp.Diff(m1, m2); diff != "" { + panicf("unexpected Message (-want +got):\n%s", diff) + } + + // 2. marshal again and compare b2 and b3 (b1 may have reserved bytes set + // which we ignore and fill with zeros when marshaling) for equality. + b3, err := m2.MarshalBinary() + if err != nil { + panicf("failed to marshal m2: %v", err) + } + + if diff := cmp.Diff(b2, b3); diff != "" { + panicf("unexpected message bytes (-want +got):\n%s", diff) + } + + // 3. unmarshal any possible attributes from m1's data and marshal them + // again for comparison. + a1, err := UnmarshalAttributes(m1.Data) + if err != nil { + return 0 + } + + ab1, err := MarshalAttributes(a1) + if err != nil { + panicf("failed to marshal a1: %v", err) + } + + a2, err := UnmarshalAttributes(ab1) + if err != nil { + panicf("failed to unmarshal a2: %v", err) + } + + if diff := cmp.Diff(a1, a2); diff != "" { + panicf("unexpected Attributes (-want +got):\n%s", diff) + } + + ab2, err := MarshalAttributes(a2) + if err != nil { + panicf("failed to marshal a2: %v", err) + } + + if diff := cmp.Diff(ab1, ab2); diff != "" { + panicf("unexpected attribute bytes (-want +got):\n%s", diff) + } + + return 1 +} diff --git a/vendor/github.com/mdlayher/netlink/message.go b/vendor/github.com/mdlayher/netlink/message.go new file mode 100644 index 000000000..57277165a --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/message.go @@ -0,0 +1,347 @@ +package netlink + +import ( + "errors" + "fmt" + "unsafe" + + "github.com/mdlayher/netlink/nlenc" +) + +// Flags which may apply to netlink attribute types when communicating with +// certain netlink families. +const ( + Nested uint16 = 0x8000 + NetByteOrder uint16 = 0x4000 + + // attrTypeMask masks off Type bits used for the above flags. + attrTypeMask uint16 = 0x3fff +) + +// Various errors which may occur when attempting to marshal or unmarshal +// a Message to and from its binary form. +var ( + errIncorrectMessageLength = errors.New("netlink message header length incorrect") + errShortMessage = errors.New("not enough data to create a netlink message") + errUnalignedMessage = errors.New("input data is not properly aligned for netlink message") +) + +// HeaderFlags specify flags which may be present in a Header. +type HeaderFlags uint16 + +const ( + // General netlink communication flags. + + // Request indicates a request to netlink. + Request HeaderFlags = 1 + + // Multi indicates a multi-part message, terminated by Done on the + // last message. + Multi HeaderFlags = 2 + + // Acknowledge requests that netlink reply with an acknowledgement + // using Error and, if needed, an error code. + Acknowledge HeaderFlags = 4 + + // Echo requests that netlink echo this request back to the sender. + Echo HeaderFlags = 8 + + // DumpInterrupted indicates that a dump was inconsistent due to a + // sequence change. + DumpInterrupted HeaderFlags = 16 + + // DumpFiltered indicates that a dump was filtered as requested. + DumpFiltered HeaderFlags = 32 + + // Flags used to retrieve data from netlink. + + // Root requests that netlink return a complete table instead of a + // single entry. + Root HeaderFlags = 0x100 + + // Match requests that netlink return a list of all matching entries. + Match HeaderFlags = 0x200 + + // Atomic requests that netlink send an atomic snapshot of its entries. + // Requires CAP_NET_ADMIN or an effective UID of 0. + Atomic HeaderFlags = 0x400 + + // Dump requests that netlink return a complete list of all entries. + Dump HeaderFlags = Root | Match + + // Flags used to create objects. + + // Replace indicates request replaces an existing matching object. + Replace HeaderFlags = 0x100 + + // Excl indicates request does not replace the object if it already exists. + Excl HeaderFlags = 0x200 + + // Create indicates request creates an object if it doesn't already exist. + Create HeaderFlags = 0x400 + + // Append indicates request adds to the end of the object list. + Append HeaderFlags = 0x800 + + // Flags for extended acknowledgements. + + // Capped indicates the size of a request was capped in an extended + // acknowledgement. + Capped HeaderFlags = 0x100 + + // AcknowledgeTLVs indicates the presence of netlink extended + // acknowledgement TLVs in a response. + AcknowledgeTLVs HeaderFlags = 0x200 +) + +// String returns the string representation of a HeaderFlags. +func (f HeaderFlags) String() string { + names := []string{ + "request", + "multi", + "acknowledge", + "echo", + "dumpinterrupted", + "dumpfiltered", + } + + var s string + + left := uint(f) + + for i, name := range names { + if f&(1< 0 { + if s != "" { + s += "|" + } + s += fmt.Sprintf("%#x", left) + } + + return s +} + +// HeaderType specifies the type of a Header. +type HeaderType uint16 + +const ( + // Noop indicates that no action was taken. + Noop HeaderType = 0x1 + + // Error indicates an error code is present, which is also used to indicate + // success when the code is 0. + Error HeaderType = 0x2 + + // Done indicates the end of a multi-part message. + Done HeaderType = 0x3 + + // Overrun indicates that data was lost from this message. + Overrun HeaderType = 0x4 +) + +// String returns the string representation of a HeaderType. +func (t HeaderType) String() string { + switch t { + case Noop: + return "noop" + case Error: + return "error" + case Done: + return "done" + case Overrun: + return "overrun" + default: + return fmt.Sprintf("unknown(%d)", t) + } +} + +// NB: the memory layout of Header and Linux's syscall.NlMsgHdr must be +// exactly the same. Cannot reorder, change data type, add, or remove fields. +// Named types of the same size (e.g. HeaderFlags is a uint16) are okay. + +// A Header is a netlink header. A Header is sent and received with each +// Message to indicate metadata regarding a Message. +type Header struct { + // Length of a Message, including this Header. + Length uint32 + + // Contents of a Message. + Type HeaderType + + // Flags which may be used to modify a request or response. + Flags HeaderFlags + + // The sequence number of a Message. + Sequence uint32 + + // The port ID of the sending process. + PID uint32 +} + +// A Message is a netlink message. It contains a Header and an arbitrary +// byte payload, which may be decoded using information from the Header. +// +// Data is often populated with netlink attributes. For easy encoding and +// decoding of attributes, see the AttributeDecoder and AttributeEncoder types. +type Message struct { + Header Header + Data []byte +} + +// MarshalBinary marshals a Message into a byte slice. +func (m Message) MarshalBinary() ([]byte, error) { + ml := nlmsgAlign(int(m.Header.Length)) + if ml < nlmsgHeaderLen || ml != int(m.Header.Length) { + return nil, errIncorrectMessageLength + } + + b := make([]byte, ml) + + nlenc.PutUint32(b[0:4], m.Header.Length) + nlenc.PutUint16(b[4:6], uint16(m.Header.Type)) + nlenc.PutUint16(b[6:8], uint16(m.Header.Flags)) + nlenc.PutUint32(b[8:12], m.Header.Sequence) + nlenc.PutUint32(b[12:16], m.Header.PID) + copy(b[16:], m.Data) + + return b, nil +} + +// UnmarshalBinary unmarshals the contents of a byte slice into a Message. +func (m *Message) UnmarshalBinary(b []byte) error { + if len(b) < nlmsgHeaderLen { + return errShortMessage + } + if len(b) != nlmsgAlign(len(b)) { + return errUnalignedMessage + } + + // Don't allow misleading length + m.Header.Length = nlenc.Uint32(b[0:4]) + if int(m.Header.Length) != len(b) { + return errShortMessage + } + + m.Header.Type = HeaderType(nlenc.Uint16(b[4:6])) + m.Header.Flags = HeaderFlags(nlenc.Uint16(b[6:8])) + m.Header.Sequence = nlenc.Uint32(b[8:12]) + m.Header.PID = nlenc.Uint32(b[12:16]) + m.Data = b[16:] + + return nil +} + +// checkMessage checks a single Message for netlink errors. +func checkMessage(m Message) error { + // NB: All non-nil errors returned from this function *must* be of type + // OpError in order to maintain the appropriate contract with callers of + // this package. + + // The libnl documentation indicates that type error can + // contain error codes: + // https://www.infradead.org/~tgr/libnl/doc/core.html#core_errmsg. + // + // However, rtnetlink at least seems to also allow errors to occur at the + // end of a multipart message with done/multi and an error number. + var hasHeader bool + switch { + case m.Header.Type == Error: + // Error code followed by nlmsghdr/ext ack attributes. + hasHeader = true + case m.Header.Type == Done && m.Header.Flags&Multi != 0: + // If no data, there must be no error number so just exit early. Some + // of the unit tests hard-coded this but I don't actually know if this + // case occurs in the wild. + if len(m.Data) == 0 { + return nil + } + + // Done|Multi potentially followed by ext ack attributes. + default: + // Neither, nothing to do. + return nil + } + + // Errno occupies 4 bytes. + const endErrno = 4 + if len(m.Data) < endErrno { + return newOpError("receive", errShortErrorMessage) + } + + c := nlenc.Int32(m.Data[:endErrno]) + if c == 0 { + // 0 indicates no error. + return nil + } + + oerr := &OpError{ + Op: "receive", + // Error code is a negative integer, convert it into an OS-specific raw + // system call error, but do not wrap with os.NewSyscallError to signify + // that this error was produced by a netlink message; not a system call. + Err: newError(-1 * int(c)), + } + + // TODO(mdlayher): investigate the Capped flag. + + if m.Header.Flags&AcknowledgeTLVs == 0 { + // No extended acknowledgement. + return oerr + } + + // Flags indicate an extended acknowledgement. The type/flags combination + // checked above determines the offset where the TLVs occur. + var off int + if hasHeader { + // There is an nlmsghdr preceding the TLVs. + if len(m.Data) < endErrno+nlmsgHeaderLen { + return newOpError("receive", errShortErrorMessage) + } + + // The TLVs should be at the offset indicated by the nlmsghdr.length, + // plus the offset where the header began. But make sure the calculated + // offset is still in-bounds. + h := *(*Header)(unsafe.Pointer(&m.Data[endErrno : endErrno+nlmsgHeaderLen][0])) + off = endErrno + int(h.Length) + + if len(m.Data) < off { + return newOpError("receive", errShortErrorMessage) + } + } else { + // There is no nlmsghdr preceding the TLVs, parse them directly. + off = endErrno + } + + ad, err := NewAttributeDecoder(m.Data[off:]) + if err != nil { + // Malformed TLVs, just return the OpError with the info we have. + return oerr + } + + for ad.Next() { + switch ad.Type() { + case 1: // unix.NLMSGERR_ATTR_MSG + oerr.Message = ad.String() + case 2: // unix.NLMSGERR_ATTR_OFFS + oerr.Offset = int(ad.Uint32()) + } + } + + // Explicitly ignore ad.Err: malformed TLVs, just return the OpError with + // the info we have. + return oerr +} diff --git a/vendor/github.com/mdlayher/netlink/nlenc/doc.go b/vendor/github.com/mdlayher/netlink/nlenc/doc.go new file mode 100644 index 000000000..990d12e64 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nlenc/doc.go @@ -0,0 +1,13 @@ +// Package nlenc implements encoding and decoding functions for netlink +// messages and attributes. +package nlenc + +import ( + "encoding/binary" +) + +// NativeEndian returns the native byte order of this system. +func NativeEndian() binary.ByteOrder { + // TODO(mdlayher): consider deprecating and removing this function for v2. + return binary.NativeEndian +} diff --git a/vendor/github.com/mdlayher/netlink/nlenc/int.go b/vendor/github.com/mdlayher/netlink/nlenc/int.go new file mode 100644 index 000000000..d56b018de --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nlenc/int.go @@ -0,0 +1,150 @@ +package nlenc + +import ( + "fmt" + "unsafe" +) + +// PutUint8 encodes a uint8 into b. +// If b is not exactly 1 byte in length, PutUint8 will panic. +func PutUint8(b []byte, v uint8) { + if l := len(b); l != 1 { + panic(fmt.Sprintf("PutUint8: unexpected byte slice length: %d", l)) + } + + b[0] = v +} + +// PutUint16 encodes a uint16 into b using the host machine's native endianness. +// If b is not exactly 2 bytes in length, PutUint16 will panic. +func PutUint16(b []byte, v uint16) { + if l := len(b); l != 2 { + panic(fmt.Sprintf("PutUint16: unexpected byte slice length: %d", l)) + } + + *(*uint16)(unsafe.Pointer(&b[0])) = v +} + +// PutUint32 encodes a uint32 into b using the host machine's native endianness. +// If b is not exactly 4 bytes in length, PutUint32 will panic. +func PutUint32(b []byte, v uint32) { + if l := len(b); l != 4 { + panic(fmt.Sprintf("PutUint32: unexpected byte slice length: %d", l)) + } + + *(*uint32)(unsafe.Pointer(&b[0])) = v +} + +// PutUint64 encodes a uint64 into b using the host machine's native endianness. +// If b is not exactly 8 bytes in length, PutUint64 will panic. +func PutUint64(b []byte, v uint64) { + if l := len(b); l != 8 { + panic(fmt.Sprintf("PutUint64: unexpected byte slice length: %d", l)) + } + + *(*uint64)(unsafe.Pointer(&b[0])) = v +} + +// PutInt32 encodes a int32 into b using the host machine's native endianness. +// If b is not exactly 4 bytes in length, PutInt32 will panic. +func PutInt32(b []byte, v int32) { + if l := len(b); l != 4 { + panic(fmt.Sprintf("PutInt32: unexpected byte slice length: %d", l)) + } + + *(*int32)(unsafe.Pointer(&b[0])) = v +} + +// Uint8 decodes a uint8 from b. +// If b is not exactly 1 byte in length, Uint8 will panic. +func Uint8(b []byte) uint8 { + if l := len(b); l != 1 { + panic(fmt.Sprintf("Uint8: unexpected byte slice length: %d", l)) + } + + return b[0] +} + +// Uint16 decodes a uint16 from b using the host machine's native endianness. +// If b is not exactly 2 bytes in length, Uint16 will panic. +func Uint16(b []byte) uint16 { + if l := len(b); l != 2 { + panic(fmt.Sprintf("Uint16: unexpected byte slice length: %d", l)) + } + + return *(*uint16)(unsafe.Pointer(&b[0])) +} + +// Uint32 decodes a uint32 from b using the host machine's native endianness. +// If b is not exactly 4 bytes in length, Uint32 will panic. +func Uint32(b []byte) uint32 { + if l := len(b); l != 4 { + panic(fmt.Sprintf("Uint32: unexpected byte slice length: %d", l)) + } + + return *(*uint32)(unsafe.Pointer(&b[0])) +} + +// Uint64 decodes a uint64 from b using the host machine's native endianness. +// If b is not exactly 8 bytes in length, Uint64 will panic. +func Uint64(b []byte) uint64 { + if l := len(b); l != 8 { + panic(fmt.Sprintf("Uint64: unexpected byte slice length: %d", l)) + } + + return *(*uint64)(unsafe.Pointer(&b[0])) +} + +// Int32 decodes an int32 from b using the host machine's native endianness. +// If b is not exactly 4 bytes in length, Int32 will panic. +func Int32(b []byte) int32 { + if l := len(b); l != 4 { + panic(fmt.Sprintf("Int32: unexpected byte slice length: %d", l)) + } + + return *(*int32)(unsafe.Pointer(&b[0])) +} + +// Uint8Bytes encodes a uint8 into a newly-allocated byte slice. It is a +// shortcut for allocating a new byte slice and filling it using PutUint8. +func Uint8Bytes(v uint8) []byte { + b := make([]byte, 1) + PutUint8(b, v) + return b +} + +// Uint16Bytes encodes a uint16 into a newly-allocated byte slice using the +// host machine's native endianness. It is a shortcut for allocating a new +// byte slice and filling it using PutUint16. +func Uint16Bytes(v uint16) []byte { + b := make([]byte, 2) + PutUint16(b, v) + return b +} + +// Uint32Bytes encodes a uint32 into a newly-allocated byte slice using the +// host machine's native endianness. It is a shortcut for allocating a new +// byte slice and filling it using PutUint32. +func Uint32Bytes(v uint32) []byte { + b := make([]byte, 4) + PutUint32(b, v) + return b +} + +// Uint64Bytes encodes a uint64 into a newly-allocated byte slice using the +// host machine's native endianness. It is a shortcut for allocating a new +// byte slice and filling it using PutUint64. +func Uint64Bytes(v uint64) []byte { + b := make([]byte, 8) + PutUint64(b, v) + return b +} + +// Int32Bytes encodes a int32 into a newly-allocated byte slice using the +// host machine's native endianness. It is a shortcut for allocating a new +// byte slice and filling it using PutInt32. +func Int32Bytes(v int32) []byte { + b := make([]byte, 4) + PutInt32(b, v) + return b +} diff --git a/vendor/github.com/mdlayher/netlink/nlenc/string.go b/vendor/github.com/mdlayher/netlink/nlenc/string.go new file mode 100644 index 000000000..c0b166ddf --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nlenc/string.go @@ -0,0 +1,18 @@ +package nlenc + +import "bytes" + +// Bytes returns a null-terminated byte slice with the contents of s. +func Bytes(s string) []byte { + return append([]byte(s), 0x00) +} + +// String returns a string with the contents of b from a null-terminated +// byte slice. +func String(b []byte) string { + // If the string has more than one NULL terminator byte, we want to remove + // all of them before returning the string to the caller; hence the use of + // strings.TrimRight instead of strings.TrimSuffix (which previously only + // removed a single NULL). + return string(bytes.TrimRight(b, "\x00")) +} diff --git a/vendor/github.com/mdlayher/netlink/nltest/errors_others.go b/vendor/github.com/mdlayher/netlink/nltest/errors_others.go new file mode 100644 index 000000000..3a29c9b1a --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nltest/errors_others.go @@ -0,0 +1,8 @@ +//go:build plan9 || windows +// +build plan9 windows + +package nltest + +func isSyscallError(_ error) bool { + return false +} diff --git a/vendor/github.com/mdlayher/netlink/nltest/errors_unix.go b/vendor/github.com/mdlayher/netlink/nltest/errors_unix.go new file mode 100644 index 000000000..f54403bb0 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nltest/errors_unix.go @@ -0,0 +1,11 @@ +//go:build !plan9 && !windows +// +build !plan9,!windows + +package nltest + +import "golang.org/x/sys/unix" + +func isSyscallError(err error) bool { + _, ok := err.(unix.Errno) + return ok +} diff --git a/vendor/github.com/mdlayher/netlink/nltest/nltest.go b/vendor/github.com/mdlayher/netlink/nltest/nltest.go new file mode 100644 index 000000000..2065bab02 --- /dev/null +++ b/vendor/github.com/mdlayher/netlink/nltest/nltest.go @@ -0,0 +1,207 @@ +// Package nltest provides utilities for netlink testing. +package nltest + +import ( + "fmt" + "io" + "os" + + "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nlenc" +) + +// PID is the netlink header PID value assigned by nltest. +const PID = 1 + +// MustMarshalAttributes marshals a slice of netlink.Attributes to their binary +// format, but panics if any errors occur. +func MustMarshalAttributes(attrs []netlink.Attribute) []byte { + b, err := netlink.MarshalAttributes(attrs) + if err != nil { + panic(fmt.Sprintf("failed to marshal attributes to binary: %v", err)) + } + + return b +} + +// Multipart sends a slice of netlink.Messages to the caller as a +// netlink multi-part message. If less than two messages are present, +// the messages are not altered. +func Multipart(msgs []netlink.Message) ([]netlink.Message, error) { + if len(msgs) < 2 { + return msgs, nil + } + + for i := range msgs { + // Last message has header type "done" in addition to multi-part flag. + if i == len(msgs)-1 { + msgs[i].Header.Type = netlink.Done + } + + msgs[i].Header.Flags |= netlink.Multi + } + + return msgs, nil +} + +// Error returns a netlink error to the caller with the specified error +// number, in the body of the specified request message. +func Error(number int, reqs []netlink.Message) ([]netlink.Message, error) { + req := reqs[0] + req.Header.Length += 4 + req.Header.Type = netlink.Error + + errno := -1 * int32(number) + req.Data = append(nlenc.Int32Bytes(errno), req.Data...) + + return []netlink.Message{req}, nil +} + +// A Func is a function that can be used to test netlink.Conn interactions. +// The function can choose to return zero or more netlink messages, or an +// error if needed. +// +// For a netlink request/response interaction, a request req is populated by +// netlink.Conn.Send and passed to the function. +// +// For multicast interactions, an empty request req is passed to the function +// when netlink.Conn.Receive is called. +// +// If a Func returns an error, the error will be returned as-is to the caller. +// If no messages and io.EOF are returned, no messages and no error will be +// returned to the caller, simulating a multi-part message with no data. +type Func func(req []netlink.Message) ([]netlink.Message, error) + +// Dial sets up a netlink.Conn for testing using the specified Func. All requests +// sent from the connection will be passed to the Func. The connection should be +// closed as usual when it is no longer needed. +func Dial(fn Func) *netlink.Conn { + sock := &socket{ + fn: fn, + } + + return netlink.NewConn(sock, PID) +} + +// CheckRequest returns a Func that verifies that each message in an incoming +// request has the specified netlink header type and flags in the same slice +// position index, and then passes the request through to fn. +// +// The length of the types and flags slices must match the number of requests +// passed to the returned Func, or CheckRequest will panic. +// +// As an example: +// - types[0] and flags[0] will be checked against reqs[0] +// - types[1] and flags[1] will be checked against reqs[1] +// - ... and so on +// +// If an element of types or flags is set to the zero value, that check will +// be skipped for the request message that occurs at the same index. +// +// As an example, if types[0] is 0 and reqs[0].Header.Type is 1, the check will +// succeed because types[0] was not specified. +func CheckRequest(types []netlink.HeaderType, flags []netlink.HeaderFlags, fn Func) Func { + if len(types) != len(flags) { + panicf("nltest: CheckRequest called with mismatched types and flags slice lengths: %d != %d", + len(types), len(flags)) + } + + return func(req []netlink.Message) ([]netlink.Message, error) { + if len(types) != len(req) { + panicf("nltest: CheckRequest function invoked types/flags and request message slice lengths: %d != %d", + len(types), len(req)) + } + + for i := range req { + if want, got := types[i], req[i].Header.Type; types[i] != 0 && want != got { + return nil, fmt.Errorf("nltest: unexpected netlink header type: %s, want: %s", got, want) + } + + if want, got := flags[i], req[i].Header.Flags; flags[i] != 0 && want != got { + return nil, fmt.Errorf("nltest: unexpected netlink header flags: %s, want: %s", got, want) + } + } + + return fn(req) + } +} + +// A socket is a netlink.Socket used for testing. +type socket struct { + fn Func + + msgs []netlink.Message + err error +} + +func (c *socket) Close() error { return nil } + +func (c *socket) SendMessages(messages []netlink.Message) error { + msgs, err := c.fn(messages) + c.msgs = append(c.msgs, msgs...) + c.err = err + return nil +} + +func (c *socket) Send(m netlink.Message) error { + c.msgs, c.err = c.fn([]netlink.Message{m}) + return nil +} + +func (c *socket) Receive() ([]netlink.Message, error) { + // No messages set by Send means that we are emulating a + // multicast response or an error occurred. + if len(c.msgs) == 0 { + switch c.err { + case nil: + // No error, simulate multicast, but also return EOF to simulate + // no replies if needed. + msgs, err := c.fn(nil) + if err == io.EOF { + err = nil + } + + return msgs, err + case io.EOF: + // EOF, simulate no replies in multi-part message. + return nil, nil + } + + // If the error is a system call error, wrap it in os.NewSyscallError + // to simulate what the Linux netlink.Conn does. + if isSyscallError(c.err) { + return nil, os.NewSyscallError("recvmsg", c.err) + } + + // Some generic error occurred and should be passed to the caller. + return nil, c.err + } + + // Detect multi-part messages. + var multi bool + for _, m := range c.msgs { + if m.Header.Flags&netlink.Multi != 0 && m.Header.Type != netlink.Done { + multi = true + } + } + + // When a multi-part message is detected, return all messages except for the + // final "multi-part done", so that a second call to Receive from netlink.Conn + // will drain that message. + if multi { + last := c.msgs[len(c.msgs)-1] + ret := c.msgs[:len(c.msgs)-1] + c.msgs = []netlink.Message{last} + + return ret, c.err + } + + msgs, err := c.msgs, c.err + c.msgs, c.err = nil, nil + + return msgs, err +} + +func panicf(format string, a ...interface{}) { + panic(fmt.Sprintf(format, a...)) +} diff --git a/vendor/github.com/mdlayher/socket/CHANGELOG.md b/vendor/github.com/mdlayher/socket/CHANGELOG.md new file mode 100644 index 000000000..e1a77c411 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/CHANGELOG.md @@ -0,0 +1,89 @@ +# CHANGELOG + +## v0.5.0 + +**This is the first release of package socket that only supports Go 1.21+. +Users on older versions of Go must use v0.4.1.** + +- [Improvement]: drop support for older versions of Go. +- [New API]: add `socket.Conn` wrappers for various `Getsockopt` and + `Setsockopt` system calls. + +## v0.4.1 + +- [Bug Fix] [commit](https://github.com/mdlayher/socket/commit/2a14ceef4da279de1f957c5761fffcc6c87bbd3b): + ensure `socket.Conn` can be used with non-socket file descriptors by handling + `ENOTSOCK` in the constructor. + +## v0.4.0 + +**This is the first release of package socket that only supports Go 1.18+. +Users on older versions of Go must use v0.3.0.** + +- [Improvement]: drop support for older versions of Go so we can begin using + modern versions of `x/sys` and other dependencies. + +## v0.3.0 + +**This is the last release of package socket that supports Go 1.17 and below.** + +- [New API/API change] [PR](https://github.com/mdlayher/socket/pull/8): + numerous `socket.Conn` methods now support context cancelation. Future + releases will continue adding support as needed. + - New `ReadContext` and `WriteContext` methods. + - `Connect`, `Recvfrom`, `Recvmsg`, `Sendmsg`, and `Sendto` methods now accept + a context. + - `Sendto` parameter order was also fixed to match the underlying syscall. + +## v0.2.3 + +- [New API] [commit](https://github.com/mdlayher/socket/commit/a425d96e0f772c053164f8ce4c9c825380a98086): + `socket.Conn` has new `Pidfd*` methods for wrapping the `pidfd_*(2)` family of + system calls. + +## v0.2.2 + +- [New API] [commit](https://github.com/mdlayher/socket/commit/a2429f1dfe8ec2586df5a09f50ead865276cd027): + `socket.Conn` has new `IoctlKCM*` methods for wrapping `ioctl(2)` for `AF_KCM` + operations. + +## v0.2.1 + +- [New API] [commit](https://github.com/mdlayher/socket/commit/b18ddbe9caa0e34552b4409a3aa311cb460d2f99): + `socket.Conn` has a new `SetsockoptPacketMreq` method for wrapping + `setsockopt(2)` for `AF_PACKET` socket options. + +## v0.2.0 + +- [New API] [commit](https://github.com/mdlayher/socket/commit/6e912a68523c45e5fd899239f4b46c402dd856da): + `socket.FileConn` can be used to create a `socket.Conn` from an existing + `os.File`, which may be provided by systemd socket activation or another + external mechanism. +- [API change] [commit](https://github.com/mdlayher/socket/commit/66d61f565188c23fe02b24099ddc856d538bf1a7): + `socket.Conn.Connect` now returns the `unix.Sockaddr` value provided by + `getpeername(2)`, since we have to invoke that system call anyway to verify + that a connection to a remote peer was successfully established. +- [Bug Fix] [commit](https://github.com/mdlayher/socket/commit/b60b2dbe0ac3caff2338446a150083bde8c5c19c): + check the correct error from `unix.GetsockoptInt` in the `socket.Conn.Connect` + method. Thanks @vcabbage! + +## v0.1.2 + +- [Bug Fix]: `socket.Conn.Connect` now properly checks the `SO_ERROR` socket + option value after calling `connect(2)` to verify whether or not a connection + could successfully be established. This means that `Connect` should now report + an error for an `AF_INET` TCP connection refused or `AF_VSOCK` connection + reset by peer. +- [New API]: add `socket.Conn.Getpeername` for use in `Connect`, but also for + use by external callers. + +## v0.1.1 + +- [New API]: `socket.Conn` now has `CloseRead`, `CloseWrite`, and `Shutdown` + methods. +- [Improvement]: internal rework to more robustly handle various errors. + +## v0.1.0 + +- Initial unstable release. Most functionality has been developed and ported +from package [`netlink`](https://github.com/mdlayher/netlink). diff --git a/vendor/github.com/mdlayher/socket/LICENSE.md b/vendor/github.com/mdlayher/socket/LICENSE.md new file mode 100644 index 000000000..3ccdb75b2 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/LICENSE.md @@ -0,0 +1,9 @@ +# MIT License + +Copyright (C) 2021 Matt Layher + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/mdlayher/socket/README.md b/vendor/github.com/mdlayher/socket/README.md new file mode 100644 index 000000000..2aa065cbb --- /dev/null +++ b/vendor/github.com/mdlayher/socket/README.md @@ -0,0 +1,23 @@ +# socket [![Test Status](https://github.com/mdlayher/socket/workflows/Test/badge.svg)](https://github.com/mdlayher/socket/actions) [![Go Reference](https://pkg.go.dev/badge/github.com/mdlayher/socket.svg)](https://pkg.go.dev/github.com/mdlayher/socket) [![Go Report Card](https://goreportcard.com/badge/github.com/mdlayher/socket)](https://goreportcard.com/report/github.com/mdlayher/socket) + +Package `socket` provides a low-level network connection type which integrates +with Go's runtime network poller to provide asynchronous I/O and deadline +support. MIT Licensed. + +This package focuses on UNIX-like operating systems which make use of BSD +sockets system call APIs. It is meant to be used as a foundation for the +creation of operating system-specific socket packages, for socket families such +as Linux's `AF_NETLINK`, `AF_PACKET`, or `AF_VSOCK`. This package should not be +used directly in end user applications. + +Any use of package socket should be guarded by build tags, as one would also +use when importing the `syscall` or `golang.org/x/sys` packages. + +## Stability + +See the [CHANGELOG](./CHANGELOG.md) file for a description of changes between +releases. + +This package only supports the two most recent major versions of Go, mirroring +Go's own release policy. Older versions of Go may lack critical features and bug +fixes which are necessary for this package to function correctly. diff --git a/vendor/github.com/mdlayher/socket/accept.go b/vendor/github.com/mdlayher/socket/accept.go new file mode 100644 index 000000000..47e9d897e --- /dev/null +++ b/vendor/github.com/mdlayher/socket/accept.go @@ -0,0 +1,23 @@ +//go:build !dragonfly && !freebsd && !illumos && !linux +// +build !dragonfly,!freebsd,!illumos,!linux + +package socket + +import ( + "fmt" + "runtime" + + "golang.org/x/sys/unix" +) + +const sysAccept = "accept" + +// accept wraps accept(2). +func accept(fd, flags int) (int, unix.Sockaddr, error) { + if flags != 0 { + // These operating systems have no support for flags to accept(2). + return 0, nil, fmt.Errorf("socket: Conn.Accept flags are ineffective on %s", runtime.GOOS) + } + + return unix.Accept(fd) +} diff --git a/vendor/github.com/mdlayher/socket/accept4.go b/vendor/github.com/mdlayher/socket/accept4.go new file mode 100644 index 000000000..e1016b206 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/accept4.go @@ -0,0 +1,15 @@ +//go:build dragonfly || freebsd || illumos || linux +// +build dragonfly freebsd illumos linux + +package socket + +import ( + "golang.org/x/sys/unix" +) + +const sysAccept = "accept4" + +// accept wraps accept4(2). +func accept(fd, flags int) (int, unix.Sockaddr, error) { + return unix.Accept4(fd, flags) +} diff --git a/vendor/github.com/mdlayher/socket/conn.go b/vendor/github.com/mdlayher/socket/conn.go new file mode 100644 index 000000000..5be502f5a --- /dev/null +++ b/vendor/github.com/mdlayher/socket/conn.go @@ -0,0 +1,894 @@ +package socket + +import ( + "context" + "errors" + "io" + "os" + "sync" + "sync/atomic" + "syscall" + "time" + + "golang.org/x/sys/unix" +) + +// Lock in an expected public interface for convenience. +var _ interface { + io.ReadWriteCloser + syscall.Conn + SetDeadline(t time.Time) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error +} = &Conn{} + +// A Conn is a low-level network connection which integrates with Go's runtime +// network poller to provide asynchronous I/O and deadline support. +// +// Many of a Conn's blocking methods support net.Conn deadlines as well as +// cancelation via context. Note that passing a context with a deadline set will +// override any of the previous deadlines set by calls to the SetDeadline family +// of methods. +type Conn struct { + // Indicates whether or not Conn.Close has been called. Must be accessed + // atomically. Atomics definitions must come first in the Conn struct. + closed uint32 + + // A unique name for the Conn which is also associated with derived file + // descriptors such as those created by accept(2). + name string + + // facts contains information we have determined about Conn to trigger + // alternate behavior in certain functions. + facts facts + + // Provides access to the underlying file registered with the runtime + // network poller, and arbitrary raw I/O calls. + fd *os.File + rc syscall.RawConn +} + +// facts contains facts about a Conn. +type facts struct { + // isStream reports whether this is a streaming descriptor, as opposed to a + // packet-based descriptor like a UDP socket. + isStream bool + + // zeroReadIsEOF reports Whether a zero byte read indicates EOF. This is + // false for a message based socket connection. + zeroReadIsEOF bool +} + +// A Config contains options for a Conn. +type Config struct { + // NetNS specifies the Linux network namespace the Conn will operate in. + // This option is unsupported on other operating systems. + // + // If set (non-zero), Conn will enter the specified network namespace and an + // error will occur in Socket if the operation fails. + // + // If not set (zero), a best-effort attempt will be made to enter the + // network namespace of the calling thread: this means that any changes made + // to the calling thread's network namespace will also be reflected in Conn. + // If this operation fails (due to lack of permissions or because network + // namespaces are disabled by kernel configuration), Socket will not return + // an error, and the Conn will operate in the default network namespace of + // the process. This enables non-privileged use of Conn in applications + // which do not require elevated privileges. + // + // Entering a network namespace is a privileged operation (root or + // CAP_SYS_ADMIN are required), and most applications should leave this set + // to 0. + NetNS int +} + +// High-level methods which provide convenience over raw system calls. + +// Close closes the underlying file descriptor for the Conn, which also causes +// all in-flight I/O operations to immediately unblock and return errors. Any +// subsequent uses of Conn will result in EBADF. +func (c *Conn) Close() error { + // The caller has expressed an intent to close the socket, so immediately + // increment s.closed to force further calls to result in EBADF before also + // closing the file descriptor to unblock any outstanding operations. + // + // Because other operations simply check for s.closed != 0, we will permit + // double Close, which would increment s.closed beyond 1. + if atomic.AddUint32(&c.closed, 1) != 1 { + // Multiple Close calls. + return nil + } + + return os.NewSyscallError("close", c.fd.Close()) +} + +// CloseRead shuts down the reading side of the Conn. Most callers should just +// use Close. +func (c *Conn) CloseRead() error { return c.Shutdown(unix.SHUT_RD) } + +// CloseWrite shuts down the writing side of the Conn. Most callers should just +// use Close. +func (c *Conn) CloseWrite() error { return c.Shutdown(unix.SHUT_WR) } + +// Read reads directly from the underlying file descriptor. +func (c *Conn) Read(b []byte) (int, error) { return c.fd.Read(b) } + +// ReadContext reads from the underlying file descriptor with added support for +// context cancelation. +func (c *Conn) ReadContext(ctx context.Context, b []byte) (int, error) { + if c.facts.isStream && len(b) > maxRW { + b = b[:maxRW] + } + + n, err := readT(c, ctx, "read", func(fd int) (int, error) { + return unix.Read(fd, b) + }) + if n == 0 && err == nil && c.facts.zeroReadIsEOF { + return 0, io.EOF + } + + return n, os.NewSyscallError("read", err) +} + +// Write writes directly to the underlying file descriptor. +func (c *Conn) Write(b []byte) (int, error) { return c.fd.Write(b) } + +// WriteContext writes to the underlying file descriptor with added support for +// context cancelation. +func (c *Conn) WriteContext(ctx context.Context, b []byte) (int, error) { + var ( + n, nn int + err error + ) + + doErr := c.write(ctx, "write", func(fd int) error { + max := len(b) + if c.facts.isStream && max-nn > maxRW { + max = nn + maxRW + } + + n, err = unix.Write(fd, b[nn:max]) + if n > 0 { + nn += n + } + if nn == len(b) { + return err + } + if n == 0 && err == nil { + err = io.ErrUnexpectedEOF + return nil + } + + return err + }) + if doErr != nil { + return 0, doErr + } + + return nn, os.NewSyscallError("write", err) +} + +// SetDeadline sets both the read and write deadlines associated with the Conn. +func (c *Conn) SetDeadline(t time.Time) error { return c.fd.SetDeadline(t) } + +// SetReadDeadline sets the read deadline associated with the Conn. +func (c *Conn) SetReadDeadline(t time.Time) error { return c.fd.SetReadDeadline(t) } + +// SetWriteDeadline sets the write deadline associated with the Conn. +func (c *Conn) SetWriteDeadline(t time.Time) error { return c.fd.SetWriteDeadline(t) } + +// ReadBuffer gets the size of the operating system's receive buffer associated +// with the Conn. +func (c *Conn) ReadBuffer() (int, error) { + return c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUF) +} + +// WriteBuffer gets the size of the operating system's transmit buffer +// associated with the Conn. +func (c *Conn) WriteBuffer() (int, error) { + return c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUF) +} + +// SetReadBuffer sets the size of the operating system's receive buffer +// associated with the Conn. +// +// When called with elevated privileges on Linux, the SO_RCVBUFFORCE option will +// be used to override operating system limits. Otherwise SO_RCVBUF is used +// (which obeys operating system limits). +func (c *Conn) SetReadBuffer(bytes int) error { return c.setReadBuffer(bytes) } + +// SetWriteBuffer sets the size of the operating system's transmit buffer +// associated with the Conn. +// +// When called with elevated privileges on Linux, the SO_SNDBUFFORCE option will +// be used to override operating system limits. Otherwise SO_SNDBUF is used +// (which obeys operating system limits). +func (c *Conn) SetWriteBuffer(bytes int) error { return c.setWriteBuffer(bytes) } + +// SyscallConn returns a raw network connection. This implements the +// syscall.Conn interface. +// +// SyscallConn is intended for advanced use cases, such as getting and setting +// arbitrary socket options using the socket's file descriptor. If possible, +// those operations should be performed using methods on Conn instead. +// +// Once invoked, it is the caller's responsibility to ensure that operations +// performed using Conn and the syscall.RawConn do not conflict with each other. +func (c *Conn) SyscallConn() (syscall.RawConn, error) { + if atomic.LoadUint32(&c.closed) != 0 { + return nil, os.NewSyscallError("syscallconn", unix.EBADF) + } + + // TODO(mdlayher): mutex or similar to enforce syscall.RawConn contract of + // FD remaining valid for duration of calls? + return c.rc, nil +} + +// Socket wraps the socket(2) system call to produce a Conn. domain, typ, and +// proto are passed directly to socket(2), and name should be a unique name for +// the socket type such as "netlink" or "vsock". +// +// The cfg parameter specifies optional configuration for the Conn. If nil, no +// additional configuration will be applied. +// +// If the operating system supports SOCK_CLOEXEC and SOCK_NONBLOCK, they are +// automatically applied to typ to mirror the standard library's socket flag +// behaviors. +func Socket(domain, typ, proto int, name string, cfg *Config) (*Conn, error) { + if cfg == nil { + cfg = &Config{} + } + + if cfg.NetNS == 0 { + // Non-Linux or no network namespace. + return socket(domain, typ, proto, name) + } + + // Linux only: create Conn in the specified network namespace. + return withNetNS(cfg.NetNS, func() (*Conn, error) { + return socket(domain, typ, proto, name) + }) +} + +// socket is the internal, cross-platform entry point for socket(2). +func socket(domain, typ, proto int, name string) (*Conn, error) { + var ( + fd int + err error + ) + + for { + fd, err = unix.Socket(domain, typ|socketFlags, proto) + switch { + case err == nil: + // Some OSes already set CLOEXEC with typ. + if !flagCLOEXEC { + unix.CloseOnExec(fd) + } + + // No error, prepare the Conn. + return New(fd, name) + case !ready(err): + // System call interrupted or not ready, try again. + continue + case err == unix.EINVAL, err == unix.EPROTONOSUPPORT: + // On Linux, SOCK_NONBLOCK and SOCK_CLOEXEC were introduced in + // 2.6.27. On FreeBSD, both flags were introduced in FreeBSD 10. + // EINVAL and EPROTONOSUPPORT check for earlier versions of these + // OSes respectively. + // + // Mirror what the standard library does when creating file + // descriptors: avoid racing a fork/exec with the creation of new + // file descriptors, so that child processes do not inherit socket + // file descriptors unexpectedly. + // + // For a more thorough explanation, see similar work in the Go tree: + // func sysSocket in net/sock_cloexec.go, as well as the detailed + // comment in syscall/exec_unix.go. + syscall.ForkLock.RLock() + fd, err = unix.Socket(domain, typ, proto) + if err != nil { + syscall.ForkLock.RUnlock() + return nil, os.NewSyscallError("socket", err) + } + unix.CloseOnExec(fd) + syscall.ForkLock.RUnlock() + + return New(fd, name) + default: + // Unhandled error. + return nil, os.NewSyscallError("socket", err) + } + } +} + +// FileConn returns a copy of the network connection corresponding to the open +// file. It is the caller's responsibility to close the file when finished. +// Closing the Conn does not affect the File, and closing the File does not +// affect the Conn. +func FileConn(f *os.File, name string) (*Conn, error) { + // First we'll try to do fctnl(2) with F_DUPFD_CLOEXEC because we can dup + // the file descriptor and set the flag in one syscall. + fd, err := unix.FcntlInt(f.Fd(), unix.F_DUPFD_CLOEXEC, 0) + switch err { + case nil: + // OK, ready to set up non-blocking I/O. + return New(fd, name) + case unix.EINVAL: + // The kernel rejected our fcntl(2), fall back to separate dup(2) and + // setting close on exec. + // + // Mirror what the standard library does when creating file descriptors: + // avoid racing a fork/exec with the creation of new file descriptors, + // so that child processes do not inherit socket file descriptors + // unexpectedly. + syscall.ForkLock.RLock() + fd, err := unix.Dup(fd) + if err != nil { + syscall.ForkLock.RUnlock() + return nil, os.NewSyscallError("dup", err) + } + unix.CloseOnExec(fd) + syscall.ForkLock.RUnlock() + + return New(fd, name) + default: + // Any other errors. + return nil, os.NewSyscallError("fcntl", err) + } +} + +// New wraps an existing file descriptor to create a Conn. name should be a +// unique name for the socket type such as "netlink" or "vsock". +// +// Most callers should use Socket or FileConn to construct a Conn. New is +// intended for integrating with specific system calls which provide a file +// descriptor that supports asynchronous I/O. The file descriptor is immediately +// set to nonblocking mode and registered with Go's runtime network poller for +// future I/O operations. +// +// Unlike FileConn, New does not duplicate the existing file descriptor in any +// way. The returned Conn takes ownership of the underlying file descriptor. +func New(fd int, name string) (*Conn, error) { + // All Conn I/O is nonblocking for integration with Go's runtime network + // poller. Depending on the OS this might already be set but it can't hurt + // to set it again. + if err := unix.SetNonblock(fd, true); err != nil { + return nil, os.NewSyscallError("setnonblock", err) + } + + // os.NewFile registers the non-blocking file descriptor with the runtime + // poller, which is then used for most subsequent operations except those + // that require raw I/O via SyscallConn. + // + // See also: https://golang.org/pkg/os/#NewFile + f := os.NewFile(uintptr(fd), name) + rc, err := f.SyscallConn() + if err != nil { + return nil, err + } + + c := &Conn{ + name: name, + fd: f, + rc: rc, + } + + // Probe the file descriptor for socket settings. + sotype, err := c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_TYPE) + switch { + case err == nil: + // File is a socket, check its properties. + c.facts = facts{ + isStream: sotype == unix.SOCK_STREAM, + zeroReadIsEOF: sotype != unix.SOCK_DGRAM && sotype != unix.SOCK_RAW, + } + case errors.Is(err, unix.ENOTSOCK): + // File is not a socket, treat it as a regular file. + c.facts = facts{ + isStream: true, + zeroReadIsEOF: true, + } + default: + return nil, err + } + + return c, nil +} + +// Low-level methods which provide raw system call access. + +// Accept wraps accept(2) or accept4(2) depending on the operating system, but +// returns a Conn for the accepted connection rather than a raw file descriptor. +// +// If the operating system supports accept4(2) (which allows flags), +// SOCK_CLOEXEC and SOCK_NONBLOCK are automatically applied to flags to mirror +// the standard library's socket flag behaviors. +// +// If the operating system only supports accept(2) (which does not allow flags) +// and flags is not zero, an error will be returned. +// +// Accept obeys context cancelation and uses the deadline set on the context to +// cancel accepting the next connection. If a deadline is set on ctx, this +// deadline will override any previous deadlines set using SetDeadline or +// SetReadDeadline. Upon return, the read deadline is cleared. +func (c *Conn) Accept(ctx context.Context, flags int) (*Conn, unix.Sockaddr, error) { + type ret struct { + nfd int + sa unix.Sockaddr + } + + r, err := readT(c, ctx, sysAccept, func(fd int) (ret, error) { + // Either accept(2) or accept4(2) depending on the OS. + nfd, sa, err := accept(fd, flags|socketFlags) + return ret{nfd, sa}, err + }) + if err != nil { + // internal/poll, context error, or user function error. + return nil, nil, err + } + + // Successfully accepted a connection, wrap it in a Conn for use by the + // caller. + ac, err := New(r.nfd, c.name) + if err != nil { + return nil, nil, err + } + + return ac, r.sa, nil +} + +// Bind wraps bind(2). +func (c *Conn) Bind(sa unix.Sockaddr) error { + return c.control("bind", func(fd int) error { return unix.Bind(fd, sa) }) +} + +// Connect wraps connect(2). In order to verify that the underlying socket is +// connected to a remote peer, Connect calls getpeername(2) and returns the +// unix.Sockaddr from that call. +// +// Connect obeys context cancelation and uses the deadline set on the context to +// cancel connecting to a remote peer. If a deadline is set on ctx, this +// deadline will override any previous deadlines set using SetDeadline or +// SetWriteDeadline. Upon return, the write deadline is cleared. +func (c *Conn) Connect(ctx context.Context, sa unix.Sockaddr) (unix.Sockaddr, error) { + const op = "connect" + + // TODO(mdlayher): it would seem that trying to connect to unbound vsock + // listeners by calling Connect multiple times results in ECONNRESET for the + // first and nil error for subsequent calls. Do we need to memoize the + // error? Check what the stdlib behavior is. + + var ( + // Track progress between invocations of the write closure. We don't + // have an explicit WaitWrite call like internal/poll does, so we have + // to wait until the runtime calls the closure again to indicate we can + // write. + progress uint32 + + // Capture closure sockaddr and error. + rsa unix.Sockaddr + err error + ) + + doErr := c.write(ctx, op, func(fd int) error { + if atomic.AddUint32(&progress, 1) == 1 { + // First call: initiate connect. + return unix.Connect(fd, sa) + } + + // Subsequent calls: the runtime network poller indicates fd is + // writable. Check for errno. + errno, gerr := c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_ERROR) + if gerr != nil { + return gerr + } + if errno != 0 { + // Connection is still not ready or failed. If errno indicates + // the socket is not ready, we will wait for the next write + // event. Otherwise we propagate this errno back to the as a + // permanent error. + uerr := unix.Errno(errno) + err = uerr + return uerr + } + + // According to internal/poll, it's possible for the runtime network + // poller to spuriously wake us and return errno 0 for SO_ERROR. + // Make sure we are actually connected to a peer. + peer, err := c.Getpeername() + if err != nil { + // internal/poll unconditionally goes back to WaitWrite. + // Synthesize an error that will do the same for us. + return unix.EAGAIN + } + + // Connection complete. + rsa = peer + return nil + }) + if doErr != nil { + // internal/poll or context error. + return nil, doErr + } + + if err == unix.EISCONN { + // TODO(mdlayher): is this block obsolete with the addition of the + // getsockopt SO_ERROR check above? + // + // EISCONN is reported if the socket is already established and should + // not be treated as an error. + // - Darwin reports this for at least TCP sockets + // - Linux reports this for at least AF_VSOCK sockets + return rsa, nil + } + + return rsa, os.NewSyscallError(op, err) +} + +// Getsockname wraps getsockname(2). +func (c *Conn) Getsockname() (unix.Sockaddr, error) { + return controlT(c, "getsockname", unix.Getsockname) +} + +// Getpeername wraps getpeername(2). +func (c *Conn) Getpeername() (unix.Sockaddr, error) { + return controlT(c, "getpeername", unix.Getpeername) +} + +// GetsockoptICMPv6Filter wraps getsockopt(2) for *unix.ICMPv6Filter values. +func (c *Conn) GetsockoptICMPv6Filter(level, opt int) (*unix.ICMPv6Filter, error) { + return controlT(c, "getsockopt", func(fd int) (*unix.ICMPv6Filter, error) { + return unix.GetsockoptICMPv6Filter(fd, level, opt) + }) +} + +// GetsockoptInt wraps getsockopt(2) for integer values. +func (c *Conn) GetsockoptInt(level, opt int) (int, error) { + return controlT(c, "getsockopt", func(fd int) (int, error) { + return unix.GetsockoptInt(fd, level, opt) + }) +} + +// GetsockoptString wraps getsockopt(2) for string values. +func (c *Conn) GetsockoptString(level, opt int) (string, error) { + return controlT(c, "getsockopt", func(fd int) (string, error) { + return unix.GetsockoptString(fd, level, opt) + }) +} + +// Listen wraps listen(2). +func (c *Conn) Listen(n int) error { + return c.control("listen", func(fd int) error { return unix.Listen(fd, n) }) +} + +// Recvmsg wraps recvmsg(2). +func (c *Conn) Recvmsg(ctx context.Context, p, oob []byte, flags int) (int, int, int, unix.Sockaddr, error) { + type ret struct { + n, oobn, recvflags int + from unix.Sockaddr + } + + r, err := readT(c, ctx, "recvmsg", func(fd int) (ret, error) { + n, oobn, recvflags, from, err := unix.Recvmsg(fd, p, oob, flags) + return ret{n, oobn, recvflags, from}, err + }) + if r.n == 0 && err == nil && c.facts.zeroReadIsEOF { + return 0, 0, 0, nil, io.EOF + } + + return r.n, r.oobn, r.recvflags, r.from, err +} + +// Recvfrom wraps recvfrom(2). +func (c *Conn) Recvfrom(ctx context.Context, p []byte, flags int) (int, unix.Sockaddr, error) { + type ret struct { + n int + addr unix.Sockaddr + } + + out, err := readT(c, ctx, "recvfrom", func(fd int) (ret, error) { + n, addr, err := unix.Recvfrom(fd, p, flags) + return ret{n, addr}, err + }) + if out.n == 0 && err == nil && c.facts.zeroReadIsEOF { + return 0, nil, io.EOF + } + + return out.n, out.addr, err +} + +// Sendmsg wraps sendmsg(2). +func (c *Conn) Sendmsg(ctx context.Context, p, oob []byte, to unix.Sockaddr, flags int) (int, error) { + return writeT(c, ctx, "sendmsg", func(fd int) (int, error) { + return unix.SendmsgN(fd, p, oob, to, flags) + }) +} + +// Sendto wraps sendto(2). +func (c *Conn) Sendto(ctx context.Context, p []byte, flags int, to unix.Sockaddr) error { + return c.write(ctx, "sendto", func(fd int) error { + return unix.Sendto(fd, p, flags, to) + }) +} + +// SetsockoptICMPv6Filter wraps setsockopt(2) for *unix.ICMPv6Filter values. +func (c *Conn) SetsockoptICMPv6Filter(level, opt int, filter *unix.ICMPv6Filter) error { + return c.control("setsockopt", func(fd int) error { + return unix.SetsockoptICMPv6Filter(fd, level, opt, filter) + }) +} + +// SetsockoptInt wraps setsockopt(2) for integer values. +func (c *Conn) SetsockoptInt(level, opt, value int) error { + return c.control("setsockopt", func(fd int) error { + return unix.SetsockoptInt(fd, level, opt, value) + }) +} + +// SetsockoptString wraps setsockopt(2) for string values. +func (c *Conn) SetsockoptString(level, opt int, value string) error { + return c.control("setsockopt", func(fd int) error { + return unix.SetsockoptString(fd, level, opt, value) + }) +} + +// Shutdown wraps shutdown(2). +func (c *Conn) Shutdown(how int) error { + return c.control("shutdown", func(fd int) error { return unix.Shutdown(fd, how) }) +} + +// Conn low-level read/write/control functions. These functions mirror the +// syscall.RawConn APIs but the input closures return errors rather than +// booleans. + +// read wraps readT to execute a function and capture its error result. This is +// a convenience wrapper for functions which don't return any extra values. +func (c *Conn) read(ctx context.Context, op string, f func(fd int) error) error { + _, err := readT(c, ctx, op, func(fd int) (struct{}, error) { + return struct{}{}, f(fd) + }) + return err +} + +// write executes f, a write function, against the associated file descriptor. +// op is used to create an *os.SyscallError if the file descriptor is closed. +func (c *Conn) write(ctx context.Context, op string, f func(fd int) error) error { + _, err := writeT(c, ctx, op, func(fd int) (struct{}, error) { + return struct{}{}, f(fd) + }) + return err +} + +// readT executes c.rc.Read for op using the input function, returning a newly +// allocated result T. +func readT[T any](c *Conn, ctx context.Context, op string, f func(fd int) (T, error)) (T, error) { + return rwT(c, rwContext[T]{ + Context: ctx, + Type: read, + Op: op, + Do: f, + }) +} + +// writeT executes c.rc.Write for op using the input function, returning a newly +// allocated result T. +func writeT[T any](c *Conn, ctx context.Context, op string, f func(fd int) (T, error)) (T, error) { + return rwT(c, rwContext[T]{ + Context: ctx, + Type: write, + Op: op, + Do: f, + }) +} + +// readWrite indicates if an operation intends to read or write. +type readWrite bool + +// Possible readWrite values. +const ( + read readWrite = false + write readWrite = true +) + +// An rwContext provides arguments to rwT. +type rwContext[T any] struct { + // The caller's context passed for cancelation. + Context context.Context + + // The type of an operation: read or write. + Type readWrite + + // The name of the operation used in errors. + Op string + + // The actual function to perform. + Do func(fd int) (T, error) +} + +// rwT executes c.rc.Read or c.rc.Write (depending on the value of rw.Type) for +// rw.Op using the input function, returning a newly allocated result T. +// +// It obeys context cancelation and the rw.Context must not be nil. +func rwT[T any](c *Conn, rw rwContext[T]) (T, error) { + if atomic.LoadUint32(&c.closed) != 0 { + // If the file descriptor is already closed, do nothing. + return *new(T), os.NewSyscallError(rw.Op, unix.EBADF) + } + + if err := rw.Context.Err(); err != nil { + // Early exit due to context cancel. + return *new(T), os.NewSyscallError(rw.Op, err) + } + + var ( + // The read or write function used to access the runtime network poller. + poll func(func(uintptr) bool) error + + // The read or write function used to set the matching deadline. + deadline func(time.Time) error + ) + + if rw.Type == write { + poll = c.rc.Write + deadline = c.SetWriteDeadline + } else { + poll = c.rc.Read + deadline = c.SetReadDeadline + } + + var ( + // Whether or not the context carried a deadline we are actively using + // for cancelation. + setDeadline bool + + // Signals for the cancelation watcher goroutine. + wg sync.WaitGroup + doneC = make(chan struct{}) + + // Atomic: reports whether we have to disarm the deadline. + needDisarm atomic.Bool + ) + + // On cancel, clean up the watcher. + defer func() { + close(doneC) + wg.Wait() + }() + + if d, ok := rw.Context.Deadline(); ok { + // The context has an explicit deadline. We will use it for cancelation + // but disarm it after poll for the next call. + if err := deadline(d); err != nil { + return *new(T), err + } + setDeadline = true + needDisarm.Store(true) + } else { + // The context does not have an explicit deadline. We have to watch for + // cancelation so we can propagate that signal to immediately unblock + // the runtime network poller. + // + // TODO(mdlayher): is it possible to detect a background context vs a + // context with possible future cancel? + wg.Add(1) + go func() { + defer wg.Done() + + select { + case <-rw.Context.Done(): + // Cancel the operation. Make the caller disarm after poll + // returns. + needDisarm.Store(true) + _ = deadline(time.Unix(0, 1)) + case <-doneC: + // Nothing to do. + } + }() + } + + var ( + t T + err error + ) + + pollErr := poll(func(fd uintptr) bool { + t, err = rw.Do(int(fd)) + return ready(err) + }) + + if needDisarm.Load() { + _ = deadline(time.Time{}) + } + + if pollErr != nil { + if rw.Context.Err() != nil || (setDeadline && errors.Is(pollErr, os.ErrDeadlineExceeded)) { + // The caller canceled the operation or we set a deadline internally + // and it was reached. + // + // Unpack a plain context error. We wait for the context to be done + // to synchronize state externally. Otherwise we have noticed I/O + // timeout wakeups when we set a deadline but the context was not + // yet marked done. + <-rw.Context.Done() + return *new(T), os.NewSyscallError(rw.Op, rw.Context.Err()) + } + + // Error from syscall.RawConn methods. Conventionally the standard + // library does not wrap internal/poll errors in os.NewSyscallError. + return *new(T), pollErr + } + + // Result from user function. + return t, os.NewSyscallError(rw.Op, err) +} + +// control executes Conn.control for op using the input function. +func (c *Conn) control(op string, f func(fd int) error) error { + _, err := controlT(c, op, func(fd int) (struct{}, error) { + return struct{}{}, f(fd) + }) + return err +} + +// controlT executes c.rc.Control for op using the input function, returning a +// newly allocated result T. +func controlT[T any](c *Conn, op string, f func(fd int) (T, error)) (T, error) { + if atomic.LoadUint32(&c.closed) != 0 { + // If the file descriptor is already closed, do nothing. + return *new(T), os.NewSyscallError(op, unix.EBADF) + } + + var ( + t T + err error + ) + + doErr := c.rc.Control(func(fd uintptr) { + // Repeatedly attempt the syscall(s) invoked by f until completion is + // indicated by the return value of ready or the context is canceled. + // + // The last values for t and err are captured outside of the closure for + // use when the loop breaks. + for { + t, err = f(int(fd)) + if ready(err) { + return + } + } + }) + if doErr != nil { + // Error from syscall.RawConn methods. Conventionally the standard + // library does not wrap internal/poll errors in os.NewSyscallError. + return *new(T), doErr + } + + // Result from user function. + return t, os.NewSyscallError(op, err) +} + +// ready indicates readiness based on the value of err. +func ready(err error) bool { + switch err { + case unix.EAGAIN, unix.EINPROGRESS, unix.EINTR: + // When a socket is in non-blocking mode, we might see a variety of errors: + // - EAGAIN: most common case for a socket read not being ready + // - EINPROGRESS: reported by some sockets when first calling connect + // - EINTR: system call interrupted, more frequently occurs in Go 1.14+ + // because goroutines can be asynchronously preempted + // + // Return false to let the poller wait for readiness. See the source code + // for internal/poll.FD.RawRead for more details. + return false + default: + // Ready regardless of whether there was an error or no error. + return true + } +} + +// Darwin and FreeBSD can't read or write 2GB+ files at a time, +// even on 64-bit systems. +// The same is true of socket implementations on many systems. +// See golang.org/issue/7812 and golang.org/issue/16266. +// Use 1GB instead of, say, 2GB-1, to keep subsequent reads aligned. +const maxRW = 1 << 30 diff --git a/vendor/github.com/mdlayher/socket/conn_linux.go b/vendor/github.com/mdlayher/socket/conn_linux.go new file mode 100644 index 000000000..081194f32 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/conn_linux.go @@ -0,0 +1,118 @@ +//go:build linux +// +build linux + +package socket + +import ( + "context" + "os" + "unsafe" + + "golang.org/x/net/bpf" + "golang.org/x/sys/unix" +) + +// IoctlKCMClone wraps ioctl(2) for unix.KCMClone values, but returns a Conn +// rather than a raw file descriptor. +func (c *Conn) IoctlKCMClone() (*Conn, error) { + info, err := controlT(c, "ioctl", unix.IoctlKCMClone) + if err != nil { + return nil, err + } + + // Successful clone, wrap in a Conn for use by the caller. + return New(int(info.Fd), c.name) +} + +// IoctlKCMAttach wraps ioctl(2) for unix.KCMAttach values. +func (c *Conn) IoctlKCMAttach(info unix.KCMAttach) error { + return c.control("ioctl", func(fd int) error { + return unix.IoctlKCMAttach(fd, info) + }) +} + +// IoctlKCMUnattach wraps ioctl(2) for unix.KCMUnattach values. +func (c *Conn) IoctlKCMUnattach(info unix.KCMUnattach) error { + return c.control("ioctl", func(fd int) error { + return unix.IoctlKCMUnattach(fd, info) + }) +} + +// PidfdGetfd wraps pidfd_getfd(2) for a Conn which wraps a pidfd, but returns a +// Conn rather than a raw file descriptor. +func (c *Conn) PidfdGetfd(targetFD, flags int) (*Conn, error) { + outFD, err := controlT(c, "pidfd_getfd", func(fd int) (int, error) { + return unix.PidfdGetfd(fd, targetFD, flags) + }) + if err != nil { + return nil, err + } + + // Successful getfd, wrap in a Conn for use by the caller. + return New(outFD, c.name) +} + +// PidfdSendSignal wraps pidfd_send_signal(2) for a Conn which wraps a Linux +// pidfd. +func (c *Conn) PidfdSendSignal(sig unix.Signal, info *unix.Siginfo, flags int) error { + return c.control("pidfd_send_signal", func(fd int) error { + return unix.PidfdSendSignal(fd, sig, info, flags) + }) +} + +// SetBPF attaches an assembled BPF program to a Conn. +func (c *Conn) SetBPF(filter []bpf.RawInstruction) error { + // We can't point to the first instruction in the array if no instructions + // are present. + if len(filter) == 0 { + return os.NewSyscallError("setsockopt", unix.EINVAL) + } + + prog := unix.SockFprog{ + Len: uint16(len(filter)), + Filter: (*unix.SockFilter)(unsafe.Pointer(&filter[0])), + } + + return c.SetsockoptSockFprog(unix.SOL_SOCKET, unix.SO_ATTACH_FILTER, &prog) +} + +// RemoveBPF removes a BPF filter from a Conn. +func (c *Conn) RemoveBPF() error { + // 0 argument is ignored. + return c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_DETACH_FILTER, 0) +} + +// SetsockoptPacketMreq wraps setsockopt(2) for unix.PacketMreq values. +func (c *Conn) SetsockoptPacketMreq(level, opt int, mreq *unix.PacketMreq) error { + return c.control("setsockopt", func(fd int) error { + return unix.SetsockoptPacketMreq(fd, level, opt, mreq) + }) +} + +// SetsockoptSockFprog wraps setsockopt(2) for unix.SockFprog values. +func (c *Conn) SetsockoptSockFprog(level, opt int, fprog *unix.SockFprog) error { + return c.control("setsockopt", func(fd int) error { + return unix.SetsockoptSockFprog(fd, level, opt, fprog) + }) +} + +// GetsockoptTpacketStats wraps getsockopt(2) for unix.TpacketStats values. +func (c *Conn) GetsockoptTpacketStats(level, name int) (*unix.TpacketStats, error) { + return controlT(c, "getsockopt", func(fd int) (*unix.TpacketStats, error) { + return unix.GetsockoptTpacketStats(fd, level, name) + }) +} + +// GetsockoptTpacketStatsV3 wraps getsockopt(2) for unix.TpacketStatsV3 values. +func (c *Conn) GetsockoptTpacketStatsV3(level, name int) (*unix.TpacketStatsV3, error) { + return controlT(c, "getsockopt", func(fd int) (*unix.TpacketStatsV3, error) { + return unix.GetsockoptTpacketStatsV3(fd, level, name) + }) +} + +// Waitid wraps waitid(2). +func (c *Conn) Waitid(idType int, info *unix.Siginfo, options int, rusage *unix.Rusage) error { + return c.read(context.Background(), "waitid", func(fd int) error { + return unix.Waitid(idType, fd, info, options, rusage) + }) +} diff --git a/vendor/github.com/mdlayher/socket/doc.go b/vendor/github.com/mdlayher/socket/doc.go new file mode 100644 index 000000000..7d4566c90 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/doc.go @@ -0,0 +1,13 @@ +// Package socket provides a low-level network connection type which integrates +// with Go's runtime network poller to provide asynchronous I/O and deadline +// support. +// +// This package focuses on UNIX-like operating systems which make use of BSD +// sockets system call APIs. It is meant to be used as a foundation for the +// creation of operating system-specific socket packages, for socket families +// such as Linux's AF_NETLINK, AF_PACKET, or AF_VSOCK. This package should not +// be used directly in end user applications. +// +// Any use of package socket should be guarded by build tags, as one would also +// use when importing the syscall or golang.org/x/sys packages. +package socket diff --git a/vendor/github.com/mdlayher/socket/netns_linux.go b/vendor/github.com/mdlayher/socket/netns_linux.go new file mode 100644 index 000000000..b29115ad1 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/netns_linux.go @@ -0,0 +1,150 @@ +//go:build linux +// +build linux + +package socket + +import ( + "errors" + "fmt" + "os" + "runtime" + + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" +) + +// errNetNSDisabled is returned when network namespaces are unavailable on +// a given system. +var errNetNSDisabled = errors.New("socket: Linux network namespaces are not enabled on this system") + +// withNetNS invokes fn within the context of the network namespace specified by +// fd, while also managing the logic required to safely do so by manipulating +// thread-local state. +func withNetNS(fd int, fn func() (*Conn, error)) (*Conn, error) { + var ( + eg errgroup.Group + conn *Conn + ) + + eg.Go(func() error { + // Retrieve and store the calling OS thread's network namespace so the + // thread can be reassigned to it after creating a socket in another network + // namespace. + runtime.LockOSThread() + + ns, err := threadNetNS() + if err != nil { + // No thread-local manipulation, unlock. + runtime.UnlockOSThread() + return err + } + defer ns.Close() + + // Beyond this point, the thread's network namespace is poisoned. Do not + // unlock the OS thread until all network namespace manipulation completes + // to avoid returning to the caller with altered thread-local state. + + // Assign the current OS thread the goroutine is locked to to the given + // network namespace. + if err := ns.Set(fd); err != nil { + return err + } + + // Attempt Conn creation and unconditionally restore the original namespace. + c, err := fn() + if nerr := ns.Restore(); nerr != nil { + // Failed to restore original namespace. Return an error and allow the + // runtime to terminate the thread. + if err == nil { + _ = c.Close() + } + + return nerr + } + + // No more thread-local state manipulation; return the new Conn. + runtime.UnlockOSThread() + conn = c + return nil + }) + + if err := eg.Wait(); err != nil { + return nil, err + } + + return conn, nil +} + +// A netNS is a handle that can manipulate network namespaces. +// +// Operations performed on a netNS must use runtime.LockOSThread before +// manipulating any network namespaces. +type netNS struct { + // The handle to a network namespace. + f *os.File + + // Indicates if network namespaces are disabled on this system, and thus + // operations should become a no-op or return errors. + disabled bool +} + +// threadNetNS constructs a netNS using the network namespace of the calling +// thread. If the namespace is not the default namespace, runtime.LockOSThread +// should be invoked first. +func threadNetNS() (*netNS, error) { + return fileNetNS(fmt.Sprintf("/proc/self/task/%d/ns/net", unix.Gettid())) +} + +// fileNetNS opens file and creates a netNS. fileNetNS should only be called +// directly in tests. +func fileNetNS(file string) (*netNS, error) { + f, err := os.Open(file) + switch { + case err == nil: + return &netNS{f: f}, nil + case os.IsNotExist(err): + // Network namespaces are not enabled on this system. Use this signal + // to return errors elsewhere if the caller explicitly asks for a + // network namespace to be set. + return &netNS{disabled: true}, nil + default: + return nil, err + } +} + +// Close releases the handle to a network namespace. +func (n *netNS) Close() error { + return n.do(func() error { return n.f.Close() }) +} + +// FD returns a file descriptor which represents the network namespace. +func (n *netNS) FD() int { + if n.disabled { + // No reasonable file descriptor value in this case, so specify a + // non-existent one. + return -1 + } + + return int(n.f.Fd()) +} + +// Restore restores the original network namespace for the calling thread. +func (n *netNS) Restore() error { + return n.do(func() error { return n.Set(n.FD()) }) +} + +// Set sets a new network namespace for the current thread using fd. +func (n *netNS) Set(fd int) error { + return n.do(func() error { + return os.NewSyscallError("setns", unix.Setns(fd, unix.CLONE_NEWNET)) + }) +} + +// do runs fn if network namespaces are enabled on this system. +func (n *netNS) do(fn func() error) error { + if n.disabled { + return errNetNSDisabled + } + + return fn() +} diff --git a/vendor/github.com/mdlayher/socket/netns_others.go b/vendor/github.com/mdlayher/socket/netns_others.go new file mode 100644 index 000000000..4cceb3d04 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/netns_others.go @@ -0,0 +1,14 @@ +//go:build !linux +// +build !linux + +package socket + +import ( + "fmt" + "runtime" +) + +// withNetNS returns an error on non-Linux systems. +func withNetNS(_ int, _ func() (*Conn, error)) (*Conn, error) { + return nil, fmt.Errorf("socket: Linux network namespace support is not available on %s", runtime.GOOS) +} diff --git a/vendor/github.com/mdlayher/socket/setbuffer_linux.go b/vendor/github.com/mdlayher/socket/setbuffer_linux.go new file mode 100644 index 000000000..0d4aa4417 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/setbuffer_linux.go @@ -0,0 +1,24 @@ +//go:build linux +// +build linux + +package socket + +import "golang.org/x/sys/unix" + +// setReadBuffer wraps the SO_RCVBUF{,FORCE} setsockopt(2) options. +func (c *Conn) setReadBuffer(bytes int) error { + err := c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bytes) + if err != nil { + err = c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUF, bytes) + } + return err +} + +// setWriteBuffer wraps the SO_SNDBUF{,FORCE} setsockopt(2) options. +func (c *Conn) setWriteBuffer(bytes int) error { + err := c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bytes) + if err != nil { + err = c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUF, bytes) + } + return err +} diff --git a/vendor/github.com/mdlayher/socket/setbuffer_others.go b/vendor/github.com/mdlayher/socket/setbuffer_others.go new file mode 100644 index 000000000..72b36dbe3 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/setbuffer_others.go @@ -0,0 +1,16 @@ +//go:build !linux +// +build !linux + +package socket + +import "golang.org/x/sys/unix" + +// setReadBuffer wraps the SO_RCVBUF setsockopt(2) option. +func (c *Conn) setReadBuffer(bytes int) error { + return c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_RCVBUF, bytes) +} + +// setWriteBuffer wraps the SO_SNDBUF setsockopt(2) option. +func (c *Conn) setWriteBuffer(bytes int) error { + return c.SetsockoptInt(unix.SOL_SOCKET, unix.SO_SNDBUF, bytes) +} diff --git a/vendor/github.com/mdlayher/socket/typ_cloexec_nonblock.go b/vendor/github.com/mdlayher/socket/typ_cloexec_nonblock.go new file mode 100644 index 000000000..40e834310 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/typ_cloexec_nonblock.go @@ -0,0 +1,12 @@ +//go:build !darwin +// +build !darwin + +package socket + +import "golang.org/x/sys/unix" + +const ( + // These operating systems support CLOEXEC and NONBLOCK socket options. + flagCLOEXEC = true + socketFlags = unix.SOCK_CLOEXEC | unix.SOCK_NONBLOCK +) diff --git a/vendor/github.com/mdlayher/socket/typ_none.go b/vendor/github.com/mdlayher/socket/typ_none.go new file mode 100644 index 000000000..9bbb1aab5 --- /dev/null +++ b/vendor/github.com/mdlayher/socket/typ_none.go @@ -0,0 +1,11 @@ +//go:build darwin +// +build darwin + +package socket + +const ( + // These operating systems do not support CLOEXEC and NONBLOCK socket + // options. + flagCLOEXEC = false + socketFlags = 0 +) diff --git a/vendor/golang.org/x/net/bpf/asm.go b/vendor/golang.org/x/net/bpf/asm.go new file mode 100644 index 000000000..15e21b181 --- /dev/null +++ b/vendor/golang.org/x/net/bpf/asm.go @@ -0,0 +1,41 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bpf + +import "fmt" + +// Assemble converts insts into raw instructions suitable for loading +// into a BPF virtual machine. +// +// Currently, no optimization is attempted, the assembled program flow +// is exactly as provided. +func Assemble(insts []Instruction) ([]RawInstruction, error) { + ret := make([]RawInstruction, len(insts)) + var err error + for i, inst := range insts { + ret[i], err = inst.Assemble() + if err != nil { + return nil, fmt.Errorf("assembling instruction %d: %s", i+1, err) + } + } + return ret, nil +} + +// Disassemble attempts to parse raw back into +// Instructions. Unrecognized RawInstructions are assumed to be an +// extension not implemented by this package, and are passed through +// unchanged to the output. The allDecoded value reports whether insts +// contains no RawInstructions. +func Disassemble(raw []RawInstruction) (insts []Instruction, allDecoded bool) { + insts = make([]Instruction, len(raw)) + allDecoded = true + for i, r := range raw { + insts[i] = r.Disassemble() + if _, ok := insts[i].(RawInstruction); ok { + allDecoded = false + } + } + return insts, allDecoded +} diff --git a/vendor/golang.org/x/net/bpf/constants.go b/vendor/golang.org/x/net/bpf/constants.go new file mode 100644 index 000000000..12f3ee835 --- /dev/null +++ b/vendor/golang.org/x/net/bpf/constants.go @@ -0,0 +1,222 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bpf + +// A Register is a register of the BPF virtual machine. +type Register uint16 + +const ( + // RegA is the accumulator register. RegA is always the + // destination register of ALU operations. + RegA Register = iota + // RegX is the indirection register, used by LoadIndirect + // operations. + RegX +) + +// An ALUOp is an arithmetic or logic operation. +type ALUOp uint16 + +// ALU binary operation types. +const ( + ALUOpAdd ALUOp = iota << 4 + ALUOpSub + ALUOpMul + ALUOpDiv + ALUOpOr + ALUOpAnd + ALUOpShiftLeft + ALUOpShiftRight + aluOpNeg // Not exported because it's the only unary ALU operation, and gets its own instruction type. + ALUOpMod + ALUOpXor +) + +// A JumpTest is a comparison operator used in conditional jumps. +type JumpTest uint16 + +// Supported operators for conditional jumps. +// K can be RegX for JumpIfX +const ( + // K == A + JumpEqual JumpTest = iota + // K != A + JumpNotEqual + // K > A + JumpGreaterThan + // K < A + JumpLessThan + // K >= A + JumpGreaterOrEqual + // K <= A + JumpLessOrEqual + // K & A != 0 + JumpBitsSet + // K & A == 0 + JumpBitsNotSet +) + +// An Extension is a function call provided by the kernel that +// performs advanced operations that are expensive or impossible +// within the BPF virtual machine. +// +// Extensions are only implemented by the Linux kernel. +// +// TODO: should we prune this list? Some of these extensions seem +// either broken or near-impossible to use correctly, whereas other +// (len, random, ifindex) are quite useful. +type Extension int + +// Extension functions available in the Linux kernel. +const ( + // extOffset is the negative maximum number of instructions used + // to load instructions by overloading the K argument. + extOffset = -0x1000 + // ExtLen returns the length of the packet. + ExtLen Extension = 1 + // ExtProto returns the packet's L3 protocol type. + ExtProto Extension = 0 + // ExtType returns the packet's type (skb->pkt_type in the kernel) + // + // TODO: better documentation. How nice an API do we want to + // provide for these esoteric extensions? + ExtType Extension = 4 + // ExtPayloadOffset returns the offset of the packet payload, or + // the first protocol header that the kernel does not know how to + // parse. + ExtPayloadOffset Extension = 52 + // ExtInterfaceIndex returns the index of the interface on which + // the packet was received. + ExtInterfaceIndex Extension = 8 + // ExtNetlinkAttr returns the netlink attribute of type X at + // offset A. + ExtNetlinkAttr Extension = 12 + // ExtNetlinkAttrNested returns the nested netlink attribute of + // type X at offset A. + ExtNetlinkAttrNested Extension = 16 + // ExtMark returns the packet's mark value. + ExtMark Extension = 20 + // ExtQueue returns the packet's assigned hardware queue. + ExtQueue Extension = 24 + // ExtLinkLayerType returns the packet's hardware address type + // (e.g. Ethernet, Infiniband). + ExtLinkLayerType Extension = 28 + // ExtRXHash returns the packets receive hash. + // + // TODO: figure out what this rxhash actually is. + ExtRXHash Extension = 32 + // ExtCPUID returns the ID of the CPU processing the current + // packet. + ExtCPUID Extension = 36 + // ExtVLANTag returns the packet's VLAN tag. + ExtVLANTag Extension = 44 + // ExtVLANTagPresent returns non-zero if the packet has a VLAN + // tag. + // + // TODO: I think this might be a lie: it reads bit 0x1000 of the + // VLAN header, which changed meaning in recent revisions of the + // spec - this extension may now return meaningless information. + ExtVLANTagPresent Extension = 48 + // ExtVLANProto returns 0x8100 if the frame has a VLAN header, + // 0x88a8 if the frame has a "Q-in-Q" double VLAN header, or some + // other value if no VLAN information is present. + ExtVLANProto Extension = 60 + // ExtRand returns a uniformly random uint32. + ExtRand Extension = 56 +) + +// The following gives names to various bit patterns used in opcode construction. + +const ( + opMaskCls uint16 = 0x7 + // opClsLoad masks + opMaskLoadDest = 0x01 + opMaskLoadWidth = 0x18 + opMaskLoadMode = 0xe0 + // opClsALU & opClsJump + opMaskOperand = 0x08 + opMaskOperator = 0xf0 +) + +const ( + // +---------------+-----------------+---+---+---+ + // | AddrMode (3b) | LoadWidth (2b) | 0 | 0 | 0 | + // +---------------+-----------------+---+---+---+ + opClsLoadA uint16 = iota + // +---------------+-----------------+---+---+---+ + // | AddrMode (3b) | LoadWidth (2b) | 0 | 0 | 1 | + // +---------------+-----------------+---+---+---+ + opClsLoadX + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | + // +---+---+---+---+---+---+---+---+ + opClsStoreA + // +---+---+---+---+---+---+---+---+ + // | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | + // +---+---+---+---+---+---+---+---+ + opClsStoreX + // +---------------+-----------------+---+---+---+ + // | Operator (4b) | OperandSrc (1b) | 1 | 0 | 0 | + // +---------------+-----------------+---+---+---+ + opClsALU + // +-----------------------------+---+---+---+---+ + // | TestOperator (4b) | 0 | 1 | 0 | 1 | + // +-----------------------------+---+---+---+---+ + opClsJump + // +---+-------------------------+---+---+---+---+ + // | 0 | 0 | 0 | RetSrc (1b) | 0 | 1 | 1 | 0 | + // +---+-------------------------+---+---+---+---+ + opClsReturn + // +---+-------------------------+---+---+---+---+ + // | 0 | 0 | 0 | TXAorTAX (1b) | 0 | 1 | 1 | 1 | + // +---+-------------------------+---+---+---+---+ + opClsMisc +) + +const ( + opAddrModeImmediate uint16 = iota << 5 + opAddrModeAbsolute + opAddrModeIndirect + opAddrModeScratch + opAddrModePacketLen // actually an extension, not an addressing mode. + opAddrModeMemShift +) + +const ( + opLoadWidth4 uint16 = iota << 3 + opLoadWidth2 + opLoadWidth1 +) + +// Operand for ALU and Jump instructions +type opOperand uint16 + +// Supported operand sources. +const ( + opOperandConstant opOperand = iota << 3 + opOperandX +) + +// An jumpOp is a conditional jump condition. +type jumpOp uint16 + +// Supported jump conditions. +const ( + opJumpAlways jumpOp = iota << 4 + opJumpEqual + opJumpGT + opJumpGE + opJumpSet +) + +const ( + opRetSrcConstant uint16 = iota << 4 + opRetSrcA +) + +const ( + opMiscTAX = 0x00 + opMiscTXA = 0x80 +) diff --git a/vendor/golang.org/x/net/bpf/doc.go b/vendor/golang.org/x/net/bpf/doc.go new file mode 100644 index 000000000..04ec1c8ab --- /dev/null +++ b/vendor/golang.org/x/net/bpf/doc.go @@ -0,0 +1,80 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package bpf implements marshaling and unmarshaling of programs for the +Berkeley Packet Filter virtual machine, and provides a Go implementation +of the virtual machine. + +BPF's main use is to specify a packet filter for network taps, so that +the kernel doesn't have to expensively copy every packet it sees to +userspace. However, it's been repurposed to other areas where running +user code in-kernel is needed. For example, Linux's seccomp uses BPF +to apply security policies to system calls. For simplicity, this +documentation refers only to packets, but other uses of BPF have their +own data payloads. + +BPF programs run in a restricted virtual machine. It has almost no +access to kernel functions, and while conditional branches are +allowed, they can only jump forwards, to guarantee that there are no +infinite loops. + +# The virtual machine + +The BPF VM is an accumulator machine. Its main register, called +register A, is an implicit source and destination in all arithmetic +and logic operations. The machine also has 16 scratch registers for +temporary storage, and an indirection register (register X) for +indirect memory access. All registers are 32 bits wide. + +Each run of a BPF program is given one packet, which is placed in the +VM's read-only "main memory". LoadAbsolute and LoadIndirect +instructions can fetch up to 32 bits at a time into register A for +examination. + +The goal of a BPF program is to produce and return a verdict (uint32), +which tells the kernel what to do with the packet. In the context of +packet filtering, the returned value is the number of bytes of the +packet to forward to userspace, or 0 to ignore the packet. Other +contexts like seccomp define their own return values. + +In order to simplify programs, attempts to read past the end of the +packet terminate the program execution with a verdict of 0 (ignore +packet). This means that the vast majority of BPF programs don't need +to do any explicit bounds checking. + +In addition to the bytes of the packet, some BPF programs have access +to extensions, which are essentially calls to kernel utility +functions. Currently, the only extensions supported by this package +are the Linux packet filter extensions. + +# Examples + +This packet filter selects all ARP packets. + + bpf.Assemble([]bpf.Instruction{ + // Load "EtherType" field from the ethernet header. + bpf.LoadAbsolute{Off: 12, Size: 2}, + // Skip over the next instruction if EtherType is not ARP. + bpf.JumpIf{Cond: bpf.JumpNotEqual, Val: 0x0806, SkipTrue: 1}, + // Verdict is "send up to 4k of the packet to userspace." + bpf.RetConstant{Val: 4096}, + // Verdict is "ignore packet." + bpf.RetConstant{Val: 0}, + }) + +This packet filter captures a random 1% sample of traffic. + + bpf.Assemble([]bpf.Instruction{ + // Get a 32-bit random number from the Linux kernel. + bpf.LoadExtension{Num: bpf.ExtRand}, + // 1% dice roll? + bpf.JumpIf{Cond: bpf.JumpLessThan, Val: 2^32/100, SkipFalse: 1}, + // Capture. + bpf.RetConstant{Val: 4096}, + // Ignore. + bpf.RetConstant{Val: 0}, + }) +*/ +package bpf // import "golang.org/x/net/bpf" diff --git a/vendor/golang.org/x/net/bpf/instructions.go b/vendor/golang.org/x/net/bpf/instructions.go new file mode 100644 index 000000000..3cffcaa01 --- /dev/null +++ b/vendor/golang.org/x/net/bpf/instructions.go @@ -0,0 +1,726 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bpf + +import "fmt" + +// An Instruction is one instruction executed by the BPF virtual +// machine. +type Instruction interface { + // Assemble assembles the Instruction into a RawInstruction. + Assemble() (RawInstruction, error) +} + +// A RawInstruction is a raw BPF virtual machine instruction. +type RawInstruction struct { + // Operation to execute. + Op uint16 + // For conditional jump instructions, the number of instructions + // to skip if the condition is true/false. + Jt uint8 + Jf uint8 + // Constant parameter. The meaning depends on the Op. + K uint32 +} + +// Assemble implements the Instruction Assemble method. +func (ri RawInstruction) Assemble() (RawInstruction, error) { return ri, nil } + +// Disassemble parses ri into an Instruction and returns it. If ri is +// not recognized by this package, ri itself is returned. +func (ri RawInstruction) Disassemble() Instruction { + switch ri.Op & opMaskCls { + case opClsLoadA, opClsLoadX: + reg := Register(ri.Op & opMaskLoadDest) + sz := 0 + switch ri.Op & opMaskLoadWidth { + case opLoadWidth4: + sz = 4 + case opLoadWidth2: + sz = 2 + case opLoadWidth1: + sz = 1 + default: + return ri + } + switch ri.Op & opMaskLoadMode { + case opAddrModeImmediate: + if sz != 4 { + return ri + } + return LoadConstant{Dst: reg, Val: ri.K} + case opAddrModeScratch: + if sz != 4 || ri.K > 15 { + return ri + } + return LoadScratch{Dst: reg, N: int(ri.K)} + case opAddrModeAbsolute: + if ri.K > extOffset+0xffffffff { + return LoadExtension{Num: Extension(-extOffset + ri.K)} + } + return LoadAbsolute{Size: sz, Off: ri.K} + case opAddrModeIndirect: + return LoadIndirect{Size: sz, Off: ri.K} + case opAddrModePacketLen: + if sz != 4 { + return ri + } + return LoadExtension{Num: ExtLen} + case opAddrModeMemShift: + return LoadMemShift{Off: ri.K} + default: + return ri + } + + case opClsStoreA: + if ri.Op != opClsStoreA || ri.K > 15 { + return ri + } + return StoreScratch{Src: RegA, N: int(ri.K)} + + case opClsStoreX: + if ri.Op != opClsStoreX || ri.K > 15 { + return ri + } + return StoreScratch{Src: RegX, N: int(ri.K)} + + case opClsALU: + switch op := ALUOp(ri.Op & opMaskOperator); op { + case ALUOpAdd, ALUOpSub, ALUOpMul, ALUOpDiv, ALUOpOr, ALUOpAnd, ALUOpShiftLeft, ALUOpShiftRight, ALUOpMod, ALUOpXor: + switch operand := opOperand(ri.Op & opMaskOperand); operand { + case opOperandX: + return ALUOpX{Op: op} + case opOperandConstant: + return ALUOpConstant{Op: op, Val: ri.K} + default: + return ri + } + case aluOpNeg: + return NegateA{} + default: + return ri + } + + case opClsJump: + switch op := jumpOp(ri.Op & opMaskOperator); op { + case opJumpAlways: + return Jump{Skip: ri.K} + case opJumpEqual, opJumpGT, opJumpGE, opJumpSet: + cond, skipTrue, skipFalse := jumpOpToTest(op, ri.Jt, ri.Jf) + switch operand := opOperand(ri.Op & opMaskOperand); operand { + case opOperandX: + return JumpIfX{Cond: cond, SkipTrue: skipTrue, SkipFalse: skipFalse} + case opOperandConstant: + return JumpIf{Cond: cond, Val: ri.K, SkipTrue: skipTrue, SkipFalse: skipFalse} + default: + return ri + } + default: + return ri + } + + case opClsReturn: + switch ri.Op { + case opClsReturn | opRetSrcA: + return RetA{} + case opClsReturn | opRetSrcConstant: + return RetConstant{Val: ri.K} + default: + return ri + } + + case opClsMisc: + switch ri.Op { + case opClsMisc | opMiscTAX: + return TAX{} + case opClsMisc | opMiscTXA: + return TXA{} + default: + return ri + } + + default: + panic("unreachable") // switch is exhaustive on the bit pattern + } +} + +func jumpOpToTest(op jumpOp, skipTrue uint8, skipFalse uint8) (JumpTest, uint8, uint8) { + var test JumpTest + + // Decode "fake" jump conditions that don't appear in machine code + // Ensures the Assemble -> Disassemble stage recreates the same instructions + // See https://github.com/golang/go/issues/18470 + if skipTrue == 0 { + switch op { + case opJumpEqual: + test = JumpNotEqual + case opJumpGT: + test = JumpLessOrEqual + case opJumpGE: + test = JumpLessThan + case opJumpSet: + test = JumpBitsNotSet + } + + return test, skipFalse, 0 + } + + switch op { + case opJumpEqual: + test = JumpEqual + case opJumpGT: + test = JumpGreaterThan + case opJumpGE: + test = JumpGreaterOrEqual + case opJumpSet: + test = JumpBitsSet + } + + return test, skipTrue, skipFalse +} + +// LoadConstant loads Val into register Dst. +type LoadConstant struct { + Dst Register + Val uint32 +} + +// Assemble implements the Instruction Assemble method. +func (a LoadConstant) Assemble() (RawInstruction, error) { + return assembleLoad(a.Dst, 4, opAddrModeImmediate, a.Val) +} + +// String returns the instruction in assembler notation. +func (a LoadConstant) String() string { + switch a.Dst { + case RegA: + return fmt.Sprintf("ld #%d", a.Val) + case RegX: + return fmt.Sprintf("ldx #%d", a.Val) + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// LoadScratch loads scratch[N] into register Dst. +type LoadScratch struct { + Dst Register + N int // 0-15 +} + +// Assemble implements the Instruction Assemble method. +func (a LoadScratch) Assemble() (RawInstruction, error) { + if a.N < 0 || a.N > 15 { + return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N) + } + return assembleLoad(a.Dst, 4, opAddrModeScratch, uint32(a.N)) +} + +// String returns the instruction in assembler notation. +func (a LoadScratch) String() string { + switch a.Dst { + case RegA: + return fmt.Sprintf("ld M[%d]", a.N) + case RegX: + return fmt.Sprintf("ldx M[%d]", a.N) + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// LoadAbsolute loads packet[Off:Off+Size] as an integer value into +// register A. +type LoadAbsolute struct { + Off uint32 + Size int // 1, 2 or 4 +} + +// Assemble implements the Instruction Assemble method. +func (a LoadAbsolute) Assemble() (RawInstruction, error) { + return assembleLoad(RegA, a.Size, opAddrModeAbsolute, a.Off) +} + +// String returns the instruction in assembler notation. +func (a LoadAbsolute) String() string { + switch a.Size { + case 1: // byte + return fmt.Sprintf("ldb [%d]", a.Off) + case 2: // half word + return fmt.Sprintf("ldh [%d]", a.Off) + case 4: // word + if a.Off > extOffset+0xffffffff { + return LoadExtension{Num: Extension(a.Off + 0x1000)}.String() + } + return fmt.Sprintf("ld [%d]", a.Off) + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// LoadIndirect loads packet[X+Off:X+Off+Size] as an integer value +// into register A. +type LoadIndirect struct { + Off uint32 + Size int // 1, 2 or 4 +} + +// Assemble implements the Instruction Assemble method. +func (a LoadIndirect) Assemble() (RawInstruction, error) { + return assembleLoad(RegA, a.Size, opAddrModeIndirect, a.Off) +} + +// String returns the instruction in assembler notation. +func (a LoadIndirect) String() string { + switch a.Size { + case 1: // byte + return fmt.Sprintf("ldb [x + %d]", a.Off) + case 2: // half word + return fmt.Sprintf("ldh [x + %d]", a.Off) + case 4: // word + return fmt.Sprintf("ld [x + %d]", a.Off) + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// LoadMemShift multiplies the first 4 bits of the byte at packet[Off] +// by 4 and stores the result in register X. +// +// This instruction is mainly useful to load into X the length of an +// IPv4 packet header in a single instruction, rather than have to do +// the arithmetic on the header's first byte by hand. +type LoadMemShift struct { + Off uint32 +} + +// Assemble implements the Instruction Assemble method. +func (a LoadMemShift) Assemble() (RawInstruction, error) { + return assembleLoad(RegX, 1, opAddrModeMemShift, a.Off) +} + +// String returns the instruction in assembler notation. +func (a LoadMemShift) String() string { + return fmt.Sprintf("ldx 4*([%d]&0xf)", a.Off) +} + +// LoadExtension invokes a linux-specific extension and stores the +// result in register A. +type LoadExtension struct { + Num Extension +} + +// Assemble implements the Instruction Assemble method. +func (a LoadExtension) Assemble() (RawInstruction, error) { + if a.Num == ExtLen { + return assembleLoad(RegA, 4, opAddrModePacketLen, 0) + } + return assembleLoad(RegA, 4, opAddrModeAbsolute, uint32(extOffset+a.Num)) +} + +// String returns the instruction in assembler notation. +func (a LoadExtension) String() string { + switch a.Num { + case ExtLen: + return "ld #len" + case ExtProto: + return "ld #proto" + case ExtType: + return "ld #type" + case ExtPayloadOffset: + return "ld #poff" + case ExtInterfaceIndex: + return "ld #ifidx" + case ExtNetlinkAttr: + return "ld #nla" + case ExtNetlinkAttrNested: + return "ld #nlan" + case ExtMark: + return "ld #mark" + case ExtQueue: + return "ld #queue" + case ExtLinkLayerType: + return "ld #hatype" + case ExtRXHash: + return "ld #rxhash" + case ExtCPUID: + return "ld #cpu" + case ExtVLANTag: + return "ld #vlan_tci" + case ExtVLANTagPresent: + return "ld #vlan_avail" + case ExtVLANProto: + return "ld #vlan_tpid" + case ExtRand: + return "ld #rand" + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// StoreScratch stores register Src into scratch[N]. +type StoreScratch struct { + Src Register + N int // 0-15 +} + +// Assemble implements the Instruction Assemble method. +func (a StoreScratch) Assemble() (RawInstruction, error) { + if a.N < 0 || a.N > 15 { + return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", a.N) + } + var op uint16 + switch a.Src { + case RegA: + op = opClsStoreA + case RegX: + op = opClsStoreX + default: + return RawInstruction{}, fmt.Errorf("invalid source register %v", a.Src) + } + + return RawInstruction{ + Op: op, + K: uint32(a.N), + }, nil +} + +// String returns the instruction in assembler notation. +func (a StoreScratch) String() string { + switch a.Src { + case RegA: + return fmt.Sprintf("st M[%d]", a.N) + case RegX: + return fmt.Sprintf("stx M[%d]", a.N) + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// ALUOpConstant executes A = A Val. +type ALUOpConstant struct { + Op ALUOp + Val uint32 +} + +// Assemble implements the Instruction Assemble method. +func (a ALUOpConstant) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsALU | uint16(opOperandConstant) | uint16(a.Op), + K: a.Val, + }, nil +} + +// String returns the instruction in assembler notation. +func (a ALUOpConstant) String() string { + switch a.Op { + case ALUOpAdd: + return fmt.Sprintf("add #%d", a.Val) + case ALUOpSub: + return fmt.Sprintf("sub #%d", a.Val) + case ALUOpMul: + return fmt.Sprintf("mul #%d", a.Val) + case ALUOpDiv: + return fmt.Sprintf("div #%d", a.Val) + case ALUOpMod: + return fmt.Sprintf("mod #%d", a.Val) + case ALUOpAnd: + return fmt.Sprintf("and #%d", a.Val) + case ALUOpOr: + return fmt.Sprintf("or #%d", a.Val) + case ALUOpXor: + return fmt.Sprintf("xor #%d", a.Val) + case ALUOpShiftLeft: + return fmt.Sprintf("lsh #%d", a.Val) + case ALUOpShiftRight: + return fmt.Sprintf("rsh #%d", a.Val) + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// ALUOpX executes A = A X +type ALUOpX struct { + Op ALUOp +} + +// Assemble implements the Instruction Assemble method. +func (a ALUOpX) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsALU | uint16(opOperandX) | uint16(a.Op), + }, nil +} + +// String returns the instruction in assembler notation. +func (a ALUOpX) String() string { + switch a.Op { + case ALUOpAdd: + return "add x" + case ALUOpSub: + return "sub x" + case ALUOpMul: + return "mul x" + case ALUOpDiv: + return "div x" + case ALUOpMod: + return "mod x" + case ALUOpAnd: + return "and x" + case ALUOpOr: + return "or x" + case ALUOpXor: + return "xor x" + case ALUOpShiftLeft: + return "lsh x" + case ALUOpShiftRight: + return "rsh x" + default: + return fmt.Sprintf("unknown instruction: %#v", a) + } +} + +// NegateA executes A = -A. +type NegateA struct{} + +// Assemble implements the Instruction Assemble method. +func (a NegateA) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsALU | uint16(aluOpNeg), + }, nil +} + +// String returns the instruction in assembler notation. +func (a NegateA) String() string { + return fmt.Sprintf("neg") +} + +// Jump skips the following Skip instructions in the program. +type Jump struct { + Skip uint32 +} + +// Assemble implements the Instruction Assemble method. +func (a Jump) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsJump | uint16(opJumpAlways), + K: a.Skip, + }, nil +} + +// String returns the instruction in assembler notation. +func (a Jump) String() string { + return fmt.Sprintf("ja %d", a.Skip) +} + +// JumpIf skips the following Skip instructions in the program if A +// Val is true. +type JumpIf struct { + Cond JumpTest + Val uint32 + SkipTrue uint8 + SkipFalse uint8 +} + +// Assemble implements the Instruction Assemble method. +func (a JumpIf) Assemble() (RawInstruction, error) { + return jumpToRaw(a.Cond, opOperandConstant, a.Val, a.SkipTrue, a.SkipFalse) +} + +// String returns the instruction in assembler notation. +func (a JumpIf) String() string { + return jumpToString(a.Cond, fmt.Sprintf("#%d", a.Val), a.SkipTrue, a.SkipFalse) +} + +// JumpIfX skips the following Skip instructions in the program if A +// X is true. +type JumpIfX struct { + Cond JumpTest + SkipTrue uint8 + SkipFalse uint8 +} + +// Assemble implements the Instruction Assemble method. +func (a JumpIfX) Assemble() (RawInstruction, error) { + return jumpToRaw(a.Cond, opOperandX, 0, a.SkipTrue, a.SkipFalse) +} + +// String returns the instruction in assembler notation. +func (a JumpIfX) String() string { + return jumpToString(a.Cond, "x", a.SkipTrue, a.SkipFalse) +} + +// jumpToRaw assembles a jump instruction into a RawInstruction +func jumpToRaw(test JumpTest, operand opOperand, k uint32, skipTrue, skipFalse uint8) (RawInstruction, error) { + var ( + cond jumpOp + flip bool + ) + switch test { + case JumpEqual: + cond = opJumpEqual + case JumpNotEqual: + cond, flip = opJumpEqual, true + case JumpGreaterThan: + cond = opJumpGT + case JumpLessThan: + cond, flip = opJumpGE, true + case JumpGreaterOrEqual: + cond = opJumpGE + case JumpLessOrEqual: + cond, flip = opJumpGT, true + case JumpBitsSet: + cond = opJumpSet + case JumpBitsNotSet: + cond, flip = opJumpSet, true + default: + return RawInstruction{}, fmt.Errorf("unknown JumpTest %v", test) + } + jt, jf := skipTrue, skipFalse + if flip { + jt, jf = jf, jt + } + return RawInstruction{ + Op: opClsJump | uint16(cond) | uint16(operand), + Jt: jt, + Jf: jf, + K: k, + }, nil +} + +// jumpToString converts a jump instruction to assembler notation +func jumpToString(cond JumpTest, operand string, skipTrue, skipFalse uint8) string { + switch cond { + // K == A + case JumpEqual: + return conditionalJump(operand, skipTrue, skipFalse, "jeq", "jneq") + // K != A + case JumpNotEqual: + return fmt.Sprintf("jneq %s,%d", operand, skipTrue) + // K > A + case JumpGreaterThan: + return conditionalJump(operand, skipTrue, skipFalse, "jgt", "jle") + // K < A + case JumpLessThan: + return fmt.Sprintf("jlt %s,%d", operand, skipTrue) + // K >= A + case JumpGreaterOrEqual: + return conditionalJump(operand, skipTrue, skipFalse, "jge", "jlt") + // K <= A + case JumpLessOrEqual: + return fmt.Sprintf("jle %s,%d", operand, skipTrue) + // K & A != 0 + case JumpBitsSet: + if skipFalse > 0 { + return fmt.Sprintf("jset %s,%d,%d", operand, skipTrue, skipFalse) + } + return fmt.Sprintf("jset %s,%d", operand, skipTrue) + // K & A == 0, there is no assembler instruction for JumpBitNotSet, use JumpBitSet and invert skips + case JumpBitsNotSet: + return jumpToString(JumpBitsSet, operand, skipFalse, skipTrue) + default: + return fmt.Sprintf("unknown JumpTest %#v", cond) + } +} + +func conditionalJump(operand string, skipTrue, skipFalse uint8, positiveJump, negativeJump string) string { + if skipTrue > 0 { + if skipFalse > 0 { + return fmt.Sprintf("%s %s,%d,%d", positiveJump, operand, skipTrue, skipFalse) + } + return fmt.Sprintf("%s %s,%d", positiveJump, operand, skipTrue) + } + return fmt.Sprintf("%s %s,%d", negativeJump, operand, skipFalse) +} + +// RetA exits the BPF program, returning the value of register A. +type RetA struct{} + +// Assemble implements the Instruction Assemble method. +func (a RetA) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsReturn | opRetSrcA, + }, nil +} + +// String returns the instruction in assembler notation. +func (a RetA) String() string { + return fmt.Sprintf("ret a") +} + +// RetConstant exits the BPF program, returning a constant value. +type RetConstant struct { + Val uint32 +} + +// Assemble implements the Instruction Assemble method. +func (a RetConstant) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsReturn | opRetSrcConstant, + K: a.Val, + }, nil +} + +// String returns the instruction in assembler notation. +func (a RetConstant) String() string { + return fmt.Sprintf("ret #%d", a.Val) +} + +// TXA copies the value of register X to register A. +type TXA struct{} + +// Assemble implements the Instruction Assemble method. +func (a TXA) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsMisc | opMiscTXA, + }, nil +} + +// String returns the instruction in assembler notation. +func (a TXA) String() string { + return fmt.Sprintf("txa") +} + +// TAX copies the value of register A to register X. +type TAX struct{} + +// Assemble implements the Instruction Assemble method. +func (a TAX) Assemble() (RawInstruction, error) { + return RawInstruction{ + Op: opClsMisc | opMiscTAX, + }, nil +} + +// String returns the instruction in assembler notation. +func (a TAX) String() string { + return fmt.Sprintf("tax") +} + +func assembleLoad(dst Register, loadSize int, mode uint16, k uint32) (RawInstruction, error) { + var ( + cls uint16 + sz uint16 + ) + switch dst { + case RegA: + cls = opClsLoadA + case RegX: + cls = opClsLoadX + default: + return RawInstruction{}, fmt.Errorf("invalid target register %v", dst) + } + switch loadSize { + case 1: + sz = opLoadWidth1 + case 2: + sz = opLoadWidth2 + case 4: + sz = opLoadWidth4 + default: + return RawInstruction{}, fmt.Errorf("invalid load byte length %d", sz) + } + return RawInstruction{ + Op: cls | sz | mode, + K: k, + }, nil +} diff --git a/vendor/golang.org/x/net/bpf/setter.go b/vendor/golang.org/x/net/bpf/setter.go new file mode 100644 index 000000000..43e35f0ac --- /dev/null +++ b/vendor/golang.org/x/net/bpf/setter.go @@ -0,0 +1,10 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bpf + +// A Setter is a type which can attach a compiled BPF filter to itself. +type Setter interface { + SetBPF(filter []RawInstruction) error +} diff --git a/vendor/golang.org/x/net/bpf/vm.go b/vendor/golang.org/x/net/bpf/vm.go new file mode 100644 index 000000000..73f57f1f7 --- /dev/null +++ b/vendor/golang.org/x/net/bpf/vm.go @@ -0,0 +1,150 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bpf + +import ( + "errors" + "fmt" +) + +// A VM is an emulated BPF virtual machine. +type VM struct { + filter []Instruction +} + +// NewVM returns a new VM using the input BPF program. +func NewVM(filter []Instruction) (*VM, error) { + if len(filter) == 0 { + return nil, errors.New("one or more Instructions must be specified") + } + + for i, ins := range filter { + check := len(filter) - (i + 1) + switch ins := ins.(type) { + // Check for out-of-bounds jumps in instructions + case Jump: + if check <= int(ins.Skip) { + return nil, fmt.Errorf("cannot jump %d instructions; jumping past program bounds", ins.Skip) + } + case JumpIf: + if check <= int(ins.SkipTrue) { + return nil, fmt.Errorf("cannot jump %d instructions in true case; jumping past program bounds", ins.SkipTrue) + } + if check <= int(ins.SkipFalse) { + return nil, fmt.Errorf("cannot jump %d instructions in false case; jumping past program bounds", ins.SkipFalse) + } + case JumpIfX: + if check <= int(ins.SkipTrue) { + return nil, fmt.Errorf("cannot jump %d instructions in true case; jumping past program bounds", ins.SkipTrue) + } + if check <= int(ins.SkipFalse) { + return nil, fmt.Errorf("cannot jump %d instructions in false case; jumping past program bounds", ins.SkipFalse) + } + // Check for division or modulus by zero + case ALUOpConstant: + if ins.Val != 0 { + break + } + + switch ins.Op { + case ALUOpDiv, ALUOpMod: + return nil, errors.New("cannot divide by zero using ALUOpConstant") + } + // Check for unknown extensions + case LoadExtension: + switch ins.Num { + case ExtLen: + default: + return nil, fmt.Errorf("extension %d not implemented", ins.Num) + } + } + } + + // Make sure last instruction is a return instruction + switch filter[len(filter)-1].(type) { + case RetA, RetConstant: + default: + return nil, errors.New("BPF program must end with RetA or RetConstant") + } + + // Though our VM works using disassembled instructions, we + // attempt to assemble the input filter anyway to ensure it is compatible + // with an operating system VM. + _, err := Assemble(filter) + + return &VM{ + filter: filter, + }, err +} + +// Run runs the VM's BPF program against the input bytes. +// Run returns the number of bytes accepted by the BPF program, and any errors +// which occurred while processing the program. +func (v *VM) Run(in []byte) (int, error) { + var ( + // Registers of the virtual machine + regA uint32 + regX uint32 + regScratch [16]uint32 + + // OK is true if the program should continue processing the next + // instruction, or false if not, causing the loop to break + ok = true + ) + + // TODO(mdlayher): implement: + // - NegateA: + // - would require a change from uint32 registers to int32 + // registers + + // TODO(mdlayher): add interop tests that check signedness of ALU + // operations against kernel implementation, and make sure Go + // implementation matches behavior + + for i := 0; i < len(v.filter) && ok; i++ { + ins := v.filter[i] + + switch ins := ins.(type) { + case ALUOpConstant: + regA = aluOpConstant(ins, regA) + case ALUOpX: + regA, ok = aluOpX(ins, regA, regX) + case Jump: + i += int(ins.Skip) + case JumpIf: + jump := jumpIf(ins, regA) + i += jump + case JumpIfX: + jump := jumpIfX(ins, regA, regX) + i += jump + case LoadAbsolute: + regA, ok = loadAbsolute(ins, in) + case LoadConstant: + regA, regX = loadConstant(ins, regA, regX) + case LoadExtension: + regA = loadExtension(ins, in) + case LoadIndirect: + regA, ok = loadIndirect(ins, in, regX) + case LoadMemShift: + regX, ok = loadMemShift(ins, in) + case LoadScratch: + regA, regX = loadScratch(ins, regScratch, regA, regX) + case RetA: + return int(regA), nil + case RetConstant: + return int(ins.Val), nil + case StoreScratch: + regScratch = storeScratch(ins, regScratch, regA, regX) + case TAX: + regX = regA + case TXA: + regA = regX + default: + return 0, fmt.Errorf("unknown Instruction at index %d: %T", i, ins) + } + } + + return 0, nil +} diff --git a/vendor/golang.org/x/net/bpf/vm_instructions.go b/vendor/golang.org/x/net/bpf/vm_instructions.go new file mode 100644 index 000000000..0aa307c06 --- /dev/null +++ b/vendor/golang.org/x/net/bpf/vm_instructions.go @@ -0,0 +1,182 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package bpf + +import ( + "encoding/binary" + "fmt" +) + +func aluOpConstant(ins ALUOpConstant, regA uint32) uint32 { + return aluOpCommon(ins.Op, regA, ins.Val) +} + +func aluOpX(ins ALUOpX, regA uint32, regX uint32) (uint32, bool) { + // Guard against division or modulus by zero by terminating + // the program, as the OS BPF VM does + if regX == 0 { + switch ins.Op { + case ALUOpDiv, ALUOpMod: + return 0, false + } + } + + return aluOpCommon(ins.Op, regA, regX), true +} + +func aluOpCommon(op ALUOp, regA uint32, value uint32) uint32 { + switch op { + case ALUOpAdd: + return regA + value + case ALUOpSub: + return regA - value + case ALUOpMul: + return regA * value + case ALUOpDiv: + // Division by zero not permitted by NewVM and aluOpX checks + return regA / value + case ALUOpOr: + return regA | value + case ALUOpAnd: + return regA & value + case ALUOpShiftLeft: + return regA << value + case ALUOpShiftRight: + return regA >> value + case ALUOpMod: + // Modulus by zero not permitted by NewVM and aluOpX checks + return regA % value + case ALUOpXor: + return regA ^ value + default: + return regA + } +} + +func jumpIf(ins JumpIf, regA uint32) int { + return jumpIfCommon(ins.Cond, ins.SkipTrue, ins.SkipFalse, regA, ins.Val) +} + +func jumpIfX(ins JumpIfX, regA uint32, regX uint32) int { + return jumpIfCommon(ins.Cond, ins.SkipTrue, ins.SkipFalse, regA, regX) +} + +func jumpIfCommon(cond JumpTest, skipTrue, skipFalse uint8, regA uint32, value uint32) int { + var ok bool + + switch cond { + case JumpEqual: + ok = regA == value + case JumpNotEqual: + ok = regA != value + case JumpGreaterThan: + ok = regA > value + case JumpLessThan: + ok = regA < value + case JumpGreaterOrEqual: + ok = regA >= value + case JumpLessOrEqual: + ok = regA <= value + case JumpBitsSet: + ok = (regA & value) != 0 + case JumpBitsNotSet: + ok = (regA & value) == 0 + } + + if ok { + return int(skipTrue) + } + + return int(skipFalse) +} + +func loadAbsolute(ins LoadAbsolute, in []byte) (uint32, bool) { + offset := int(ins.Off) + size := ins.Size + + return loadCommon(in, offset, size) +} + +func loadConstant(ins LoadConstant, regA uint32, regX uint32) (uint32, uint32) { + switch ins.Dst { + case RegA: + regA = ins.Val + case RegX: + regX = ins.Val + } + + return regA, regX +} + +func loadExtension(ins LoadExtension, in []byte) uint32 { + switch ins.Num { + case ExtLen: + return uint32(len(in)) + default: + panic(fmt.Sprintf("unimplemented extension: %d", ins.Num)) + } +} + +func loadIndirect(ins LoadIndirect, in []byte, regX uint32) (uint32, bool) { + offset := int(ins.Off) + int(regX) + size := ins.Size + + return loadCommon(in, offset, size) +} + +func loadMemShift(ins LoadMemShift, in []byte) (uint32, bool) { + offset := int(ins.Off) + + // Size of LoadMemShift is always 1 byte + if !inBounds(len(in), offset, 1) { + return 0, false + } + + // Mask off high 4 bits and multiply low 4 bits by 4 + return uint32(in[offset]&0x0f) * 4, true +} + +func inBounds(inLen int, offset int, size int) bool { + return offset+size <= inLen +} + +func loadCommon(in []byte, offset int, size int) (uint32, bool) { + if !inBounds(len(in), offset, size) { + return 0, false + } + + switch size { + case 1: + return uint32(in[offset]), true + case 2: + return uint32(binary.BigEndian.Uint16(in[offset : offset+size])), true + case 4: + return uint32(binary.BigEndian.Uint32(in[offset : offset+size])), true + default: + panic(fmt.Sprintf("invalid load size: %d", size)) + } +} + +func loadScratch(ins LoadScratch, regScratch [16]uint32, regA uint32, regX uint32) (uint32, uint32) { + switch ins.Dst { + case RegA: + regA = regScratch[ins.N] + case RegX: + regX = regScratch[ins.N] + } + + return regA, regX +} + +func storeScratch(ins StoreScratch, regScratch [16]uint32, regA uint32, regX uint32) [16]uint32 { + switch ins.Src { + case RegA: + regScratch[ins.N] = regA + case RegX: + regScratch[ins.N] = regX + } + + return regScratch +} diff --git a/vendor/modules.txt b/vendor/modules.txt index ebea165f7..cf4415731 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -175,6 +175,15 @@ github.com/google/go-cmp/cmp/internal/diff github.com/google/go-cmp/cmp/internal/flags github.com/google/go-cmp/cmp/internal/function github.com/google/go-cmp/cmp/internal/value +# github.com/google/nftables v0.3.0 +## explicit; go 1.21 +github.com/google/nftables +github.com/google/nftables/alignedbuff +github.com/google/nftables/binaryutil +github.com/google/nftables/expr +github.com/google/nftables/internal/parseexprfunc +github.com/google/nftables/userdata +github.com/google/nftables/xt # github.com/google/pprof v0.0.0-20260202012954-cb029daf43ef ## explicit; go 1.24.0 github.com/google/pprof/profile @@ -261,6 +270,14 @@ github.com/maxbrunsfeld/counterfeiter/v6 github.com/maxbrunsfeld/counterfeiter/v6/arguments github.com/maxbrunsfeld/counterfeiter/v6/command github.com/maxbrunsfeld/counterfeiter/v6/generator +# github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 +## explicit; go 1.21 +github.com/mdlayher/netlink +github.com/mdlayher/netlink/nlenc +github.com/mdlayher/netlink/nltest +# github.com/mdlayher/socket v0.5.0 +## explicit; go 1.21 +github.com/mdlayher/socket # github.com/mitchellh/mapstructure v1.5.0 ## explicit; go 1.14 github.com/mitchellh/mapstructure @@ -385,6 +402,7 @@ golang.org/x/mod/module golang.org/x/mod/semver # golang.org/x/net v0.49.0 ## explicit; go 1.24.0 +golang.org/x/net/bpf golang.org/x/net/context golang.org/x/net/html golang.org/x/net/html/atom