diff --git a/internal/celt/bands.go b/internal/celt/bands.go index f7d741e..6cc5540 100644 --- a/internal/celt/bands.go +++ b/internal/celt/bands.go @@ -9,6 +9,7 @@ import ( "math/bits" "github.com/pion/opus/internal/rangecoding" + "github.com/pion/opus/internal/slicetools" ) const ( @@ -28,6 +29,10 @@ type bandDecodeState struct { seed uint32 pulseScratch []int tmpScratch []float32 + normScratch []float32 + lowScratch []float32 + maskScratch []byte + cwrsRows map[cwrsRowKey][]uint32 } // quantAllBands drives RFC 6716 Section 4.3.4 shape decoding across the coded @@ -44,10 +49,10 @@ func quantAllBands(info *frameSideInfo, x []float32, y []float32, totalBits int, } scale := 1 << info.lm frameBins := scale * int(bandEdges[maxBands]) - norm := make([]float32, channelCount*frameBins) + norm := slicetools.ResizeZero(&state.normScratch, channelCount*frameBins) norm2 := norm[frameBins:] - lowbandScratch := make([]float32, scale*int(bandEdges[maxBands]-bandEdges[maxBands-1])) - collapseMasks := make([]byte, channelCount*maxBands) + lowbandScratch := slicetools.Resize(&state.lowScratch, scale*int(bandEdges[maxBands]-bandEdges[maxBands-1])) + collapseMasks := slicetools.ResizeZero(&state.maskScratch, channelCount*maxBands) lowbandOffset := 0 updateLowband := true @@ -714,7 +719,7 @@ func haar1(x []float32, n0 int, stride int) { } func deinterleaveHadamard(x []float32, n0 int, stride int, hadamard bool, state *bandDecodeState) { - tmp := state.floatScratch(n0 * stride) + tmp := slicetools.Resize(&state.tmpScratch, n0*stride) if hadamard { ordery := orderyTable[stride-2:] for i := range stride { @@ -733,7 +738,7 @@ func deinterleaveHadamard(x []float32, n0 int, stride int, hadamard bool, state } func interleaveHadamard(x []float32, n0 int, stride int, hadamard bool, state *bandDecodeState) { - tmp := state.floatScratch(n0 * stride) + tmp := slicetools.Resize(&state.tmpScratch, n0*stride) if hadamard { ordery := orderyTable[stride-2:] for i := range stride { @@ -751,25 +756,6 @@ func interleaveHadamard(x []float32, n0 int, stride int, hadamard bool, state *b copy(x, tmp) } -func (s *bandDecodeState) intScratch(n int) []int { - if cap(s.pulseScratch) < n { - s.pulseScratch = make([]int, n) - } - s.pulseScratch = s.pulseScratch[:n] - clear(s.pulseScratch) - - return s.pulseScratch -} - -func (s *bandDecodeState) floatScratch(n int) []float32 { - if cap(s.tmpScratch) < n { - s.tmpScratch = make([]float32, n) - } - s.tmpScratch = s.tmpScratch[:n] - - return s.tmpScratch -} - func bitInterleave(fill uint) uint { table := [...]uint{0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3} diff --git a/internal/celt/cwrs.go b/internal/celt/cwrs.go index 2a83875..990c9ff 100644 --- a/internal/celt/cwrs.go +++ b/internal/celt/cwrs.go @@ -6,30 +6,211 @@ package celt import "github.com/pion/opus/internal/rangecoding" +// The static CELT pulse cache tops out at getPulses(40) == 128. +const cwrsMaxPulseCount = 128 + +type cwrsRowKey uint32 + // decodePulses implements the RFC 6716 Section 4.3.4.2 CWRS index decode for // the PVQ pulse vector. The row buffer stores one recurrence row of V(N,K). -func decodePulses(y []int, n, k int, rangeDecoder *rangecoding.Decoder) { - for i := range n { - y[i] = 0 - } +func decodePulses( + y []int, + n, + k int, + rangeDecoder *rangecoding.Decoder, + cwrsRows map[cwrsRowKey][]uint32, +) { if k <= 0 { + clear(y[:n]) + return } - u := cwrsUrow(n, k) + switch n { + case 2: + index, _ := rangeDecoder.DecodeUniform(cwrsCodewordCount2(k)) + cwrsDecode2(y, k, index) + + return + case 3: + index, _ := rangeDecoder.DecodeUniform(cwrsCodewordCount3(k)) + cwrsDecode3(y, k, index) + + return + case 4: + index, _ := rangeDecoder.DecodeUniform(cwrsCodewordCount4(k)) + cwrsDecode4(y, k, index) + + return + } + + var row [cwrsMaxPulseCount + 2]uint32 + var u []uint32 + if k > cwrsMaxPulseCount { + u = make([]uint32, k+2) + } else { + u = row[:k+2] + } + if cwrsRows == nil { + cwrsUrowInto(u, n) + } else { + copy(u, cachedCWRSRow(cwrsRows, n, k)) + } total := u[k] + u[k+1] index, _ := rangeDecoder.DecodeUniform(total) cwrsDecode(y, n, k, index, u) } +func cachedCWRSRow(cwrsRows map[cwrsRowKey][]uint32, n, k int) []uint32 { + key := cwrsRowKey(cwrsCodebookUint32(n)<<8 | cwrsCodebookUint32(k)) + row := cwrsRows[key] + if row == nil { + row = cwrsUrow(n, k) + cwrsRows[key] = row + } + + return row +} + // cwrsUrow initializes the recurrence row needed to count PVQ codewords for a // vector of n dimensions and up to k pulses. func cwrsUrow(n, k int) []uint32 { row := make([]uint32, k+2) + cwrsUrowInto(row, n) + + return row +} + +func cwrsCodewordCount2(k int) uint32 { + return 4 * cwrsCodebookUint32(k) +} + +func cwrsCodewordCount3(k int) uint32 { + pulses := cwrsCodebookUint32(k) + + return 2 * (2*pulses*pulses + 1) +} + +func cwrsCodewordCount4(k int) uint32 { + pulses := cwrsCodebookUint32(k) + + return ((pulses*pulses + 2) * pulses / 3) << 3 +} + +func cwrsDecode1(y []int, k int, index uint32) { + if len(y) == 0 { + return + } + if index == 0 { + y[0] = k + + return + } + + y[0] = -k +} + +func cwrsDecode2(y []int, k int, index uint32) { + p := cwrsU2(k + 1) + signMask := 0 + if index >= p { + index -= p + signMask = -1 + } + yj := k + k = int((index + 1) >> 1) + if k != 0 { + index -= cwrsU2(k) + } + yj -= k + y[0] = (yj + signMask) ^ signMask + cwrsDecode1(y[1:], k, index) +} + +func cwrsDecode3(y []int, k int, index uint32) { + p := cwrsU3(k + 1) + signMask := 0 + if index >= p { + index -= p + signMask = -1 + } + yj := k + if index != 0 { + k = int((isqrt32(2*index-1) + 1) >> 1) + } else { + k = 0 + } + if k != 0 { + index -= cwrsU3(k) + } + yj -= k + y[0] = (yj + signMask) ^ signMask + cwrsDecode2(y[1:], k, index) +} + +func cwrsDecode4(y []int, k int, index uint32) { + p := cwrsU4(k + 1) + signMask := 0 + if index >= p { + index -= p + signMask = -1 + } + yj := k + low := 0 + high := k + for { + k = (low + high) >> 1 + if k != 0 { + p = cwrsU4(k) + } else { + p = 0 + } + switch { + case p < index: + if k >= high { + goto decoded + } + low = k + 1 + case p > index: + high = k - 1 + default: + goto decoded + } + } + +decoded: + index -= p + yj -= k + y[0] = (yj + signMask) ^ signMask + cwrsDecode3(y[1:], k, index) +} + +func cwrsU2(k int) uint32 { + return cwrsCodebookUint32((k << 1) - 1) +} + +func cwrsU3(k int) uint32 { + pulses := cwrsCodebookUint32(k) + + return (2*pulses-2)*pulses + 1 +} + +func cwrsU4(k int) uint32 { + pulses := cwrsCodebookUint32(k) + + return (2*pulses*((2*pulses-3)*pulses+4) - 3) / 3 +} + +// CELT codebook dimensions and pulse counts are small non-negative values. +func cwrsCodebookUint32(value int) uint32 { + return uint32(value) //nolint:gosec +} + +func cwrsUrowInto(row []uint32, n int) { if n == 0 { row[0] = 1 - return row + return } row[0] = 0 if len(row) > 1 { @@ -40,7 +221,7 @@ func cwrsUrow(n, k int) []uint32 { row[i] = 1 } - return row + return } for pulses := 2; pulses < len(row); pulses++ { row[pulses] = uint32((pulses << 1) - 1) @@ -48,8 +229,6 @@ func cwrsUrow(n, k int) []uint32 { for rowIndex := 2; rowIndex < n; rowIndex++ { cwrsNextRow(row[1:], 1) } - - return row } // cwrsNextRow advances the V(N,K) recurrence by one dimension. @@ -66,11 +245,16 @@ func cwrsNextRow(u []uint32, value0 uint32) { // cwrsDecode walks the recurrence row to recover signs and pulse magnitudes // from the uniformly decoded codeword index. func cwrsDecode(y []int, n, k int, index uint32, u []uint32) { - for j := range n { + genericLength := n + if n > 4 { + genericLength -= 4 + } + for j := range genericLength { p := u[k+1] - negative := index >= p - if negative { + signMask := 0 + if index >= p { index -= p + signMask = -1 } yj := k @@ -81,25 +265,25 @@ func cwrsDecode(y []int, n, k int, index uint32, u []uint32) { } index -= p yj -= k - if negative { - y[j] = -yj - } else { - y[j] = yj - } + y[j] = (yj + signMask) ^ signMask cwrsPreviousRow(u, k+2, 0) } + if genericLength < n { + cwrsDecode4(y[genericLength:], k, index) + } } // cwrsPreviousRow rewinds the recurrence after one coefficient has been // decoded, matching the row update used by the reference CWRS decoder. func cwrsPreviousRow(u []uint32, n int, value0 uint32) { + u = u[:n] value := value0 - for j := 1; j < n; j++ { + for j := 1; j < len(u); j++ { next := u[j] - u[j-1] - value u[j-1] = value value = next } - u[n-1] = value + u[len(u)-1] = value } // encodePulses writes a CWRS index for the PVQ pulse vector y to the range diff --git a/internal/celt/cwrs_test.go b/internal/celt/cwrs_test.go index 815f81d..064edbc 100644 --- a/internal/celt/cwrs_test.go +++ b/internal/celt/cwrs_test.go @@ -22,7 +22,7 @@ func TestCWRSRows(t *testing.T) { func TestCWRSDecode(t *testing.T) { y := []int{99, 99, 99} - decodePulses(y, len(y), 0, nil) + decodePulses(y, len(y), 0, nil, nil) assert.Equal(t, []int{0, 0, 0}, y) row := cwrsUrow(3, 2) @@ -30,10 +30,70 @@ func TestCWRSDecode(t *testing.T) { assert.Equal(t, []int{2, 0, 0}, y) decoder := rangeDecoderWithCDFSymbol(0, cwrsUrow(3, 2)[2]+cwrsUrow(3, 2)[3]) - decodePulses(y, len(y), 2, &decoder) + decodePulses(y, len(y), 2, &decoder, nil) assert.Equal(t, []int{2, 0, 0}, y) } +func TestCWRSDirectDecodeMatchesGeneric(t *testing.T) { + for n := 2; n <= 4; n++ { + for k := 1; k <= 4; k++ { + row := cwrsUrow(n, k) + total := row[k] + row[k+1] + for index := range total { + expected := make([]int, n) + cwrsDecode(expected, n, k, index, append([]uint32(nil), row...)) + + decoder := rangeDecoderWithCDFSymbol(index, total) + got := make([]int, n) + decodePulses(got, n, k, &decoder, nil) + assert.Equalf(t, expected, got, "n=%d k=%d index=%d", n, k, index) + } + } + } +} + +func TestCWRSTailDecodeMatchesGeneric(t *testing.T) { + for n := 5; n <= 10; n++ { + for k := 1; k <= 4; k++ { + row := cwrsUrow(n, k) + total := row[k] + row[k+1] + for index := range total { + expected := make([]int, n) + cwrsDecodeGenericForTest(expected, n, k, index, append([]uint32(nil), row...)) + + got := make([]int, n) + cwrsDecode(got, n, k, index, append([]uint32(nil), row...)) + assert.Equalf(t, expected, got, "n=%d k=%d index=%d", n, k, index) + } + } + } +} + +func cwrsDecodeGenericForTest(vector []int, dimension, pulseCount int, index uint32, row []uint32) { + for vectorIndex := range dimension { + p := row[pulseCount+1] + negative := index >= p + if negative { + index -= p + } + + value := pulseCount + p = row[pulseCount] + for p > index { + pulseCount-- + p = row[pulseCount] + } + index -= p + value -= pulseCount + if negative { + vector[vectorIndex] = -value + } else { + vector[vectorIndex] = value + } + cwrsPreviousRow(row, pulseCount+2, 0) + } +} + func TestCWRSEncodeZeroPulses(t *testing.T) { assert.Equal(t, uint32(0), cwrsEncode([]int{0, 0, 0}, 3, 0)) } diff --git a/internal/celt/decoder.go b/internal/celt/decoder.go index 004a629..3122d23 100644 --- a/internal/celt/decoder.go +++ b/internal/celt/decoder.go @@ -20,6 +20,7 @@ type Decoder struct { rng uint32 lossCount int scratch *decoderScratch + cwrsRows map[cwrsRowKey][]uint32 } // NewDecoder creates a CELT decoder with the static Opus 48 kHz mode. @@ -227,6 +228,12 @@ func (d *Decoder) decode( state := bandDecodeState{ rangeDecoder: &d.rangeDecoder, seed: d.rng, + pulseScratch: scratch.bandPulses[:], + tmpScratch: scratch.bandTmp[:], + normScratch: scratch.bandNorm[:], + lowScratch: scratch.bandLow[:], + maskScratch: scratch.collapseMasks[:], + cwrsRows: d.cwrsRowCache(), } totalBits := (int(info.totalBits) << bitResolution) - info.antiCollapseRsv collapseMasks := quantAllBands(&info, x, y, totalBits, &state) @@ -253,6 +260,16 @@ func (d *Decoder) decode( return nil } +func (d *Decoder) cwrsRowCache() map[cwrsRowKey][]uint32 { + // CWRS rows only depend on static codebook dimensions and pulse counts, so + // retaining them across Reset avoids rebuilding the same immutable rows. + if d.cwrsRows == nil { + d.cwrsRows = make(map[cwrsRowKey][]uint32) + } + + return d.cwrsRows +} + func (d *Decoder) decodeLostFrame(info *frameSideInfo, out []float32) { clear(out) decay := float32(1.5) diff --git a/internal/celt/encode_bands.go b/internal/celt/encode_bands.go index 20ed944..e191e71 100644 --- a/internal/celt/encode_bands.go +++ b/internal/celt/encode_bands.go @@ -8,6 +8,7 @@ import ( "math" "github.com/pion/opus/internal/rangecoding" + "github.com/pion/opus/internal/slicetools" ) type bandEncodeState struct { @@ -17,12 +18,7 @@ type bandEncodeState struct { } func (s *bandEncodeState) floatScratch(n int) []float32 { - if cap(s.tmpScratch) < n { - s.tmpScratch = make([]float32, n) - } - s.tmpScratch = s.tmpScratch[:n] - - return s.tmpScratch + return slicetools.Resize(&s.tmpScratch, n) } func normaliseBandsForEncoding( diff --git a/internal/celt/pvq.go b/internal/celt/pvq.go index 6284d11..cd24faf 100644 --- a/internal/celt/pvq.go +++ b/internal/celt/pvq.go @@ -8,6 +8,7 @@ import ( "math" "github.com/pion/opus/internal/rangecoding" + "github.com/pion/opus/internal/slicetools" ) const ( @@ -30,17 +31,14 @@ func algUnquant( gain float32, state *bandDecodeState, ) uint { - iy := state.intScratch(n) - decodePulses(iy, n, k, rangeDecoder) + iy := slicetools.Resize(&state.pulseScratch, n) + decodePulses(iy, n, k, rangeDecoder, state.cwrsRows) - energy := 0 - for i := range n { - energy += iy[i] * iy[i] - } + energy, collapseMask := pulseEnergyAndCollapseMask(iy, n, blocks) normaliseResidual(iy, x, n, energy, gain) expRotation(x, n, -1, blocks, k, spread) - return extractCollapseMask(iy, n, blocks) + return collapseMask } // normaliseResidual maps integer PVQ pulses back to a floating-point unit @@ -79,6 +77,29 @@ func extractCollapseMask(iy []int, n int, blocks int) uint { return mask } +func pulseEnergyAndCollapseMask(iy []int, n int, blocks int) (energy int, mask uint) { + if blocks <= 1 { + for i := range n { + energy += iy[i] * iy[i] + } + + return energy, 1 + } + + blockSize := n / blocks + for block := range blocks { + for i := range blockSize { + pulse := iy[block*blockSize+i] + energy += pulse * pulse + if pulse != 0 { + mask |= 1 << block + } + } + } + + return energy, mask +} + // renormaliseVector restores unit energy after lowband folding or noise fill. func renormaliseVector(x []float32, n int, gain float32) { energy := float32(1e-27) @@ -202,16 +223,31 @@ func algQuant( } func expRotation1(x []float32, length int, stride int, c float32, s float32) { - for i := 0; i < length-stride; i++ { - x1 := x[i] - x2 := x[i+stride] - x[i+stride] = c*x2 + s*x1 - x[i] = c*x1 - s*x2 - } - for i := length - 2*stride - 1; i >= 0; i-- { - x1 := x[i] - x2 := x[i+stride] - x[i+stride] = c*x2 + s*x1 - x[i] = c*x1 - s*x2 + if length <= stride { + return + } + + lower := x[:length-stride] + upper := x[stride:length] + for i := range lower { + x1 := lower[i] + x2 := upper[i] + upper[i] = c*x2 + s*x1 + lower[i] = c*x1 - s*x2 + } + + backwardLength := len(lower) - stride + if backwardLength <= 0 { + return + } + backwardLower := lower[:backwardLength] + backwardUpper := upper[:backwardLength] + // slices.Backward adds iterator overhead in this hot loop. + //nolint:modernize + for i := backwardLength - 1; i >= 0; i-- { + x1 := backwardLower[i] + x2 := backwardUpper[i] + backwardUpper[i] = c*x2 + s*x1 + backwardLower[i] = c*x1 - s*x2 } } diff --git a/internal/celt/synthesis.go b/internal/celt/synthesis.go index 0016a03..31d3612 100644 --- a/internal/celt/synthesis.go +++ b/internal/celt/synthesis.go @@ -28,10 +28,15 @@ type complex32 struct { } type decoderScratch struct { - x [maxFrameSampleCount]float32 - y [maxFrameSampleCount]float32 - channels [2]channelScratch - postfilter [2][postfilterHistorySampleCount + maxFrameSampleCount]float32 + x [maxFrameSampleCount]float32 + y [maxFrameSampleCount]float32 + bandNorm [2 * maxFrameSampleCount]float32 + bandLow [maxFrameSampleCount]float32 + bandTmp [maxFrameSampleCount]float32 + bandPulses [maxFrameSampleCount]int + collapseMasks [2 * maxBands]byte + channels [2]channelScratch + postfilter [2][postfilterHistorySampleCount + maxFrameSampleCount]float32 } type channelScratch struct { @@ -43,12 +48,11 @@ type channelScratch struct { } type mdctScratch struct { - preRotated [maxFrameSampleCount / 2]complex32 - fftOut [maxFrameSampleCount / 2]complex32 - fftWork [maxFrameSampleCount / 2]complex32 - postRotated [maxFrameSampleCount]float32 - deshuffled [maxFrameSampleCount]float32 - out [maxFrameSampleCount + shortBlockSampleCount]float32 + preRotated [maxFrameSampleCount / 2]complex32 + fftOut [maxFrameSampleCount / 2]complex32 + fftWork [maxFrameSampleCount / 2]complex32 + deshuffled [maxFrameSampleCount]float32 + out [maxFrameSampleCount + shortBlockSampleCount]float32 } type inverseTransformPlan struct { @@ -133,7 +137,7 @@ func (d *Decoder) log2Amp(info *frameSideInfo) [2][maxBands]float32 { for channel := range info.channelCount { for band := info.startBand; band < info.endBand; band++ { lg := minFloat32(32, d.previousLogE[channel][band]+energyMeans[band]) - energy[channel][band] = float32(math.Pow(2, float64(lg))) + energy[channel][band] = float32(math.Exp2(float64(lg))) } } @@ -191,17 +195,15 @@ func (d *Decoder) denormaliseAndSynthesize( // antiCollapse implements RFC 6716 Section 4.3.5 by injecting low-energy // noise into transient short blocks that received no PVQ pulses. func (d *Decoder) antiCollapse(info *frameSideInfo, x []float32, y []float32, collapseMasks []byte, seed uint32) { - channels := [][]float32{x} - if info.channelCount == 2 { - channels = append(channels, y) - } + channels := [2][]float32{x, y} for band := info.startBand; band < info.endBand; band++ { n0 := int(bandEdges[band+1] - bandEdges[band]) n := n0 << info.lm depth := (1 + info.allocation.pulses[band]) / n threshold := 0.5 * math.Pow(2, -0.125*float64(depth)) sqrtInv := 1 / math.Sqrt(float64(n)) - for channel, spectrum := range channels { + for channel := range info.channelCount { + spectrum := channels[channel] prev1 := d.previousLogE1[channel][band] prev2 := d.previousLogE2[channel][band] if info.channelCount == 1 { @@ -361,30 +363,39 @@ func combFilter(buf []float32, start int, period0 int, period1 int, n int, gain0 g11 := gain1 * gains[tapset1][1] g12 := gain1 * gains[tapset1][2] overlap := min(shortBlockSampleCount, n) + output := buf[start : start+n] + previous0 := buf[start-period0 : start-period0+overlap] + previous0Minus1 := buf[start-period0-1 : start-period0-1+overlap] + previous0Plus1 := buf[start-period0+1 : start-period0+1+overlap] + previous0Minus2 := buf[start-period0-2 : start-period0-2+overlap] + previous0Plus2 := buf[start-period0+2 : start-period0+2+overlap] + previous1 := buf[start-period1 : start-period1+n] + previous1Minus1 := buf[start-period1-1 : start-period1-1+n] + previous1Plus1 := buf[start-period1+1 : start-period1+1+n] + previous1Minus2 := buf[start-period1-2 : start-period1-2+n] + previous1Plus2 := buf[start-period1+2 : start-period1+2+n] for i := 0; i < overlap; i++ { window := celtWindow(i) fade := window * window - index := start + i - buf[index] = buf[index] + - (1-fade)*g00*buf[index-period0] + - (1-fade)*g01*buf[index-period0-1] + - (1-fade)*g01*buf[index-period0+1] + - (1-fade)*g02*buf[index-period0-2] + - (1-fade)*g02*buf[index-period0+2] + - fade*g10*buf[index-period1] + - fade*g11*buf[index-period1-1] + - fade*g11*buf[index-period1+1] + - fade*g12*buf[index-period1-2] + - fade*g12*buf[index-period1+2] + output[i] = output[i] + + (1-fade)*g00*previous0[i] + + (1-fade)*g01*previous0Minus1[i] + + (1-fade)*g01*previous0Plus1[i] + + (1-fade)*g02*previous0Minus2[i] + + (1-fade)*g02*previous0Plus2[i] + + fade*g10*previous1[i] + + fade*g11*previous1Minus1[i] + + fade*g11*previous1Plus1[i] + + fade*g12*previous1Minus2[i] + + fade*g12*previous1Plus2[i] } for i := overlap; i < n; i++ { - index := start + i - buf[index] = buf[index] + - g10*buf[index-period1] + - g11*buf[index-period1-1] + - g11*buf[index-period1+1] + - g12*buf[index-period1-2] + - g12*buf[index-period1+2] + output[i] = output[i] + + g10*previous1[i] + + g11*previous1Minus1[i] + + g11*previous1Plus1[i] + + g12*previous1Minus2[i] + + g12*previous1Plus2[i] } } @@ -486,9 +497,10 @@ func inverseMDCTWithScratch(freq []float32, scratch *mdctScratch) []float32 { fftOut := scratch.fftOut[:n4] inverseComplexDFTInto(preRotated, fftOut, scratch.fftWork[:n4], plan) - postRotated := scratch.postRotated[:n2] + deshuffled := scratch.deshuffled[:n2] // Rotate back out of the complex domain and restore the packed even/odd - // ordering expected by the time-domain mirror step. + // ordering expected by the time-domain mirror step. Write directly into + // that order to avoid an intermediate buffer and a second pass. for i, value := range fftOut { re := value.r im := value.i @@ -496,14 +508,8 @@ func inverseMDCTWithScratch(freq []float32, scratch *mdctScratch) []float32 { sineQuarter := plan.rotateSinQuarter[i] yr := re*cosine - im*sineQuarter yi := im*cosine + re*sineQuarter - postRotated[2*i] = yr - yi*plan.sine - postRotated[2*i+1] = yi + yr*plan.sine - } - - deshuffled := scratch.deshuffled[:n2] - for i := range n4 { - deshuffled[2*i] = -postRotated[2*i] - deshuffled[2*i+1] = postRotated[n2-1-2*i] + deshuffled[2*i] = -(yr - yi*plan.sine) + deshuffled[n2-1-2*i] = yi + yr*plan.sine } overlap := shortBlockSampleCount @@ -558,7 +564,6 @@ func inverseComplexDFTInto(in []complex32, out []complex32, work []complex32, pl for i, value := range in { out[plan.fftBitrev[i]] = value } - for stage := len(plan.fftFactors) - 1; stage >= 0; stage-- { factor := plan.fftFactors[stage] fstride := plan.n4 / (factor.radix * factor.size) diff --git a/internal/rangecoding/decoder.go b/internal/rangecoding/decoder.go index 4cecd63..c4f20ba 100644 --- a/internal/rangecoding/decoder.go +++ b/internal/rangecoding/decoder.go @@ -287,6 +287,18 @@ func (r *Decoder) getBit() uint32 { } func (r *Decoder) getBits(n int) uint32 { + if n > 0 && n <= 8 && r.bitsRead+uint(n) <= uint(len(r.data))*8 { + byteIndex := r.bitsRead / 8 + offset := r.bitsRead % 8 + combined := uint32(r.data[byteIndex]) << 8 + if byteIndex+1 < uint(len(r.data)) { + combined |= uint32(r.data[byteIndex+1]) + } + r.bitsRead += uint(n) + + return combined >> (16 - offset - uint(n)) & ((1 << n) - 1) + } + bits := uint32(0) for i := range n { diff --git a/internal/rangecoding/decoder_test.go b/internal/rangecoding/decoder_test.go index 61c89a8..47148fd 100644 --- a/internal/rangecoding/decoder_test.go +++ b/internal/rangecoding/decoder_test.go @@ -201,6 +201,36 @@ func TestDecoderInitEmptyInput(t *testing.T) { }) } +func TestGetBitsCrossesByteBoundary(t *testing.T) { + decoder := &Decoder{ + data: []byte{0b10101100, 0b01110010}, + bitsRead: 5, + } + + assert.Equal(t, uint32(0b10001110), decoder.getBits(8)) + assert.Equal(t, uint(13), decoder.bitsRead) +} + +func TestGetBitsMatchesBitByBit(t *testing.T) { + data := []byte{0b10101100, 0b01110010, 0b11010001} + for offset := range len(data) * 8 { + for width := 1; width <= 8 && offset+width <= len(data)*8; width++ { + fast := &Decoder{data: data, bitsRead: uint(offset)} + bitByBit := &Decoder{data: data, bitsRead: uint(offset)} + want := uint32(0) + for i := range width { + if i != 0 { + want <<= 1 + } + want |= bitByBit.getBit() + } + + assert.Equalf(t, want, fast.getBits(width), "offset=%d width=%d", offset, width) + assert.Equalf(t, bitByBit.bitsRead, fast.bitsRead, "offset=%d width=%d", offset, width) + } + } +} + func TestDecodeRawBits(t *testing.T) { t.Run("reads bits from the end of the frame in LSB-first order", func(t *testing.T) { decoder := &Decoder{data: []byte{0xB2}} diff --git a/internal/silk/decoder.go b/internal/silk/decoder.go index b2d4300..852fbb0 100644 --- a/internal/silk/decoder.go +++ b/internal/silk/decoder.go @@ -8,6 +8,7 @@ import ( "slices" "github.com/pion/opus/internal/rangecoding" + "github.com/pion/opus/internal/slicetools" ) // Decoder maintains the state needed to decode a stream @@ -140,7 +141,7 @@ func (d *Decoder) decodeHeaderBitsInto( if dst == nil { voiceActivityDetected = make([]bool, frameCount) } else { - voiceActivityDetected = resizeZero(dst, frameCount) + voiceActivityDetected = slicetools.ResizeZero(dst, frameCount) } for i := range frameCount { voiceActivityDetected[i] = d.rangeDecoder.DecodeSymbolLogP(1) == 1 @@ -150,21 +151,6 @@ func (d *Decoder) decodeHeaderBitsInto( return } -func resize[T any](buffer *[]T, size int) []T { - if cap(*buffer) < size { - *buffer = make([]T, size) - } - - return (*buffer)[:size] -} - -func resizeZero[T any](buffer *[]T, size int) []T { - out := resize(buffer, size) - clear(out) - - return out -} - // decodeLowBitrateRedundancyFlags expands RFC 6716 Section 4.2.4's global // LBRR-present bit into one flag per SILK frame. func (d *Decoder) decodeLowBitrateRedundancyFlags(frameCount int, present bool) []bool { @@ -176,7 +162,7 @@ func (d *Decoder) decodeLowBitrateRedundancyFlagsInto(dst *[]bool, frameCount in if dst == nil { flags = make([]bool, frameCount) } else { - flags = resizeZero(dst, frameCount) + flags = slicetools.ResizeZero(dst, frameCount) } if !present { return flags @@ -306,7 +292,7 @@ func (d *Decoder) decodeSubframeQuantizations( isFirstSilkFrameInOpusFrame bool, ) (gainQ16 []float32) { var logGain, deltaGainIndex, gainIndex int32 - gainQ16 = resizeZero(&d.gainQ16, subframeCount) + gainQ16 = slicetools.ResizeZero(&d.gainQ16, subframeCount) for subframeIndex := range subframeCount { // The subframe gains are either coded independently, or relative to the @@ -444,7 +430,7 @@ func (d *Decoder) normalizeLineSpectralFrequencyStageTwo( codebook = codebookNormalizedLSFStageTwoIndexNarrowbandOrMediumband } - I2 := resizeZero(&d.i2, len(codebook[0])) + I2 := slicetools.ResizeZero(&d.i2, len(codebook[0])) for i := range I2 { // the decoder reads a symbol using the PDF corresponding // to I1 from either Table 17 or Table 18 and subtracts 4 from the @@ -489,7 +475,7 @@ func (d *Decoder) normalizeLineSpectralFrequencyStageTwo( } // stage-2 residual - resQ10 = resizeZero(&d.resQ10, len(I2)) + resQ10 = slicetools.ResizeZero(&d.resQ10, len(I2)) // Let d_LPC be the order of the codebook, i.e., 10 for NB and MB, and 16 for WB dLPC = len(I2) @@ -549,9 +535,9 @@ func (d *Decoder) normalizeLineSpectralFrequencyCoefficients( resQ10 []int16, stageOneIndex uint32, ) (nlsfQ15 []int16) { - nlsfQ15 = resizeZero(&d.nlsfQ15, dLPC) - w2Q18 := resizeZero(&d.w2Q18, dLPC) - wQ9 := resizeZero(&d.wQ9, dLPC) + nlsfQ15 = slicetools.ResizeZero(&d.nlsfQ15, dLPC) + w2Q18 := slicetools.ResizeZero(&d.w2Q18, dLPC) + wQ9 := slicetools.ResizeZero(&d.wQ9, dLPC) cb1Q8 := codebookNormalizedLSFStageOneNarrowbandOrMediumband if bandwidth == BandwidthWideband { @@ -788,7 +774,7 @@ func (d *Decoder) normalizeLSFInterpolation(n2Q15 []int16, nanoseconds int) (n1Q return nil, wQ2 } - n1Q15 = resizeZero(&d.n1Q15, len(n2Q15)) + n1Q15 = slicetools.ResizeZero(&d.n1Q15, len(n2Q15)) for k := range n1Q15 { interpolated := int32(wQ2) * (int32(n2Q15[k]) - int32(d.n0Q15[k])) >> 2 //nolint:gosec // G602 n1Q15[k] = int16(int32(d.n0Q15[k]) + interpolated) //nolint:gosec // G115 @@ -815,7 +801,7 @@ func (d *Decoder) generateAQ12(q15 []int16, bandwidth Bandwidth, aQ12 [][]float3 } func (d *Decoder) convertNormalizedLSFsToLPCCoefficients(n1Q15 []int16, bandwidth Bandwidth) (a32Q17 []int32) { - cQ17 := resizeZero(&d.cQ17, len(n1Q15)) + cQ17 := slicetools.ResizeZero(&d.cQ17, len(n1Q15)) cosQ12 := q12CosineTableForLSFConverion ordering := lsfOrderingForPolynomialEvaluationNarrowbandAndMediumband @@ -845,8 +831,8 @@ func (d *Decoder) convertNormalizedLSFsToLPCCoefficients(n1Q15 []int16, bandwidt (cosQ12[i+1]-cosQ12[i])*f + 4) >> 3 } - pQ16 := resizeZero(&d.pQ16, (len(n1Q15)/2)+1) - qQ16 := resizeZero(&d.qQ16, (len(n1Q15)/2)+1) + pQ16 := slicetools.ResizeZero(&d.pQ16, (len(n1Q15)/2)+1) + qQ16 := slicetools.ResizeZero(&d.qQ16, (len(n1Q15)/2)+1) // Given the list of cosine values compute the coefficients of P and Q, // described here via a simple recurrence. Let p_Q16[k][j] and q_Q16[k][j] @@ -914,7 +900,7 @@ func (d *Decoder) convertNormalizedLSFsToLPCCoefficients(n1Q15 []int16, bandwidt // // https://datatracker.ietf.org/doc/html/rfc6716#section-4.2.7.5.6 - a32Q17 = resizeZero(&d.a32Q17, len(n1Q15)) + a32Q17 = slicetools.ResizeZero(&d.a32Q17, len(n1Q15)) for k := range d2 { a32Q17[k] = -(qQ16[k+1] - qQ16[k]) - (pQ16[k+1] + pQ16[k]) a32Q17[dLPC-k-1] = (qQ16[k+1] - qQ16[k]) - (pQ16[k+1] + pQ16[k]) @@ -997,8 +983,8 @@ func (d *Decoder) decodeRatelevel(voiceActivityDetected bool) uint32 { // // https://datatracker.ietf.org/doc/html/rfc6716#section-4.2.7.8.2 func (d *Decoder) decodePulseAndLSBCounts(shellblocks int, rateLevel uint32) (pulsecounts []uint8, lsbcounts []uint8) { - pulsecounts = resizeZero(&d.pulsecounts, shellblocks) - lsbcounts = resizeZero(&d.lsbcounts, shellblocks) + pulsecounts = slicetools.ResizeZero(&d.pulsecounts, shellblocks) + lsbcounts = slicetools.ResizeZero(&d.lsbcounts, shellblocks) for i := range shellblocks { pulsecounts[i] = uint8(d.rangeDecoder.DecodeSymbolWithICDF(icdfPulseCount[rateLevel])) //nolint:gosec // g115 @@ -1043,7 +1029,7 @@ func (d *Decoder) decodePulseAndLSBCounts(shellblocks int, rateLevel uint32) (pu // // https://datatracker.ietf.org/doc/html/rfc6716#section-4.2.7.8.3 func (d *Decoder) decodePulseLocation(pulsecounts []uint8) (eRaw []int32) { - eRaw = resizeZero(&d.eRaw, len(pulsecounts)*pulsecountLargestPartitionSize) + eRaw = slicetools.ResizeZero(&d.eRaw, len(pulsecounts)*pulsecountLargestPartitionSize) for i := range pulsecounts { // This process skips partitions without any pulses, i.e., where // the initial pulse count from Section 4.2.7.8.2 was zero, or where the @@ -1316,7 +1302,7 @@ func (d *Decoder) decodeExcitation( // and with the corresponding sign decoded in Section 4.2.7.8.5. d.decodeExcitationSign(eRaw, signalType, quantizationOffsetType, pulsecounts) - eQ23 = resizeZero(&d.eQ23, len(eRaw)) + eQ23 = slicetools.ResizeZero(&d.eQ23, len(eRaw)) for i := range eRaw { // Additionally, let seed be the current pseudorandom seed, which is initialized to the // value decoded from Section 4.2.7.7 for the first sample in the current SILK frame, and @@ -1471,7 +1457,7 @@ func (d *Decoder) limitLPCFilterPredictionGainInto(a32Q17 []int32, slot int) (aQ // a32_Q12[n] = (a32_Q17[n] + 16) >> 5 // // https://datatracker.ietf.org/doc/html/rfc6716#section-4.2.7.5.8 - aQ12Int := resizeZero(&d.aQ12Int[slot], len(a32Q17)) + aQ12Int := slicetools.ResizeZero(&d.aQ12Int[slot], len(a32Q17)) for n := range a32Q17 { aQ12Int[n] = int16((a32Q17[n] + 16) >> 5) //nolint:gosec // G115 } @@ -1493,7 +1479,7 @@ func (d *Decoder) limitLPCFilterPredictionGainInto(a32Q17 []int32, slot int) (aQ } } - aQ12 = resizeZero(&d.aQ12Coefficients[slot], len(aQ12Int)) + aQ12 = slicetools.ResizeZero(&d.aQ12Coefficients[slot], len(aQ12Int)) for n := range aQ12Int { aQ12[n] = float32(aQ12Int[n]) } @@ -1721,7 +1707,7 @@ func (d *Decoder) decodePitchLags( // // pitch_lags[k] = clamp(lag_min, lag + lag_cb[contour_index][k], // lag_max) - pitchLags = resizeZero(&d.pitchLags, subframeCount(nanoseconds)) + pitchLags = slicetools.ResizeZero(&d.pitchLags, subframeCount(nanoseconds)) for i := range pitchLags { pitchLags[i] = int(clamp( int32(lagMin), //nolint:gosec @@ -1784,8 +1770,8 @@ func (d *Decoder) decodeLTPFilterCoefficients(signalType frameSignalType, subfra return bQ7 } - bQ7 = resize(&d.bQ7, subframeCount) - bQ7Data := resizeZero(&d.bQ7Data, subframeCount*5) + bQ7 = slicetools.Resize(&d.bQ7, subframeCount) + bQ7Data := slicetools.ResizeZero(&d.bQ7Data, subframeCount*5) for i := range bQ7 { start := i * 5 bQ7[i] = bQ7Data[start : start+5] @@ -2138,7 +2124,7 @@ func (d *Decoder) silkFrameReconstruction( // let lpc[i] be the result of LPC synthesis from the last d_LPC samples of the // previous subframe or zeros in the first subframe for this channel - lpc := resizeZero(&d.lpc, n*subframeCount) + lpc := slicetools.ResizeZero(&d.lpc, n*subframeCount) // For unvoiced frames (see Section 4.2.7.3), the LPC residual for i // such that j <= i < (j + n) is simply a normalized copy of the @@ -2147,8 +2133,8 @@ func (d *Decoder) silkFrameReconstruction( // e_Q23[i] // res[i] = --------- // 2.0**23 - res := resizeZero(&d.res, len(eQ23)) - resLag := resizeZero(&d.resLag, int(lagMax)+2) + res := slicetools.ResizeZero(&d.res, len(eQ23)) + resLag := slicetools.ResizeZero(&d.resLag, int(lagMax)+2) for i := range res { res[i] = float32(eQ23[i]) / 8388608.0 } diff --git a/internal/slicetools/slices.go b/internal/slicetools/slices.go new file mode 100644 index 0000000..7c59607 --- /dev/null +++ b/internal/slicetools/slices.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package slicetools provides shared helpers for reusing scratch slices. +package slicetools + +// Resize returns a slice of size elements, reusing buffer's allocation when +// possible. +func Resize[T any](buffer *[]T, size int) []T { + if cap(*buffer) < size { + *buffer = make([]T, size) + } + + return (*buffer)[:size] +} + +// ResizeZero returns a zeroed slice of size elements, reusing buffer's +// allocation when possible. +func ResizeZero[T any](buffer *[]T, size int) []T { + out := Resize(buffer, size) + clear(out) + + return out +} diff --git a/internal/slicetools/slices_test.go b/internal/slicetools/slices_test.go new file mode 100644 index 0000000..f0164fb --- /dev/null +++ b/internal/slicetools/slices_test.go @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package slicetools + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestResize(t *testing.T) { + buffer := []int{1, 2, 3} + + assert.Equal(t, []int{1, 2}, Resize(&buffer, 2)) + assert.Equal(t, []int{0, 0, 0, 0}, Resize(&buffer, 4)) +} + +func TestResizeZero(t *testing.T) { + buffer := []int{1, 2, 3} + + assert.Equal(t, []int{0, 0}, ResizeZero(&buffer, 2)) +}