diff --git a/encoder.go b/encoder.go new file mode 100644 index 0000000..9a2571d --- /dev/null +++ b/encoder.go @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package opus + +import ( + "encoding/binary" + "fmt" + + "github.com/pion/opus/internal/celt" +) + +const ( + defaultBitrate = 24000 + minBitrate = 6000 + maxBitrate = 510000 + frame20msNS = 20000000 +) + +// celtOnlyFullband20msConfig is the TOC config number (bits 3..7) for +// CELT-only, fullband, 20 ms frames per RFC 6716 Table 2. The mono/stereo bit +// is separate (bit 2 of the TOC) and not part of this constant. +const celtOnlyFullband20msConfig = 31 + +// Encoder encodes PCM into Opus packets. +type Encoder struct { + celtEncoder celt.Encoder + sampleRate int + channels int + bitrate int + complexity int +} + +// EncoderOption configures an Encoder during construction. +// +// Options are applied in the order they are passed to NewEncoder. Each option +// returns an error if the requested value is unsupported by the current +// encoder slice, so callers can detect unsupported configurations at +// construction time rather than at first encode. +type EncoderOption func(*Encoder) error + +// WithSampleRate sets the input sample rate in Hz. The current encoder only +// supports 48 kHz (the CELT internal rate). +func WithSampleRate(rate int) EncoderOption { + return func(e *Encoder) error { + if rate != celtSampleRate { + return errInvalidSampleRate + } + e.sampleRate = rate + + return nil + } +} + +// WithChannels sets the channel count. The current encoder only supports +// mono (1 channel); stereo is planned in a follow-up PR. +func WithChannels(channels int) EncoderOption { + return func(e *Encoder) error { + if channels != 1 { + return errInvalidChannelCount + } + e.channels = channels + + return nil + } +} + +// WithBitrate sets the target bitrate in bits per second. Valid range is +// 6000 to 510000. +func WithBitrate(bps int) EncoderOption { + return func(e *Encoder) error { + if bps < minBitrate || bps > maxBitrate { + return fmt.Errorf("%w: %d", errBitrateOutOfRange, bps) + } + e.bitrate = bps + + return nil + } +} + +// WithComplexity sets the encoder complexity on the standard Opus 0..10 +// scale. The current CELT encoder does not vary behavior by complexity, but +// the public API accepts the value for future expansion. +func WithComplexity(complexity int) EncoderOption { + return func(e *Encoder) error { + if complexity < 0 || complexity > 10 { + return fmt.Errorf("%w: %d", errInvalidComplexity, complexity) + } + e.complexity = complexity + + return nil + } +} + +// NewEncoder creates a new Opus encoder with the supplied options. +// +// Defaults: 48 kHz, mono, 24 kbit/s, complexity 0. Pass options to override +// any of these. The current API surface only supports 48 kHz mono 20 ms +// CELT-only packets; stereo, transient detection, and SILK encoding will land +// in follow-up PRs. +func NewEncoder(opts ...EncoderOption) (*Encoder, error) { + encoder := &Encoder{ + celtEncoder: celt.NewEncoder(), + sampleRate: celtSampleRate, + channels: 1, + bitrate: defaultBitrate, + complexity: 0, + } + + for _, opt := range opts { + if err := opt(encoder); err != nil { + return nil, err + } + } + + return encoder, nil +} + +// SetBitrate updates the target bitrate in bits per second. +func (e *Encoder) SetBitrate(bps int) error { + return WithBitrate(bps)(e) +} + +// SetComplexity updates the encoder complexity on the standard Opus 0..10 +// scale. +func (e *Encoder) SetComplexity(complexity int) error { + return WithComplexity(complexity)(e) +} + +// Encode encodes S16LE PCM into a single Opus packet. +// +// The input must contain exactly one 20 ms mono 48 kHz frame. +func (e *Encoder) Encode(in []byte, out []byte) (int, error) { + if len(in)%2 != 0 { + return 0, fmt.Errorf("%w: s16le length %d not a multiple of 2", errInvalidInputLength, len(in)) + } + + expectedSamples := e.frameSampleCount() * e.channels + if len(in)/2 != expectedSamples { + return 0, fmt.Errorf("%w: got %d samples, want %d", errInvalidFrameSize, len(in)/2, expectedSamples) + } + + pcm := make([]float32, len(in)/2) + for i := range pcm { + sample := int16(binary.LittleEndian.Uint16(in[i*2:])) //nolint:gosec // G115: little-endian s16 round-trip. + pcm[i] = float32(sample) / 32768 + } + + return e.EncodeFloat32(pcm, out) +} + +// EncodeFloat32 encodes float PCM into a single Opus packet. +// +// The input must contain exactly one 20 ms mono 48 kHz frame. +func (e *Encoder) EncodeFloat32(in []float32, out []byte) (int, error) { + if e.sampleRate != celtSampleRate { + return 0, errInvalidSampleRate + } + + if e.channels != 1 { + return 0, errInvalidChannelCount + } + frameSamples := e.frameSampleCount() + if len(in) != frameSamples*e.channels { + return 0, fmt.Errorf("%w: got %d samples, want %d", errInvalidFrameSize, len(in), frameSamples*e.channels) + } + + frameBytes := e.frameBytes() + if frameBytes <= 0 || frameBytes > maxOpusFrameSize { + return 0, fmt.Errorf("%w: %d", errInvalidFrameByteBudget, frameBytes) + } + if len(out) < frameBytes+1 { + return 0, errOutBufferTooSmall + } + + payload, err := e.celtEncoder.EncodeFrame(in, frameBytes, 0, e.celtEncoder.Mode().BandCount()) + if err != nil { + return 0, err + } + if len(payload) > maxOpusFrameSize { + return 0, fmt.Errorf("%w: frame size %d exceeds %d", errMalformedPacket, len(payload), maxOpusFrameSize) + } + if len(out) < len(payload)+1 { + return 0, errOutBufferTooSmall + } + + out[0] = byte(e.tocHeader()) + copy(out[1:], payload) + + return 1 + len(payload), nil +} + +func (e *Encoder) tocHeader() tableOfContentsHeader { + header := byte(celtOnlyFullband20msConfig << 3) + header |= byte(frameCodeOneFrame) + + return tableOfContentsHeader(header) +} + +func (e *Encoder) frameBytes() int { + return int(int64(e.bitrate) * frame20msNS / 1000000000 / 8) +} + +func (e *Encoder) frameSampleCount() int { + return int(int64(celtSampleRate) * frame20msNS / 1000000000) +} diff --git a/encoder_test.go b/encoder_test.go new file mode 100644 index 0000000..cdf4f00 --- /dev/null +++ b/encoder_test.go @@ -0,0 +1,181 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package opus + +import ( + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const encoderTestFrameSampleCount = 960 + +func TestNewEncoder(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + assert.Equal(t, 48000, encoder.sampleRate) + assert.Equal(t, 1, encoder.channels) + assert.Equal(t, defaultBitrate, encoder.bitrate) + + _, err = NewEncoder(WithSampleRate(16000)) + assert.ErrorIs(t, err, errInvalidSampleRate) + + _, err = NewEncoder(WithChannels(2)) + assert.ErrorIs(t, err, errInvalidChannelCount) +} + +func TestNewEncoderOptions(t *testing.T) { + encoder, err := NewEncoder( + WithSampleRate(48000), + WithChannels(1), + WithBitrate(64000), + WithComplexity(5), + ) + require.NoError(t, err) + + assert.Equal(t, 64000, encoder.bitrate) + assert.Equal(t, 5, encoder.complexity) + + _, err = NewEncoder(WithBitrate(1000)) + assert.ErrorIs(t, err, errBitrateOutOfRange) + + _, err = NewEncoder(WithComplexity(11)) + assert.ErrorIs(t, err, errInvalidComplexity) +} + +func TestEncodeFloat32RoundTrip(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + decoder, err := NewDecoderWithOutput(48000, 1) + require.NoError(t, err) + + pcm := testEncoderSineFloat32() + packet := make([]byte, 256) + + n, err := encoder.EncodeFloat32(pcm, packet) + require.NoError(t, err) + require.Positive(t, n) + + assert.Equal(t, byte(celtOnlyFullband20msConfig<<3)|byte(frameCodeOneFrame), packet[0]) + + out := make([]float32, encoderTestFrameSampleCount) + bandwidth, isStereo, err := decoder.DecodeFloat32(packet[:n], out) + require.NoError(t, err) + + assert.Equal(t, BandwidthFullband, bandwidth) + assert.False(t, isStereo) + assert.Greater(t, vectorEnergyFloat32(out), 1e-6) + + // Output amplitude must stay in a sane range. Opus is perceptual so some + // overshoot above the input peak is expected, but a sample reaching ±2 + // indicates a gain or scaling defect in the analysis/synthesis pair. + for i, sample := range out { + require.InDelta(t, 0, sample, 2.0, "decoded sample %d out of sane amplitude range", i) + } +} + +func TestEncodeS16LERoundTrip(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + decoder, err := NewDecoderWithOutput(48000, 1) + require.NoError(t, err) + + pcm := testEncoderSineS16LE() + packet := make([]byte, 256) + + n, err := encoder.Encode(pcm, packet) + require.NoError(t, err) + require.Positive(t, n) + + out := make([]float32, encoderTestFrameSampleCount) + _, _, err = decoder.DecodeFloat32(packet[:n], out) + require.NoError(t, err) + + assert.Greater(t, vectorEnergyFloat32(out), 1e-6) +} + +func TestEncodeRejectsInvalidS16LEInputLength(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + _, err = encoder.Encode(make([]byte, 3), make([]byte, 64)) + assert.ErrorIs(t, err, errInvalidInputLength) +} + +func TestEncodeFloat32RejectsInvalidFrameSize(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + _, err = encoder.EncodeFloat32(make([]float32, encoderTestFrameSampleCount-1), make([]byte, 64)) + assert.ErrorIs(t, err, errInvalidFrameSize) +} + +func TestEncodeRejectsSmallOutputBuffer(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + pcm := testEncoderSineFloat32() + packet := make([]byte, 8) + + _, err = encoder.EncodeFloat32(pcm, packet) + assert.ErrorIs(t, err, errOutBufferTooSmall) +} + +func TestSetBitrate(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + require.NoError(t, encoder.SetBitrate(32000)) + assert.Equal(t, 32000, encoder.bitrate) + + assert.ErrorIs(t, encoder.SetBitrate(1000), errBitrateOutOfRange) + assert.ErrorIs(t, encoder.SetBitrate(999999), errBitrateOutOfRange) +} + +func TestSetComplexity(t *testing.T) { + encoder, err := NewEncoder() + require.NoError(t, err) + + require.NoError(t, encoder.SetComplexity(10)) + assert.Equal(t, 10, encoder.complexity) + + assert.ErrorIs(t, encoder.SetComplexity(-1), errInvalidComplexity) + assert.ErrorIs(t, encoder.SetComplexity(11), errInvalidComplexity) +} + +func testEncoderSineFloat32() []float32 { + pcm := make([]float32, encoderTestFrameSampleCount) + for i := range pcm { + pcm[i] = float32(math.Sin(2 * math.Pi * 440 * float64(i) / 48000)) + } + + return pcm +} + +func testEncoderSineS16LE() []byte { + pcm := make([]byte, encoderTestFrameSampleCount*2) + for i := range encoderTestFrameSampleCount { + // math.Round breaks gosec's constant-folding so the int16() conversion + // is analyzed against a runtime float, not a constant expression. + sample := int16(math.Round(math.Sin(2*math.Pi*440*float64(i)/48000) * 16000)) + binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) //nolint:gosec // G115: little-endian s16 round-trip. + } + + return pcm +} + +func vectorEnergyFloat32(x []float32) float64 { + var e float64 + for _, v := range x { + e += float64(v * v) + } + + return math.Sqrt(e) +} diff --git a/errors.go b/errors.go index fe2c0f7..e75ad4e 100644 --- a/errors.go +++ b/errors.go @@ -19,4 +19,14 @@ var ( errOutBufferTooSmall = errors.New("out isn't large enough") errMalformedPacket = errors.New("malformed packet") + + errBitrateOutOfRange = errors.New("bitrate out of range") + + errInvalidComplexity = errors.New("invalid complexity") + + errInvalidInputLength = errors.New("invalid input length") + + errInvalidFrameSize = errors.New("invalid frame size") + + errInvalidFrameByteBudget = errors.New("invalid frame byte budget") ) diff --git a/internal/celt/mdct.go b/internal/celt/mdct.go index 60e0cce..cfa17fa 100644 --- a/internal/celt/mdct.go +++ b/internal/celt/mdct.go @@ -68,11 +68,9 @@ func forwardMDCT(time []float32) []float32 { for i := 0; i < overlap/2; i++ { windowValue := celtWindow(i) - if windowValue == 0 { - continue - } - // inverseMDCT: out[i] = -celtWindow(i) * deshuffled[overlap/2-1-i] - deshuffled[overlap/2-1-i] = -time[i] / windowValue + // Apply analysis window: multiply (not divide) to mirror the synthesis + // windowing in inverseMDCT for TDAC. + deshuffled[overlap/2-1-i] = -time[i] * windowValue } for i := overlap / 2; i < n4; i++ { @@ -85,11 +83,7 @@ func forwardMDCT(time []float32) []float32 { for i := 0; i < overlap/2; i++ { windowValue := celtWindow(i) - if windowValue == 0 { - continue - } - // inverseMDCT: out[n2+overlap-1-i] = celtWindow(i) * deshuffled[n2-overlap/2+i] - deshuffled[n2-overlap/2+i] = time[n2+overlap-1-i] / windowValue + deshuffled[n2-overlap/2+i] = time[n2+overlap-1-i] * windowValue } postRotated := make([]float32, n2) diff --git a/internal/celt/mdct_test.go b/internal/celt/mdct_test.go index 5d4b575..fb93c17 100644 --- a/internal/celt/mdct_test.go +++ b/internal/celt/mdct_test.go @@ -24,7 +24,15 @@ func TestForwardComplexDFTRoundTrip(t *testing.T) { assertComplexSliceClose(t, input, recovered, 1e-4) } -func TestForwardMDCTInvertsInverseMDCT(t *testing.T) { +// TestForwardMDCTPreservesPlainRegion verifies the analysis-side MDCT +// produces a frequency representation that, when inverted, recovers the +// unwindowed (middle) portion of the input time signal. +// +// This is the practical TDAC property: the plain region between the two +// windowed overlap halves must roundtrip exactly. The windowed edges are not +// reconstructed by a single frame — overlap-add with neighboring frames +// completes them in the full pipeline. +func TestForwardMDCTPreservesPlainRegion(t *testing.T) { testCases := []int{ shortBlockSampleCount, shortBlockSampleCount << 1, @@ -33,14 +41,24 @@ func TestForwardMDCTInvertsInverseMDCT(t *testing.T) { } for _, frameSampleCount := range testCases { t.Run(frameSampleCountName(frameSampleCount), func(t *testing.T) { - freq := make([]float32, frameSampleCount) - for i := range freq { - freq[i] = float32(math.Sin(0.013*float64(i)) + 0.25*math.Cos(0.037*float64(i))) + time := make([]float32, frameSampleCount+shortBlockSampleCount) + for i := range time { + time[i] = float32(math.Sin(2*math.Pi*7*float64(i)/float64(frameSampleCount)) + + 0.25*math.Cos(2*math.Pi*23*float64(i)/float64(frameSampleCount))) } - time := inverseMDCT(freq) - recovered := forwardMDCT(time) - require.Len(t, recovered, len(freq)) - assertFloat32SliceClose(t, freq, recovered, 1e-3) + freq := forwardMDCT(time) + require.Len(t, freq, frameSampleCount) + recovered := inverseMDCT(freq) + require.Len(t, recovered, frameSampleCount+shortBlockSampleCount) + // The plain region (where neither forward nor inverse apply the + // window) must roundtrip exactly. The forward MDCT copies + // time[overlap..n+overlap-overlap] verbatim; sample edges within + // overlap/2 of either boundary go through windowing and are only + // reconstructed via TDAC overlap-add with neighboring frames. + plainStart := shortBlockSampleCount + plainEnd := frameSampleCount + assertFloat32SliceClose(t, time[plainStart:plainEnd], + recovered[plainStart:plainEnd], 1e-3) }) } }