From ae588d82e3b4e4006f758c5880e903b2d21f2125 Mon Sep 17 00:00:00 2001
From: Alex Valiushko
Date: Wed, 25 Mar 2026 21:09:00 -0700
Subject: [PATCH 1/4] bump go to 1.26

Signed-off-by: Alex Valiushko
Change-Id: I3852ae79c0a5f6bed0fad16e1860eed06a6a6964
---
 go.mod | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/go.mod b/go.mod
index 37b0b6daf..56db8f3c8 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module github.com/tailscale/wireguard-go

-go 1.25
+go 1.26

 require (
 	golang.org/x/crypto v0.13.0

From fcac8646321b076cc3cfe7c8d9a4c66340715ece Mon Sep 17 00:00:00 2001
From: Alex Valiushko
Date: Wed, 11 Mar 2026 11:07:45 -0700
Subject: [PATCH 2/4] device: fix slow test and go vet

Signed-off-by: Alex Valiushko
Change-Id: I97a2d22561468de14b17e09d557f212b6a6a6964
---
 device/pools_test.go | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/device/pools_test.go b/device/pools_test.go
index 2b16f3984..0d1009b8e 100644
--- a/device/pools_test.go
+++ b/device/pools_test.go
@@ -17,7 +17,8 @@ import (
 func TestWaitPool(t *testing.T) {
 	var wg sync.WaitGroup
 	var trials atomic.Int32
-	startTrials := int32(100000)
+	n := int32(runtime.NumCPU())
+	startTrials := 125 * n * n
 	if raceEnabled {
 		// This test can be very slow with -race.
 		startTrials /= 10
@@ -63,7 +64,7 @@ func TestWaitPool(t *testing.T) {
 	}
 	wg.Wait()
 	if max.Load() != p.max {
-		t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
+		t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max.Load(), p.max)
 	}
 }

From f34bee92efd328c90883586877bc0f70d51248ba Mon Sep 17 00:00:00 2001
From: Alex Valiushko
Date: Tue, 3 Mar 2026 13:15:34 -0800
Subject: [PATCH 3/4] tun, device: allocate buffers in the edge devices

Signed-off-by: Alex Valiushko
Change-Id: I58908d9d3fd09441e9378a74b0ee19136a6a6964
---
 buffer/buffer.go          | 111 +++++++++++++++++++++++++++++++++++
 buffer/pool.go            |  59 +++++++++++++++++++
 conn/bind_std.go          |  68 +++++++++++++++-------
 conn/bind_std_test.go     |  32 ++++++++---
 conn/bind_windows.go      |   9 +--
 conn/bindtest/bindtest.go |   8 ++-
 conn/conn.go              |   5 +-
 conn/conn_test.go         |   4 +-
 device/channels.go        |   4 +-
 device/device.go          |   3 +-
 device/device_test.go     |   3 +-
 device/pools.go           |  18 +++---
 device/receive.go         |  65 +++++++++++----------
 device/send.go            | 118 ++++++++++++++++++------------------
 tun/netstack/tun.go       |   8 ++-
 tun/offload.go            |  18 ++++--
 tun/offload_linux.go      |  35 +++++++----
 tun/offload_linux_test.go |  19 +++---
 tun/offload_test.go       |  19 +++---
 tun/tun.go                |   5 +-
 tun/tun_darwin.go         |   8 ++-
 tun/tun_freebsd.go        |   8 ++-
 tun/tun_linux.go          |  49 ++++++++++----
 tun/tun_openbsd.go        |   8 ++-
 tun/tun_plan9.go          |  12 +++-
 tun/tun_windows.go        |   8 ++-
 tun/tuntest/tuntest.go    |   8 ++-
 27 files changed, 512 insertions(+), 200 deletions(-)
 create mode 100644 buffer/buffer.go
 create mode 100644 buffer/pool.go

diff --git a/buffer/buffer.go b/buffer/buffer.go
new file mode 100644
index 000000000..83bda2fe2
--- /dev/null
+++ b/buffer/buffer.go
@@ -0,0 +1,111 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+// Package buffer implements a reusable buffer abstraction.
+//
+// Wireguard-go's data processing is constrained by both the host's API
+// and the transformations performed during encapsulation:
+//
+// 1. Encryption requires tail- and headroom for extra headers and padding.
+// Available via winrio and pread(2).
+// 2. Systems are moving towards coalesced reads for both TCP and UDP.
+// The read data has no gaps for individual slices.
+// 3. The crypto.AEAD interface requires a contiguous dst []byte for sealing,
+// so we can't use scatter-gather to inject the gaps.
+//
+// Until one of these three conditions is changed, the encryption strategy
+// is to copy on read into buffers with the required gaps.
+// The buffers are right-sized for the packet to avoid memory inflation.
+// To recycle said allocations, each buffer carries a recycle function
+// that routes it back to its originating pool.
+//
+// Decryption shrinks each fragment instead of growing it, so buffers can pass
+// through the pipeline without copying until the egress coalescing step.
+// Depending on the chosen head of the coalescing, there may or may not be room,
+// and reallocation is a necessary fallback until we start passing
+// buffers in batches.

+package buffer

+const (
+	MaxMessageSize = (1 << 16) - 1 // largest possible UDP datagram
+)
+
+// Source produces new Buffers.
+type Source interface {
+
+	// Get returns a Buffer of at least the requested size.
+	// Implementations may choose to return an error if the request cannot be fulfilled.
+	Get(size int) (*Buffer, error)
+}
+
+// Recycler holds the state necessary to return a Buffer to its originating Source.
+type Recycler interface {
+	Recycle(*Buffer)
+}
+
+// Buffer is a reusable slice of bytes of fixed length.
+// The returned Data slice must not be retained past Release.
+type Buffer struct {
+	data     []byte
+	recycler Recycler
+}
+
+// New creates a standalone Buffer.
+func New(b []byte, recycler Recycler) *Buffer {
+	return &Buffer{data: b, recycler: recycler}
+}
+
+// Make creates a standalone Buffer with a new byte slice of the requested size.
+func Make(size int) *Buffer {
+	return &Buffer{data: make([]byte, size), recycler: nil}
+}
+
+// Data returns the full underlying byte slice of the Buffer.
+func (b *Buffer) Data() []byte {
+	return b.data
+}
+
+// Release returns the Buffer to its originating Source for reuse.
+// The Buffer must not be used after calling Release.
+func (b *Buffer) Release() {
+	if b.recycler != nil {
+		clear(b.data)
+		b.recycler.Recycle(b)
+	}
+}
+
+// ReleaseAll calls Release on each non-nil Buffer in the slice, and sets the slice elements to nil.
+func ReleaseAll(bs []*Buffer) {
+	for i := range bs {
+		if bs[i] != nil {
+			bs[i].Release()
+			bs[i] = nil
+		}
+	}
+}
+
+// Arena is a Buffer with an internal watermark for sequential allocations.
+// FIXME Arena needs a graceful fallback on overflow.
+type Arena struct {
+	*Buffer
+	watermark int
+}
+
+// Get returns a slice of the Arena's Buffer of the requested size, and advances the watermark.
+func (a *Arena) Get(size int) []byte {
+	if a.watermark+size > len(a.Buffer.Data()) {
+		panic("arena overflow") // or return a heap-allocated fallback
+	}
+	b := a.Buffer.Data()[a.watermark : a.watermark+size]
+	a.watermark += size
+	return b
+}
+
+// Flush resets the Arena's watermark to zero, and clears the valid data in the Buffer.
+func (a *Arena) Flush() {
+	clear(a.Buffer.Data()[:a.watermark])
+	a.watermark = 0
+}
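A minimal usage sketch of the lifecycle described above, assuming a Source
implementation such as the FragmentPool introduced below; src, pkt, and
process are hypothetical stand-ins, not part of the patch:

	func handlePacket(src buffer.Source, pkt []byte, process func([]byte)) error {
		buf, err := src.Get(len(pkt)) // right-sized request
		if err != nil {
			return err
		}
		defer buf.Release() // routes buf back to its originating pool; Data must not be retained past this
		n := copy(buf.Data(), pkt)
		process(buf.Data()[:n])
		return nil
	}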
diff --git a/buffer/pool.go b/buffer/pool.go
new file mode 100644
index 000000000..acc38b0ea
--- /dev/null
+++ b/buffer/pool.go
@@ -0,0 +1,59 @@
+package buffer
+
+import "sync"
+
+const (
+	min = 2 << 10  // 2KB, enough for a typical MTU-sized packet
+	mid = 10 << 10 // 10KB, enough for a jumbo frame
+	max = 65 << 10 // 65KB, enough for the maximum possible UDP datagram size
+)
+
+var _ Source = (*FragmentPool)(nil)
+
+// FragmentPool is a tiered source of buffers. Tiers are balanced
+// to accommodate regular MTU sizes, jumbo frames, and the maximum possible UDP datagram size.
+type FragmentPool struct {
+	minPool sync.Pool
+	midPool sync.Pool
+	maxPool sync.Pool
+}
+
+type poolRecycler struct {
+	*sync.Pool
+}
+
+func (p *poolRecycler) Recycle(b *Buffer) {
+	p.Put(b)
+}
+
+func NewFragmentPool() *FragmentPool {
+	p := new(FragmentPool)
+	p.minPool.New = func() any {
+		return &Buffer{data: make([]byte, min), recycler: &poolRecycler{&p.minPool}}
+	}
+	p.midPool.New = func() any {
+		return &Buffer{data: make([]byte, mid), recycler: &poolRecycler{&p.midPool}}
+	}
+	p.maxPool.New = func() any {
+		return &Buffer{data: make([]byte, max), recycler: &poolRecycler{&p.maxPool}}
+	}
+	return p
+}
+
+// Get returns a Buffer of at least the requested size.
+// Implementations may choose to return an error if the request cannot be fulfilled.
+// FIXME current code ignores err return
+func (p *FragmentPool) Get(size int) (*Buffer, error) {
+	var buf *Buffer
+	switch {
+	case size <= min:
+		buf = p.minPool.Get().(*Buffer)
+	case size <= mid:
+		buf = p.midPool.Get().(*Buffer)
+	case size <= max:
+		buf = p.maxPool.Get().(*Buffer)
+	default:
+		return &Buffer{data: make([]byte, size)}, nil
+	}
+	buf.data = buf.data[:size]
+	return buf, nil
+}
diff --git a/conn/bind_std.go b/conn/bind_std.go
index fc0563456..8f553a6c5 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -16,6 +16,7 @@ import (
 	"sync"
 	"syscall"

+	"github.com/tailscale/wireguard-go/buffer"
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv6"
 )
@@ -44,6 +45,7 @@ type StdNetBind struct {
 	// these two fields are not guarded by mu
 	udpAddrPool sync.Pool
 	msgsPool    sync.Pool
+	bufPool     buffer.Source

 	blackhole4 bool
 	blackhole6 bool
@@ -63,12 +65,14 @@ func NewStdNetBind() Bind {
 			New: func() any {
 				msgs := make([]ipv6.Message, IdealBatchSize)
 				for i := range msgs {
-					msgs[i].Buffers = make(net.Buffers, 1)
+					msgs[i].Buffers = make(net.Buffers, 1, udpSegmentMaxDatagrams)
 					msgs[i].OOB = make([]byte, controlSize)
 				}
 				return &msgs
 			},
 		},
+
+		bufPool: buffer.NewFragmentPool(),
 	}
 }
@@ -204,7 +208,7 @@ again:
 func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
 	for i := range *msgs {
-		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
+		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers[:1], OOB: (*msgs)[i].OOB}
 	}
 	s.msgsPool.Put(msgs)
 }
@@ -230,36 +234,52 @@ func (s *StdNetBind) receiveIP(
 	br batchReader,
 	conn *net.UDPConn,
 	rxOffload bool,
-	bufs [][]byte,
+	bufs []*buffer.Buffer,
 	sizes []int,
 	eps []Endpoint,
 ) (n int, err error) {
 	msgs := s.getMessages()
-	for i := range bufs {
-		(*msgs)[i].Buffers[0] = bufs[i]
-		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
-	}
 	defer s.putMessages(msgs)
 	var numMsgs int
 	if runtime.GOOS == "linux" {
 		if rxOffload {
-			readAt := len(*msgs) - 2
+			const readBatch = 2
+			readAt := len(*msgs) - readBatch
+			for i := readAt; i < readAt+readBatch; i++ {
+				if bufs[i] == nil {
+					bufs[i], _ = s.bufPool.Get(buffer.MaxMessageSize)
+				}
+				(*msgs)[i].Buffers[0] = bufs[i].Data()
+				(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
+			}
 			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
 			if err != nil {
 				return 0, err
 			}
-			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
+			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize, bufs, s.bufPool)
 			if err != nil {
 				return 0, err
 			}
 		} else {
+			for i := range bufs {
+				if bufs[i] == nil {
+					bufs[i], _ = s.bufPool.Get(buffer.MaxMessageSize)
+				}
+				(*msgs)[i].Buffers[0] = bufs[i].Data()
+				(*msgs)[i].OOB =
(*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } } } else { + if bufs[0] == nil { + bufs[0], _ = s.bufPool.Get(buffer.MaxMessageSize) + } msg := &(*msgs)[0] + msg.Buffers[0] = bufs[0].Data() + msg.OOB = msg.OOB[:cap(msg.OOB)] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { return 0, err @@ -281,13 +301,13 @@ func (s *StdNetBind) receiveIP( } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) { return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) { return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } @@ -452,10 +472,11 @@ type setGSOFunc func(control *[]byte, gsoSize uint16) func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offset int, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - base = -1 // index of msg we are currently coalescing into - gsoSize int // segmentation size of msgs[base] - dgramCnt int // number of dgrams coalesced into msgs[base] - endBatch bool // tracking flag to start a new batch on next iteration of bufs + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs + coalescedLen int // bytes coalesced into msgs[base] ) maxPayloadLen := maxIPv4PayloadLen if ep.DstIP().Is6() { @@ -465,18 +486,16 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs buf = buf[offset:] if i > 0 { msgLen := len(buf) - baseLenBefore := len(msgs[base].Buffers[0]) - freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore - if msgLen+baseLenBefore <= maxPayloadLen && + if msgLen+coalescedLen <= maxPayloadLen && msgLen <= gsoSize && - msgLen <= freeBaseCap && dgramCnt < udpSegmentMaxDatagrams && !endBatch { - msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + msgs[base].Buffers = append(msgs[base].Buffers, buf) if i == len(bufs)-1 { setGSO(&msgs[base].OOB, uint16(gsoSize)) } dgramCnt++ + coalescedLen += msgLen if msgLen < gsoSize { // A smaller than gsoSize packet on the tail is legal, but // it must end the batch. 
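+					// Example: with gsoSize=1200, payloads of 1200, 1200 and
+					// 800 bytes coalesce into one message; the short 800-byte
+					// tail then forces any following packet into a new batch.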
@@ -497,13 +516,14 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs msgs[base].Buffers[0] = buf msgs[base].Addr = addr dgramCnt = 1 + coalescedLen = gsoSize } return base + 1 } type getGSOFunc func(control []byte) (int, error) -func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc, bufs []*buffer.Buffer, pool buffer.Source) (n int, err error) { for i := firstMsgAt; i < len(msgs); i++ { msg := &msgs[i] if msg.N == 0 { @@ -527,6 +547,12 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu if n > i { return n, errors.New("splitting coalesced packet resulted in overflow") } + segLen := end - start + if bufs[n] == nil { + bufs[n], _ = pool.Get(segLen) + } + msgs[n].Buffers[0] = bufs[n].Data() + msgs[n].OOB = msgs[n].OOB[:cap(msgs[n].OOB)] copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) msgs[n].N = copied msgs[n].Addr = msg.Addr diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 77af0d925..0b3d01476 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,10 +1,12 @@ package conn import ( + "bytes" "encoding/binary" "net" "testing" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/net/ipv6" ) @@ -15,10 +17,10 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { t.Fatal(err) } bind.Close() - bufs := make([][]byte, 1) - bufs[0] = make([]byte, 1) - sizes := make([]int, 1) - eps := make([]Endpoint, 1) + bufs := make([]*buffer.Buffer, IdealBatchSize) + bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + sizes := make([]int, IdealBatchSize) + eps := make([]Endpoint, IdealBatchSize) for _, fn := range fns { // The ReceiveFuncs must not access conn-related fields on StdNetBind // unguarded. 
Close() nils the conn-related fields resulting in a panic @@ -82,8 +84,8 @@ func Test_coalesceMessages(t *testing.T) { make([]byte, 2, 2), make([]byte, 2, 2), }, - wantLens: []int{4, 2}, - wantGSO: []int{2, 0}, + wantLens: []int{6}, + wantGSO: []int{2}, }, } @@ -106,9 +108,12 @@ func Test_coalesceMessages(t *testing.T) { if msgs[i].Addr != addr { t.Errorf("msgs[%d].Addr != passed addr", i) } - gotLen := len(msgs[i].Buffers[0]) + var gotLen int + for _, b := range msgs[i].Buffers { + gotLen += len(b) + } if gotLen != tt.wantLens[i] { - t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + t.Errorf("len(msgs[%d].Buffers) %d != %d", i, gotLen, tt.wantLens[i]) } gotGSO, err := mockGetGSOSize(msgs[i].OOB) if err != nil { @@ -233,7 +238,11 @@ func Test_splitCoalescedMessages(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + bufs := make([]*buffer.Buffer, len(tt.msgs)) + for i := range tt.msgs { + bufs[i] = buffer.New(tt.msgs[i].Buffers[0], nil) + } + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize, bufs, buffer.NewFragmentPool()) if err != nil && !tt.wantErr { t.Fatalf("err: %v", err) } @@ -245,6 +254,11 @@ func Test_splitCoalescedMessages(t *testing.T) { t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) } } + for i := range got { + if !bytes.Equal(bufs[i].Data(), tt.msgs[i].Buffers[0]) { + t.Fatalf("bufs[%d].Data() and tt.msgs[%d] unequal", i, i) + } + } }) } } diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 737b475e1..0c99acfeb 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -17,6 +17,7 @@ import ( "golang.org/x/sys/windows" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn/winrio" ) @@ -416,19 +417,19 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv4(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) + n, ep, err := bind.v4.Receive(bufs[0].Data(), &bind.isOpen) sizes[0] = n eps[0] = ep return 1, err } -func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv6(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) + n, ep, err := bind.v6.Receive(bufs[0].Data(), &bind.isOpen) sizes[0] = n eps[0] = ep return 1, err diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 741b776c4..2fe41b7a7 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" ) @@ -94,12 +95,15 @@ func (c *ChannelBind) BatchSize() int { return 1 } func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + return func(bufs []*buffer.Buffer, sizes []int, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: return 0, net.ErrClosed case rx := <-ch: - copied := copy(bufs[0], rx) + if bufs[0] == nil { + bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + } + copied := 
copy(bufs[0].Data(), rx) sizes[0] = copied eps[0] = c.target6 return 1, nil diff --git a/conn/conn.go b/conn/conn.go index f1781614d..c674cf849 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -13,6 +13,8 @@ import ( "reflect" "runtime" "strings" + + "github.com/tailscale/wireguard-go/buffer" ) const ( @@ -25,7 +27,8 @@ const ( // sizes may be zero, and callers should ignore them. Callers must pass a sizes // and eps slice with a length greater than or equal to the length of packets. // These lengths must not exceed the length of the associated Bind.BatchSize(). -type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) +// Nil entries in bufs are allocated by the implementation. +type ReceiveFunc func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // diff --git a/conn/conn_test.go b/conn/conn_test.go index c6194ee0c..77d493ea0 100644 --- a/conn/conn_test.go +++ b/conn/conn_test.go @@ -7,11 +7,13 @@ package conn import ( "testing" + + "github.com/tailscale/wireguard-go/buffer" ) func TestPrettyName(t *testing.T) { var ( - recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + recvFunc ReceiveFunc = func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) { return } ) const want = "TestPrettyName" diff --git a/device/channels.go b/device/channels.go index e526f6bb1..7fff0febd 100644 --- a/device/channels.go +++ b/device/channels.go @@ -93,7 +93,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buf.Release() device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -126,7 +126,7 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buf.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/device/device.go b/device/device.go index 0e720f251..267cb91c8 100644 --- a/device/device.go +++ b/device/device.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/ratelimiter" "github.com/tailscale/wireguard-go/rwcancel" @@ -73,9 +74,9 @@ type Device struct { pool struct { inboundElementsContainer *WaitPool outboundElementsContainer *WaitPool - messageBuffers *WaitPool inboundElements *WaitPool outboundElements *WaitPool + messageBuffers buffer.Source } queue struct { diff --git a/device/device_test.go b/device/device_test.go index e44342170..3c00c91f8 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn/bindtest" "github.com/tailscale/wireguard-go/tun" @@ -437,7 +438,7 @@ type fakeTUNDeviceSized struct { } func (t *fakeTUNDeviceSized) File() *os.File { return nil } -func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (t *fakeTUNDeviceSized) Read(bufs []*buffer.Buffer, sizes []int, offset int) (n int, err error) { return 0, nil } func (t *fakeTUNDeviceSized) Write(bufs 
[][]byte, offset int) (int, error) { return 0, nil } diff --git a/device/pools.go b/device/pools.go index 55d2be7df..ee3cb46d9 100644 --- a/device/pools.go +++ b/device/pools.go @@ -7,6 +7,8 @@ package device import ( "sync" + + "github.com/tailscale/wireguard-go/buffer" ) type WaitPool struct { @@ -55,15 +57,13 @@ func (device *Device) PopulatePools() { s := make([]*QueueOutboundElement, 0, device.BatchSize()) return &QueueOutboundElementsContainer{elems: s} }) - device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { - return new([MaxMessageSize]byte) - }) device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueInboundElement) }) device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueOutboundElement) }) + device.pool.messageBuffers = buffer.NewFragmentPool() } func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { @@ -94,12 +94,12 @@ func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsConta device.pool.outboundElementsContainer.Put(c) } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - device.pool.messageBuffers.Put(msg) +func (device *Device) GetMessageBuffer(size int) *buffer.Buffer { + b, err := device.pool.messageBuffers.Get(size) + if err != nil { + panic("failed to get buffer: " + err.Error()) + } + return b } func (device *Device) GetInboundElement() *QueueInboundElement { diff --git a/device/receive.go b/device/receive.go index 56cde1047..9dbea3f38 100644 --- a/device/receive.go +++ b/device/receive.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" @@ -23,11 +24,11 @@ type QueueHandshakeElement struct { msgType uint32 packet []byte endpoint conn.Endpoint - buffer *[MaxMessageSize]byte + buf *buffer.Buffer } type QueueInboundElement struct { - buffer *[MaxMessageSize]byte + buf *buffer.Buffer packet []byte counter uint64 keypair *Keypair @@ -44,7 +45,7 @@ type QueueInboundElementsContainer struct { // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. 
func (elem *QueueInboundElement) clearPointers() { - elem.buffer = nil + elem.buf = nil elem.packet = nil elem.keypair = nil elem.endpoint = nil @@ -84,27 +85,17 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive // receive datagrams until conn is closed var ( - bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) - bufs = make([][]byte, maxBatchSize) - err error + bufs = make([]*buffer.Buffer, maxBatchSize) // nil entries; recv allocates sizes = make([]int, maxBatchSize) + err error count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for i := range bufsArrs { - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] - } - defer func() { - for i := 0; i < maxBatchSize; i++ { - if bufsArrs[i] != nil { - device.PutMessageBuffer(bufsArrs[i]) - } - } + buffer.ReleaseAll(bufs) }() for { @@ -134,7 +125,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive // check size of packet - packet := bufsArrs[i][:size] + packet := bufs[i].Data()[:size] msgType := binary.LittleEndian.Uint32(packet[:4]) switch msgType { @@ -170,7 +161,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive peer := value.peer elem := device.GetInboundElement() elem.packet = packet - elem.buffer = bufsArrs[i] + elem.buf = bufs[i] elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 @@ -182,8 +173,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] + bufs[i] = nil // consumed; next recv allocates fresh continue // otherwise it is a fixed size & handshake related packet @@ -211,22 +201,27 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, - buffer: bufsArrs[i], + buf: bufs[i], packet: packet, endpoint: endpoints[i], }: - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] + bufs[i] = nil // consumed; next recv allocates fresh default: } } + for i := 0; i < count; i++ { + if bufs[i] != nil { + bufs[i].Release() + bufs[i] = nil + } + } for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer device.queue.decryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buf.Release() device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -423,7 +418,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.PutMessageBuffer(elem.buffer) + elem.buf.Release() } } @@ -435,7 +430,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { }() device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - bufs := make([][]byte, 0, maxBatchSize) + writeBufs := make([]*buffer.Buffer, 0, maxBatchSize) + legacyBufs := make([][]byte, 0, maxBatchSize) for elemsContainer := range peer.queue.inbound.c { if elemsContainer == nil { @@ -513,7 +509,9 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } - bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) + legacyBufs = append(legacyBufs, elem.buf.Data()[:MessageTransportOffsetContent+len(elem.packet)]) 
+ writeBufs = append(writeBufs, elem.buf) + elem.buf = nil // ownership transferred to writeBufs } peer.rxBytes.Add(rxBytesLen) @@ -526,17 +524,22 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { if dataPacketReceived { peer.timersDataReceived() } - if len(bufs) > 0 { - _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) + if len(writeBufs) > 0 { + _, err := device.tun.device.Write(legacyBufs, MessageTransportOffsetContent) if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packets to TUN device: %v", err) } + buffer.ReleaseAll(writeBufs) } + // Release buffers for skipped elements (not transferred to writeBufs). for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + if elem.buf != nil { + elem.buf.Release() + } device.PutInboundElement(elem) } - bufs = bufs[:0] + writeBufs = writeBufs[:0] + legacyBufs = legacyBufs[:0] device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index 89269fc07..0a79d662f 100644 --- a/device/send.go +++ b/device/send.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/tun" "golang.org/x/crypto/chacha20poly1305" @@ -46,7 +47,7 @@ import ( */ type QueueOutboundElement struct { - buffer *[MaxMessageSize]byte // slice holding the packet data + buf *buffer.Buffer // packet is always a slice of "buffer". The starting offset in buffer // is either: // a) MessageEncapsulatingTransportSize+MessageTransportHeaderSize (plaintext) @@ -64,7 +65,7 @@ type QueueOutboundElementsContainer struct { func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() - elem.buffer = device.GetMessageBuffer() + elem.buf = device.GetMessageBuffer(MaxMessageSize) elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -75,7 +76,7 @@ func (device *Device) NewOutboundElement() *QueueOutboundElement { // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. 
func (elem *QueueOutboundElement) clearPointers() { - elem.buffer = nil + elem.buf = nil elem.packet = nil elem.keypair = nil elem.peer = nil @@ -92,7 +93,7 @@ func (peer *Peer) SendKeepalive() { case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - peer.device.PutMessageBuffer(elem.buffer) + elem.buf.Release() peer.device.PutOutboundElement(elem) peer.device.PutOutboundElementsContainer(elemsContainer) } @@ -128,15 +129,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := peer.device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageInitiationSize) + packet := buf.Data()[MessageEncapsulatingTransportSize:] _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{buf}) + err = peer.SendBuffers([][]byte{buf.Data()}) + buf.Release() if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -158,13 +160,14 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := peer.device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageResponseSize) + packet := buf.Data()[MessageEncapsulatingTransportSize:] _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() if err != nil { + buf.Release() peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) return err } @@ -174,7 +177,8 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{buf}) + err = peer.SendBuffers([][]byte{buf.Data()}) + buf.Release() if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -191,11 +195,12 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageCookieReplySize) + packet := buf.Data()[MessageEncapsulatingTransportSize:] _ = reply.marshal(packet) - // TODO: allocation could be avoided - device.net.bind.Send([][]byte{buf}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + + device.net.bind.Send([][]byte{buf.Data()}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + buf.Release() return nil } @@ -223,57 +228,44 @@ func (device *Device) RoutineReadFromTUN() { var ( batchSize = device.BatchSize() readErr error - elems = make([]*QueueOutboundElement, batchSize) - bufs = make([][]byte, batchSize) + bufs = make([]*buffer.Buffer, batchSize) elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) offset = MessageEncapsulatingTransportSize + MessageTransportHeaderSize ) - for i := range elems { - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] - } - defer func() { - for _, elem := range elems { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - } - } + buffer.ReleaseAll(bufs) }() 
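+
+	// Buffer ownership in this loop: a bufs[i] handed off to an outbound
+	// element is nilled out so the next Read allocates a fresh one; any
+	// entry still non-nil after a pass is released before the next Read.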
for { - // read packets count, readErr = device.tun.device.Read(bufs, sizes, offset) + for i := 0; i < count; i++ { - if sizes[i] < 1 { + packet := bufs[i].Data()[offset : offset+sizes[i]] + if len(packet) < 1 { continue } - elem := elems[i] - elem.packet = bufs[i][offset : offset+sizes[i]] - // lookup peer var peer *Peer - switch elem.packet[0] >> 4 { + switch packet[0] >> 4 { case 4: - if len(elem.packet) < ipv4.HeaderLen { + if len(packet) < ipv4.HeaderLen { continue } - src := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) - dst := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom4([4]byte(packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) + dst := netip.AddrFrom4([4]byte(packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) case 6: - if len(elem.packet) < ipv6.HeaderLen { + if len(packet) < ipv6.HeaderLen { continue } - src := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) - dst := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom16([16]byte(packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) + dst := netip.AddrFrom16([16]byte(packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) default: device.log.Verbosef("Received packet with unknown IP version") @@ -282,14 +274,26 @@ func (device *Device) RoutineReadFromTUN() { if peer == nil { continue } + + elem := device.GetOutboundElement() + elem.buf = bufs[i] + elem.packet = packet + elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] + bufs[i] = nil // consumed; next Read allocates fresh + } + // Release unconsumed buffers so stale right-sized entries + // don't persist into the next Read call. 
+ for i := 0; i < count; i++ { + if bufs[i] != nil { + bufs[i].Release() + bufs[i] = nil + } } for peer, elemsForPeer := range elemsByPeer { @@ -298,7 +302,7 @@ func (device *Device) RoutineReadFromTUN() { peer.SendStagedPackets() } else { for _, elem := range elemsForPeer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buf.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsForPeer) @@ -335,7 +339,7 @@ func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { select { case tooOld := <-peer.queue.staged: for _, elem := range tooOld.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buf.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(tooOld) @@ -396,7 +400,7 @@ top: peer.device.queue.encryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buf.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -416,7 +420,7 @@ func (peer *Peer) FlushStagedPackets() { select { case elemsContainer := <-peer.queue.staged: for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buf.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -456,7 +460,7 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] + header := elem.buf.Data()[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] @@ -481,7 +485,7 @@ func (device *Device) RoutineEncryption(id int) { ) // re-slice packet to include encapsulating transport space - elem.packet = elem.buffer[:MessageEncapsulatingTransportSize+len(elem.packet)] + elem.packet = elem.buf.Data()[:MessageEncapsulatingTransportSize+len(elem.packet)] } elemsContainer.Unlock() } @@ -495,10 +499,12 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) - bufs := make([][]byte, 0, maxBatchSize) + rawBufs := make([][]byte, 0, maxBatchSize) + releaseBufs := make([]*buffer.Buffer, 0, maxBatchSize) for elemsContainer := range peer.queue.outbound.c { - bufs = bufs[:0] + rawBufs = rawBufs[:0] + releaseBufs = releaseBufs[:0] if elemsContainer == nil { return } @@ -511,7 +517,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // that we never accidentally keep timers alive longer than necessary. 
elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buf.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) @@ -523,18 +529,20 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if len(elem.packet[MessageEncapsulatingTransportSize:]) != MessageKeepaliveSize { dataSent = true } - bufs = append(bufs, elem.packet) + rawBufs = append(rawBufs, elem.packet) + releaseBufs = append(releaseBufs, elem.buf) + elem.buf = nil } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffers(bufs) + err := peer.SendBuffers(rawBufs) + buffer.ReleaseAll(releaseBufs) if dataSent { peer.timersDataSent() } for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index d8e70bb03..b8824f674 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -22,6 +22,7 @@ import ( "syscall" "time" + wgbuf "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/tun" "golang.org/x/net/dns/dnsmessage" @@ -119,13 +120,16 @@ func (tun *netTun) Events() <-chan tun.Event { return tun.events } -func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { +func (tun *netTun) Read(bufs []*wgbuf.Buffer, sizes []int, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } - n, err := view.Read(buf[0][offset:]) + if bufs[0] == nil { + bufs[0] = wgbuf.New(make([]byte, wgbuf.MaxMessageSize), nil) + } + n, err := view.Read(bufs[0].Data()[offset:]) if err != nil { return 0, err } diff --git a/tun/offload.go b/tun/offload.go index 6db437c34..f464c742c 100644 --- a/tun/offload.go +++ b/tun/offload.go @@ -3,6 +3,8 @@ package tun import ( "encoding/binary" "fmt" + + "github.com/tailscale/wireguard-go/buffer" ) // GSOType represents the type of segmentation offload. @@ -81,7 +83,7 @@ const ( // value of options.NeedsCsum. Length of each outBufs element must be greater // than or equal to the length of 'in', otherwise output may be silently // truncated. -func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { +func GSOSplit(in []byte, options GSOOptions, outBufs []*buffer.Buffer, pool buffer.Source, outOffset int) (int, error) { cSumAt := int(options.CsumStart) + int(options.CsumOffset) if cSumAt+1 >= len(in) { return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) @@ -94,8 +96,9 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO // Handle the conditions where we are copying a single element to outBuffs. 
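+	// For example, a packet whose payload is smaller than GSOSize (or any
+	// packet with GSOType GSONone) is not split: it is copied whole into
+	// outBufs[0], drawing a right-sized buffer from the pool when the
+	// caller passed a nil slot.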
payloadLen := len(in) - int(options.HdrLen) if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { - if len(in) > len(outBufs[0][outOffset:]) { - return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + needed := outOffset + len(in) + if outBufs[0] == nil { + outBufs[0], _ = pool.Get(needed) } if options.NeedsCsum { // The initial value at the checksum offset should be summed with @@ -104,7 +107,7 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO in[cSumAt], in[cSumAt+1] = 0, 0 binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial)) } - sizes[0] = copy(outBufs[0][outOffset:], in) + copy(outBufs[0].Data()[outOffset:], in) return 1, nil } @@ -164,8 +167,11 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO } segmentDataLen := nextSegmentEnd - nextSegmentDataAt totalLen := int(options.HdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBufs[i][outOffset:] + needed := outOffset + totalLen + if outBufs[i] == nil { + outBufs[i], _ = pool.Get(needed) + } + out := outBufs[i].Data()[outOffset:] copy(out, in[:iphLen]) if ipVersion == 4 { diff --git a/tun/offload_linux.go b/tun/offload_linux.go index fb6ac5b94..864d28244 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -13,6 +13,7 @@ import ( "io" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "golang.org/x/sys/unix" ) @@ -428,7 +429,13 @@ const ( // coalesceUDPPackets attempts to coalesce pkt with the packet described by // item, and returns the outcome. -func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, arena *buffer.Arena) coalesceResult { + if head := bufs[item.bufsIndex]; cap(head) < buffer.MaxMessageSize { + new := arena.Get(buffer.MaxMessageSize) + n := copy(new, head) + new = new[:n] + bufs[item.bufsIndex] = new + } pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front headersLen := item.iphLen + udphLen coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) @@ -458,7 +465,13 @@ func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset // item, and returns the outcome. This function may swap bufs elements in the // event of a prepend as item's bufs index is already being tracked for writing // to a Device. 
-func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, arena *buffer.Arena) coalesceResult { + if head := bufs[item.bufsIndex]; cap(head) < buffer.MaxMessageSize { + new := arena.Get(buffer.MaxMessageSize) + n := copy(new, head) + new = new[:n] + bufs[item.bufsIndex] = new + } var pktHead []byte // the packet that will end up at the front headersLen := item.iphLen + item.tcphLen coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) @@ -543,7 +556,7 @@ const ( // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool, arena *buffer.Arena) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. @@ -615,7 +628,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) item := items[i] can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6, arena) switch result { case coalesceSuccess: table.updateAt(item, i) @@ -798,7 +811,7 @@ const ( // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool, arena *buffer.Arena) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. @@ -851,7 +864,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) var pktCSumKnownInvalid bool if can == coalesceAppend { - result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6, arena) switch result { case coalesceSuccess: table.updateAt(item, len(items)-1) @@ -877,7 +890,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) // empty (but non-nil), and are passed in to save allocs as the caller may reset // and recycle them across vectors of packets. gro indicates if TCP and UDP GRO // are supported/enabled. 
-func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int, arena *buffer.Arena) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") @@ -885,13 +898,13 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR var result groResult switch packetIsGROCandidate(bufs[i][offset:], gro) { case tcp4GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, false) + result = tcpGRO(bufs, offset, i, tcpTable, false, arena) case tcp6GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, true) + result = tcpGRO(bufs, offset, i, tcpTable, true, arena) case udp4GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, false) + result = udpGRO(bufs, offset, i, udpTable, false, arena) case udp6GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, true) + result = udpGRO(bufs, offset, i, udpTable, true, arena) } switch result { case groResultNoop: diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index 407037863..a2db02e19 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -9,6 +9,7 @@ import ( "net/netip" "testing" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" @@ -234,13 +235,9 @@ func Test_handleVirtioRead(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } + out := make([]*buffer.Buffer, conn.IdealBatchSize) tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + n, err := handleVirtioRead(tt.pktIn, out, buffer.NewFragmentPool(), offset) if err != nil { if tt.wantErr { return @@ -251,8 +248,8 @@ func Test_handleVirtioRead(t *testing.T) { t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) } for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + if tt.wantLens[i]+offset != out[i].Len() { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i]+offset, out[i].Len()) } } }) @@ -290,7 +287,8 @@ func Fuzz_handleGRO(f *testing.F) { f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, gro int, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite) + arena := &buffer.Arena{Buffer: buffer.New(make([]byte, 1<<20), nil)} + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite, arena) if len(toWrite) > len(pkts) { t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } @@ -507,7 +505,8 @@ func Test_handleGRO(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite) + arena := &buffer.Arena{Buffer: buffer.New(make([]byte, 1<<20), nil)} + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite, arena) 
if err != nil { if tt.wantErr { return diff --git a/tun/offload_test.go b/tun/offload_test.go index 82a37b9cc..d0e536492 100644 --- a/tun/offload_test.go +++ b/tun/offload_test.go @@ -4,6 +4,7 @@ import ( "net/netip" "testing" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -67,11 +68,7 @@ func Fuzz_GSOSplit(f *testing.F) { }) header.UDP(gsoUDPv6[20:]).Encode(udpFields) - out := make([][]byte, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } - sizes := make([]int, conn.IdealBatchSize) + out := make([]*buffer.Buffer, conn.IdealBatchSize) f.Add(gsoTCPv4, int(GSOTCPv4), uint16(40), uint16(20), uint16(16), uint16(100), false) f.Add(gsoUDPv4, int(GSOUDPL4), uint16(28), uint16(20), uint16(6), uint16(100), false) @@ -87,9 +84,15 @@ func Fuzz_GSOSplit(f *testing.F) { GSOSize: gsoSize, NeedsCsum: needsCsum, } - n, _ := GSOSplit(pkt, options, out, sizes, 0) - if n > len(sizes) { - t.Errorf("n (%d) > len(sizes): %d", n, len(sizes)) + n, _ := GSOSplit(pkt, options, out, buffer.NewFragmentPool(), 0) + if n > len(out) { + t.Errorf("n (%d) > len(sizes): %d", n, len(out)) + } + for i := range out { + if out[i] != nil { + out[i].Release() + out[i] = nil + } } }) } diff --git a/tun/tun.go b/tun/tun.go index 719a60631..7a79c0e8f 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -7,6 +7,8 @@ package tun import ( "os" + + "github.com/tailscale/wireguard-go/buffer" ) type Event int @@ -24,9 +26,10 @@ type Device interface { // Read one or more packets from the Device (without any additional headers). // On a successful read it returns the number of packets read, and sets // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // Nil entries in bufs are allocated by the implementation. // A nonzero offset can be used to instruct the Device on where to begin // reading into each element of the bufs slice. - Read(bufs [][]byte, sizes []int, offset int) (n int, err error) + Read(bufs []*buffer.Buffer, sizes []int, offset int) (n int, err error) // Write one or more packets to the device (without any additional headers). // On a successful write it returns the number of packets written. A nonzero diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index c9a6c0bc4..f2f75bcaf 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -16,6 +16,7 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/sys/unix" ) @@ -217,7 +218,7 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) { // TODO: the BSDs look very similar in Read() and Write(). They should be // collapsed, with platform-specific files containing the varying parts of // their implementations. 
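For context on the offset-4 arithmetic in the next hunk: darwin utun devices
prepend a 4-byte protocol-family word to every packet, so the read reserves
4 bytes of the caller's headroom and the payload begins after it. A rough
sketch, with hypothetical names:

	buf := bufs[0].Data()[offset-4:] // keep 4 bytes of headroom for the AF word
	n, err := tun.tunFile.Read(buf)
	packet := buf[4:n] // payload starts after the AF_INET/AF_INET6 word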
@@ -225,7 +226,10 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] + if bufs[0] == nil { + bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + } + buf := bufs[0].Data()[offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 7c65fd999..6a532e1e6 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/sys/unix" ) @@ -333,12 +334,15 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] + if bufs[0] == nil { + bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + } + buf := bufs[0].Data()[offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 7cdbf8825..81bb83a1f 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,6 +17,7 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" @@ -53,6 +54,9 @@ type NativeTun struct { tcpGROTable *tcpGROTable udpGROTable *udpGROTable gro groDisablementFlags + + bufPool buffer.Source + coalescedBufs buffer.Arena } type groDisablementFlags int @@ -354,6 +358,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { defer func() { tun.tcpGROTable.reset() tun.udpGROTable.reset() + tun.coalescedBufs.Flush() tun.writeOpMu.Unlock() }() var ( @@ -362,7 +367,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite, &tun.coalescedBufs) if err != nil { return 0, err } @@ -389,7 +394,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { // handleVirtioRead splits in into bufs, leaving offset bytes at the front of // each buffer. It mutates sizes to reflect the size of each element of bufs, // and returns the number of packets read. 
-func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { +func handleVirtioRead(in []byte, bufs []*buffer.Buffer, pool buffer.Source, offset int) (int, error) { var hdr virtioNetHdr err := hdr.decode(in) if err != nil { @@ -421,19 +426,24 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e options.HdrLen = options.CsumStart + tcpHLen } - return GSOSplit(in, options, bufs, sizes, offset) + return GSOSplit(in, options, bufs, pool, offset) } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) { tun.readOpMu.Lock() defer tun.readOpMu.Unlock() select { case err := <-tun.errors: return 0, err default: - readInto := bufs[0][offset:] + var readInto []byte if tun.vnetHdr { readInto = tun.readBuff[:] + } else { + if bufs[0] == nil { + bufs[0], _ = tun.bufPool.Get(buffer.MaxMessageSize) + } + readInto = bufs[0].Data()[offset:] } n, err := tun.tunFile.Read(readInto) if errors.Is(err, syscall.EBADFD) { @@ -443,7 +453,14 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) return 0, err } if tun.vnetHdr { - return handleVirtioRead(readInto[:n], bufs, sizes, offset) + count, err := handleVirtioRead(readInto[:n], bufs, tun.bufPool, offset) + if err != nil { + return 0, err + } + for i := 0; i < count; i++ { + sizes[i] = len(bufs[i].Data()) - offset + } + return count, nil } else { sizes[0] = n return 1, nil @@ -577,6 +594,8 @@ func CreateTUN(name string, mtu int) (Device, error) { // CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { + bufPool := buffer.NewFragmentPool() + arenaBuf, _ := bufPool.Get(32 * 65 << 10) // 32 buffers of coalesced size tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), @@ -585,6 +604,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), + bufPool: bufPool, + coalescedBufs: buffer.Arena{Buffer: arenaBuf}, } name, err := tun.Name() @@ -634,13 +655,17 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { return nil, "", err } file := os.NewFile(uintptr(fd), "/dev/tun") + bufPool := buffer.NewFragmentPool() + arenaBuf, _ := bufPool.Get(32 * 65 << 10) // 32 buffers of coalesced size tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - tcpGROTable: newTCPGROTable(), - udpGROTable: newUDPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), + bufPool: bufPool, + coalescedBufs: buffer.Arena{Buffer: arenaBuf}, } name, err := tun.Name() if err != nil { diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index ae571b90c..cdda7feaa 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/sys/unix" ) @@ -204,12 +205,15 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, 
err
	default:
-		buf := bufs[0][offset-4:]
+		if bufs[0] == nil {
+			bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil)
+		}
+		buf := bufs[0].Data()[offset-4:]
 		n, err := tun.tunFile.Read(buf[:])
 		if n < 4 {
 			return 0, err
diff --git a/tun/tun_plan9.go b/tun/tun_plan9.go
index 7b66eadf6..afe2ad7bb 100644
--- a/tun/tun_plan9.go
+++ b/tun/tun_plan9.go
@@ -12,6 +12,8 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+
+	"github.com/tailscale/wireguard-go/buffer"
 )
 
 type NativeTun struct {
@@ -81,13 +83,17 @@ func (tun *NativeTun) Events() <-chan Event {
 	return tun.events
 }
 
-func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
 	select {
 	case err := <-tun.errors:
 		return 0, err
 	default:
-		n, err := tun.dataFile.Read(bufs[0][offset:])
-		if n == 1 && bufs[0][offset] == 0 {
+		if bufs[0] == nil {
+			bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil)
+		}
+		data := bufs[0].Data()
+		n, err := tun.dataFile.Read(data[offset:])
+		if n == 1 && data[offset] == 0 {
 			// EOF
 			err = io.EOF
 			n = 0
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index 34f29805d..caf6caa05 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -14,6 +14,7 @@ import (
 	"time"
 	_ "unsafe"
 
+	"github.com/tailscale/wireguard-go/buffer"
 	"golang.org/x/sys/windows"
 	"golang.zx2c4.com/wintun"
 )
@@ -144,7 +145,7 @@ func (tun *NativeTun) BatchSize() int {
 
 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
 
-func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
 	tun.running.Add(1)
 	defer tun.running.Done()
retry:
@@ -161,7 +162,10 @@ retry:
 		switch err {
 		case nil:
 			packetSize := len(packet)
-			copy(bufs[0][offset:], packet)
+			if bufs[0] == nil {
+				// Size the allocation to offset+packetSize so the copy at offset is not truncated.
+				bufs[0] = buffer.New(make([]byte, offset+packetSize), nil)
+			}
+			copy(bufs[0].Data()[offset:], packet)
 			sizes[0] = packetSize
 			tun.session.ReleaseReceivePacket(packet)
 			tun.rate.update(uint64(packetSize))
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index e7507c26c..0873163ef 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -11,6 +11,7 @@ import (
 	"net/netip"
 	"os"
 
+	"github.com/tailscale/wireguard-go/buffer"
 	"github.com/tailscale/wireguard-go/tun"
 )
 
@@ -110,12 +111,15 @@ type chTun struct {
 
 func (t *chTun) File() *os.File { return nil }
 
-func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
+func (t *chTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
 	select {
 	case <-t.c.closed:
 		return 0, os.ErrClosed
 	case msg := <-t.c.Outbound:
-		n := copy(packets[0][offset:], msg)
+		if bufs[0] == nil {
+			bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil)
+		}
+		n := copy(bufs[0].Data()[offset:], msg)
 		sizes[0] = n
 		return 1, nil
 	}

From f357c5af4648a891abacdca0754905144cd82537 Mon Sep 17 00:00:00 2001
From: Alex Valiushko
Date: Tue, 17 Mar 2026 09:01:47 -0700
Subject: [PATCH 4/4] store read size with the buffer; helpers

Signed-off-by: Alex Valiushko
Change-Id: I48217f8f461b17a01901cee9ab64d45e6a6a6964
---
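Note (illustrative, not part of the diff below): this patch moves the read size into the Buffer itself, so the read-side lifecycle becomes Ensure, read, SetLen, Release. A minimal sketch of that flow, where readPacket and offset are hypothetical stand-ins rather than symbols from this change:

package sketch

import "github.com/tailscale/wireguard-go/buffer"

// readOne sketches the read-side Buffer lifecycle introduced by this patch:
// grow-or-allocate with Ensure, read into the backing slice, then record the
// valid length with SetLen. readPacket stands in for a platform read call.
func readOne(readPacket func([]byte) (int, error), offset int) (*buffer.Buffer, error) {
	buf, err := buffer.Ensure(nil, buffer.MaxMessageSize, nil) // nil Source selects buffer.DefaultSource
	if err != nil {
		return nil, err // e.g. buffer.ErrSizeExceedsCap under a capped Source
	}
	n, err := readPacket(buf.Bytes()[offset:])
	if err != nil {
		buffer.Release(buf)
		return nil, err
	}
	buf.SetLen(offset + n) // valid payload is now buf.Bytes()[offset:]
	return buf, nil
}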
 buffer/buffer.go            | 116 ++++++++++++++----------
 buffer/constants.go         |  10 +++
 buffer/constants_android.go |  13 +++
 buffer/constants_default.go |  13 +++
 buffer/constants_ios.go     |  17 ++++
 buffer/constants_windows.go |  13 +++
 buffer/pool.go              |  59 ------------
 buffer/source.go            | 175 ++++++++++++++++++++++++++++++++++++
 buffer/source_test.go       |  45 ++++++++++
 conn/bind_std.go            |  53 +++++------
 conn/bind_std_test.go       |   9 +-
 conn/bind_windows.go        |  20 +++--
 conn/bindtest/bindtest.go   |  11 +--
 conn/conn.go                |  14 ++-
 conn/conn_test.go           |   2 +-
 device/channels.go          |   2 -
 device/device_test.go       |   2 +-
 device/pools.go             |   4 +-
 device/receive.go           |  37 ++++----
 device/send.go              |  49 +++++-----
 tun/netstack/tun.go         |  12 +--
 tun/offload.go              |  35 ++++----
 tun/offload_linux.go        |  48 ++++++----
 tun/offload_linux_test.go   |  16 ++--
 tun/offload_test.go         |  11 +--
 tun/tun.go                  |   4 +-
 tun/tun_darwin.go           |  11 +--
 tun/tun_freebsd.go          |  11 +--
 tun/tun_linux.go            |  46 ++++------
 tun/tun_openbsd.go          |  11 +--
 tun/tun_plan9.go            |  11 +--
 tun/tun_windows.go          |  11 +--
 tun/tuntest/tuntest.go      |  12 +--
 33 files changed, 577 insertions(+), 326 deletions(-)
 create mode 100644 buffer/constants.go
 create mode 100644 buffer/constants_android.go
 create mode 100644 buffer/constants_default.go
 create mode 100644 buffer/constants_ios.go
 create mode 100644 buffer/constants_windows.go
 delete mode 100644 buffer/pool.go
 create mode 100644 buffer/source.go
 create mode 100644 buffer/source_test.go

diff --git a/buffer/buffer.go b/buffer/buffer.go
index 83bda2fe2..2106554c7 100644
--- a/buffer/buffer.go
+++ b/buffer/buffer.go
@@ -29,83 +29,103 @@
 
 package buffer
 
-const (
-	MaxMessageSize = (1 << 16) - 1 // largest possible UDP datagram
-)
-
-// Source produces new Buffers.
-type Source interface {
-
-	// Get returns a Buffer of at least the requested size.
-	// Implementations may chose to return error if the request can not be fulfilled.
-	Get(size int) (*Buffer, error)
-}
+import "fmt"
 
 // Recycler holds state necessary for a correct Buffer return to its originating Source
 type Recycler interface {
 	Recycle(*Buffer)
 }
 
+// RecycleFunc adapts arbitrary closures to the Recycler interface.
+type RecycleFunc func(*Buffer)
+
+func (f RecycleFunc) Recycle(b *Buffer) {
+	f(b)
+}
+
 // Buffer is a reusable slice of bytes of fixed length.
-// The returned Data slice must not be retained past Release.
+// A Buffer and its data must not be retained past Release.
 type Buffer struct {
+	// The length of data tracks valid payload + offset (the offset itself is passed out of band).
+	// It starts equal to the requested size for new buffers, and can be adjusted.
+	// data is exposed read-only via Bytes and BytesAt, and must never be reassigned by callers.
+	// Buffer methods maintain the invariant that data[offset:len(data)] is valid payload;
+	// SetLen and Shift let callers adjust the offset and length of the valid payload as needed.
 	data     []byte
 	recycler Recycler
 }
 
-// New creates a standalone Buffer.
+// New creates a Buffer that is routed back to the provided Recycler on Release.
 func New(b []byte, recycler Recycler) *Buffer {
 	return &Buffer{data: b, recycler: recycler}
 }
 
-// Make creates a standalone Buffer with a new byte slice of the requested size.
+// Make creates a Buffer with a new byte slice of the requested size.
 func Make(size int) *Buffer {
-	return &Buffer{data: make([]byte, size), recycler: nil}
+	buf, _ := DefaultSource.Get(size) // never errors while DefaultSource is uncapped; a capped Source may yield nil
+	return buf
 }
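As an aside, New composes with the RecycleFunc adapter above for callers that manage their own storage. A sketch under that assumption; the freeList channel is illustrative only, not part of this patch:

package sketch

import "github.com/tailscale/wireguard-go/buffer"

// newFreelistBuffer routes released Buffers into a caller-owned freelist via
// the RecycleFunc adapter. Release zeroes the data before invoking the
// recycler, so the slice comes back cleared.
func newFreelistBuffer(freeList chan []byte) *buffer.Buffer {
	recycler := buffer.RecycleFunc(func(b *buffer.Buffer) {
		select {
		case freeList <- b.Bytes(): // keep the zeroed backing slice around
		default: // freelist full; let the GC reclaim it
		}
	})
	return buffer.New(make([]byte, buffer.MaxMessageSize), recycler)
}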
-// Data returns the full underlying byte slice of the Buffer.
-func (b *Buffer) Data() []byte {
+// Bytes returns the valid data in the Buffer.
+func (b *Buffer) Bytes() []byte {
 	return b.data
 }
 
-// Release returns the Buffer to its originating Source for reuse.
-// The Buffer must not be used after calling Release.
-func (b *Buffer) Release() {
-	if b.recycler != nil {
-		clear(b.data)
-		b.recycler.Recycle(b)
-	}
+// BytesAt returns the valid data in the Buffer starting at offset.
+func (b *Buffer) BytesAt(offset int) []byte {
+	return b.data[offset:]
 }
 
-// ReleaseAll calls Release on each non-nil Buffer in the slice, and sets the slice elements to nil.
-func ReleaseAll(bs []*Buffer) {
-	for i := range bs {
-		if bs[i] != nil {
-			bs[i].Release()
-			bs[i] = nil
-		}
-	}
+// SetLen sets the length of the valid data in the Buffer.
+// Intended for truncating the valid data after a read, or extending
+// it after encryption. The capacity is not checked here; slicing
+// beyond it panics.
+func (b *Buffer) SetLen(l int) {
+	b.data = b.data[:l]
 }
 
-// Arena is a Buffer with an internal watermark for sequential allocations.
-// FIXME Arena needs a graceful fallback on overflow.
-type Arena struct {
-	*Buffer
-	watermark int
+// Ensure returns a Buffer of the requested length, preserving the valid data from the provided Buffer.
+// The returned Buffer may be the same as the provided Buffer if it has sufficient capacity, or a new Buffer otherwise.
+// Safe to call on a nil Buffer.
+func Ensure(b *Buffer, size int, src Source) (*Buffer, error) {
+	if src == nil {
+		src = DefaultSource
+	}
+	if b == nil {
+		return src.Get(size)
+	}
+	if size > cap(b.data) {
+		bb, err := src.Get(size)
+		if err != nil {
+			return nil, err
+		}
+		n := copy(bb.data, b.data)
+		if n != len(b.data) {
+			panic(fmt.Sprintf("short copy: %d != %d", n, len(b.data)))
+		}
+		Release(b)
+		return bb, nil
+	}
+	b.data = b.data[:size]
+	return b, nil
 }
 
-// Get returns a slice of the Arena's Buffer of the requested size, and advances the watermark.
-func (a *Arena) Get(size int) []byte {
-	if a.watermark+size > len(a.Buffer.Data()) {
-		panic("arena overflow") // or return a heap-allocated fallback
+// Release returns the Buffer to its Source for reuse.
+// Safe to call on a nil Buffer.
+func Release(b *Buffer) {
+	if b == nil {
+		return
+	}
+	b.data = b.data[:cap(b.data)]
+	clear(b.data)
+	if b.recycler != nil {
+		b.recycler.Recycle(b)
 	}
-	b := a.Buffer.Data()[a.watermark : a.watermark+size]
-	a.watermark += size
-	return b
 }
 
-// Flush resets the Arena's watermark to zero, and clears the valid data in the Buffer.
-func (a *Arena) Flush() {
-	clear(a.Buffer.Data()[:a.watermark])
-	a.watermark = 0
+// ReleaseAll calls Release on each non-nil Buffer in the slice, and sets the slice elements to nil.
+func ReleaseAll(bs []*Buffer) {
+	for i := range bs {
+		Release(bs[i])
+		bs[i] = nil
+	}
 }
diff --git a/buffer/constants.go b/buffer/constants.go
new file mode 100644
index 000000000..83bc85707
--- /dev/null
+++ b/buffer/constants.go
@@ -0,0 +1,10 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+package buffer
+
+const (
+	// MaxMessageSize is the largest buffer that callers may request from a Source.
+	MaxMessageSize = MaxSegmentSize
+)
diff --git a/buffer/constants_android.go b/buffer/constants_android.go
new file mode 100644
index 000000000..86cf297e6
--- /dev/null
+++ b/buffer/constants_android.go
@@ -0,0 +1,13 @@
+//go:build android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package buffer
+
+const (
+	MaxSegmentSize    = 2200 // largest possible Android read
+	MaxBytesPerSource = 4096 * MaxSegmentSize
+)
diff --git a/buffer/constants_default.go b/buffer/constants_default.go
new file mode 100644
index 000000000..ee2c556c8
--- /dev/null
+++ b/buffer/constants_default.go
@@ -0,0 +1,13 @@
+//go:build !android && !ios && !windows
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */ + +package buffer + +const ( + MaxSegmentSize = (1 << 16) - 1 // largest possible Unix read + MaxBytesPerSource = 0 // Disable and allow for infinite memory growth +) diff --git a/buffer/constants_ios.go b/buffer/constants_ios.go new file mode 100644 index 000000000..4d4e06604 --- /dev/null +++ b/buffer/constants_ios.go @@ -0,0 +1,17 @@ +//go:build ios + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package buffer + +// Fit within memory limits for iOS's Network Extension API, which has stricter requirements. +// These are vars instead of consts, because heavier network extensions might want to reduce +// them further. +var ( + MaxBytesPerSource = 1024 * MaxSegmentSize +) + +const MaxSegmentSize = 1700 // largest possible iOS read diff --git a/buffer/constants_windows.go b/buffer/constants_windows.go new file mode 100644 index 000000000..7f2f4e8b6 --- /dev/null +++ b/buffer/constants_windows.go @@ -0,0 +1,13 @@ +//go:build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package buffer + +const ( + MaxSegmentSize = 2048 - 32 // largest possible Windows read + MaxBytesPerSource = 0 // Disable and allow for infinite memory growth +) diff --git a/buffer/pool.go b/buffer/pool.go deleted file mode 100644 index acc38b0ea..000000000 --- a/buffer/pool.go +++ /dev/null @@ -1,59 +0,0 @@ -package buffer - -import "sync" - -const ( - min = 2 << 10 // 2KB, enough for a typical MTU-sized packet - mid = 10 << 10 // 10KB, enough for a jumbo frame - max = 65 << 10 // 65KB, enough for the maximum possible UDP datagram size -) - -var _ Source = (*FragmentPool)(nil) - -// FragmentPool is a tiered source of buffers. Tiers are balanced -// to accomodate regular MTU sizes, jumbo frames, and the maximum possible UDP datagram size. -type FragmentPool struct { - minPool sync.Pool - midPool sync.Pool - maxPool sync.Pool -} - -type poolRecycler struct { - *sync.Pool -} - -func (p *poolRecycler) Recycle(b *Buffer) { - p.Put(b) -} - -func NewFragmentPool() *FragmentPool { - p := new(FragmentPool) - p.minPool.New = func() any { - return &Buffer{data: make([]byte, min), recycler: &poolRecycler{&p.minPool}} - } - p.midPool.New = func() any { - return &Buffer{data: make([]byte, mid), recycler: &poolRecycler{&p.midPool}} - } - p.maxPool.New = func() any { - return &Buffer{data: make([]byte, max), recycler: &poolRecycler{&p.maxPool}} - } - return p -} - -// Get returns a Buffer of at least the requested size. Implementations may chose to return error if the request can not be fulfilled. -// FIXME current code ignores err return -func (p *FragmentPool) Get(size int) (*Buffer, error) { - var buf *Buffer - switch { - case size <= min: - buf = p.minPool.Get().(*Buffer) - case size <= mid: - buf = p.midPool.Get().(*Buffer) - case size <= max: - buf = p.maxPool.Get().(*Buffer) - default: - return &Buffer{data: make([]byte, size)}, nil - } - buf.data = buf.data[:size] - return buf, nil -} diff --git a/buffer/source.go b/buffer/source.go new file mode 100644 index 000000000..fa75a2120 --- /dev/null +++ b/buffer/source.go @@ -0,0 +1,175 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ +package buffer + +import ( + "fmt" + "sync" + "sync/atomic" +) + +const ( + min = 2 << 10 // 2KB, typical MTU-sized packet + mid = 10 << 10 // 10KB, jumbo frame + max = 65 << 10 // 65KB, max UDP datagram read +) + +// Source produces new Buffers. 
+type Source interface {
+	// Get returns a Buffer of exactly the requested length and at least the requested capacity.
+	// Implementations may return an error if the request cannot be fulfilled.
+	Get(size int) (*Buffer, error)
+}
+
+// DefaultSource is a package-level Source of Buffers.
+// Used by Make and Ensure when no Source is provided.
+var DefaultSource Source = &CappedSource{s: NewPoolSource(), cap: int64(MaxBytesPerSource)}
+
+var (
+	_ Source = (*PoolSource)(nil)
+	_ Source = (*LoggingSource)(nil)
+	_ Source = (*CappedSource)(nil)
+)
+
+// PoolSource is a tiered [Source] of buffers. Tiers are balanced
+// to accommodate regular MTU sizes, jumbo frames, and the maximum possible UDP datagram size.
+// PoolSource never errors, instead allocating a GC-managed buffer for requests that exceed the max tier size.
+type PoolSource struct {
+	minPool sync.Pool
+	midPool sync.Pool
+	maxPool sync.Pool
+}
+
+type poolRecycler struct {
+	*sync.Pool
+}
+
+func (p *poolRecycler) Recycle(b *Buffer) {
+	p.Put(b)
+}
+
+// NewPoolSource returns a PoolSource with the default tier sizes.
+func NewPoolSource() *PoolSource {
+	p := new(PoolSource)
+	p.minPool.New = func() any {
+		return &Buffer{data: make([]byte, min), recycler: &poolRecycler{&p.minPool}}
+	}
+	p.midPool.New = func() any {
+		return &Buffer{data: make([]byte, mid), recycler: &poolRecycler{&p.midPool}}
+	}
+	p.maxPool.New = func() any {
+		return &Buffer{data: make([]byte, max), recycler: &poolRecycler{&p.maxPool}}
+	}
+	return p
+}
+
+func (p *PoolSource) Get(size int) (*Buffer, error) {
+	var buf *Buffer
+	switch {
+	case size <= min:
+		buf = p.minPool.Get().(*Buffer)
+	case size <= mid:
+		buf = p.midPool.Get().(*Buffer)
+	case size <= max:
+		buf = p.maxPool.Get().(*Buffer)
+	default:
+		return &Buffer{data: make([]byte, size)}, nil
+	}
+	buf.data = buf.data[:size]
+	return buf, nil
+}
+
+// LoggingSource is a Source that keeps track of all Buffers it has produced.
+// Use it when the buffers are not retained and cannot be released back individually.
+// Not safe for concurrent use.
+type LoggingSource struct {
+	Source
+	log []*Buffer
+}
+
+// Get returns a Buffer and records it for later bulk release.
+func (l *LoggingSource) Get(size int) (*Buffer, error) {
+	buf, err := l.Source.Get(size)
+	if err != nil {
+		return nil, err
+	}
+	l.log = append(l.log, buf)
+	return buf, nil
+}
+
+// Log returns all Buffers produced by this source.
+func (l *LoggingSource) Log() []*Buffer {
+	return l.log
+}
+
+// ReleaseAll releases all tracked Buffers and resets the log.
+func (l *LoggingSource) ReleaseAll() {
+	ReleaseAll(l.log)
+	l.log = l.log[:0]
+}
+
+// CappedSource is a Source that tracks the total capacity of all Buffers
+// it has produced and returns an error if a request would cause the total
+// to exceed a specified cap.
+type CappedSource struct {
+	Used atomic.Int64 // exported to expose metrics
+
+	s         Source
+	cap       int64
+	recyclers sync.Pool
+}
+
+// NewCappedSource returns a CappedSource wrapping src with the given byte cap.
+// A cap of zero or less disables the limit.
+func NewCappedSource(src Source, cap int64) *CappedSource {
+	s := &CappedSource{
+		s:   src,
+		cap: cap,
+	}
+	s.recyclers = sync.Pool{
+		New: func() any {
+			return &cappedRecycler{s: s}
+		},
+	}
+	return s
+}
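The sources above are designed to stack. A hedged sketch of the intended composition; the 1 MiB cap is an arbitrary example value, not a recommendation from this patch:

package sketch

import (
	"fmt"

	"github.com/tailscale/wireguard-go/buffer"
)

// composedSource stacks the three sources: logging on top of a cap on top
// of the tiered pool. The cap charges each Buffer's capacity against Used.
func composedSource() {
	capped := buffer.NewCappedSource(buffer.NewPoolSource(), 1<<20)
	logged := &buffer.LoggingSource{Source: capped}

	if b, err := logged.Get(buffer.MaxMessageSize); err == nil {
		_ = b // hand it to a reader; logged tracks it in Log()
	}
	fmt.Println("bytes charged against the cap:", capped.Used.Load())
	logged.ReleaseAll() // bulk-release everything handed out above
}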
+// ErrSizeExceedsCap is returned by CappedSource.Get when the cap would be exceeded.
+var ErrSizeExceedsCap = fmt.Errorf("buffer: request exceeds cap")
+
+func (p *CappedSource) Get(size int) (*Buffer, error) {
+	// This implementation acquires the buffer first to
+	// account for the correct memory footprint via cap().
+	// There is little point in optimizing this to just
+	// spin faster when we're over.
+	b, err := p.s.Get(size)
+	if err != nil {
+		return nil, err
+	}
+	if p.cap <= 0 {
+		return b, nil // uncapped path
+	}
+	charge := int64(cap(b.data))
+	if new := p.Used.Add(charge); new > p.cap {
+		p.Used.Add(-charge)
+		Release(b)
+		return nil, ErrSizeExceedsCap
+	}
+	r := p.recyclers.Get().(*cappedRecycler)
+	r.next, b.recycler = b.recycler, r // wrap recycler
+	return b, nil
+}
+
+type cappedRecycler struct {
+	s    *CappedSource
+	next Recycler
+}
+
+func (r *cappedRecycler) Recycle(b *Buffer) {
+	b.recycler, r.next = r.next, nil // unwrap recycler
+	r.s.Used.Add(-int64(cap(b.data)))
+	Release(b) // safe release, recycler may be nil
+	r.s.recyclers.Put(r)
+}
diff --git a/buffer/source_test.go b/buffer/source_test.go
new file mode 100644
index 000000000..4ff11fb95
--- /dev/null
+++ b/buffer/source_test.go
@@ -0,0 +1,45 @@
+package buffer
+
+import (
+	"errors"
+	"sync"
+	"sync/atomic"
+	"testing"
+)
+
+func BenchmarkCappedSource(b *testing.B) {
+	for _, sc := range []struct {
+		name string
+		src  Source
+	}{
+		{name: "baseline", src: NewPoolSource()},
+		{name: "uncapped", src: NewCappedSource(NewPoolSource(), 0)},
+		{name: "capped", src: NewCappedSource(NewPoolSource(), 2*MaxMessageSize)},
+	} {
+		b.Run(sc.name, func(b *testing.B) {
+			var cappedErrCount atomic.Int64 // RunParallel bodies run concurrently; a plain int would race
+			consumer := make(chan *Buffer, 1)
+			var wg sync.WaitGroup
+			wg.Go(func() {
+				for b := range consumer {
+					Release(b)
+				}
+			})
+			b.RunParallel(func(pb *testing.PB) {
+				for pb.Next() {
+					buf, err := sc.src.Get(MaxMessageSize)
+					if err != nil {
+						if !errors.Is(err, ErrSizeExceedsCap) {
+							b.Fatalf("unexpected error: %v", err)
+						}
+						cappedErrCount.Add(1)
+					}
+					consumer <- buf
+				}
+			})
+			close(consumer)
+			wg.Wait()
+			b.ReportMetric(float64(cappedErrCount.Load()), "capped")
+		})
+	}
+}
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 8f553a6c5..507001e02 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -43,9 +43,9 @@ type StdNetBind struct {
 	ipv6RxOffload bool
 
 	// these two fields are not guarded by mu
-	udpAddrPool sync.Pool
-	msgsPool    sync.Pool
-	bufPool     buffer.Source
+	udpAddrPool  sync.Pool
+	msgsPool     sync.Pool
+	bufferSource buffer.Source
 
 	blackhole4 bool
 	blackhole6 bool
@@ -72,7 +72,7 @@ func NewStdNetBind() Bind {
 			},
 		},
 
-		bufPool: buffer.NewFragmentPool(),
+		bufferSource: buffer.DefaultSource,
 	}
 }
 
@@ -235,7 +235,6 @@ func (s *StdNetBind) receiveIP(
 	conn *net.UDPConn,
 	rxOffload bool,
 	bufs []*buffer.Buffer,
-	sizes []int,
 	eps []Endpoint,
 ) (n int, err error) {
 	msgs := s.getMessages()
@@ -246,26 +245,28 @@
 		const readBatch = 2
 		readAt := len(*msgs) - readBatch
 		for i := readAt; i < readAt+readBatch; i++ {
-			if bufs[i] == nil {
-				bufs[i], _ = s.bufPool.Get(buffer.MaxMessageSize)
+			bufs[i], err = buffer.Ensure(bufs[i], buffer.MaxMessageSize, s.bufferSource)
+			if err != nil {
+				return 0, err
 			}
-			(*msgs)[i].Buffers[0] = bufs[i].Data()
+			(*msgs)[i].Buffers[0] = bufs[i].Bytes()
 			(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
 		}
 		numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
 		if err != nil {
 			return 0, err
 		}
-		numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize, bufs, s.bufPool)
+		numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize, bufs, s.bufferSource)
 		if err != nil {
 			return 0, err
 		}
 	}
else { for i := range bufs { - if bufs[i] == nil { - bufs[i], _ = s.bufPool.Get(buffer.MaxMessageSize) + bufs[i], err = buffer.Ensure(bufs[i], buffer.MaxMessageSize, s.bufferSource) + if err != nil { + return 0, err } - (*msgs)[i].Buffers[0] = bufs[i].Data() + (*msgs)[i].Buffers[0] = bufs[i].Bytes() (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] } numMsgs, err = br.ReadBatch(*msgs, 0) @@ -274,11 +275,12 @@ func (s *StdNetBind) receiveIP( } } } else { - if bufs[0] == nil { - bufs[0], _ = s.bufPool.Get(buffer.MaxMessageSize) + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, s.bufferSource) + if err != nil { + return 0, err } msg := &(*msgs)[0] - msg.Buffers[0] = bufs[0].Data() + msg.Buffers[0] = bufs[0].Bytes() msg.OOB = msg.OOB[:cap(msg.OOB)] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { @@ -288,8 +290,8 @@ func (s *StdNetBind) receiveIP( } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] - sizes[i] = msg.N - if sizes[i] == 0 { + bufs[i].SetLen(msg.N) + if msg.N == 0 { continue } addrPort := msg.Addr.(*net.UDPAddr).AddrPort() @@ -301,14 +303,14 @@ func (s *StdNetBind) receiveIP( } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } @@ -512,11 +514,11 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs endBatch = false base++ gsoSize = len(buf) + coalescedLen = len(buf) setSrcControl(&msgs[base].OOB, ep) msgs[base].Buffers[0] = buf msgs[base].Addr = addr dgramCnt = 1 - coalescedLen = gsoSize } return base + 1 } @@ -548,10 +550,11 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu return n, errors.New("splitting coalesced packet resulted in overflow") } segLen := end - start - if bufs[n] == nil { - bufs[n], _ = pool.Get(segLen) + bufs[n], err = buffer.Ensure(bufs[n], segLen, pool) + if err != nil { + return 0, err } - msgs[n].Buffers[0] = bufs[n].Data() + msgs[n].Buffers[0] = bufs[n].Bytes() msgs[n].OOB = msgs[n].OOB[:cap(msgs[n].OOB)] copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) msgs[n].N = copied diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 0b3d01476..3db3d1831 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -19,13 +19,12 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind.Close() bufs := make([]*buffer.Buffer, IdealBatchSize) bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) - sizes := make([]int, IdealBatchSize) eps := make([]Endpoint, IdealBatchSize) for _, fn := range fns { // The ReceiveFuncs must not access conn-related fields on StdNetBind // unguarded. Close() nils the conn-related fields resulting in a panic // if they violate the mutex. 
- fn(bufs, sizes, eps) + fn(bufs, eps) } } @@ -242,7 +241,7 @@ func Test_splitCoalescedMessages(t *testing.T) { for i := range tt.msgs { bufs[i] = buffer.New(tt.msgs[i].Buffers[0], nil) } - got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize, bufs, buffer.NewFragmentPool()) + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize, bufs, buffer.DefaultSource) if err != nil && !tt.wantErr { t.Fatalf("err: %v", err) } @@ -255,8 +254,8 @@ func Test_splitCoalescedMessages(t *testing.T) { } } for i := range got { - if !bytes.Equal(bufs[i].Data(), tt.msgs[i].Buffers[0]) { - t.Fatalf("bufs[%d].Data() and tt.msgs[%d] unequal", i, i) + if !bytes.Equal(bufs[i].Bytes(), tt.msgs[i].Buffers[0]) { + t.Fatalf("bufs[%d].Data and tt.msgs[%d] unequal", i, i) } } }) diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 0c99acfeb..99a32c23a 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -417,20 +417,28 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv4(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v4.Receive(bufs[0].Data(), &bind.isOpen) - sizes[0] = n + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + n, ep, err := bind.v4.Receive(bufs[0].Bytes(), &bind.isOpen) + bufs[0].SetLen(n) eps[0] = ep return 1, err } -func (bind *WinRingBind) receiveIPv6(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv6(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v6.Receive(bufs[0].Data(), &bind.isOpen) - sizes[0] = n + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + n, ep, err := bind.v6.Receive(bufs[0].Bytes(), &bind.isOpen) + bufs[0].SetLen(n) eps[0] = ep return 1, err } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 2fe41b7a7..a9cc66f0c 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -95,16 +95,17 @@ func (c *ChannelBind) BatchSize() int { return 1 } func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(bufs []*buffer.Buffer, sizes []int, eps []conn.Endpoint) (n int, err error) { + return func(bufs []*buffer.Buffer, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: return 0, net.ErrClosed case rx := <-ch: - if bufs[0] == nil { - bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err } - copied := copy(bufs[0].Data(), rx) - sizes[0] = copied + copied := copy(bufs[0].Bytes(), rx) + bufs[0].SetLen(copied) eps[0] = c.target6 return 1, nil } diff --git a/conn/conn.go b/conn/conn.go index c674cf849..c46913431 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -21,14 +21,12 @@ const ( IdealBatchSize = 128 // maximum number of packets handled per read and write ) -// A ReceiveFunc receives at least one packet from the network and writes them -// into packets. On a successful read it returns the number of elements of -// sizes, packets, and endpoints that should be evaluated. Some elements of -// sizes may be zero, and callers should ignore them. 
Callers must pass a sizes -// and eps slice with a length greater than or equal to the length of packets. -// These lengths must not exceed the length of the associated Bind.BatchSize(). -// Nil entries in bufs are allocated by the implementation. -type ReceiveFunc func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) +// A ReceiveFunc receives at least one packet from the network into bufs. +// On a successful read it returns the number of elements of bufs and eps +// that should be evaluated. Callers must pass an eps slice with a length +// greater than or equal to the length of bufs. These lengths must not +// exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // diff --git a/conn/conn_test.go b/conn/conn_test.go index 77d493ea0..3b9e92c2f 100644 --- a/conn/conn_test.go +++ b/conn/conn_test.go @@ -13,7 +13,7 @@ import ( func TestPrettyName(t *testing.T) { var ( - recvFunc ReceiveFunc = func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) { return } + recvFunc ReceiveFunc = func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { return } ) const want = "TestPrettyName" diff --git a/device/channels.go b/device/channels.go index 7fff0febd..c1be585b9 100644 --- a/device/channels.go +++ b/device/channels.go @@ -93,7 +93,6 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - elem.buf.Release() device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -126,7 +125,6 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - elem.buf.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/device/device_test.go b/device/device_test.go index 3c00c91f8..9454db6e5 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -438,7 +438,7 @@ type fakeTUNDeviceSized struct { } func (t *fakeTUNDeviceSized) File() *os.File { return nil } -func (t *fakeTUNDeviceSized) Read(bufs []*buffer.Buffer, sizes []int, offset int) (n int, err error) { +func (t *fakeTUNDeviceSized) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { return 0, nil } func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } diff --git a/device/pools.go b/device/pools.go index ee3cb46d9..b1988a424 100644 --- a/device/pools.go +++ b/device/pools.go @@ -63,7 +63,7 @@ func (device *Device) PopulatePools() { device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueOutboundElement) }) - device.pool.messageBuffers = buffer.NewFragmentPool() + device.pool.messageBuffers = buffer.DefaultSource } func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { @@ -107,6 +107,7 @@ func (device *Device) GetInboundElement() *QueueInboundElement { } func (device *Device) PutInboundElement(elem *QueueInboundElement) { + buffer.Release(elem.buf) elem.clearPointers() device.pool.inboundElements.Put(elem) } @@ -116,6 +117,7 @@ func (device *Device) GetOutboundElement() *QueueOutboundElement { } func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { + buffer.Release(elem.buf) elem.clearPointers() 
	device.pool.outboundElements.Put(elem)
}
diff --git a/device/receive.go b/device/receive.go
index 9dbea3f38..df4a43870 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -86,7 +86,6 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
 	var (
 		bufs      = make([]*buffer.Buffer, maxBatchSize) // nil entries; recv allocates
-		sizes     = make([]int, maxBatchSize)
 		err       error
 		count     int
 		endpoints = make([]conn.Endpoint, maxBatchSize)
@@ -99,8 +98,12 @@
 	}()
 
 	for {
-		count, err = recv(bufs, sizes, endpoints)
+		count, err = recv(bufs, endpoints)
 		if err != nil {
+			if errors.Is(err, buffer.ErrSizeExceedsCap) {
+				// This is expected when the receive buffers come from a capped Source ([buffer.CappedSource]). Spin until some memory is freed.
+				continue
+			}
 			if errors.Is(err, net.ErrClosed) {
 				return
 			}
@@ -118,14 +121,14 @@
 		deathSpiral = 0
 
 		// handle each packet in the batch
-		for i, size := range sizes[:count] {
-			if size < MinMessageSize {
+		for i := 0; i < count; i++ {
+			if len(bufs[i].Bytes()) < MinMessageSize {
 				continue
 			}
 
 			// check size of packet
 
-			packet := bufs[i].Data()[:size]
+			packet := bufs[i].Bytes()
 			msgType := binary.LittleEndian.Uint32(packet[:4])
 
 			switch msgType {
@@ -209,19 +212,13 @@
 			default:
 			}
 		}
-		for i := 0; i < count; i++ {
-			if bufs[i] != nil {
-				bufs[i].Release()
-				bufs[i] = nil
-			}
-		}
+		buffer.ReleaseAll(bufs[:count])
 		for peer, elemsContainer := range elemsByPeer {
 			if peer.isRunning.Load() {
 				peer.queue.inbound.c <- elemsContainer
 				device.queue.decryption.c <- elemsContainer
 			} else {
 				for _, elem := range elemsContainer.elems {
-					elem.buf.Release()
 					device.PutInboundElement(elem)
 				}
 				device.PutInboundElementsContainer(elemsContainer)
@@ -418,7 +415,7 @@
 			peer.SendKeepalive()
 		}
 	skip:
-		elem.buf.Release()
+		buffer.Release(elem.buf)
 	}
}
@@ -430,8 +427,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
 	}()
 	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
 
-	writeBufs := make([]*buffer.Buffer, 0, maxBatchSize)
-	legacyBufs := make([][]byte, 0, maxBatchSize)
+	writeBufs := make([]*buffer.Buffer, 0, maxBatchSize) // references to the transferred buffers; released after the batch write
+	writeSlices := make([][]byte, 0, maxBatchSize)       // slices of the above buffers; passed to the TUN device; not used after the write
 
 	for elemsContainer := range peer.queue.inbound.c {
 		if elemsContainer == nil {
@@ -509,7 +506,8 @@
 				continue
 			}
 
-			legacyBufs = append(legacyBufs, elem.buf.Data()[:MessageTransportOffsetContent+len(elem.packet)])
+			elem.buf.SetLen(MessageTransportOffsetContent + len(elem.packet))
+			writeSlices = append(writeSlices, elem.buf.Bytes())
 			writeBufs = append(writeBufs, elem.buf)
 			elem.buf = nil // ownership transferred to writeBufs
 		}
@@ -525,7 +523,7 @@
 			peer.timersDataReceived()
 		}
 		if len(writeBufs) > 0 {
-			_, err := device.tun.device.Write(legacyBufs, MessageTransportOffsetContent)
+			_, err := device.tun.device.Write(writeSlices, MessageTransportOffsetContent)
 			if err != nil && !device.isClosed() {
 				device.log.Errorf("Failed to write packets to TUN device: %v", err)
 			}
@@ -533,13 +531,10 @@
 		}
 		// Release buffers for
skipped elements (not transferred to writeBufs). for _, elem := range elemsContainer.elems { - if elem.buf != nil { - elem.buf.Release() - } device.PutInboundElement(elem) } writeBufs = writeBufs[:0] - legacyBufs = legacyBufs[:0] + writeSlices = writeSlices[:0] device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index 0a79d662f..365ff015e 100644 --- a/device/send.go +++ b/device/send.go @@ -48,7 +48,7 @@ import ( type QueueOutboundElement struct { buf *buffer.Buffer - // packet is always a slice of "buffer". The starting offset in buffer + // packet is always a slice of buf. The starting offset in buf // is either: // a) MessageEncapsulatingTransportSize+MessageTransportHeaderSize (plaintext) // b) 0 (post-encryption) @@ -93,7 +93,6 @@ func (peer *Peer) SendKeepalive() { case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - elem.buf.Release() peer.device.PutOutboundElement(elem) peer.device.PutOutboundElementsContainer(elemsContainer) } @@ -130,15 +129,15 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { } buf := peer.device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageInitiationSize) - packet := buf.Data()[MessageEncapsulatingTransportSize:] + packet := buf.Bytes()[MessageEncapsulatingTransportSize:] _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{buf.Data()}) - buf.Release() + err = peer.SendBuffers([][]byte{buf.Bytes()}) + buffer.Release(buf) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -161,13 +160,13 @@ func (peer *Peer) SendHandshakeResponse() error { } buf := peer.device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageResponseSize) - packet := buf.Data()[MessageEncapsulatingTransportSize:] + packet := buf.Bytes()[MessageEncapsulatingTransportSize:] _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() if err != nil { - buf.Release() + buffer.Release(buf) peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) return err } @@ -177,8 +176,8 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{buf.Data()}) - buf.Release() + err = peer.SendBuffers([][]byte{buf.Bytes()}) + buffer.Release(buf) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -196,11 +195,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) } buf := device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageCookieReplySize) - packet := buf.Data()[MessageEncapsulatingTransportSize:] + packet := buf.Bytes()[MessageEncapsulatingTransportSize:] _ = reply.marshal(packet) - device.net.bind.Send([][]byte{buf.Data()}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) - buf.Release() + device.net.bind.Send([][]byte{buf.Bytes()}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + buffer.Release(buf) return nil } @@ -231,7 +230,6 @@ func (device *Device) RoutineReadFromTUN() { bufs = make([]*buffer.Buffer, batchSize) elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 - sizes = make([]int, batchSize) offset = MessageEncapsulatingTransportSize + MessageTransportHeaderSize ) @@ 
-240,10 +238,10 @@ func (device *Device) RoutineReadFromTUN() { }() for { - count, readErr = device.tun.device.Read(bufs, sizes, offset) + count, readErr = device.tun.device.Read(bufs, offset) for i := 0; i < count; i++ { - packet := bufs[i].Data()[offset : offset+sizes[i]] + packet := bufs[i].Bytes()[offset:] if len(packet) < 1 { continue } @@ -289,12 +287,7 @@ func (device *Device) RoutineReadFromTUN() { } // Release unconsumed buffers so stale right-sized entries // don't persist into the next Read call. - for i := 0; i < count; i++ { - if bufs[i] != nil { - bufs[i].Release() - bufs[i] = nil - } - } + buffer.ReleaseAll(bufs[:count]) for peer, elemsForPeer := range elemsByPeer { if peer.isRunning.Load() { @@ -302,7 +295,6 @@ func (device *Device) RoutineReadFromTUN() { peer.SendStagedPackets() } else { for _, elem := range elemsForPeer.elems { - elem.buf.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsForPeer) @@ -311,6 +303,10 @@ func (device *Device) RoutineReadFromTUN() { } if readErr != nil { + if errors.Is(readErr, buffer.ErrSizeExceedsCap) { + // This is expected if TUN uses [buffer.CappedSource]. Spin until some memory is freed. + continue + } if errors.Is(readErr, tun.ErrTooManySegments) { // TODO: record stat for this // This will happen if MSS is surprisingly small (< 576) @@ -339,7 +335,6 @@ func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { select { case tooOld := <-peer.queue.staged: for _, elem := range tooOld.elems { - elem.buf.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(tooOld) @@ -400,7 +395,6 @@ top: peer.device.queue.encryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - elem.buf.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -420,7 +414,6 @@ func (peer *Peer) FlushStagedPackets() { select { case elemsContainer := <-peer.queue.staged: for _, elem := range elemsContainer.elems { - elem.buf.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -460,7 +453,7 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buf.Data()[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] + header := elem.buf.Bytes()[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] @@ -485,7 +478,8 @@ func (device *Device) RoutineEncryption(id int) { ) // re-slice packet to include encapsulating transport space - elem.packet = elem.buf.Data()[:MessageEncapsulatingTransportSize+len(elem.packet)] + elem.buf.SetLen(MessageEncapsulatingTransportSize + len(elem.packet)) + elem.packet = elem.buf.Bytes() } elemsContainer.Unlock() } @@ -517,7 +511,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // that we never accidentally keep timers alive longer than necessary. 
elemsContainer.Lock() for _, elem := range elemsContainer.elems { - elem.buf.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index b8824f674..9a494fbc5 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -120,20 +120,22 @@ func (tun *netTun) Events() <-chan tun.Event { return tun.events } -func (tun *netTun) Read(bufs []*wgbuf.Buffer, sizes []int, offset int) (int, error) { +func (tun *netTun) Read(bufs []*wgbuf.Buffer, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } - if bufs[0] == nil { - bufs[0] = wgbuf.New(make([]byte, wgbuf.MaxMessageSize), nil) + var err error + bufs[0], err = wgbuf.Ensure(bufs[0], wgbuf.MaxMessageSize, nil) + if err != nil { + return 0, err } - n, err := view.Read(bufs[0].Data()[offset:]) + n, err := view.Read(bufs[0].Bytes()[offset:]) if err != nil { return 0, err } - sizes[0] = n + bufs[0].SetLen(offset + n) return 1, nil } diff --git a/tun/offload.go b/tun/offload.go index f464c742c..454d16e88 100644 --- a/tun/offload.go +++ b/tun/offload.go @@ -75,15 +75,14 @@ const ( ipProtoUDP = 17 ) -// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing -// the size of each element into sizes. It returns the number of buffers -// populated, and/or an error. Callers may pass an 'in' slice that overlaps with -// the first element of outBuffers, i.e. &in[0] may be equal to -// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the -// value of options.NeedsCsum. Length of each outBufs element must be greater -// than or equal to the length of 'in', otherwise output may be silently -// truncated. -func GSOSplit(in []byte, options GSOOptions, outBufs []*buffer.Buffer, pool buffer.Source, outOffset int) (int, error) { +// GSOSplit splits packets from 'in' into outBufs, allocating from src as +// needed. Each output buffer is sized to outOffset + segment length. +// It returns the number of buffers populated, and/or an error. +// Callers may pass an 'in' slice that overlaps with the first element +// of outBuffers, i.e. &in[0] may be equal to &outBufs[0].BytesAt(outOffset). +// GSONone is a valid options.GSOType regardless of the value of +// options.NeedsCsum. 
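A sketch of the calling convention just described, as a batch reader might use it; the coalesced input, options, and the batch size of 128 (mirroring conn.IdealBatchSize) are placeholder assumptions, not symbols from this diff:

package sketch

import (
	"github.com/tailscale/wireguard-go/buffer"
	"github.com/tailscale/wireguard-go/tun"
)

// splitSketch exercises the GSOSplit contract: each populated output Buffer
// is sized to outOffset plus its segment length, allocated from the Source.
func splitSketch(coalesced []byte, options tun.GSOOptions, outOffset int) ([]*buffer.Buffer, error) {
	out := make([]*buffer.Buffer, 128) // nil entries; GSOSplit allocates as needed
	n, err := tun.GSOSplit(coalesced, options, out, buffer.DefaultSource, outOffset)
	if err != nil {
		buffer.ReleaseAll(out[:n]) // drop anything populated before the failure
		return nil, err
	}
	return out[:n], nil
}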
+func GSOSplit(in []byte, options GSOOptions, outBufs []*buffer.Buffer, src buffer.Source, outOffset int) (int, error) { cSumAt := int(options.CsumStart) + int(options.CsumOffset) if cSumAt+1 >= len(in) { return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) @@ -97,8 +96,10 @@ func GSOSplit(in []byte, options GSOOptions, outBufs []*buffer.Buffer, pool buff payloadLen := len(in) - int(options.HdrLen) if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { needed := outOffset + len(in) - if outBufs[0] == nil { - outBufs[0], _ = pool.Get(needed) + var err error + outBufs[0], err = buffer.Ensure(outBufs[0], needed, src) + if err != nil { + return 0, err } if options.NeedsCsum { // The initial value at the checksum offset should be summed with @@ -107,7 +108,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs []*buffer.Buffer, pool buff in[cSumAt], in[cSumAt+1] = 0, 0 binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial)) } - copy(outBufs[0].Data()[outOffset:], in) + n := copy(outBufs[0].Bytes()[outOffset:], in) + outBufs[0].SetLen(outOffset + n) return 1, nil } @@ -167,11 +169,12 @@ func GSOSplit(in []byte, options GSOOptions, outBufs []*buffer.Buffer, pool buff } segmentDataLen := nextSegmentEnd - nextSegmentDataAt totalLen := int(options.HdrLen) + segmentDataLen - needed := outOffset + totalLen - if outBufs[i] == nil { - outBufs[i], _ = pool.Get(needed) + var err error + outBufs[i], err = buffer.Ensure(outBufs[i], outOffset+totalLen, src) + if err != nil { + return i, err } - out := outBufs[i].Data()[outOffset:] + out := outBufs[i].Bytes()[outOffset:] copy(out, in[:iphLen]) if ipVersion == 4 { diff --git a/tun/offload_linux.go b/tun/offload_linux.go index 864d28244..c7a06d401 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -429,12 +429,17 @@ const ( // coalesceUDPPackets attempts to coalesce pkt with the packet described by // item, and returns the outcome. -func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, arena *buffer.Arena) coalesceResult { +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, bufSrc *buffer.LoggingSource) coalesceResult { if head := bufs[item.bufsIndex]; cap(head) < buffer.MaxMessageSize { - new := arena.Get(buffer.MaxMessageSize) - n := copy(new, head) - new = new[:n] - bufs[item.bufsIndex] = new + b, err := bufSrc.Get(buffer.MaxMessageSize) + if err == nil { + // Upsize head buffer to max. Abandon buffer struct, bufSrc tracks it. + new := b.Bytes() + n := copy(new, head) + new = new[:n] + bufs[item.bufsIndex] = new + } + // If we fail to get a new buffer, attempt to coalesce whatever fits. } pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front headersLen := item.iphLen + udphLen @@ -465,12 +470,17 @@ func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset // item, and returns the outcome. This function may swap bufs elements in the // event of a prepend as item's bufs index is already being tracked for writing // to a Device. 
-func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, arena *buffer.Arena) coalesceResult { +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, bufSrc *buffer.LoggingSource) coalesceResult { if head := bufs[item.bufsIndex]; cap(head) < buffer.MaxMessageSize { - new := arena.Get(buffer.MaxMessageSize) - n := copy(new, head) - new = new[:n] - bufs[item.bufsIndex] = new + b, err := bufSrc.Get(buffer.MaxMessageSize) + if err == nil { + // Upsize head buffer to max. Abandon buffer struct, bufSrc tracks it. + new := b.Bytes() + n := copy(new, head) + new = new[:n] + bufs[item.bufsIndex] = new + } + // If we fail to get a new buffer, attempt to coalesce whatever fits. } var pktHead []byte // the packet that will end up at the front headersLen := item.iphLen + item.tcphLen @@ -556,7 +566,7 @@ const ( // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool, arena *buffer.Arena) groResult { +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool, bufSrc *buffer.LoggingSource) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. @@ -628,7 +638,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool, item := items[i] can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6, arena) + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6, bufSrc) switch result { case coalesceSuccess: table.updateAt(item, i) @@ -811,7 +821,7 @@ const ( // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool, arena *buffer.Arena) groResult { +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool, bufSrc *buffer.LoggingSource) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. @@ -864,7 +874,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool, can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) var pktCSumKnownInvalid bool if can == coalesceAppend { - result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6, arena) + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6, bufSrc) switch result { case coalesceSuccess: table.updateAt(item, len(items)-1) @@ -890,7 +900,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool, // empty (but non-nil), and are passed in to save allocs as the caller may reset // and recycle them across vectors of packets. gro indicates if TCP and UDP GRO // are supported/enabled. 
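For context before the handleGRO changes below, a sketch of the bufSrc lifetime the GRO path assumes, mirroring the Write path in tun_linux.go; everything here except the LoggingSource wiring is the caller's own state, written as if inside package tun since handleGRO is unexported:

package tun // sketch only: handleGRO is unexported, so this lives beside it

import "github.com/tailscale/wireguard-go/buffer"

// groBatchSketch shows why a LoggingSource fits here: head buffers upsized
// during coalescing are referenced only through the bufs slices, so they are
// tracked by the source and bulk-released once the batch has been flushed.
func groBatchSketch(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags) error {
	toWrite := make([]int, 0, len(bufs))
	bufSrc := &buffer.LoggingSource{Source: buffer.DefaultSource}
	defer bufSrc.ReleaseAll() // safe: coalesced heads are not retained past the flush
	if err := handleGRO(bufs, offset, tcpTable, udpTable, gro, &toWrite, bufSrc); err != nil {
		return err
	}
	// flush bufs[i] for each i in toWrite here, before ReleaseAll fires
	return nil
}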
-func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int, arena *buffer.Arena) error { +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int, bufSrc *buffer.LoggingSource) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") @@ -898,13 +908,13 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR var result groResult switch packetIsGROCandidate(bufs[i][offset:], gro) { case tcp4GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, false, arena) + result = tcpGRO(bufs, offset, i, tcpTable, false, bufSrc) case tcp6GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, true, arena) + result = tcpGRO(bufs, offset, i, tcpTable, true, bufSrc) case udp4GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, false, arena) + result = udpGRO(bufs, offset, i, udpTable, false, bufSrc) case udp6GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, true, arena) + result = udpGRO(bufs, offset, i, udpTable, true, bufSrc) } switch result { case groResultNoop: diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index a2db02e19..9dce3bb45 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -237,7 +237,7 @@ func Test_handleVirtioRead(t *testing.T) { t.Run(tt.name, func(t *testing.T) { out := make([]*buffer.Buffer, conn.IdealBatchSize) tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, buffer.NewFragmentPool(), offset) + n, err := handleVirtioRead(tt.pktIn, out, buffer.DefaultSource, offset) if err != nil { if tt.wantErr { return @@ -248,8 +248,8 @@ func Test_handleVirtioRead(t *testing.T) { t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) } for i := range tt.wantLens { - if tt.wantLens[i]+offset != out[i].Len() { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i]+offset, out[i].Len()) + if tt.wantLens[i] != len(out[i].Bytes())-offset { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], len(out[i].Bytes())-offset) } } }) @@ -287,8 +287,9 @@ func Fuzz_handleGRO(f *testing.F) { f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, gro int, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} toWrite := make([]int, 0, len(pkts)) - arena := &buffer.Arena{Buffer: buffer.New(make([]byte, 1<<20), nil)} - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite, arena) + bufSrc := &buffer.LoggingSource{Source: buffer.DefaultSource} + defer bufSrc.ReleaseAll() + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite, bufSrc) if len(toWrite) > len(pkts) { t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } @@ -505,8 +506,9 @@ func Test_handleGRO(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { toWrite := make([]int, 0, len(tt.pktsIn)) - arena := &buffer.Arena{Buffer: buffer.New(make([]byte, 1<<20), nil)} - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite, arena) + bufSrc := &buffer.LoggingSource{Source: buffer.DefaultSource} + defer bufSrc.ReleaseAll() + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite, bufSrc) if err != nil { if tt.wantErr { return diff --git 
a/tun/offload_test.go b/tun/offload_test.go index d0e536492..13e201668 100644 --- a/tun/offload_test.go +++ b/tun/offload_test.go @@ -84,15 +84,10 @@ func Fuzz_GSOSplit(f *testing.F) { GSOSize: gsoSize, NeedsCsum: needsCsum, } - n, _ := GSOSplit(pkt, options, out, buffer.NewFragmentPool(), 0) + n, _ := GSOSplit(pkt, options, out, buffer.DefaultSource, 0) if n > len(out) { - t.Errorf("n (%d) > len(sizes): %d", n, len(out)) - } - for i := range out { - if out[i] != nil { - out[i].Release() - out[i] = nil - } + t.Errorf("n (%d) > len(out): %d", n, len(out)) } + buffer.ReleaseAll(out) }) } diff --git a/tun/tun.go b/tun/tun.go index 7a79c0e8f..72c7edb72 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -25,11 +25,11 @@ type Device interface { // Read one or more packets from the Device (without any additional headers). // On a successful read it returns the number of packets read, and sets - // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // each buf's length to include the read data. // Nil entries in bufs are allocated by the implementation. // A nonzero offset can be used to instruct the Device on where to begin // reading into each element of the bufs slice. - Read(bufs []*buffer.Buffer, sizes []int, offset int) (n int, err error) + Read(bufs []*buffer.Buffer, offset int) (n int, err error) // Write one or more packets to the device (without any additional headers). // On a successful write it returns the number of packets written. A nonzero diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index f2f75bcaf..89133a900 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -218,7 +218,7 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { // TODO: the BSDs look very similar in Read() and Write(). They should be // collapsed, with platform-specific files containing the varying parts of // their implementations. 
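For context on the offset-4 arithmetic in the BSD Read implementations below: the kernel prepends a 4-byte protocol-family word to every packet, so the read lands at offset-4 and the IP packet itself stays aligned at offset. A sketch of that bookkeeping, where utunRead is a hypothetical stand-in for tunFile.Read:

package sketch

import "github.com/tailscale/wireguard-go/buffer"

// bsdRead sketches the utun framing shared by the Darwin/FreeBSD Read paths.
func bsdRead(utunRead func([]byte) (int, error), offset int) (*buffer.Buffer, error) {
	buf, err := buffer.Ensure(nil, buffer.MaxMessageSize, nil)
	if err != nil {
		return nil, err
	}
	n, err := utunRead(buf.Bytes()[offset-4:]) // n counts the 4 header bytes too
	if n < 4 {
		buffer.Release(buf)
		return nil, err
	}
	buf.SetLen(offset - 4 + n) // valid payload: buf.Bytes()[offset:]
	return buf, nil
}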
@@ -226,15 +226,16 @@ func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, case err := <-tun.errors: return 0, err default: - if bufs[0] == nil { - bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err } - buf := bufs[0].Data()[offset-4:] + buf := bufs[0].Bytes()[offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].SetLen(offset - 4 + n) return 1, err } } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 6a532e1e6..7f68743e0 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -334,20 +334,21 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - if bufs[0] == nil { - bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err } - buf := bufs[0].Data()[offset-4:] + buf := bufs[0].Bytes()[offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].SetLen(offset - 4 + n) return 1, err } } diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 81bb83a1f..ff6f947dc 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -55,8 +55,8 @@ type NativeTun struct { udpGROTable *udpGROTable gro groDisablementFlags - bufPool buffer.Source - coalescedBufs buffer.Arena + bufferSource buffer.Source + coalescedBufs buffer.LoggingSource } type groDisablementFlags int @@ -358,7 +358,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { defer func() { tun.tcpGROTable.reset() tun.udpGROTable.reset() - tun.coalescedBufs.Flush() + tun.coalescedBufs.ReleaseAll() tun.writeOpMu.Unlock() }() var ( @@ -392,9 +392,8 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { } // handleVirtioRead splits in into bufs, leaving offset bytes at the front of -// each buffer. It mutates sizes to reflect the size of each element of bufs, -// and returns the number of packets read. -func handleVirtioRead(in []byte, bufs []*buffer.Buffer, pool buffer.Source, offset int) (int, error) { +// each buffer, and returns the number of packets read. 
+func handleVirtioRead(in []byte, bufs []*buffer.Buffer, src buffer.Source, offset int) (int, error) {
 	var hdr virtioNetHdr
 	err := hdr.decode(in)
 	if err != nil {
@@ -426,10 +425,10 @@ func handleVirtioRead(in []byte, bufs []*buffer.Buffer, pool buffer.Source, offs
 		options.HdrLen = options.CsumStart + tcpHLen
 	}
 
-	return GSOSplit(in, options, bufs, pool, offset)
+	return GSOSplit(in, options, bufs, src, offset)
 }
 
-func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) {
 	tun.readOpMu.Lock()
 	defer tun.readOpMu.Unlock()
 	select {
@@ -440,10 +439,11 @@ func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int,
 		if tun.vnetHdr {
 			readInto = tun.readBuff[:]
 		} else {
-			if bufs[0] == nil {
-				bufs[0], _ = tun.bufPool.Get(buffer.MaxMessageSize)
+			bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil)
+			if err != nil {
+				return 0, err
 			}
-			readInto = bufs[0].Data()[offset:]
+			readInto = bufs[0].Bytes()[offset:]
 		}
 		n, err := tun.tunFile.Read(readInto)
 		if errors.Is(err, syscall.EBADFD) {
@@ -453,16 +453,9 @@ func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int,
 			return 0, err
 		}
 		if tun.vnetHdr {
-			count, err := handleVirtioRead(readInto[:n], bufs, tun.bufPool, offset)
-			if err != nil {
-				return 0, err
-			}
-			for i := 0; i < count; i++ {
-				sizes[i] = len(bufs[i].Data()) - offset
-			}
-			return count, nil
+			return handleVirtioRead(readInto[:n], bufs, tun.bufferSource, offset)
 		} else {
-			sizes[0] = n
+			bufs[0].SetLen(offset + n)
 			return 1, nil
 		}
 	}
@@ -594,8 +587,6 @@ func CreateTUN(name string, mtu int) (Device, error) {
 
 // CreateTUNFromFile creates a Device from an os.File with the provided MTU.
 func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
-	bufPool := buffer.NewFragmentPool()
-	arenaBuf, _ := bufPool.Get(32 * 65 << 10) // 32 buffers of coalesced size
 	tun := &NativeTun{
 		tunFile:     file,
 		events:      make(chan Event, 5),
@@ -604,9 +595,9 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
 		tcpGROTable: newTCPGROTable(),
 		udpGROTable: newUDPGROTable(),
 		toWrite:     make([]int, 0, conn.IdealBatchSize),
-		bufPool:       bufPool,
-		coalescedBufs: buffer.Arena{Buffer: arenaBuf},
-	}
+		bufferSource:  buffer.DefaultSource,
+		coalescedBufs: buffer.LoggingSource{Source: buffer.DefaultSource},
+	}
 
 	name, err := tun.Name()
 	if err != nil {
@@ -655,8 +645,6 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
 		return nil, "", err
 	}
 	file := os.NewFile(uintptr(fd), "/dev/tun")
-	bufPool := buffer.NewFragmentPool()
-	arenaBuf, _ := bufPool.Get(32 * 65 << 10) // 32 buffers of coalesced size
 	tun := &NativeTun{
 		tunFile:     file,
 		events:      make(chan Event, 5),
 		tcpGROTable: newTCPGROTable(),
 		udpGROTable: newUDPGROTable(),
 		toWrite:     make([]int, 0, conn.IdealBatchSize),
-		bufPool:       bufPool,
-		coalescedBufs: buffer.Arena{Buffer: arenaBuf},
+		bufferSource:  buffer.DefaultSource,
+		coalescedBufs: buffer.LoggingSource{Source: buffer.DefaultSource},
 	}
 	name, err := tun.Name()
 	if err != nil {
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index cdda7feaa..dfcd5247f 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -205,20 +205,21 @@ func (tun *NativeTun) Events() <-chan Event {
 	return tun.events
 }
 
-func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) {
 	select {
 	case err := <-tun.errors:
 		return 0, err
 	default:
-		if bufs[0] == nil {
-			bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil)
+		bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil)
+		if err != nil {
+			return 0, err
 		}
-		buf := bufs[0].Data()[offset-4:]
+		buf := bufs[0].Bytes()[offset-4:]
 		n, err := tun.tunFile.Read(buf[:])
 		if n < 4 {
 			return 0, err
 		}
-		sizes[0] = n - 4
+		bufs[0].SetLen(offset - 4 + n)
 		return 1, err
 	}
 }
diff --git a/tun/tun_plan9.go b/tun/tun_plan9.go
index afe2ad7bb..fc821ff1f 100644
--- a/tun/tun_plan9.go
+++ b/tun/tun_plan9.go
@@ -83,22 +83,23 @@ func (tun *NativeTun) Events() <-chan Event {
 	return tun.events
 }
 
-func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) {
 	select {
 	case err := <-tun.errors:
 		return 0, err
 	default:
-		if bufs[0] == nil {
-			bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil)
+		bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil)
+		if err != nil {
+			return 0, err
 		}
-		data := bufs[0].Data()
+		data := bufs[0].Bytes()
 		n, err := tun.dataFile.Read(data[offset:])
 		if n == 1 && data[offset] == 0 {
 			// EOF
 			err = io.EOF
 			n = 0
 		}
-		sizes[0] = n
+		bufs[0].SetLen(offset + n)
 		return 1, err
 	}
 }
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index caf6caa05..5d4a29e34 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -145,7 +145,7 @@ func (tun *NativeTun) BatchSize() int {
 
 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
 
-func (tun *NativeTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) {
 	tun.running.Add(1)
 	defer tun.running.Done()
 retry:
@@ -162,11 +162,12 @@ retry:
 	switch err {
 	case nil:
 		packetSize := len(packet)
-		if bufs[0] == nil {
-			bufs[0] = buffer.New(make([]byte, packetSize), nil)
+		bufs[0], err = buffer.Ensure(bufs[0], offset+packetSize, nil)
+		if err != nil {
+			return 0, err
 		}
-		copy(bufs[0].Data()[offset:], packet)
-		sizes[0] = packetSize
+		copy(bufs[0].Bytes()[offset:], packet)
+		bufs[0].SetLen(offset + packetSize)
 		tun.session.ReleaseReceivePacket(packet)
 		tun.rate.update(uint64(packetSize))
 		return 1, nil
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
index 0873163ef..e56fbd529 100644
--- a/tun/tuntest/tuntest.go
+++ b/tun/tuntest/tuntest.go
@@ -111,16 +111,18 @@ type chTun struct {
 
 func (t *chTun) File() *os.File { return nil }
 
-func (t *chTun) Read(bufs []*buffer.Buffer, sizes []int, offset int) (int, error) {
+func (t *chTun) Read(bufs []*buffer.Buffer, offset int) (int, error) {
 	select {
 	case <-t.c.closed:
 		return 0, os.ErrClosed
 	case msg := <-t.c.Outbound:
-		if bufs[0] == nil {
-			bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil)
+		var err error
+		bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil)
+		if err != nil {
+			return 0, err
 		}
-		n := copy(bufs[0].Data()[offset:], msg)
-		sizes[0] = n
+		n := copy(bufs[0].Bytes()[offset:], msg)
+		bufs[0].SetLen(offset + n)
 		return 1, nil
 	}
 }
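
The new Read contract in practice: callers no longer thread a separate
sizes slice alongside bufs; each returned Buffer carries its own length,
and nil entries are allocated by the device itself. A minimal consumer
sketch, assuming the Ensure, Bytes, SetLen, ReleaseAll, and DefaultSource
helpers from this series (Bytes and SetLen are defined in parts of the
buffer package not quoted above):

	package example

	import (
		"log"

		"github.com/tailscale/wireguard-go/buffer"
		"github.com/tailscale/wireguard-go/tun"
	)

	// drainDevice reads packet batches from dev until the first error.
	// offset is the headroom the caller wants left in front of each packet.
	func drainDevice(dev tun.Device, offset int) error {
		bufs := make([]*buffer.Buffer, dev.BatchSize())
		for {
			// Nil entries in bufs are allocated by the implementation.
			n, err := dev.Read(bufs, offset)
			if err != nil {
				return err
			}
			for _, b := range bufs[:n] {
				pkt := b.Bytes()[offset:] // SetLen covered offset plus payload
				log.Printf("read %d-byte packet", len(pkt))
			}
			// Recycle pool-backed buffers; ReleaseAll also nils the slots,
			// so the next Read allocates fresh ones.
			buffer.ReleaseAll(bufs[:n])
		}
	}

Since each Buffer carries a recycle hook back to its originating Source,
skipping ReleaseAll costs only reuse, not correctness: unreferenced
buffers are simply garbage collected.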