diff --git a/internal/celt/synthesis.go b/internal/celt/synthesis.go index 69b93dc..0016a03 100644 --- a/internal/celt/synthesis.go +++ b/internal/celt/synthesis.go @@ -57,8 +57,14 @@ type inverseTransformPlan struct { sine float32 rotateCos []float32 rotateSinQuarter []float32 - fftCos []float32 - fftSin []float32 + fftTwiddles []complex32 + fftBitrev []int + fftFactors []fftFactor +} + +type fftFactor struct { + radix int + size int } var inverseTransformPlans = [maxLM + 1]inverseTransformPlan{ //nolint:gochecknoglobals @@ -549,83 +555,196 @@ func inverseComplexDFTInto(in []complex32, out []complex32, work []complex32, pl return } - inverseComplexFFTRecursive(in, 1, out, work, len(in), plan) + for i, value := range in { + out[plan.fftBitrev[i]] = value + } + + for stage := len(plan.fftFactors) - 1; stage >= 0; stage-- { + factor := plan.fftFactors[stage] + fstride := plan.n4 / (factor.radix * factor.size) + switch factor.radix { + case 2: + inverseFFTButterfly2(out, factor.size, fstride, plan) + case 3: + inverseFFTButterfly3(out, factor.size, fstride, plan) + case 4: + inverseFFTButterfly4(out, factor.size, fstride, plan) + case 5: + inverseFFTButterfly5(out, factor.size, fstride, plan) + default: + inverseFFTButterflyGeneric(out, work, factor.radix, factor.size, fstride, plan) + } + } } -func inverseComplexFFTRecursive( - in []complex32, - stride int, - out []complex32, - work []complex32, - n int, - plan *inverseTransformPlan, -) { - if n <= 1 { - if n == 1 { - out[0] = in[0] +func inverseFFTButterfly2(out []complex32, subtransformLength, fstride int, plan *inverseTransformPlan) { + groupSize := 2 * subtransformLength + for group := 0; group < fstride; group++ { + base := group * groupSize + for frequencyIndex := 0; frequencyIndex < subtransformLength; frequencyIndex++ { + out0 := out[base+frequencyIndex] + out1 := out[base+subtransformLength+frequencyIndex] + twiddle := plan.fftTwiddles[frequencyIndex*fstride] + out1 = multiplyComplex32(out1, twiddle) + out[base+frequencyIndex] = addComplex32(out0, out1) + out[base+subtransformLength+frequencyIndex] = subtractComplex32(out0, out1) } + } +} - return +func inverseFFTButterfly3(out []complex32, subtransformLength, fstride int, plan *inverseTransformPlan) { + groupSize := 3 * subtransformLength + epi3 := plan.fftTwiddles[fstride*subtransformLength] + for group := 0; group < fstride; group++ { + base := group * groupSize + for frequencyIndex := 0; frequencyIndex < subtransformLength; frequencyIndex++ { + out0 := out[base+frequencyIndex] + out1 := multiplyComplex32(out[base+subtransformLength+frequencyIndex], plan.fftTwiddles[frequencyIndex*fstride]) + out2 := multiplyComplex32(out[base+2*subtransformLength+frequencyIndex], plan.fftTwiddles[2*frequencyIndex*fstride]) + sum := addComplex32(out1, out2) + diff := multiplyComplex32ByFloat32(subtractComplex32(out1, out2), epi3.i) + mid := subtractComplex32(out0, multiplyComplex32ByFloat32(sum, 0.5)) + out[base+frequencyIndex] = addComplex32(out0, sum) + out[base+subtransformLength+frequencyIndex] = complex32{r: mid.r - diff.i, i: mid.i + diff.r} + out[base+2*subtransformLength+frequencyIndex] = complex32{r: mid.r + diff.i, i: mid.i - diff.r} + } } +} - radix := fftRadix(n) - subtransformLength := n / radix - for subtransform := range radix { - inverseComplexFFTRecursive( - in[subtransform*stride:], - stride*radix, - work[subtransform*subtransformLength:(subtransform+1)*subtransformLength], - out[subtransform*subtransformLength:(subtransform+1)*subtransformLength], - subtransformLength, - plan, - ) +func inverseFFTButterfly4(out []complex32, subtransformLength, fstride int, plan *inverseTransformPlan) { + groupSize := 4 * subtransformLength + for group := 0; group < fstride; group++ { + base := group * groupSize + for frequencyIndex := 0; frequencyIndex < subtransformLength; frequencyIndex++ { + out0 := out[base+frequencyIndex] + out1 := multiplyComplex32(out[base+subtransformLength+frequencyIndex], plan.fftTwiddles[frequencyIndex*fstride]) + out2 := multiplyComplex32(out[base+2*subtransformLength+frequencyIndex], plan.fftTwiddles[2*frequencyIndex*fstride]) + out3 := multiplyComplex32(out[base+3*subtransformLength+frequencyIndex], plan.fftTwiddles[3*frequencyIndex*fstride]) + evenDifference := subtractComplex32(out0, out2) + evenSum := addComplex32(out0, out2) + oddSum := addComplex32(out1, out3) + oddDifference := subtractComplex32(out1, out3) + out[base+frequencyIndex] = addComplex32(evenSum, oddSum) + out[base+2*subtransformLength+frequencyIndex] = subtractComplex32(evenSum, oddSum) + out[base+subtransformLength+frequencyIndex] = complex32{r: evenDifference.r - oddDifference.i, i: evenDifference.i + oddDifference.r} + out[base+3*subtransformLength+frequencyIndex] = complex32{r: evenDifference.r + oddDifference.i, i: evenDifference.i - oddDifference.r} + } } +} - for k := range subtransformLength { - for frequencyGroup := range radix { - sum := complex32{} - for subtransform := range radix { - value := work[subtransform*subtransformLength+k] - twiddle := plan.fftTwiddle(subtransform*(k+frequencyGroup*subtransformLength), n) - sum.r += value.r*twiddle.r - value.i*twiddle.i - sum.i += value.r*twiddle.i + value.i*twiddle.r +func inverseFFTButterfly5(out []complex32, subtransformLength, fstride int, plan *inverseTransformPlan) { + groupSize := 5 * subtransformLength + ya := plan.fftTwiddles[fstride*subtransformLength] + yb := plan.fftTwiddles[2*fstride*subtransformLength] + for group := 0; group < fstride; group++ { + base := group * groupSize + for frequencyIndex := 0; frequencyIndex < subtransformLength; frequencyIndex++ { + out0 := out[base+frequencyIndex] + out1 := multiplyComplex32(out[base+subtransformLength+frequencyIndex], plan.fftTwiddles[frequencyIndex*fstride]) + out2 := multiplyComplex32(out[base+2*subtransformLength+frequencyIndex], plan.fftTwiddles[2*frequencyIndex*fstride]) + out3 := multiplyComplex32(out[base+3*subtransformLength+frequencyIndex], plan.fftTwiddles[3*frequencyIndex*fstride]) + out4 := multiplyComplex32(out[base+4*subtransformLength+frequencyIndex], plan.fftTwiddles[4*frequencyIndex*fstride]) + sum14 := addComplex32(out1, out4) + diff14 := subtractComplex32(out1, out4) + sum23 := addComplex32(out2, out3) + diff23 := subtractComplex32(out2, out3) + out[base+frequencyIndex] = addComplex32(out0, addComplex32(sum14, sum23)) + + base14 := addComplex32( + out0, + addComplex32( + multiplyComplex32ByFloat32(sum14, ya.r), + multiplyComplex32ByFloat32(sum23, yb.r), + ), + ) + rotate14 := complex32{ + r: -diff14.i*ya.i - diff23.i*yb.i, + i: diff14.r*ya.i + diff23.r*yb.i, + } + out[base+subtransformLength+frequencyIndex] = addComplex32(base14, rotate14) + out[base+4*subtransformLength+frequencyIndex] = subtractComplex32(base14, rotate14) + + base23 := addComplex32( + out0, + addComplex32( + multiplyComplex32ByFloat32(sum14, yb.r), + multiplyComplex32ByFloat32(sum23, ya.r), + ), + ) + rotate23 := complex32{ + r: -diff14.i*yb.i + diff23.i*ya.i, + i: diff14.r*yb.i - diff23.r*ya.i, } - out[k+frequencyGroup*subtransformLength] = sum + out[base+2*subtransformLength+frequencyIndex] = addComplex32(base23, rotate23) + out[base+3*subtransformLength+frequencyIndex] = subtractComplex32(base23, rotate23) } } } -func fftRadix(n int) int { - switch { - case n%2 == 0: - return 2 - case n%3 == 0: - return 3 - case n%5 == 0: - return 5 - default: - return n +func inverseFFTButterflyGeneric( + out []complex32, + work []complex32, + radix int, + subtransformLength int, + fstride int, + plan *inverseTransformPlan, +) { + groupSize := radix * subtransformLength + for group := 0; group < fstride; group++ { + base := group * groupSize + for frequencyIndex := 0; frequencyIndex < subtransformLength; frequencyIndex++ { + for frequencyGroup := 0; frequencyGroup < radix; frequencyGroup++ { + sum := complex32{} + for subtransform := 0; subtransform < radix; subtransform++ { + value := out[base+subtransform*subtransformLength+frequencyIndex] + twiddle := plan.fftTwiddles[subtransform*(frequencyIndex+frequencyGroup*subtransformLength)*fstride%plan.n4] + sum.r += value.r*twiddle.r - value.i*twiddle.i + sum.i += value.r*twiddle.i + value.i*twiddle.r + } + work[frequencyGroup] = sum + } + for frequencyGroup := 0; frequencyGroup < radix; frequencyGroup++ { + out[base+frequencyGroup*subtransformLength+frequencyIndex] = work[frequencyGroup] + } + } } } +func addComplex32(a, b complex32) complex32 { + return complex32{r: a.r + b.r, i: a.i + b.i} +} + +func subtractComplex32(a, b complex32) complex32 { + return complex32{r: a.r - b.r, i: a.i - b.i} +} + +func multiplyComplex32(a, b complex32) complex32 { + return complex32{r: a.r*b.r - a.i*b.i, i: a.r*b.i + a.i*b.r} +} + +func multiplyComplex32ByFloat32(value complex32, factor float32) complex32 { + return complex32{r: value.r * factor, i: value.i * factor} +} + func newInverseTransformPlan(frameSampleCount int) inverseTransformPlan { n := 2 * frameSampleCount n4 := n >> 2 + fftFactors := newFFTFactors(n4) plan := inverseTransformPlan{ frameSampleCount: frameSampleCount, n4: n4, sine: float32(2 * math.Pi * 0.125 / float64(n)), rotateCos: make([]float32, n4), rotateSinQuarter: make([]float32, n4), - fftCos: make([]float32, n4), - fftSin: make([]float32, n4), + fftTwiddles: make([]complex32, n4), + fftBitrev: newFFTBitrev(n4, fftFactors), + fftFactors: fftFactors, } for i := range n4 { plan.rotateCos[i] = float32(math.Cos(2 * math.Pi * float64(i) / float64(n))) plan.rotateSinQuarter[i] = float32(math.Cos(2 * math.Pi * float64(n4-i) / float64(n))) angle := 2 * math.Pi * float64(i) / float64(n4) - plan.fftCos[i] = float32(math.Cos(angle)) - plan.fftSin[i] = float32(math.Sin(angle)) + plan.fftTwiddles[i] = complex32{r: float32(math.Cos(angle)), i: float32(math.Sin(angle))} } return plan @@ -643,14 +762,53 @@ func inverseTransformPlanForFrameSampleCount(frameSampleCount int) *inverseTrans return &plan } -func (p *inverseTransformPlan) fftTwiddle(index, transformSize int) complex32 { - index *= p.n4 / transformSize - index %= p.n4 - if index < 0 { - index += p.n4 +func newFFTFactors(n int) []fftFactor { + factors := make([]fftFactor, 0, 5) + radix := 4 + for { + for n%radix != 0 { + switch radix { + case 4: + radix = 2 + case 2: + radix = 3 + default: + radix += 2 + } + if radix*radix > n { + radix = n + } + } + n /= radix + factors = append(factors, fftFactor{radix: radix, size: n}) + if n == 1 { + return factors + } + } +} + +func newFFTBitrev(n int, factors []fftFactor) []int { + bitrev := make([]int, n) + fillFFTBitrev(bitrev, 0, 1, 0, factors, 0) + + return bitrev +} + +func fillFFTBitrev(bitrev []int, outputIndex, fstride, fftOutput int, factors []fftFactor, factorIndex int) { + factor := factors[factorIndex] + if factor.size == 1 { + for j := 0; j < factor.radix; j++ { + bitrev[outputIndex+j*fstride] = fftOutput + j + } + + return } - return complex32{r: p.fftCos[index], i: p.fftSin[index]} + for j := 0; j < factor.radix; j++ { + fillFFTBitrev(bitrev, outputIndex, fstride*factor.radix, fftOutput, factors, factorIndex+1) + outputIndex += fstride + fftOutput += factor.size + } } func celtWindow(i int) float32 {