Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,10 @@ func WithSampleRate(rate int) EncoderOption {
}
}

// WithChannels sets the channel count. The current encoder only supports
// mono (1 channel); stereo is planned in a follow-up PR.
// WithChannels sets the channel count (1 for mono, 2 for stereo).
func WithChannels(channels int) EncoderOption {
return func(e *Encoder) error {
if channels != 1 {
if channels < 1 || channels > 2 {
return errInvalidChannelCount
}
e.channels = channels
Expand Down Expand Up @@ -95,8 +94,8 @@ func WithComplexity(complexity int) EncoderOption {
// NewEncoder creates a new Opus encoder with the supplied options.
//
// Defaults: 48 kHz, mono, 24 kbit/s, complexity 0. Pass options to override
// any of these. The current API surface only supports 48 kHz mono 20 ms
// CELT-only packets; stereo, transient detection, and SILK encoding will land
// any of these. The current implementation supports 48 kHz, 1 or 2 channels,
// 20 ms CELT-only packets. Transient detection and SILK encoding will land
// in follow-up PRs.
func NewEncoder(opts ...EncoderOption) (*Encoder, error) {
encoder := &Encoder{
Expand Down Expand Up @@ -151,20 +150,19 @@ func (e *Encoder) Encode(in []byte, out []byte) (int, error) {

// EncodeFloat32 encodes float PCM into a single Opus packet.
//
// The input must contain exactly one 20 ms mono 48 kHz frame.
// The input must contain one 20 ms 48 kHz frame.
func (e *Encoder) EncodeFloat32(in []float32, out []byte) (int, error) {
if e.sampleRate != celtSampleRate {
return 0, errInvalidSampleRate
}

if e.channels != 1 {
return 0, errInvalidChannelCount
}
frameSamples := e.frameSampleCount()
if len(in) != frameSamples*e.channels {
return 0, fmt.Errorf("%w: got %d samples, want %d", errInvalidFrameSize, len(in), frameSamples*e.channels)
}

channels := splitChannels(in, e.channels, frameSamples)

frameBytes := e.frameBytes()
if frameBytes <= 0 || frameBytes > maxOpusFrameSize {
return 0, fmt.Errorf("%w: %d", errInvalidFrameByteBudget, frameBytes)
Expand All @@ -173,7 +171,7 @@ func (e *Encoder) EncodeFloat32(in []float32, out []byte) (int, error) {
return 0, errOutBufferTooSmall
}

payload, err := e.celtEncoder.EncodeFrame(in, frameBytes, 0, e.celtEncoder.Mode().BandCount())
payload, err := e.celtEncoder.EncodeFrame(channels, frameBytes, 0, e.celtEncoder.Mode().BandCount())
if err != nil {
return 0, err
}
Expand All @@ -193,10 +191,33 @@ func (e *Encoder) EncodeFloat32(in []float32, out []byte) (int, error) {
func (e *Encoder) tocHeader() tableOfContentsHeader {
header := byte(celtOnlyFullband20msConfig << 3)
header |= byte(frameCodeOneFrame)
if e.channels == 2 {
header |= 1 << 2
}

return tableOfContentsHeader(header)
}

// splitChannels splits interleaved PCM into per-channel slices.
// For mono, it returns the input directly without allocation.
func splitChannels(in []float32, numChannels, frameSamples int) [][]float32 {
ch := make([][]float32, numChannels)
if numChannels == 1 {
ch[0] = in

return ch
}

for c := range numChannels {
ch[c] = make([]float32, frameSamples)
for i := range frameSamples {
ch[c][i] = in[i*numChannels+c]
}
}

return ch
}

func (e *Encoder) frameBytes() int {
return int(int64(e.bitrate) * frame20msNS / 1000000000 / 8)
}
Expand Down
134 changes: 132 additions & 2 deletions encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ func TestNewEncoder(t *testing.T) {
_, err = NewEncoder(WithSampleRate(16000))
assert.ErrorIs(t, err, errInvalidSampleRate)

_, err = NewEncoder(WithChannels(2))
assert.ErrorIs(t, err, errInvalidChannelCount)
encoder, err = NewEncoder(WithChannels(2))
require.NoError(t, err)
assert.Equal(t, 2, encoder.channels)
}

func TestNewEncoderOptions(t *testing.T) {
Expand Down Expand Up @@ -101,6 +102,98 @@ func TestEncodeS16LERoundTrip(t *testing.T) {
assert.Greater(t, vectorEnergyFloat32(out), 1e-6)
}

func TestEncodeFloat32StereoRoundTrip(t *testing.T) {
encoder, err := NewEncoder(WithChannels(2))
require.NoError(t, err)

decoder, err := NewDecoderWithOutput(48000, 2)
require.NoError(t, err)

pcm := testEncoderStereoSineFloat32()
packet := make([]byte, 256)

n, err := encoder.EncodeFloat32(pcm, packet)
require.NoError(t, err)
require.Positive(t, n)

assert.Equal(t, byte(celtOnlyFullband20msConfig<<3)|byte(frameCodeOneFrame)|(1<<2), packet[0])

out := make([]float32, encoderTestFrameSampleCount*2)
bandwidth, isStereo, err := decoder.DecodeFloat32(packet[:n], out)
require.NoError(t, err)

assert.Equal(t, BandwidthFullband, bandwidth)
assert.True(t, isStereo)
assert.Greater(t, vectorEnergyFloat32(out), 1e-6)

L := make([]float32, encoderTestFrameSampleCount)
R := make([]float32, encoderTestFrameSampleCount)
for i := range encoderTestFrameSampleCount {
L[i] = out[i*2]
R[i] = out[i*2+1]
}
L440 := freqEnergy(L, 440)
L660 := freqEnergy(L, 660)
R440 := freqEnergy(R, 440)
R660 := freqEnergy(R, 660)
assert.Greater(t, L440, L660*1.5, "L channel: 440 Hz should dominate over 660 Hz")
assert.Greater(t, R660, R440*1.5, "R channel: 660 Hz should dominate over 440 Hz")
}

func TestEncodeS16LEStereoRoundTrip(t *testing.T) {
encoder, err := NewEncoder(WithChannels(2))
require.NoError(t, err)

decoder, err := NewDecoderWithOutput(48000, 2)
require.NoError(t, err)

pcm := testEncoderStereoSineS16LE()
packet := make([]byte, 256)

n, err := encoder.Encode(pcm, packet)
require.NoError(t, err)
require.Positive(t, n)

out := make([]float32, encoderTestFrameSampleCount*2)
_, isStereo, err := decoder.DecodeFloat32(packet[:n], out)
require.NoError(t, err)

assert.True(t, isStereo)
assert.Greater(t, vectorEnergyFloat32(out), 1e-6)
}

func TestStereoMultiFramePersistence(t *testing.T) {
encoder, err := NewEncoder(WithChannels(2))
require.NoError(t, err)

decoder, err := NewDecoderWithOutput(48000, 2)
require.NoError(t, err)

pcm := testEncoderStereoSineFloat32()
packet := make([]byte, 256)
out := make([]float32, encoderTestFrameSampleCount*2)

const frames = 10
energies := make([]float64, frames)
for i := range frames {
n, encErr := encoder.EncodeFloat32(pcm, packet)
require.NoError(t, encErr, "frame %d encode failed", i)
require.Positive(t, n)

_, _, decErr := decoder.DecodeFloat32(packet[:n], out)
require.NoError(t, decErr, "frame %d decode failed", i)

energies[i] = vectorEnergyFloat32(out)
assert.Greater(t, energies[i], 1e-6, "frame %d should have non-zero energy", i)
}

for i := 1; i < frames; i++ {
ratio := energies[i] / energies[0]
assert.InDelta(t, 1.0, ratio, 0.75,
"frame %d energy ratio %.3f deviates too far from frame 0", i, ratio)
}
}

func TestEncodeRejectsInvalidS16LEInputLength(t *testing.T) {
encoder, err := NewEncoder()
require.NoError(t, err)
Expand Down Expand Up @@ -159,6 +252,18 @@ func testEncoderSineFloat32() []float32 {
return pcm
}

func testEncoderStereoSineFloat32() []float32 {
pcm := make([]float32, encoderTestFrameSampleCount*2)
for i := range encoderTestFrameSampleCount {
left := float32(math.Sin(2 * math.Pi * 440 * float64(i) / 48000))
right := float32(math.Sin(2 * math.Pi * 660 * float64(i) / 48000))
pcm[i*2] = left
pcm[i*2+1] = right
}

return pcm
}

func testEncoderSineS16LE() []byte {
pcm := make([]byte, encoderTestFrameSampleCount*2)
for i := range encoderTestFrameSampleCount {
Expand All @@ -171,6 +276,18 @@ func testEncoderSineS16LE() []byte {
return pcm
}

func testEncoderStereoSineS16LE() []byte {
pcm := make([]byte, encoderTestFrameSampleCount*4) // 2 channels × 2 bytes each
for i := range encoderTestFrameSampleCount {
left := int16(math.Round(math.Sin(2*math.Pi*440*float64(i)/48000) * 16000))
right := int16(math.Round(math.Sin(2*math.Pi*660*float64(i)/48000) * 16000))
binary.LittleEndian.PutUint16(pcm[i*4:], uint16(left)) //nolint:gosec // G115
binary.LittleEndian.PutUint16(pcm[i*4+2:], uint16(right)) //nolint:gosec // G115
}

return pcm
}

func vectorEnergyFloat32(x []float32) float64 {
var e float64
for _, v := range x {
Expand All @@ -179,3 +296,16 @@ func vectorEnergyFloat32(x []float32) float64 {

return math.Sqrt(e)
}

// freqEnergy returns the DFT magnitude at freq Hz over a 48 kHz signal.
// It is phase-invariant so it survives the CELT analysis/synthesis delay.
func freqEnergy(samples []float32, freq float64) float64 {
var re, im float64
for i, s := range samples {
angle := 2 * math.Pi * freq * float64(i) / 48000
re += float64(s) * math.Cos(angle)
im += float64(s) * math.Sin(angle)
}

return math.Sqrt(re*re+im*im) / float64(len(samples))
}
75 changes: 35 additions & 40 deletions internal/celt/analysis.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,68 +10,63 @@ import (
const preemphasisCoefficient = 0.85000610

type analysisState struct {
prevPCM []float32
preemphasisMem float32
prevPCM [2][]float32
preemphasisMem [2]float32
}

type analysisResult struct {
info frameSideInfo
preemphasized []float32
mdct []float32
logBandAmp [maxBands]float32
info frameSideInfo
mdct [2][]float32
logBandAmp [2][maxBands]float32
}

func newAnalysisState() analysisState {
return analysisState{
prevPCM: make([]float32, shortBlockSampleCount),
prevPCM: [2][]float32{
make([]float32, shortBlockSampleCount),
make([]float32, shortBlockSampleCount),
},
}
}

// analyzeFrame prepares the mono CELT encoder input: applies pre-emphasis,
// extends the frame with the previous overlap for the MDCT window, runs the
// forward MDCT, and returns per-band log amplitude for coarse energy coding.
func analyzeFrame(
mode *Mode,
frame []float32,
startBand int,
endBand int,
state *analysisState,
) (analysisResult, error) {
lm, err := mode.LMForFrameSampleCount(len(frame))
// analyzeFrame applies pre-emphasis, builds the MDCT overlap window, runs the
// forward MDCT, and returns per-band log amplitude for each input channel.
func analyzeFrame(mode *Mode, pcm [][]float32, startBand, endBand int, state *analysisState) (analysisResult, error) {
lm, err := mode.LMForFrameSampleCount(len(pcm[0]))
if err != nil {
return analysisResult{}, err
}

result := analysisResult{
res := analysisResult{
info: frameSideInfo{
lm: lm,
startBand: startBand,
endBand: endBand,
channelCount: 1,
transient: false,
shortBlockCount: 0,
intraEnergy: false,
spread: defaultSpreadDecision,
allocationTrim: defaultAllocationTrim,
lm: lm,
startBand: startBand,
endBand: endBand,
channelCount: len(pcm),
transient: false,
spread: defaultSpreadDecision,
allocationTrim: defaultAllocationTrim,
},
preemphasized: make([]float32, len(frame)),
}

applyPreemphasis(frame, result.preemphasized, &state.preemphasisMem)
for ch := range pcm {
pre := make([]float32, len(pcm[ch]))
applyPreemphasis(pcm[ch], pre, &state.preemphasisMem[ch])

mdctInput := make([]float32, shortBlockSampleCount+len(frame))
copy(mdctInput, state.prevPCM)
copy(mdctInput[shortBlockSampleCount:], result.preemphasized)
mdctInput := make([]float32, shortBlockSampleCount+len(pre))
copy(mdctInput, state.prevPCM[ch])
copy(mdctInput[shortBlockSampleCount:], pre)

result.mdct = forwardMDCT(mdctInput)
if result.mdct == nil {
return analysisResult{}, errInvalidFrameSize
}
res.mdct[ch] = forwardMDCT(mdctInput)
if res.mdct[ch] == nil {
return analysisResult{}, errInvalidFrameSize
}

result.logBandAmp = computeBandLogAmp(result.mdct, lm, startBand, endBand)
copy(state.prevPCM, result.preemphasized[len(result.preemphasized)-shortBlockSampleCount:])
res.logBandAmp[ch] = computeBandLogAmp(res.mdct[ch], lm, startBand, endBand)
copy(state.prevPCM[ch], pre[len(pre)-shortBlockSampleCount:])
}

return result, nil
return res, nil
}

func applyPreemphasis(in []float32, out []float32, mem *float32) {
Expand Down
Loading
Loading