diff --git a/buffer/buffer.go b/buffer/buffer.go new file mode 100644 index 000000000..2106554c7 --- /dev/null +++ b/buffer/buffer.go @@ -0,0 +1,131 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +// Package buffer implements a reusable buffer abstraction. + +// +// Wireguard-go's data processing is constrained by both the host's API, +// and the transformations performed during encapsulation: +// +// 1. Encryption requires tail- and headroom for extra headers and padding. +// Available via winrio, and pread(2). +// 2. Systems are moving towards coalesced reads for both TCP and UDP. +// The read data has no gaps for individual slices. +// 3. crypto.AEAD interface requires a contiguous dst []byte for Sealing. +// So we can't use scatter-gather to inject the gaps. +// +// Until one of these three conditions is changed, the encryption strategy +// is to copy on read into buffers with the required gaps. +// The buffers are right-sized for the packet to avoid memory inflation. +// To recycle said allocations, each buffer carries a recycle function +// that routes it back to its originating pool. +// +// Decryption shrinks each fragment instead of growing, so buffers can pass +// through the pipeline without copying till the egress coalescing. +// Depending on the chosen head of the coalescing, there may or may not be room, +// and reallocation is a necessary fallback until we start passing +// buffers in batches. + +package buffer + +import "fmt" + +// Recycler holds state necessary for a correct Buffer return to its originating Source. +type Recycler interface { + Recycle(*Buffer) +} + +// RecycleFunc adapts arbitrary closures to the Recycler interface. +type RecycleFunc func(*Buffer) + +func (f RecycleFunc) Recycle(b *Buffer) { + f(b) +} + +// Buffer is a reusable slice of bytes of fixed length. +// Buffer or its data must not be retained past Release. 
+type Buffer struct { + // data length tracks valid payload + offset (offset passed out of band). + // It starts equal to the requested size for new buffers, and can be adjusted. + // data is read-only, exposed via Bytes for convenience but should never be assigned to. + // Buffer methods maintain the invariant that data[offset:len(data)] is valid payload, + // SetLen lets callers adjust the length of the valid payload as needed. + data []byte + recycler Recycler +} + +// New creates Buffer referencing the provided Recycler. +func New(b []byte, recycler Recycler) *Buffer { + return &Buffer{data: b, recycler: recycler} +} + +// Make creates Buffer with a new byte slice of the requested size. +func Make(size int) *Buffer { + buf, _ := DefaultSource.Get(size) // fragment pool never errors + return buf +} + +// Bytes returns the valid data in the Buffer. +func (b *Buffer) Bytes() []byte { + return b.data +} + +// BytesAt returns the valid data in the Buffer starting at offset. +func (b *Buffer) BytesAt(offset int) []byte { + return b.data[offset:] +} + +// SetLen sets the length of the valid data in the Buffer. +// Intended to be used for truncating the valid data post read, +// or extending post encryption. Does not check the capacity. +func (b *Buffer) SetLen(l int) { + b.data = b.data[:l] +} + +// Ensure returns a Buffer of the requested length, with the valid data from the provided Buffer. +// The returned Buffer may be the same as the provided Buffer if it has sufficient capacity, or a new Buffer otherwise. +// Safe to call on a nil Buffer. 
+func Ensure(b *Buffer, size int, src Source) (*Buffer, error) { + if src == nil { + src = DefaultSource + } + if b == nil { + return src.Get(size) + } + if size > cap(b.data) { + bb, err := src.Get(size) + if err != nil { + return nil, err + } + n := copy(bb.data, b.data) + if n != len(b.data) { + panic(fmt.Sprintf("short copy: %d != %d", n, len(b.data))) + } + Release(b) + return bb, nil + } + b.data = b.data[:size] + return b, nil +} + +// Release returns Buffer to its Source for reuse. +// Safe to call on a nil Buffer. +func Release(b *Buffer) { + if b == nil { + return + } + b.data = b.data[:cap(b.data)] + clear(b.data) + if b.recycler != nil { + b.recycler.Recycle(b) + } +} + +// ReleaseAll calls Release on each non-nil Buffer in the slice, and sets the slice elements to nil. +func ReleaseAll(bs []*Buffer) { + for i := range bs { + Release(bs[i]) + bs[i] = nil + } +} diff --git a/buffer/constants.go b/buffer/constants.go new file mode 100644 index 000000000..83bc85707 --- /dev/null +++ b/buffer/constants.go @@ -0,0 +1,10 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ +package buffer + +const ( + // MaxMessageSize is the largest buffer that callers may request from a Source. + MaxMessageSize = MaxSegmentSize +) diff --git a/buffer/constants_android.go b/buffer/constants_android.go new file mode 100644 index 000000000..86cf297e6 --- /dev/null +++ b/buffer/constants_android.go @@ -0,0 +1,13 @@ +//go:build android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package buffer + +const ( + MaxSegmentSize = 2200 // largest possible Android read + MaxBytesPerSource = 4096 * MaxSegmentSize +) diff --git a/buffer/constants_default.go b/buffer/constants_default.go new file mode 100644 index 000000000..ee2c556c8 --- /dev/null +++ b/buffer/constants_default.go @@ -0,0 +1,13 @@ +//go:build !android && !ios && !windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package buffer + +const ( + MaxSegmentSize = (1 << 16) - 1 // largest possible Unix read + MaxBytesPerSource = 0 // Disable and allow for infinite memory growth +) diff --git a/buffer/constants_ios.go b/buffer/constants_ios.go new file mode 100644 index 000000000..4d4e06604 --- /dev/null +++ b/buffer/constants_ios.go @@ -0,0 +1,17 @@ +//go:build ios + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package buffer + +// Fit within memory limits for iOS's Network Extension API, which has stricter requirements. +// These are vars instead of consts, because heavier network extensions might want to reduce +// them further. +var ( + MaxBytesPerSource = 1024 * MaxSegmentSize +) + +const MaxSegmentSize = 1700 // largest possible iOS read diff --git a/buffer/constants_windows.go b/buffer/constants_windows.go new file mode 100644 index 000000000..7f2f4e8b6 --- /dev/null +++ b/buffer/constants_windows.go @@ -0,0 +1,13 @@ +//go:build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package buffer + +const ( + MaxSegmentSize = 2048 - 32 // largest possible Windows read + MaxBytesPerSource = 0 // Disable and allow for infinite memory growth +) diff --git a/buffer/source.go b/buffer/source.go new file mode 100644 index 000000000..fa75a2120 --- /dev/null +++ b/buffer/source.go @@ -0,0 +1,175 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. 
All Rights Reserved. + */ +package buffer + +import ( + "fmt" + "sync" + "sync/atomic" +) + +const ( + min = 2 << 10 // 2KB, typical MTU-sized packet + mid = 10 << 10 // 10KB, jumbo frame + max = 65 << 10 // 65KB, max UDP datagram read +) + +// Source produces new Buffers. +type Source interface { + // Get returns a Buffer of exactly requested len and at least the requested cap. + // Implementations may return error if the request can not be fulfilled. + Get(size int) (*Buffer, error) +} + +// DefaultSource is a package-level Source of Buffers. +// Used by Make and Ensure when no Source is provided. +var DefaultSource Source = &CappedSource{s: NewPoolSource(), cap: int64(MaxBytesPerSource)} + +var ( + _ Source = (*PoolSource)(nil) + _ Source = (*LoggingSource)(nil) + _ Source = (*CappedSource)(nil) +) + +// PoolSource is a tiered [Source] of buffers. Tiers are balanced +// to accommodate regular MTU sizes, jumbo frames, and the maximum possible UDP datagram size. +// PoolSource never errors, instead allocating a GC-managed buffer for requests that exceed the max tier size. +type PoolSource struct { + minPool sync.Pool + midPool sync.Pool + maxPool sync.Pool +} + +type poolRecycler struct { + *sync.Pool +} + +func (p *poolRecycler) Recycle(b *Buffer) { + p.Put(b) +} + +// NewPoolSource returns a PoolSource with the default tier sizes. 
+func NewPoolSource() *PoolSource { + p := new(PoolSource) + p.minPool.New = func() any { + return &Buffer{data: make([]byte, min), recycler: &poolRecycler{&p.minPool}} + } + p.midPool.New = func() any { + return &Buffer{data: make([]byte, mid), recycler: &poolRecycler{&p.midPool}} + } + p.maxPool.New = func() any { + return &Buffer{data: make([]byte, max), recycler: &poolRecycler{&p.maxPool}} + } + return p +} + +func (p *PoolSource) Get(size int) (*Buffer, error) { + var buf *Buffer + switch { + case size <= min: + buf = p.minPool.Get().(*Buffer) + case size <= mid: + buf = p.midPool.Get().(*Buffer) + case size <= max: + buf = p.maxPool.Get().(*Buffer) + default: + return &Buffer{data: make([]byte, size)}, nil + } + buf.data = buf.data[:size] + return buf, nil +} + +// LoggingSource is a Source that keeps track of all Buffers it has produced. +// Use when the buffers are not retained and can not be released back individually. +// Not safe for concurrent use. +type LoggingSource struct { + Source + log []*Buffer +} + +// Get returns a Buffer and records it for later bulk release. +func (l *LoggingSource) Get(size int) (*Buffer, error) { + buf, err := l.Source.Get(size) + if err != nil { + return nil, err + } + l.log = append(l.log, buf) + return buf, nil +} + +// Log returns all Buffers produced by this source. +func (l *LoggingSource) Log() []*Buffer { + return l.log +} + +// ReleaseAll releases all tracked Buffers and resets the log. +func (l *LoggingSource) ReleaseAll() { + ReleaseAll(l.log) + l.log = l.log[:0] +} + +// CappedSource is a Source that tracks the total capacity of all Buffers +// it has produced and returns an error if a request would cause the total +// to exceed a specified cap. +type CappedSource struct { + Used atomic.Int64 // public to expose metrics + + s Source + cap int64 + recyclers sync.Pool +} + +// NewCappedSource returns a CappedSource wrapping src with the given byte cap. +// A cap of zero or less disables the limit. 
+func NewCappedSource(src Source, cap int64) *CappedSource { + s := &CappedSource{ + s: src, + cap: cap, + } + s.recyclers = sync.Pool{ + New: func() any { + return &cappedRecycler{s: s} + }} + return s + +} + +// ErrSizeExceedsCap is returned by CappedSource.Get when the cap would be exceeded. +var ErrSizeExceedsCap = fmt.Errorf("buffer: request exceeds cap") + +func (p *CappedSource) Get(size int) (*Buffer, error) { + // This implementation acquires the buffer first to + // account for correct memory footprint via cap(). + // There is little point in optimizing this to just + // spin faster when we're over. + b, err := p.s.Get(size) + if err != nil { + return nil, err + } + if p.cap <= 0 { + return b, nil // uncapped path + } + charge := int64(cap(b.data)) + if new := p.Used.Add(charge); new > p.cap { + p.Used.Add(-charge) + Release(b) + return nil, ErrSizeExceedsCap + } + r := p.recyclers.Get().(*cappedRecycler) + r.next, b.recycler = b.recycler, r // wrap recycler + return b, nil +} + +type cappedRecycler struct { + s *CappedSource + next Recycler +} + +func (r *cappedRecycler) Recycle(b *Buffer) { + b.recycler, r.next = r.next, nil // unwrap recycler + r.s.Used.Add(-int64(cap(b.data))) + Release(b) // safe release, recycler may be nil + r.s.recyclers.Put(r) +} diff --git a/buffer/source_test.go b/buffer/source_test.go new file mode 100644 index 000000000..4ff11fb95 --- /dev/null +++ b/buffer/source_test.go @@ -0,0 +1,45 @@ +package buffer + +import ( + "errors" + "sync" + "testing" +) + +func BenchmarkCappedSource(b *testing.B) { + for _, sc := range []struct { + name string + src Source + }{ + {name: "baseline", src: NewPoolSource()}, + {name: "uncapped", src: NewCappedSource(NewPoolSource(), 0)}, + {name: "capped", src: NewCappedSource(NewPoolSource(), 2*MaxMessageSize)}, + } { + b.Run(sc.name, func(b *testing.B) { + var cappedErrCount int + consumer := make(chan *Buffer, 1) + var wg sync.WaitGroup + wg.Go(func() { + for b := range consumer { + Release(b) + 
} + }) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + buf, err := sc.src.Get(MaxMessageSize) + if err != nil { + if !errors.Is(err, ErrSizeExceedsCap) { + b.Fatalf("unexpected error: %v", err) + } + cappedErrCount++ + } + consumer <- buf + } + }) + close(consumer) + wg.Wait() + b.ReportMetric(float64(cappedErrCount), "capped") + }) + + } +} diff --git a/conn/bind_std.go b/conn/bind_std.go index fc0563456..507001e02 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -16,6 +16,7 @@ import ( "sync" "syscall" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) @@ -42,8 +43,9 @@ type StdNetBind struct { ipv6RxOffload bool // these two fields are not guarded by mu - udpAddrPool sync.Pool - msgsPool sync.Pool + udpAddrPool sync.Pool + msgsPool sync.Pool + bufferSource buffer.Source blackhole4 bool blackhole6 bool @@ -63,12 +65,14 @@ func NewStdNetBind() Bind { New: func() any { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].Buffers = make(net.Buffers, 1, udpSegmentMaxDatagrams) msgs[i].OOB = make([]byte, controlSize) } return &msgs }, }, + + bufferSource: buffer.DefaultSource, } } @@ -204,7 +208,7 @@ again: func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { for i := range *msgs { - (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers[:1], OOB: (*msgs)[i].OOB} } s.msgsPool.Put(msgs) } @@ -230,36 +234,54 @@ func (s *StdNetBind) receiveIP( br batchReader, conn *net.UDPConn, rxOffload bool, - bufs [][]byte, - sizes []int, + bufs []*buffer.Buffer, eps []Endpoint, ) (n int, err error) { msgs := s.getMessages() - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] - } defer s.putMessages(msgs) var numMsgs int if runtime.GOOS == "linux" { if rxOffload { - readAt := len(*msgs) - 2 + const readBatch = 2 + 
readAt := len(*msgs) - readBatch + for i := readAt; i < readAt+readBatch; i++ { + bufs[i], err = buffer.Ensure(bufs[i], buffer.MaxMessageSize, s.bufferSource) + if err != nil { + return 0, err + } + (*msgs)[i].Buffers[0] = bufs[i].Bytes() + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) if err != nil { return 0, err } - numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize, bufs, s.bufferSource) if err != nil { return 0, err } } else { + for i := range bufs { + bufs[i], err = buffer.Ensure(bufs[i], buffer.MaxMessageSize, s.bufferSource) + if err != nil { + return 0, err + } + (*msgs)[i].Buffers[0] = bufs[i].Bytes() + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } } } else { + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, s.bufferSource) + if err != nil { + return 0, err + } msg := &(*msgs)[0] + msg.Buffers[0] = bufs[0].Bytes() + msg.OOB = msg.OOB[:cap(msg.OOB)] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { return 0, err @@ -268,8 +290,8 @@ func (s *StdNetBind) receiveIP( } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] - sizes[i] = msg.N - if sizes[i] == 0 { + bufs[i].SetLen(msg.N) + if msg.N == 0 { continue } addrPort := msg.Addr.(*net.UDPAddr).AddrPort() @@ -281,14 +303,14 @@ func (s *StdNetBind) receiveIP( } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - 
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } @@ -452,10 +474,11 @@ type setGSOFunc func(control *[]byte, gsoSize uint16) func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offset int, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - base = -1 // index of msg we are currently coalescing into - gsoSize int // segmentation size of msgs[base] - dgramCnt int // number of dgrams coalesced into msgs[base] - endBatch bool // tracking flag to start a new batch on next iteration of bufs + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs + coalescedLen int // bytes coalesced into msgs[base] ) maxPayloadLen := maxIPv4PayloadLen if ep.DstIP().Is6() { @@ -465,18 +488,16 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs buf = buf[offset:] if i > 0 { msgLen := len(buf) - baseLenBefore := len(msgs[base].Buffers[0]) - freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore - if msgLen+baseLenBefore <= maxPayloadLen && + if msgLen+coalescedLen <= maxPayloadLen && msgLen <= gsoSize && - msgLen <= freeBaseCap && dgramCnt < udpSegmentMaxDatagrams && !endBatch { - msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + msgs[base].Buffers = append(msgs[base].Buffers, buf) if i == len(bufs)-1 { setGSO(&msgs[base].OOB, uint16(gsoSize)) } dgramCnt++ + coalescedLen += msgLen if msgLen < gsoSize { // A smaller than gsoSize packet on the tail is legal, but // it must end the batch. 
@@ -493,6 +514,7 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs endBatch = false base++ gsoSize = len(buf) + coalescedLen = len(buf) setSrcControl(&msgs[base].OOB, ep) msgs[base].Buffers[0] = buf msgs[base].Addr = addr @@ -503,7 +525,7 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs type getGSOFunc func(control []byte) (int, error) -func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc, bufs []*buffer.Buffer, pool buffer.Source) (n int, err error) { for i := firstMsgAt; i < len(msgs); i++ { msg := &msgs[i] if msg.N == 0 { @@ -527,6 +549,13 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu if n > i { return n, errors.New("splitting coalesced packet resulted in overflow") } + segLen := end - start + bufs[n], err = buffer.Ensure(bufs[n], segLen, pool) + if err != nil { + return 0, err + } + msgs[n].Buffers[0] = bufs[n].Bytes() + msgs[n].OOB = msgs[n].OOB[:cap(msgs[n].OOB)] copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) msgs[n].N = copied msgs[n].Addr = msg.Addr diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 77af0d925..3db3d1831 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,10 +1,12 @@ package conn import ( + "bytes" "encoding/binary" "net" "testing" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/net/ipv6" ) @@ -15,15 +17,14 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { t.Fatal(err) } bind.Close() - bufs := make([][]byte, 1) - bufs[0] = make([]byte, 1) - sizes := make([]int, 1) - eps := make([]Endpoint, 1) + bufs := make([]*buffer.Buffer, IdealBatchSize) + bufs[0] = buffer.New(make([]byte, buffer.MaxMessageSize), nil) + eps := make([]Endpoint, IdealBatchSize) for _, fn := range fns { // The ReceiveFuncs must not access conn-related fields on 
StdNetBind // unguarded. Close() nils the conn-related fields resulting in a panic // if they violate the mutex. - fn(bufs, sizes, eps) + fn(bufs, eps) } } @@ -82,8 +83,8 @@ func Test_coalesceMessages(t *testing.T) { make([]byte, 2, 2), make([]byte, 2, 2), }, - wantLens: []int{4, 2}, - wantGSO: []int{2, 0}, + wantLens: []int{6}, + wantGSO: []int{2}, }, } @@ -106,9 +107,12 @@ func Test_coalesceMessages(t *testing.T) { if msgs[i].Addr != addr { t.Errorf("msgs[%d].Addr != passed addr", i) } - gotLen := len(msgs[i].Buffers[0]) + var gotLen int + for _, b := range msgs[i].Buffers { + gotLen += len(b) + } if gotLen != tt.wantLens[i] { - t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + t.Errorf("len(msgs[%d].Buffers) %d != %d", i, gotLen, tt.wantLens[i]) } gotGSO, err := mockGetGSOSize(msgs[i].OOB) if err != nil { @@ -233,7 +237,11 @@ func Test_splitCoalescedMessages(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + bufs := make([]*buffer.Buffer, len(tt.msgs)) + for i := range tt.msgs { + bufs[i] = buffer.New(tt.msgs[i].Buffers[0], nil) + } + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize, bufs, buffer.DefaultSource) if err != nil && !tt.wantErr { t.Fatalf("err: %v", err) } @@ -245,6 +253,11 @@ func Test_splitCoalescedMessages(t *testing.T) { t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) } } + for i := range got { + if !bytes.Equal(bufs[i].Bytes(), tt.msgs[i].Buffers[0]) { + t.Fatalf("bufs[%d].Data and tt.msgs[%d] unequal", i, i) + } + } }) } } diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 737b475e1..99a32c23a 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -17,6 +17,7 @@ import ( "golang.org/x/sys/windows" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn/winrio" ) @@ -416,20 +417,28 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) 
receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv4(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) - sizes[0] = n + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + n, ep, err := bind.v4.Receive(bufs[0].Bytes(), &bind.isOpen) + bufs[0].SetLen(n) eps[0] = ep return 1, err } -func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv6(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) - sizes[0] = n + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + n, ep, err := bind.v6.Receive(bufs[0].Bytes(), &bind.isOpen) + bufs[0].SetLen(n) eps[0] = ep return 1, err } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 741b776c4..a9cc66f0c 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" ) @@ -94,13 +95,17 @@ func (c *ChannelBind) BatchSize() int { return 1 } func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + return func(bufs []*buffer.Buffer, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: return 0, net.ErrClosed case rx := <-ch: - copied := copy(bufs[0], rx) - sizes[0] = copied + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + copied := copy(bufs[0].Bytes(), rx) + bufs[0].SetLen(copied) eps[0] = c.target6 
return 1, nil } diff --git a/conn/conn.go b/conn/conn.go index f1781614d..c46913431 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -13,19 +13,20 @@ import ( "reflect" "runtime" "strings" + + "github.com/tailscale/wireguard-go/buffer" ) const ( IdealBatchSize = 128 // maximum number of packets handled per read and write ) -// A ReceiveFunc receives at least one packet from the network and writes them -// into packets. On a successful read it returns the number of elements of -// sizes, packets, and endpoints that should be evaluated. Some elements of -// sizes may be zero, and callers should ignore them. Callers must pass a sizes -// and eps slice with a length greater than or equal to the length of packets. -// These lengths must not exceed the length of the associated Bind.BatchSize(). -type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) +// A ReceiveFunc receives at least one packet from the network into bufs. +// On a successful read it returns the number of elements of bufs and eps +// that should be evaluated. Callers must pass an eps slice with a length +// greater than or equal to the length of bufs. These lengths must not +// exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. 
// diff --git a/conn/conn_test.go b/conn/conn_test.go index c6194ee0c..3b9e92c2f 100644 --- a/conn/conn_test.go +++ b/conn/conn_test.go @@ -7,11 +7,13 @@ package conn import ( "testing" + + "github.com/tailscale/wireguard-go/buffer" ) func TestPrettyName(t *testing.T) { var ( - recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + recvFunc ReceiveFunc = func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) { return } ) const want = "TestPrettyName" diff --git a/device/channels.go b/device/channels.go index e526f6bb1..c1be585b9 100644 --- a/device/channels.go +++ b/device/channels.go @@ -93,7 +93,6 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -126,7 +125,6 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/device/device.go b/device/device.go index 0e720f251..267cb91c8 100644 --- a/device/device.go +++ b/device/device.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/ratelimiter" "github.com/tailscale/wireguard-go/rwcancel" @@ -73,9 +74,9 @@ type Device struct { pool struct { inboundElementsContainer *WaitPool outboundElementsContainer *WaitPool - messageBuffers *WaitPool inboundElements *WaitPool outboundElements *WaitPool + messageBuffers buffer.Source } queue struct { diff --git a/device/device_test.go b/device/device_test.go index e44342170..9454db6e5 100644 --- a/device/device_test.go 
+++ b/device/device_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn/bindtest" "github.com/tailscale/wireguard-go/tun" @@ -437,7 +438,7 @@ type fakeTUNDeviceSized struct { } func (t *fakeTUNDeviceSized) File() *os.File { return nil } -func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (t *fakeTUNDeviceSized) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { return 0, nil } func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } diff --git a/device/pools.go b/device/pools.go index 55d2be7df..b1988a424 100644 --- a/device/pools.go +++ b/device/pools.go @@ -7,6 +7,8 @@ package device import ( "sync" + + "github.com/tailscale/wireguard-go/buffer" ) type WaitPool struct { @@ -55,15 +57,13 @@ func (device *Device) PopulatePools() { s := make([]*QueueOutboundElement, 0, device.BatchSize()) return &QueueOutboundElementsContainer{elems: s} }) - device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { - return new([MaxMessageSize]byte) - }) device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueInboundElement) }) device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueOutboundElement) }) + device.pool.messageBuffers = buffer.DefaultSource } func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { @@ -94,12 +94,12 @@ func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsConta device.pool.outboundElementsContainer.Put(c) } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - device.pool.messageBuffers.Put(msg) +func (device *Device) 
GetMessageBuffer(size int) *buffer.Buffer { + b, err := device.pool.messageBuffers.Get(size) + if err != nil { + panic("failed to get buffer: " + err.Error()) + } + return b } func (device *Device) GetInboundElement() *QueueInboundElement { @@ -107,6 +107,7 @@ func (device *Device) GetInboundElement() *QueueInboundElement { } func (device *Device) PutInboundElement(elem *QueueInboundElement) { + buffer.Release(elem.buf) elem.clearPointers() device.pool.inboundElements.Put(elem) } @@ -116,6 +117,7 @@ func (device *Device) GetOutboundElement() *QueueOutboundElement { } func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { + buffer.Release(elem.buf) elem.clearPointers() device.pool.outboundElements.Put(elem) } diff --git a/device/pools_test.go b/device/pools_test.go index 2b16f3984..0d1009b8e 100644 --- a/device/pools_test.go +++ b/device/pools_test.go @@ -17,7 +17,8 @@ import ( func TestWaitPool(t *testing.T) { var wg sync.WaitGroup var trials atomic.Int32 - startTrials := int32(100000) + n := int32(runtime.NumCPU()) + startTrials := 125 * n * n if raceEnabled { // This test can be very slow with -race. 
startTrials /= 10 @@ -63,7 +64,7 @@ func TestWaitPool(t *testing.T) { } wg.Wait() if max.Load() != p.max { - t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) + t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max.Load(), p.max) } } diff --git a/device/receive.go b/device/receive.go index 56cde1047..df4a43870 100644 --- a/device/receive.go +++ b/device/receive.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" @@ -23,11 +24,11 @@ type QueueHandshakeElement struct { msgType uint32 packet []byte endpoint conn.Endpoint - buffer *[MaxMessageSize]byte + buf *buffer.Buffer } type QueueInboundElement struct { - buffer *[MaxMessageSize]byte + buf *buffer.Buffer packet []byte counter uint64 keypair *Keypair @@ -44,7 +45,7 @@ type QueueInboundElementsContainer struct { // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. 
func (elem *QueueInboundElement) clearPointers() { - elem.buffer = nil + elem.buf = nil elem.packet = nil elem.keypair = nil elem.endpoint = nil @@ -84,32 +85,25 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive // receive datagrams until conn is closed var ( - bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) - bufs = make([][]byte, maxBatchSize) + bufs = make([]*buffer.Buffer, maxBatchSize) // nil entries; recv allocates err error - sizes = make([]int, maxBatchSize) count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for i := range bufsArrs { - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] - } - defer func() { - for i := 0; i < maxBatchSize; i++ { - if bufsArrs[i] != nil { - device.PutMessageBuffer(bufsArrs[i]) - } - } + buffer.ReleaseAll(bufs) }() for { - count, err = recv(bufs, sizes, endpoints) + count, err = recv(bufs, endpoints) if err != nil { + if errors.Is(err, buffer.ErrSizeExceedsCap) { + // This is expected if TUN uses [buffer.CappedSource]. Spin until some memory is freed. 
+ continue + } if errors.Is(err, net.ErrClosed) { return } @@ -127,14 +121,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive deathSpiral = 0 // handle each packet in the batch - for i, size := range sizes[:count] { - if size < MinMessageSize { + for i := 0; i < count; i++ { + if len(bufs[i].Bytes()) < MinMessageSize { continue } // check size of packet - packet := bufsArrs[i][:size] + packet := bufs[i].Bytes() msgType := binary.LittleEndian.Uint32(packet[:4]) switch msgType { @@ -170,7 +164,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive peer := value.peer elem := device.GetInboundElement() elem.packet = packet - elem.buffer = bufsArrs[i] + elem.buf = bufs[i] elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 @@ -182,8 +176,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] + bufs[i] = nil // consumed; next recv allocates fresh continue // otherwise it is a fixed size & handshake related packet @@ -211,22 +204,21 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, - buffer: bufsArrs[i], + buf: bufs[i], packet: packet, endpoint: endpoints[i], }: - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] + bufs[i] = nil // consumed; next recv allocates fresh default: } } + buffer.ReleaseAll(bufs[:count]) for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer device.queue.decryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -423,7 +415,7 @@ func (device 
*Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.PutMessageBuffer(elem.buffer) + buffer.Release(elem.buf) } } @@ -435,7 +427,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { }() device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - bufs := make([][]byte, 0, maxBatchSize) + writeBufs := make([]*buffer.Buffer, 0, maxBatchSize) // references to transferred buffers; released after batch write + writeSlices := make([][]byte, 0, maxBatchSize) // slices of the above buffers; passed to TUN device; not used after write for elemsContainer := range peer.queue.inbound.c { if elemsContainer == nil { @@ -513,7 +506,10 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { continue } - bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) + elem.buf.SetLen(MessageTransportOffsetContent + len(elem.packet)) + writeSlices = append(writeSlices, elem.buf.Bytes()) + writeBufs = append(writeBufs, elem.buf) + elem.buf = nil // ownership transferred to writeBufs } peer.rxBytes.Add(rxBytesLen) @@ -526,17 +522,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { if dataPacketReceived { peer.timersDataReceived() } - if len(bufs) > 0 { - _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) + if len(writeBufs) > 0 { + _, err := device.tun.device.Write(writeSlices, MessageTransportOffsetContent) if err != nil && !device.isClosed() { device.log.Errorf("Failed to write packets to TUN device: %v", err) } + buffer.ReleaseAll(writeBufs) } + // Release buffers for skipped elements (not transferred to writeBufs). 
for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - bufs = bufs[:0] + writeBufs = writeBufs[:0] + writeSlices = writeSlices[:0] device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index 89269fc07..365ff015e 100644 --- a/device/send.go +++ b/device/send.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/tun" "golang.org/x/crypto/chacha20poly1305" @@ -46,8 +47,8 @@ import ( */ type QueueOutboundElement struct { - buffer *[MaxMessageSize]byte // slice holding the packet data - // packet is always a slice of "buffer". The starting offset in buffer + buf *buffer.Buffer + // packet is always a slice of buf. The starting offset in buf // is either: // a) MessageEncapsulatingTransportSize+MessageTransportHeaderSize (plaintext) // b) 0 (post-encryption) @@ -64,7 +65,7 @@ type QueueOutboundElementsContainer struct { func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() - elem.buffer = device.GetMessageBuffer() + elem.buf = device.GetMessageBuffer(MaxMessageSize) elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -75,7 +76,7 @@ func (device *Device) NewOutboundElement() *QueueOutboundElement { // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. 
func (elem *QueueOutboundElement) clearPointers() { - elem.buffer = nil + elem.buf = nil elem.packet = nil elem.keypair = nil elem.peer = nil @@ -92,7 +93,6 @@ func (peer *Peer) SendKeepalive() { case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) peer.device.PutOutboundElementsContainer(elemsContainer) } @@ -128,15 +128,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := peer.device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageInitiationSize) + packet := buf.Bytes()[MessageEncapsulatingTransportSize:] _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{buf}) + err = peer.SendBuffers([][]byte{buf.Bytes()}) + buffer.Release(buf) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -158,13 +159,14 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := peer.device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageResponseSize) + packet := buf.Bytes()[MessageEncapsulatingTransportSize:] _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() if err != nil { + buffer.Release(buf) peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) return err } @@ -174,7 +176,8 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{buf}) + err = 
peer.SendBuffers([][]byte{buf.Bytes()}) + buffer.Release(buf) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -191,11 +194,12 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := device.GetMessageBuffer(MessageEncapsulatingTransportSize + MessageCookieReplySize) + packet := buf.Bytes()[MessageEncapsulatingTransportSize:] _ = reply.marshal(packet) - // TODO: allocation could be avoided - device.net.bind.Send([][]byte{buf}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + + device.net.bind.Send([][]byte{buf.Bytes()}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + buffer.Release(buf) return nil } @@ -223,57 +227,43 @@ func (device *Device) RoutineReadFromTUN() { var ( batchSize = device.BatchSize() readErr error - elems = make([]*QueueOutboundElement, batchSize) - bufs = make([][]byte, batchSize) + bufs = make([]*buffer.Buffer, batchSize) elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 - sizes = make([]int, batchSize) offset = MessageEncapsulatingTransportSize + MessageTransportHeaderSize ) - for i := range elems { - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] - } - defer func() { - for _, elem := range elems { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - } - } + buffer.ReleaseAll(bufs) }() for { - // read packets - count, readErr = device.tun.device.Read(bufs, sizes, offset) + count, readErr = device.tun.device.Read(bufs, offset) + for i := 0; i < count; i++ { - if sizes[i] < 1 { + packet := bufs[i].Bytes()[offset:] + if len(packet) < 1 { continue } - elem := elems[i] - elem.packet = bufs[i][offset : offset+sizes[i]] - // lookup peer var peer *Peer - switch elem.packet[0] >> 4 { + switch packet[0] 
>> 4 { case 4: - if len(elem.packet) < ipv4.HeaderLen { + if len(packet) < ipv4.HeaderLen { continue } - src := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) - dst := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom4([4]byte(packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) + dst := netip.AddrFrom4([4]byte(packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) case 6: - if len(elem.packet) < ipv6.HeaderLen { + if len(packet) < ipv6.HeaderLen { continue } - src := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) - dst := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom16([16]byte(packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) + dst := netip.AddrFrom16([16]byte(packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) default: device.log.Verbosef("Received packet with unknown IP version") @@ -282,15 +272,22 @@ func (device *Device) RoutineReadFromTUN() { if peer == nil { continue } + + elem := device.GetOutboundElement() + elem.buf = bufs[i] + elem.packet = packet + elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] + bufs[i] = nil // consumed; next Read allocates fresh } + // Release unconsumed buffers so stale right-sized entries + // don't persist into the next Read call. 
+ buffer.ReleaseAll(bufs[:count]) for peer, elemsForPeer := range elemsByPeer { if peer.isRunning.Load() { @@ -298,7 +295,6 @@ func (device *Device) RoutineReadFromTUN() { peer.SendStagedPackets() } else { for _, elem := range elemsForPeer.elems { - device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsForPeer) @@ -307,6 +303,10 @@ func (device *Device) RoutineReadFromTUN() { } if readErr != nil { + if errors.Is(readErr, buffer.ErrSizeExceedsCap) { + // This is expected if TUN uses [buffer.CappedSource]. Spin until some memory is freed. + continue + } if errors.Is(readErr, tun.ErrTooManySegments) { // TODO: record stat for this // This will happen if MSS is surprisingly small (< 576) @@ -335,7 +335,6 @@ func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { select { case tooOld := <-peer.queue.staged: for _, elem := range tooOld.elems { - peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(tooOld) @@ -396,7 +395,6 @@ top: peer.device.queue.encryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -416,7 +414,6 @@ func (peer *Peer) FlushStagedPackets() { select { case elemsContainer := <-peer.queue.staged: for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -456,7 +453,7 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] + header := 
elem.buf.Bytes()[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] @@ -481,7 +478,8 @@ func (device *Device) RoutineEncryption(id int) { ) // re-slice packet to include encapsulating transport space - elem.packet = elem.buffer[:MessageEncapsulatingTransportSize+len(elem.packet)] + elem.buf.SetLen(MessageEncapsulatingTransportSize + len(elem.packet)) + elem.packet = elem.buf.Bytes() } elemsContainer.Unlock() } @@ -495,10 +493,12 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) - bufs := make([][]byte, 0, maxBatchSize) + rawBufs := make([][]byte, 0, maxBatchSize) + releaseBufs := make([]*buffer.Buffer, 0, maxBatchSize) for elemsContainer := range peer.queue.outbound.c { - bufs = bufs[:0] + rawBufs = rawBufs[:0] + releaseBufs = releaseBufs[:0] if elemsContainer == nil { return } @@ -511,7 +511,6 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // that we never accidentally keep timers alive longer than necessary. 
elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) @@ -523,18 +522,20 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if len(elem.packet[MessageEncapsulatingTransportSize:]) != MessageKeepaliveSize { dataSent = true } - bufs = append(bufs, elem.packet) + rawBufs = append(rawBufs, elem.packet) + releaseBufs = append(releaseBufs, elem.buf) + elem.buf = nil } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffers(bufs) + err := peer.SendBuffers(rawBufs) + buffer.ReleaseAll(releaseBufs) if dataSent { peer.timersDataSent() } for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/go.mod b/go.mod index 37b0b6daf..56db8f3c8 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/tailscale/wireguard-go -go 1.25 +go 1.26 require ( golang.org/x/crypto v0.13.0 diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index d8e70bb03..9a494fbc5 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -22,6 +22,7 @@ import ( "syscall" "time" + wgbuf "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/tun" "golang.org/x/net/dns/dnsmessage" @@ -119,17 +120,22 @@ func (tun *netTun) Events() <-chan tun.Event { return tun.events } -func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { +func (tun *netTun) Read(bufs []*wgbuf.Buffer, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } - n, err := view.Read(buf[0][offset:]) + var err error + bufs[0], err = wgbuf.Ensure(bufs[0], wgbuf.MaxMessageSize, nil) if err != nil { return 0, err } - sizes[0] = n + n, err := view.Read(bufs[0].Bytes()[offset:]) + if err != nil { + return 0, err + } + 
bufs[0].SetLen(offset + n) return 1, nil } diff --git a/tun/offload.go b/tun/offload.go index 6db437c34..454d16e88 100644 --- a/tun/offload.go +++ b/tun/offload.go @@ -3,6 +3,8 @@ package tun import ( "encoding/binary" "fmt" + + "github.com/tailscale/wireguard-go/buffer" ) // GSOType represents the type of segmentation offload. @@ -73,15 +75,14 @@ const ( ipProtoUDP = 17 ) -// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing -// the size of each element into sizes. It returns the number of buffers -// populated, and/or an error. Callers may pass an 'in' slice that overlaps with -// the first element of outBuffers, i.e. &in[0] may be equal to -// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the -// value of options.NeedsCsum. Length of each outBufs element must be greater -// than or equal to the length of 'in', otherwise output may be silently -// truncated. -func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { +// GSOSplit splits packets from 'in' into outBufs, allocating from src as +// needed. Each output buffer is sized to outOffset + segment length. +// It returns the number of buffers populated, and/or an error. +// Callers may pass an 'in' slice that overlaps with the first element +// of outBuffers, i.e. &in[0] may be equal to &outBufs[0].BytesAt(outOffset). +// GSONone is a valid options.GSOType regardless of the value of +// options.NeedsCsum. +func GSOSplit(in []byte, options GSOOptions, outBufs []*buffer.Buffer, src buffer.Source, outOffset int) (int, error) { cSumAt := int(options.CsumStart) + int(options.CsumOffset) if cSumAt+1 >= len(in) { return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) @@ -94,8 +95,11 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO // Handle the conditions where we are copying a single element to outBuffs. 
payloadLen := len(in) - int(options.HdrLen) if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { - if len(in) > len(outBufs[0][outOffset:]) { - return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + needed := outOffset + len(in) + var err error + outBufs[0], err = buffer.Ensure(outBufs[0], needed, src) + if err != nil { + return 0, err } if options.NeedsCsum { // The initial value at the checksum offset should be summed with @@ -104,7 +108,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO in[cSumAt], in[cSumAt+1] = 0, 0 binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial)) } - sizes[0] = copy(outBufs[0][outOffset:], in) + n := copy(outBufs[0].Bytes()[outOffset:], in) + outBufs[0].SetLen(outOffset + n) return 1, nil } @@ -164,8 +169,12 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO } segmentDataLen := nextSegmentEnd - nextSegmentDataAt totalLen := int(options.HdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBufs[i][outOffset:] + var err error + outBufs[i], err = buffer.Ensure(outBufs[i], outOffset+totalLen, src) + if err != nil { + return i, err + } + out := outBufs[i].Bytes()[outOffset:] copy(out, in[:iphLen]) if ipVersion == 4 { diff --git a/tun/offload_linux.go b/tun/offload_linux.go index fb6ac5b94..c7a06d401 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -13,6 +13,7 @@ import ( "io" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "golang.org/x/sys/unix" ) @@ -428,7 +429,18 @@ const ( // coalesceUDPPackets attempts to coalesce pkt with the packet described by // item, and returns the outcome. 
-func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, bufSrc *buffer.LoggingSource) coalesceResult { + if head := bufs[item.bufsIndex]; cap(head) < buffer.MaxMessageSize { + b, err := bufSrc.Get(buffer.MaxMessageSize) + if err == nil { + // Upsize head buffer to max. Abandon buffer struct, bufSrc tracks it. + new := b.Bytes() + n := copy(new, head) + new = new[:n] + bufs[item.bufsIndex] = new + } + // If we fail to get a new buffer, attempt to coalesce whatever fits. + } pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front headersLen := item.iphLen + udphLen coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) @@ -458,7 +470,18 @@ func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset // item, and returns the outcome. This function may swap bufs elements in the // event of a prepend as item's bufs index is already being tracked for writing // to a Device. -func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool, bufSrc *buffer.LoggingSource) coalesceResult { + if head := bufs[item.bufsIndex]; cap(head) < buffer.MaxMessageSize { + b, err := bufSrc.Get(buffer.MaxMessageSize) + if err == nil { + // Upsize head buffer to max. Abandon buffer struct, bufSrc tracks it. + new := b.Bytes() + n := copy(new, head) + new = new[:n] + bufs[item.bufsIndex] = new + } + // If we fail to get a new buffer, attempt to coalesce whatever fits. 
+ } var pktHead []byte // the packet that will end up at the front headersLen := item.iphLen + item.tcphLen coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) @@ -543,7 +566,7 @@ const ( // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool, bufSrc *buffer.LoggingSource) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. @@ -615,7 +638,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) item := items[i] can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6, bufSrc) switch result { case coalesceSuccess: table.updateAt(item, i) @@ -798,7 +821,7 @@ const ( // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool, bufSrc *buffer.LoggingSource) groResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. 
@@ -851,7 +874,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) var pktCSumKnownInvalid bool if can == coalesceAppend { - result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6, bufSrc) switch result { case coalesceSuccess: table.updateAt(item, len(items)-1) @@ -877,7 +900,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) // empty (but non-nil), and are passed in to save allocs as the caller may reset // and recycle them across vectors of packets. gro indicates if TCP and UDP GRO // are supported/enabled. -func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int, bufSrc *buffer.LoggingSource) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") @@ -885,13 +908,13 @@ func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGR var result groResult switch packetIsGROCandidate(bufs[i][offset:], gro) { case tcp4GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, false) + result = tcpGRO(bufs, offset, i, tcpTable, false, bufSrc) case tcp6GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, true) + result = tcpGRO(bufs, offset, i, tcpTable, true, bufSrc) case udp4GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, false) + result = udpGRO(bufs, offset, i, udpTable, false, bufSrc) case udp6GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, true) + result = udpGRO(bufs, offset, i, udpTable, true, bufSrc) } switch result { case groResultNoop: diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index 407037863..9dce3bb45 
100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -9,6 +9,7 @@ import ( "net/netip" "testing" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" @@ -234,13 +235,9 @@ func Test_handleVirtioRead(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } + out := make([]*buffer.Buffer, conn.IdealBatchSize) tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + n, err := handleVirtioRead(tt.pktIn, out, buffer.DefaultSource, offset) if err != nil { if tt.wantErr { return @@ -251,8 +248,8 @@ func Test_handleVirtioRead(t *testing.T) { t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) } for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + if tt.wantLens[i] != len(out[i].Bytes())-offset { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], len(out[i].Bytes())-offset) } } }) @@ -290,7 +287,9 @@ func Fuzz_handleGRO(f *testing.F) { f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, gro int, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite) + bufSrc := &buffer.LoggingSource{Source: buffer.DefaultSource} + defer bufSrc.ReleaseAll() + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite, bufSrc) if len(toWrite) > len(pkts) { t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) } @@ -507,7 +506,9 @@ func Test_handleGRO(t *testing.T) { for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite) + bufSrc := &buffer.LoggingSource{Source: buffer.DefaultSource} + defer bufSrc.ReleaseAll() + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite, bufSrc) if err != nil { if tt.wantErr { return diff --git a/tun/offload_test.go b/tun/offload_test.go index 82a37b9cc..13e201668 100644 --- a/tun/offload_test.go +++ b/tun/offload_test.go @@ -4,6 +4,7 @@ import ( "net/netip" "testing" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -67,11 +68,7 @@ func Fuzz_GSOSplit(f *testing.F) { }) header.UDP(gsoUDPv6[20:]).Encode(udpFields) - out := make([][]byte, conn.IdealBatchSize) - for i := range out { - out[i] = make([]byte, 65535) - } - sizes := make([]int, conn.IdealBatchSize) + out := make([]*buffer.Buffer, conn.IdealBatchSize) f.Add(gsoTCPv4, int(GSOTCPv4), uint16(40), uint16(20), uint16(16), uint16(100), false) f.Add(gsoUDPv4, int(GSOUDPL4), uint16(28), uint16(20), uint16(6), uint16(100), false) @@ -87,9 +84,10 @@ func Fuzz_GSOSplit(f *testing.F) { GSOSize: gsoSize, NeedsCsum: needsCsum, } - n, _ := GSOSplit(pkt, options, out, sizes, 0) - if n > len(sizes) { - t.Errorf("n (%d) > len(sizes): %d", n, len(sizes)) + n, _ := GSOSplit(pkt, options, out, buffer.DefaultSource, 0) + if n > len(out) { + t.Errorf("n (%d) > len(out): %d", n, len(out)) } + buffer.ReleaseAll(out) }) } diff --git a/tun/tun.go b/tun/tun.go index 719a60631..72c7edb72 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -7,6 +7,8 @@ package tun import ( "os" + + "github.com/tailscale/wireguard-go/buffer" ) type Event int @@ -23,10 +25,11 @@ type Device interface { // Read one or more packets from the Device (without any additional headers). 
// On a successful read it returns the number of packets read, and sets - // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // each buf's length to include the read data. + // Nil entries in bufs are allocated by the implementation. // A nonzero offset can be used to instruct the Device on where to begin // reading into each element of the bufs slice. - Read(bufs [][]byte, sizes []int, offset int) (n int, err error) + Read(bufs []*buffer.Buffer, offset int) (n int, err error) // Write one or more packets to the device (without any additional headers). // On a successful write it returns the number of packets written. A nonzero diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index c9a6c0bc4..89133a900 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -16,6 +16,7 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/sys/unix" ) @@ -217,7 +218,7 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { // TODO: the BSDs look very similar in Read() and Write(). They should be // collapsed, with platform-specific files containing the varying parts of // their implementations. 
@@ -225,12 +226,16 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + buf := bufs[0].Bytes()[offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].SetLen(offset - 4 + n) return 1, err } } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 7c65fd999..7f68743e0 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/sys/unix" ) @@ -333,17 +334,21 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + buf := bufs[0].Bytes()[offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].SetLen(offset - 4 + n) return 1, err } } diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 7cdbf8825..ff6f947dc 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,6 +17,7 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" @@ -53,6 +54,9 @@ type NativeTun struct { tcpGROTable *tcpGROTable udpGROTable *udpGROTable gro groDisablementFlags + + bufferSource buffer.Source + coalescedBufs buffer.LoggingSource } type groDisablementFlags int @@ -354,6 +358,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { defer func() { tun.tcpGROTable.reset() 
tun.udpGROTable.reset() + tun.coalescedBufs.ReleaseAll() tun.writeOpMu.Unlock() }() var ( @@ -362,7 +367,7 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { ) tun.toWrite = tun.toWrite[:0] if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite, &tun.coalescedBufs) if err != nil { return 0, err } @@ -387,9 +392,8 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { } // handleVirtioRead splits in into bufs, leaving offset bytes at the front of -// each buffer. It mutates sizes to reflect the size of each element of bufs, -// and returns the number of packets read. -func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { +// each buffer, and returns the number of packets read. +func handleVirtioRead(in []byte, bufs []*buffer.Buffer, src buffer.Source, offset int) (int, error) { var hdr virtioNetHdr err := hdr.decode(in) if err != nil { @@ -421,19 +425,25 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e options.HdrLen = options.CsumStart + tcpHLen } - return GSOSplit(in, options, bufs, sizes, offset) + return GSOSplit(in, options, bufs, src, offset) } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { tun.readOpMu.Lock() defer tun.readOpMu.Unlock() select { case err := <-tun.errors: return 0, err default: - readInto := bufs[0][offset:] + var readInto []byte if tun.vnetHdr { readInto = tun.readBuff[:] + } else { + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + readInto = bufs[0].Bytes()[offset:] } n, err := tun.tunFile.Read(readInto) if errors.Is(err, syscall.EBADFD) { @@ -443,9 +453,9 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset 
int) (int, error) return 0, err } if tun.vnetHdr { - return handleVirtioRead(readInto[:n], bufs, sizes, offset) + return handleVirtioRead(readInto[:n], bufs, tun.bufferSource, offset) } else { - sizes[0] = n + bufs[0].SetLen(offset + n) return 1, nil } } @@ -585,7 +595,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), toWrite: make([]int, 0, conn.IdealBatchSize), - } + bufferSource: buffer.DefaultSource, + coalescedBufs: buffer.LoggingSource{Source: buffer.DefaultSource}} name, err := tun.Name() if err != nil { @@ -635,12 +646,14 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - tcpGROTable: newTCPGROTable(), - udpGROTable: newUDPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), + bufferSource: buffer.DefaultSource, + coalescedBufs: buffer.LoggingSource{Source: buffer.DefaultSource}, } name, err := tun.Name() if err != nil { diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index ae571b90c..dfcd5247f 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/sys/unix" ) @@ -204,17 +205,21 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + buf := 
bufs[0].Bytes()[offset-4:] n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].SetLen(offset - 4 + n) return 1, err } } diff --git a/tun/tun_plan9.go b/tun/tun_plan9.go index 7b66eadf6..fc821ff1f 100644 --- a/tun/tun_plan9.go +++ b/tun/tun_plan9.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "sync" + + "github.com/tailscale/wireguard-go/buffer" ) type NativeTun struct { @@ -81,18 +83,23 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - n, err := tun.dataFile.Read(bufs[0][offset:]) - if n == 1 && bufs[0][offset] == 0 { + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + data := bufs[0].Bytes() + n, err := tun.dataFile.Read(data[offset:]) + if n == 1 && data[offset] == 0 { // EOF err = io.EOF n = 0 } - sizes[0] = n + bufs[0].SetLen(offset + n) return 1, err } } diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 34f29805d..5d4a29e34 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -14,6 +14,7 @@ import ( "time" _ "unsafe" + "github.com/tailscale/wireguard-go/buffer" "golang.org/x/sys/windows" "golang.zx2c4.com/wintun" ) @@ -144,7 +145,7 @@ func (tun *NativeTun) BatchSize() int { // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 
-func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []*buffer.Buffer, offset int) (n int, err error) { tun.running.Add(1) defer tun.running.Done() retry: @@ -161,8 +162,12 @@ retry: switch err { case nil: packetSize := len(packet) - copy(bufs[0][offset:], packet) - sizes[0] = packetSize + bufs[0], err = buffer.Ensure(bufs[0], offset+packetSize, nil) + if err != nil { + return 0, err + } + copy(bufs[0].Bytes()[offset:], packet) + bufs[0].SetLen(offset + packetSize) tun.session.ReleaseReceivePacket(packet) tun.rate.update(uint64(packetSize)) return 1, nil diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index e7507c26c..e56fbd529 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -11,6 +11,7 @@ import ( "net/netip" "os" + "github.com/tailscale/wireguard-go/buffer" "github.com/tailscale/wireguard-go/tun" ) @@ -110,13 +111,18 @@ type chTun struct { func (t *chTun) File() *os.File { return nil } -func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { +func (t *chTun) Read(bufs []*buffer.Buffer, offset int) (int, error) { select { case <-t.c.closed: return 0, os.ErrClosed case msg := <-t.c.Outbound: - n := copy(packets[0][offset:], msg) - sizes[0] = n + var err error + bufs[0], err = buffer.Ensure(bufs[0], buffer.MaxMessageSize, nil) + if err != nil { + return 0, err + } + n := copy(bufs[0].Bytes()[offset:], msg) + bufs[0].SetLen(offset + n) return 1, nil } }