From 7edb4969c040de5d909d5c585c886a514ce9fc3e Mon Sep 17 00:00:00 2001 From: Thomas Vilte Date: Tue, 9 Jun 2026 18:52:15 -0300 Subject: [PATCH 1/2] Add stereo mid/side encoding --- encoder.go | 41 +++- encoder_test.go | 134 ++++++++++++- internal/celt/analysis.go | 75 ++++--- internal/celt/encode_bands.go | 360 ++++++++++++++++++++++++++++++++++ internal/celt/encoder.go | 39 ++-- internal/celt/encoder_test.go | 35 +++- 6 files changed, 615 insertions(+), 69 deletions(-) diff --git a/encoder.go b/encoder.go index 9a2571d..7256f1f 100644 --- a/encoder.go +++ b/encoder.go @@ -52,11 +52,10 @@ func WithSampleRate(rate int) EncoderOption { } } -// WithChannels sets the channel count. The current encoder only supports -// mono (1 channel); stereo is planned in a follow-up PR. +// WithChannels sets the channel count (1 for mono, 2 for stereo). func WithChannels(channels int) EncoderOption { return func(e *Encoder) error { - if channels != 1 { + if channels < 1 || channels > 2 { return errInvalidChannelCount } e.channels = channels @@ -95,8 +94,8 @@ func WithComplexity(complexity int) EncoderOption { // NewEncoder creates a new Opus encoder with the supplied options. // // Defaults: 48 kHz, mono, 24 kbit/s, complexity 0. Pass options to override -// any of these. The current API surface only supports 48 kHz mono 20 ms -// CELT-only packets; stereo, transient detection, and SILK encoding will land +// any of these. The current implementation supports 48 kHz, 1 or 2 channels, +// 20 ms CELT-only packets. Transient detection and SILK encoding will land // in follow-up PRs. func NewEncoder(opts ...EncoderOption) (*Encoder, error) { encoder := &Encoder{ @@ -151,20 +150,19 @@ func (e *Encoder) Encode(in []byte, out []byte) (int, error) { // EncodeFloat32 encodes float PCM into a single Opus packet. // -// The input must contain exactly one 20 ms mono 48 kHz frame. +// The input must contain one 20 ms 48 kHz frame. func (e *Encoder) EncodeFloat32(in []float32, out []byte) (int, error) { if e.sampleRate != celtSampleRate { return 0, errInvalidSampleRate } - if e.channels != 1 { - return 0, errInvalidChannelCount - } frameSamples := e.frameSampleCount() if len(in) != frameSamples*e.channels { return 0, fmt.Errorf("%w: got %d samples, want %d", errInvalidFrameSize, len(in), frameSamples*e.channels) } + channels := splitChannels(in, e.channels, frameSamples) + frameBytes := e.frameBytes() if frameBytes <= 0 || frameBytes > maxOpusFrameSize { return 0, fmt.Errorf("%w: %d", errInvalidFrameByteBudget, frameBytes) @@ -173,7 +171,7 @@ func (e *Encoder) EncodeFloat32(in []float32, out []byte) (int, error) { return 0, errOutBufferTooSmall } - payload, err := e.celtEncoder.EncodeFrame(in, frameBytes, 0, e.celtEncoder.Mode().BandCount()) + payload, err := e.celtEncoder.EncodeFrame(channels, frameBytes, 0, e.celtEncoder.Mode().BandCount()) if err != nil { return 0, err } @@ -193,10 +191,33 @@ func (e *Encoder) EncodeFloat32(in []float32, out []byte) (int, error) { func (e *Encoder) tocHeader() tableOfContentsHeader { header := byte(celtOnlyFullband20msConfig << 3) header |= byte(frameCodeOneFrame) + if e.channels == 2 { + header |= 1 << 2 + } return tableOfContentsHeader(header) } +// splitChannels splits interleaved PCM into per-channel slices. +// For mono, it returns the input directly without allocation. +func splitChannels(in []float32, numChannels, frameSamples int) [][]float32 { + ch := make([][]float32, numChannels) + if numChannels == 1 { + ch[0] = in + + return ch + } + + for c := range numChannels { + ch[c] = make([]float32, frameSamples) + for i := range frameSamples { + ch[c][i] = in[i*numChannels+c] + } + } + + return ch +} + func (e *Encoder) frameBytes() int { return int(int64(e.bitrate) * frame20msNS / 1000000000 / 8) } diff --git a/encoder_test.go b/encoder_test.go index cdf4f00..a08b443 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -25,8 +25,9 @@ func TestNewEncoder(t *testing.T) { _, err = NewEncoder(WithSampleRate(16000)) assert.ErrorIs(t, err, errInvalidSampleRate) - _, err = NewEncoder(WithChannels(2)) - assert.ErrorIs(t, err, errInvalidChannelCount) + encoder, err = NewEncoder(WithChannels(2)) + require.NoError(t, err) + assert.Equal(t, 2, encoder.channels) } func TestNewEncoderOptions(t *testing.T) { @@ -101,6 +102,98 @@ func TestEncodeS16LERoundTrip(t *testing.T) { assert.Greater(t, vectorEnergyFloat32(out), 1e-6) } +func TestEncodeFloat32StereoRoundTrip(t *testing.T) { + encoder, err := NewEncoder(WithChannels(2)) + require.NoError(t, err) + + decoder, err := NewDecoderWithOutput(48000, 2) + require.NoError(t, err) + + pcm := testEncoderStereoSineFloat32() + packet := make([]byte, 256) + + n, err := encoder.EncodeFloat32(pcm, packet) + require.NoError(t, err) + require.Positive(t, n) + + assert.Equal(t, byte(celtOnlyFullband20msConfig<<3)|byte(frameCodeOneFrame)|(1<<2), packet[0]) + + out := make([]float32, encoderTestFrameSampleCount*2) + bandwidth, isStereo, err := decoder.DecodeFloat32(packet[:n], out) + require.NoError(t, err) + + assert.Equal(t, BandwidthFullband, bandwidth) + assert.True(t, isStereo) + assert.Greater(t, vectorEnergyFloat32(out), 1e-6) + + L := make([]float32, encoderTestFrameSampleCount) + R := make([]float32, encoderTestFrameSampleCount) + for i := range encoderTestFrameSampleCount { + L[i] = out[i*2] + R[i] = out[i*2+1] + } + L440 := freqEnergy(L, 440) + L660 := freqEnergy(L, 660) + R440 := freqEnergy(R, 440) + R660 := freqEnergy(R, 660) + assert.Greater(t, L440, L660*1.5, "L channel: 440 Hz should dominate over 660 Hz") + assert.Greater(t, R660, R440*1.5, "R channel: 660 Hz should dominate over 440 Hz") +} + +func TestEncodeS16LEStereoRoundTrip(t *testing.T) { + encoder, err := NewEncoder(WithChannels(2)) + require.NoError(t, err) + + decoder, err := NewDecoderWithOutput(48000, 2) + require.NoError(t, err) + + pcm := testEncoderStereoSineS16LE() + packet := make([]byte, 256) + + n, err := encoder.Encode(pcm, packet) + require.NoError(t, err) + require.Positive(t, n) + + out := make([]float32, encoderTestFrameSampleCount*2) + _, isStereo, err := decoder.DecodeFloat32(packet[:n], out) + require.NoError(t, err) + + assert.True(t, isStereo) + assert.Greater(t, vectorEnergyFloat32(out), 1e-6) +} + +func TestStereoMultiFramePersistence(t *testing.T) { + encoder, err := NewEncoder(WithChannels(2)) + require.NoError(t, err) + + decoder, err := NewDecoderWithOutput(48000, 2) + require.NoError(t, err) + + pcm := testEncoderStereoSineFloat32() + packet := make([]byte, 256) + out := make([]float32, encoderTestFrameSampleCount*2) + + const frames = 10 + energies := make([]float64, frames) + for i := range frames { + n, encErr := encoder.EncodeFloat32(pcm, packet) + require.NoError(t, encErr, "frame %d encode failed", i) + require.Positive(t, n) + + _, _, decErr := decoder.DecodeFloat32(packet[:n], out) + require.NoError(t, decErr, "frame %d decode failed", i) + + energies[i] = vectorEnergyFloat32(out) + assert.Greater(t, energies[i], 1e-6, "frame %d should have non-zero energy", i) + } + + for i := 1; i < frames; i++ { + ratio := energies[i] / energies[0] + assert.InDelta(t, 1.0, ratio, 0.75, + "frame %d energy ratio %.3f deviates too far from frame 0", i, ratio) + } +} + func TestEncodeRejectsInvalidS16LEInputLength(t *testing.T) { encoder, err := NewEncoder() require.NoError(t, err) @@ -159,6 +252,18 @@ func testEncoderSineFloat32() []float32 { return pcm } +func testEncoderStereoSineFloat32() []float32 { + pcm := make([]float32, encoderTestFrameSampleCount*2) + for i := range encoderTestFrameSampleCount { + left := float32(math.Sin(2 * math.Pi * 440 * float64(i) / 48000)) + right := float32(math.Sin(2 * math.Pi * 660 * float64(i) / 48000)) + pcm[i*2] = left + pcm[i*2+1] = right + } + + return pcm +} + func testEncoderSineS16LE() []byte { pcm := make([]byte, encoderTestFrameSampleCount*2) for i := range encoderTestFrameSampleCount { @@ -171,6 +276,18 @@ func testEncoderSineS16LE() []byte { return pcm } +func testEncoderStereoSineS16LE() []byte { + pcm := make([]byte, encoderTestFrameSampleCount*4) // 2 channels × 2 bytes each + for i := range encoderTestFrameSampleCount { + left := int16(math.Round(math.Sin(2*math.Pi*440*float64(i)/48000) * 16000)) + right := int16(math.Round(math.Sin(2*math.Pi*660*float64(i)/48000) * 16000)) + binary.LittleEndian.PutUint16(pcm[i*4:], uint16(left)) //nolint:gosec // G115 + binary.LittleEndian.PutUint16(pcm[i*4+2:], uint16(right)) //nolint:gosec // G115 + } + + return pcm +} + func vectorEnergyFloat32(x []float32) float64 { var e float64 for _, v := range x { @@ -179,3 +296,16 @@ func vectorEnergyFloat32(x []float32) float64 { return math.Sqrt(e) } + +// freqEnergy returns the DFT magnitude at freq Hz over a 48 kHz signal. +// It is phase-invariant so it survives the CELT analysis/synthesis delay. +func freqEnergy(samples []float32, freq float64) float64 { + var re, im float64 + for i, s := range samples { + angle := 2 * math.Pi * freq * float64(i) / 48000 + re += float64(s) * math.Cos(angle) + im += float64(s) * math.Sin(angle) + } + + return math.Sqrt(re*re+im*im) / float64(len(samples)) +} diff --git a/internal/celt/analysis.go b/internal/celt/analysis.go index e2269d4..5f371b4 100644 --- a/internal/celt/analysis.go +++ b/internal/celt/analysis.go @@ -10,68 +10,63 @@ import ( const preemphasisCoefficient = 0.85000610 type analysisState struct { - prevPCM []float32 - preemphasisMem float32 + prevPCM [2][]float32 + preemphasisMem [2]float32 } type analysisResult struct { - info frameSideInfo - preemphasized []float32 - mdct []float32 - logBandAmp [maxBands]float32 + info frameSideInfo + mdct [2][]float32 + logBandAmp [2][maxBands]float32 } func newAnalysisState() analysisState { return analysisState{ - prevPCM: make([]float32, shortBlockSampleCount), + prevPCM: [2][]float32{ + make([]float32, shortBlockSampleCount), + make([]float32, shortBlockSampleCount), + }, } } -// analyzeFrame prepares the mono CELT encoder input: applies pre-emphasis, -// extends the frame with the previous overlap for the MDCT window, runs the -// forward MDCT, and returns per-band log amplitude for coarse energy coding. -func analyzeFrame( - mode *Mode, - frame []float32, - startBand int, - endBand int, - state *analysisState, -) (analysisResult, error) { - lm, err := mode.LMForFrameSampleCount(len(frame)) +// analyzeFrame applies pre-emphasis, builds the MDCT overlap window, runs the +// forward MDCT, and returns per-band log amplitude for each input channel. +func analyzeFrame(mode *Mode, pcm [][]float32, startBand, endBand int, state *analysisState) (analysisResult, error) { + lm, err := mode.LMForFrameSampleCount(len(pcm[0])) if err != nil { return analysisResult{}, err } - result := analysisResult{ + res := analysisResult{ info: frameSideInfo{ - lm: lm, - startBand: startBand, - endBand: endBand, - channelCount: 1, - transient: false, - shortBlockCount: 0, - intraEnergy: false, - spread: defaultSpreadDecision, - allocationTrim: defaultAllocationTrim, + lm: lm, + startBand: startBand, + endBand: endBand, + channelCount: len(pcm), + transient: false, + spread: defaultSpreadDecision, + allocationTrim: defaultAllocationTrim, }, - preemphasized: make([]float32, len(frame)), } - applyPreemphasis(frame, result.preemphasized, &state.preemphasisMem) + for ch := range pcm { + pre := make([]float32, len(pcm[ch])) + applyPreemphasis(pcm[ch], pre, &state.preemphasisMem[ch]) - mdctInput := make([]float32, shortBlockSampleCount+len(frame)) - copy(mdctInput, state.prevPCM) - copy(mdctInput[shortBlockSampleCount:], result.preemphasized) + mdctInput := make([]float32, shortBlockSampleCount+len(pre)) + copy(mdctInput, state.prevPCM[ch]) + copy(mdctInput[shortBlockSampleCount:], pre) - result.mdct = forwardMDCT(mdctInput) - if result.mdct == nil { - return analysisResult{}, errInvalidFrameSize - } + res.mdct[ch] = forwardMDCT(mdctInput) + if res.mdct[ch] == nil { + return analysisResult{}, errInvalidFrameSize + } - result.logBandAmp = computeBandLogAmp(result.mdct, lm, startBand, endBand) - copy(state.prevPCM, result.preemphasized[len(result.preemphasized)-shortBlockSampleCount:]) + res.logBandAmp[ch] = computeBandLogAmp(res.mdct[ch], lm, startBand, endBand) + copy(state.prevPCM[ch], pre[len(pre)-shortBlockSampleCount:]) + } - return result, nil + return res, nil } func applyPreemphasis(in []float32, out []float32, mem *float32) { diff --git a/internal/celt/encode_bands.go b/internal/celt/encode_bands.go index e191e71..e8c62a2 100644 --- a/internal/celt/encode_bands.go +++ b/internal/celt/encode_bands.go @@ -483,3 +483,363 @@ func encodeBandThetaMono(symbol int, qn int, blocks int, rangeEncoder *rangecodi low := int(total) - (freq * (freq + 1) >> 1) rangeEncoder.EncodeCumulative(uint32(low), uint32(low+freq), total) } + +func quantBandStereo( + band int, + x []float32, + y []float32, + n int, + bandBits int, + spread int, + blocks int, + intensity int, + tfChange int, + lowband []float32, + remainingBits *int, + lm int, + gain float32, + lowbandScratch []float32, + fill uint, + state *bandEncodeState, +) uint { + if n == 1 { + xSign := uint32(0) + if x[0] < 0 { + xSign = 1 + } + if *remainingBits >= 1<= 1<>1)-thetaOffset, pulseCap, true) + if band >= intensity { + qn = 1 + } + + tell := int(state.rangeEncoder.TellFrac()) + thetaSym := 0 + itheta := 0 + invert := false + if qn != 1 { + thetaSym = quantizeStereoBandTheta(x, y, qn) + encodeBandTheta(thetaSym, qn, n, true, blocks, state.rangeEncoder) + itheta = thetaSym * 16384 / qn + } else if bandBits > 2< 2<= sideBits { + collapseMask := quantBandMono( + band, x, n, midBits, spread, blocks, tfChange, + lowband, remainingBits, lm, nil, 0, 1, lowbandScratch, fill, state, + ) + rebalance = midBits - (rebalance - *remainingBits) + if rebalance > 3<>blocks, state, + ) + if n != 2 { + stereoMerge(x, y, mid, n) + } + if invert { + for i := range n { + y[i] = -y[i] + } + } + + return collapseMask + } + + collapseMask := quantBandMono( + band, y, n, sideBits, spread, blocks, tfChange, + nil, remainingBits, lm, nil, 0, gain*side, nil, originalFill>>blocks, state, + ) + rebalance = sideBits - (rebalance - *remainingBits) + if rebalance > 3< 2: + encodeBandThetaStereoLarge(symbol, qn, rangeEncoder) + case blocks > 1 || stereo: + rangeEncoder.EncodeUniform(uint32(qn+1), uint32(symbol)) + default: + encodeBandThetaMono(symbol, qn, blocks, rangeEncoder) + } +} + +func encodeBandThetaStereoLarge(symbol int, qn int, rangeEncoder *rangecoding.Encoder) { + p0 := uint32(3) + x0 := uint32(qn / 2) + total := p0*(x0+1) + x0 + var low, high uint32 + if uint32(symbol) <= x0 { + low = p0 * uint32(symbol) + high = p0 * (uint32(symbol) + 1) + } else { + low = (x0+1)*p0 + uint32(symbol) - x0 - 1 + high = low + 1 + } + rangeEncoder.EncodeCumulative(low, high, total) +} + +func quantizeStereoBandTheta(x []float32, y []float32, qn int) int { + if qn <= 1 { + return 0 + } + + var ex, ey float64 + for i := range x { + ex += float64(x[i] * x[i]) + ey += float64(y[i] * y[i]) + } + if ex+ey <= 1e-30 { + return 0 + } + + theta := math.Atan2(math.Sqrt(ey), math.Sqrt(ex)) + symbol := int(math.Round(theta * float64(qn) / (0.5 * math.Pi))) + + return min(qn, max(0, symbol)) +} + +func quantAllBandsStereo( + info *frameSideInfo, + x []float32, + y []float32, + totalBits int, + state *bandEncodeState, +) []byte { + channelCount := 2 + blocks := 1 + if info.transient { + blocks = 1 << info.lm + } + scale := 1 << info.lm + frameBins := scale * int(bandEdges[maxBands]) + norm := make([]float32, channelCount*frameBins) + norm2 := norm[frameBins:] + lowbandScratch := make([]float32, scale*int(bandEdges[maxBands]-bandEdges[maxBands-1])) + collapseMasks := make([]byte, channelCount*maxBands) + + lowbandOffset := 0 + updateLowband := true + balance := info.allocation.balance + dualStereo := channelCount == 2 && info.allocation.dualStereo != 0 + for band := info.startBand; band < info.endBand; band++ { + tell := int(state.rangeEncoder.TellFrac()) + if band != info.startBand { + balance -= tell + } + remainingBits := totalBits - tell - 1 + bandBits := 0 + if band <= info.allocation.codedBands-1 { + currentBalance := balance / min(3, info.allocation.codedBands-band) + bandBits = max(0, min(16383, min(remainingBits+1, info.allocation.pulses[band]+currentBalance))) + } + + bandStart := scale * int(bandEdges[band]) + bandEnd := scale * int(bandEdges[band+1]) + bandWidth := bandEnd - bandStart + if bandStart-bandWidth >= scale*int(bandEdges[info.startBand]) || band == info.startBand+1 { + if updateLowband || lowbandOffset == 0 { + lowbandOffset = band + } + } + if band == info.startBand+1 && info.startBand+2 <= maxBands { + n1 := scale * int(bandEdges[info.startBand+1]-bandEdges[info.startBand]) + n2 := scale * int(bandEdges[info.startBand+2]-bandEdges[info.startBand+1]) + offset := scale * int(bandEdges[info.startBand]) + if n2 > n1 { + copy(norm[offset+n1:offset+n2], norm[offset+2*n1-n2:offset+n1]) + copy(norm2[offset+n1:offset+n2], norm2[offset+2*n1-n2:offset+n1]) + } + } + + effectiveLowband := -1 + xMask := uint(0) + yMask := uint(0) + if lowbandOffset != 0 && (info.spread != spreadAggressive || blocks > 1 || info.tfChange[band] < 0) { + effectiveLowband = max(scale*int(bandEdges[info.startBand]), scale*int(bandEdges[lowbandOffset])-bandWidth) + foldStart := lowbandOffset + for { + foldStart-- + if scale*int(bandEdges[foldStart]) <= effectiveLowband { + break + } + } + foldEnd := lowbandOffset - 1 + for { + foldEnd++ + if foldEnd >= band || scale*int(bandEdges[foldEnd]) >= effectiveLowband+bandWidth { + break + } + } + for fold := foldStart; fold < foldEnd; fold++ { + xMask |= uint(collapseMasks[fold*channelCount]) + yMask |= uint(collapseMasks[fold*channelCount+channelCount-1]) + } + } else { + xMask = (1 << blocks) - 1 + yMask = xMask + } + + if dualStereo && band == info.allocation.intensity { + dualStereo = false + for i := scale * int(bandEdges[info.startBand]); i < bandStart; i++ { + norm[i] = 0.5 * (norm[i] + norm2[i]) + } + } + + var lowband []float32 + if effectiveLowband >= 0 { + lowband = norm[effectiveLowband:] + } + if dualStereo { + xMask = quantBandMono( + band, + x[bandStart:bandEnd], + bandWidth, + bandBits/2, + info.spread, + blocks, + info.tfChange[band], + lowband, + &remainingBits, + info.lm, + norm[bandStart:], + 0, + 1, + lowbandScratch, + xMask, + state, + ) + var lowbandY []float32 + if effectiveLowband >= 0 { + lowbandY = norm2[effectiveLowband:] + } + yMask = quantBandMono( + band, + y[bandStart:bandEnd], + bandWidth, + bandBits/2, + info.spread, + blocks, + info.tfChange[band], + lowbandY, + &remainingBits, + info.lm, + norm2[bandStart:], + 0, + 1, + lowbandScratch, + yMask, + state, + ) + } else { + xMask = quantBandStereo( + band, + x[bandStart:bandEnd], + y[bandStart:bandEnd], + bandWidth, + bandBits, + info.spread, + blocks, + info.allocation.intensity, + info.tfChange[band], + lowband, + &remainingBits, + info.lm, + 1, + lowbandScratch, + xMask|yMask, + state, + ) + yMask = xMask + } + + copy(norm[bandStart:bandEnd], x[bandStart:bandEnd]) + copy(norm2[bandStart:bandEnd], y[bandStart:bandEnd]) + collapseMasks[band*channelCount] = byte(xMask) + collapseMasks[band*channelCount+channelCount-1] = byte(yMask) + balance += info.allocation.pulses[band] + tell + updateLowband = bandBits > bandWidth<= e.mode.BandCount() { return nil, errInvalidBand @@ -229,7 +237,9 @@ func (e *Encoder) EncodeFrame( e.encodeIntraEnergyFlag(&info) var targetLogE [2][maxBands]float32 - targetLogE[0] = analysis.logBandAmp + for ch := range info.channelCount { + targetLogE[ch] = analysis.logBandAmp[ch] + } e.encodeCoarseEnergy(&info, targetLogE) e.encodeTimeFrequencyChanges(&info) @@ -245,12 +255,17 @@ func (e *Encoder) EncodeFrame( e.encodeFineEnergy(&info, info.allocation.fineQuant, targetLogE) totalBits := (int(info.totalBits) << bitResolution) - info.antiCollapseRsv - shape := normaliseBandsForEncoding(&info, analysis.mdct, analysis.logBandAmp) bandState := bandEncodeState{ rangeEncoder: &e.rangeEncoder, seed: e.rng, } - _ = quantAllBandsMono(&info, shape, totalBits, &bandState) + shape0 := normaliseBandsForEncoding(&info, analysis.mdct[0], analysis.logBandAmp[0]) + if info.channelCount == 2 { + shape1 := normaliseBandsForEncoding(&info, analysis.mdct[1], analysis.logBandAmp[1]) + _ = quantAllBandsStereo(&info, shape0, shape1, totalBits, &bandState) + } else { + _ = quantAllBandsMono(&info, shape0, totalBits, &bandState) + } bitsLeft := int(info.totalBits) - int(e.rangeEncoder.Tell()) e.finalizeFineEnergy(&info, info.allocation.fineQuant, info.allocation.finePriority, targetLogE, bitsLeft) diff --git a/internal/celt/encoder_test.go b/internal/celt/encoder_test.go index 36c82c5..b29bcf6 100644 --- a/internal/celt/encoder_test.go +++ b/internal/celt/encoder_test.go @@ -24,7 +24,7 @@ func TestEncodeFrameRoundTripMono20ms(t *testing.T) { } for i := range 3 { - data, err := encoder.EncodeFrame(pcm, frameBytes, 0, maxBands) + data, err := encoder.EncodeFrame([][]float32{pcm}, frameBytes, 0, maxBands) require.NoError(t, err) require.NotEmpty(t, data) assert.LessOrEqual(t, len(data), frameBytes, @@ -53,7 +53,7 @@ func TestEncodeFrameRoundTripMono20msTightBudget(t *testing.T) { pcm[i] = float32(math.Sin(2 * math.Pi * 440 * float64(i) / sampleRate)) } - data, err := encoder.EncodeFrame(pcm, frameBytes, 0, maxBands) + data, err := encoder.EncodeFrame([][]float32{pcm}, frameBytes, 0, maxBands) require.NoError(t, err) require.NotEmpty(t, data) assert.LessOrEqual(t, len(data), frameBytes) @@ -80,7 +80,7 @@ func TestEncodeFrameMonoPersistence(t *testing.T) { pcm[i] = float32(math.Sin(2 * math.Pi * 440 * float64(i) / sampleRate)) } - data1, err := encoder.EncodeFrame(pcm, frameBytes, 0, maxBands) + data1, err := encoder.EncodeFrame([][]float32{pcm}, frameBytes, 0, maxBands) require.NoError(t, err) out1 := make([]float32, frameSampleCount) @@ -91,7 +91,7 @@ func TestEncodeFrameMonoPersistence(t *testing.T) { out1b += float64(out1[i]) } - data2, err := encoder.EncodeFrame(pcm, frameBytes, 0, maxBands) + data2, err := encoder.EncodeFrame([][]float32{pcm}, frameBytes, 0, maxBands) require.NoError(t, err) out2 := make([]float32, frameSampleCount) @@ -116,10 +116,35 @@ func TestEncodeFrameMonoRngStability(t *testing.T) { } for range 3 { - data, err := encoder.EncodeFrame(pcm, frameBytes, 0, maxBands) + data, err := encoder.EncodeFrame([][]float32{pcm}, frameBytes, 0, maxBands) require.NoError(t, err) require.NotEmpty(t, data) _ = encoder.rangeEncoder.FinalRange() } } + +func TestEncodeFrameStereoFinalRange(t *testing.T) { + encoder := NewEncoder() + decoder := NewDecoder() + + frameSampleCount := shortBlockSampleCount << maxLM + frameBytes := 60 + + L := make([]float32, frameSampleCount) + R := make([]float32, frameSampleCount) + for i := range frameSampleCount { + L[i] = float32(math.Sin(2 * math.Pi * 440 * float64(i) / sampleRate)) + R[i] = float32(math.Sin(2 * math.Pi * 660 * float64(i) / sampleRate)) + } + + data, err := encoder.EncodeFrame([][]float32{L, R}, frameBytes, 0, maxBands) + require.NoError(t, err) + require.NotEmpty(t, data) + + out := make([]float32, frameSampleCount*2) + require.NoError(t, decoder.Decode(data, out, true, 2, frameSampleCount, 0, maxBands)) + + assert.Equal(t, encoder.FinalRange(), decoder.FinalRange(), + "range coder must be in sync after stereo encode/decode") +} From ac16983c93bb14a1840d0d6bc5af777e5fd06c70 Mon Sep 17 00:00:00 2001 From: Thomas Vilte Date: Tue, 9 Jun 2026 19:58:23 -0300 Subject: [PATCH 2/2] Add stereo encoding edge case tests --- internal/celt/encoder_test.go | 65 +++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/internal/celt/encoder_test.go b/internal/celt/encoder_test.go index b29bcf6..221f1be 100644 --- a/internal/celt/encoder_test.go +++ b/internal/celt/encoder_test.go @@ -148,3 +148,68 @@ func TestEncodeFrameStereoFinalRange(t *testing.T) { assert.Equal(t, encoder.FinalRange(), decoder.FinalRange(), "range coder must be in sync after stereo encode/decode") } + +func TestQuantBandStereoN1(t *testing.T) { + enc := NewEncoder() + enc.rangeEncoder.Init() + state := bandEncodeState{rangeEncoder: &enc.rangeEncoder} + + x := []float32{0.7} + y := []float32{-0.5} + remaining := 100 << bitResolution + + mask := quantBandStereo( + 0, x, y, 1, 10<