Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,29 @@ julia> fft(rand(Double64, 2))
2-element Vector{Complex{Double64}}:
0.4026739024263829 + 0.0im
0.3969515892883767 + 0.0im
```
```

## Usage for low-precision FFTs

```julia
julia> using GenericFFT, BFloat16s

julia> fs = 1000.0
julia> t = 0:1/fs:1-1/fs
julia> f1, f2 = 50.0, 120.0

julia> T = Float16
julia> # see: https://www.mathworks.com/help/matlab/ref/fft.html
julia> x = T.(0.7*sin.(2π * f1 * t) .+ 0.5 * sin.(2π * f2 * t)) .+ T(0.8)
julia> X = fft(x)
julia> println("Max round-trip error: ", maximum(abs.(x - real(ifft(X)))))

julia> T = BFloat16
julia> x = T.(0.7*sin.(2π * f1 * t) .+ 0.5 * sin.(2π * f2 * t)) .+ T(0.8)
julia> X = fft(x)
julia> println("Max round-trip error: ", maximum(abs.(x - real(ifft(X)))))
```


## History

Expand Down
116 changes: 89 additions & 27 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
const AbstractFloats = Union{AbstractFloat,Complex{T} where T<:AbstractFloat}

# We use these type definitions for clarity
const RealFloats = T where T<:AbstractFloat
const ComplexFloats = Complex{T} where T<:AbstractFloat

const AbstractFloats = Union{RealFloats, ComplexFloats}

# The following implements Bluestein's algorithm, following http://www.dsprelated.com/dspbooks/mdft/Bluestein_s_FFT_Algorithm.html
# To add more types, add them in the union of the function's signature.

function generic_fft!(x::AbstractVector{Complex{T}}) where {T<:AbstractFloat}
if ispow2(length(x))
Expand All @@ -23,7 +20,7 @@ function generic_fft!(x::AbstractVector{Complex{T}}, region::Integer) where {T<:
end

function _generic_fft_first_dim!(x, Ipost)
Threads.@threads for I in Ipost
for I in Ipost
generic_fft!(@view x[:, I])
end
x
Expand Down Expand Up @@ -81,18 +78,24 @@ function generic_fft!(x)
end


# Return a freshly allocated complex array with the entries of `A`.
# Complex input is copied; real input is promoted with `complex`, which already
# allocates a new array because the element type changes.
copycomplex(A::AbstractArray{<:Complex}) = copy(A)
copycomplex(A::AbstractArray{<:Real}) = complex(A)

# Out-of-place transforms: operate on a complex copy so that `x` itself is left
# untouched and real-valued input is accepted as well.
generic_fft(x, region) = generic_fft!(copycomplex(x), region)
generic_fft(x) = generic_fft!(copycomplex(x))

# Bluestein (chirp-z) FFT for vectors of arbitrary length. Power-of-two lengths are
# handed straight to the radix-2 kernel.
#
# The chirp `Wks[k+1] = cispi(-k^2/n)` is evaluated in `S`, a working precision of at
# least Float64, and only afterwards rounded to `Complex{real(T)}`. Forming `k^2`
# directly in a narrow element type (e.g. Float16) would overflow or lose most digits.
function generic_fft(x::AbstractVector{T}) where T<:AbstractFloats
    n = length(x)
    ispow2(n) && return generic_fft_pow2(x)
    S = promote_type(real(T), Float64)
    ks = range(zero(S), stop=S(n)-one(S), length=n)
    Wks = Complex{real(T)}.(cispi.(-ks.^2 ./ S(n))) # always Complex
    Wksrev = @view Wks[reverse(eachindex(Wks))]
    # `complex(x)` lets real-valued input take part in the complex convolution.
    xq = complex(x) .* Wks
    # Conjugated chirp for k = n, n-1, ..., 0, 1, ..., n-1 (2n entries in total).
    wq = conj!([Complex{real(T)}(cispi(-S(n))); Wksrev; @view Wks[2:end]])
    return Wks .* @view _conv!(xq, wq)[n+1:2n]
end

# Backward (unnormalised) FFT expressed through the forward transform:
# conjugate the input, apply `generic_fft`, and conjugate the result in place.
generic_bfft(x::AbstractArray{T, N}, region) where {T <: AbstractFloats, N} = conj!(generic_fft(conj(x), region))
Expand All @@ -105,27 +108,78 @@ generic_ifft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv
# In-place inverse FFT: conjugate `x`, run the in-place forward transform, conjugate
# back, then divide in place by `_regionscale(x, region)` converted to `T`.
generic_ifft!(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(T(_regionscale(x, region)), conj!(generic_fft!(conj!(x), region)))

# Real-input FFT of a vector: compute the full transform with `generic_fft` and keep
# only the first `div(length(v), 2) + 1` coefficients.
generic_rfft(v::AbstractVector{T}, region) where T<:AbstractFloats = generic_fft(v, region)[1:div(length(v),2)+1]

# Real-input FFT of an N-dimensional array.
#
# The half-length real transform is applied along the first dimension in `region`;
# any remaining dimensions in `region` are then transformed with the full complex
# `generic_fft`. The result has `size(x, d) ÷ 2 + 1` entries along that first
# dimension `d` and the same extent as `x` everywhere else.
function generic_rfft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N}
    d = first(region)
    if length(region) > 1
        return generic_fft(generic_rfft(x, d), region[2:end])
    end

    nout = size(x, d) ÷ 2 + 1
    # Build the output size as an NTuple so the dimensionality stays inferable.
    outsize = ntuple(i -> i == d ? nout : size(x, i), Val(N))
    out = similar(x, Complex{real(T)}, outsize)

    # CartesianIndices enables iterating over slices in arbitrary dimensions
    Rpre = CartesianIndices(size(x)[1:d-1])
    Rpost = CartesianIndices(size(x)[d+1:end])

    for Ipost in Rpost, Ipre in Rpre
        # Each 1-D slice along `d` goes through the vector method.
        out[Ipre, :, Ipost] .= generic_rfft(view(x, Ipre, :, Ipost), 1)
    end
    return out
end

# Inverse real FFT of a half-spectrum vector `v`, producing a real vector of length `n`.
#
# `v` must hold exactly `n>>1 + 1` coefficients. The full length-`n` spectrum is
# rebuilt from the conjugates of `v[2:end]` in reverse order, and the real part of
# its inverse FFT is returned.
function generic_irfft(v::AbstractVector{T}, n::Integer, region) where T<:ComplexFloats
    nv = length(v)
    @assert nv == n>>1 + 1
    r = Vector{T}(undef, n)
    r[1:nv] = v
    # Upper half of the spectrum: conjugates of v[2], v[3], ... in reverse order.
    r[nv+1:n] = reverse(conj(v[2:end])[1:n-nv])
    return real(generic_ifft(r, region))
end

# Inverse real FFT of an N-dimensional half-spectrum array.
#
# When `region` names several dimensions, all but the first are inverted with the
# complex `generic_ifft` before the real inverse is taken along the first one.
# The result is real-valued, with length `n` along the first dimension of `region`
# and the same extent as `x` in every other dimension.
function generic_irfft(x::AbstractArray{T, N}, n::Integer, region) where {T<:ComplexFloats, N}
    d = first(region)
    if length(region) > 1
        return generic_irfft(generic_ifft(x, region[2:end]), n, d)
    end

    # Build the output size as an NTuple so the dimensionality stays inferable.
    outsize = ntuple(i -> i == d ? Int(n) : size(x, i), Val(N))
    out = similar(x, real(T), outsize)

    Rpre = CartesianIndices(size(x)[1:d-1])
    Rpost = CartesianIndices(size(x)[d+1:end])

    for Ipost in Rpost, Ipre in Rpre
        # Each 1-D slice along `d` goes through the vector method.
        out[Ipre, :, Ipost] .= generic_irfft(view(x, Ipre, :, Ipost), n, 1)
    end
    return out
end

# Unnormalised ("backward") real inverse FFT.
#
# `generic_irfft` returns the normalised inverse, so its result is multiplied back
# by `n` (the real-transform dimension) times `_regionscale` of the remaining
# dimensions in `region`. For an integer `region` there are no remaining
# dimensions and an empty tuple is passed instead.
# NOTE(review): relies on `_regionscale(v, ())` evaluating to 1 — confirm against
# its definition.
function generic_brfft(v::AbstractArray, n::Integer, region)
    scale = n * _regionscale(v, region isa Integer ? () : region[2:end])
    return generic_irfft(v, n, region) * scale
end

# Linear convolution of `u` and `v` via zero-padded power-of-two FFTs.
#
# Both inputs are padded in place (hence the `!`) to the next power of two that is
# at least `length(u) + length(v) - 1`. The transforms themselves run in `Complex{S}`
# with `S` at least Float64, and the first `n` entries of the result are converted
# back to `T` at the end.
function _conv!(u::AbstractVector{T}, v::AbstractVector{T}) where T<:AbstractFloats
    nu, nv = length(u), length(v)
    n = nu + nv - 1
    np2 = nextpow(2, n)
    append!(u, zeros(T, np2-nu))
    append!(v, zeros(T, np2-nv))
    S = promote_type(real(T), Float64)
    uf = Complex{S}.(u)
    vf = Complex{S}.(v)
    y = generic_ifft_pow2(generic_fft_pow2(uf) .* generic_fft_pow2(vf))
    #TODO This would not handle Dual/ComplexDual numbers correctly
    return T <: Real ? T.(real(y[1:n])) : T.(y[1:n])
end


# This is a Cooley-Tukey FFT algorithm inspired by many widely available algorithms including:
# c_radix2.c in the GNU Scientific Library and four1 in the Numerical Recipes in C.
# However, the trigonometric recurrence is improved for greater efficiency.
Expand Down Expand Up @@ -262,7 +316,7 @@ for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiD
@eval begin
mutable struct $P{T,inplace,G} <: DummyPlan{T}
region::G # region (iterable) of dims that are transformed
pinv::DummyPlan{T}
pinv::Plan
$P{T,inplace,G}(region::G) where {T<:AbstractFloats, inplace, G} = new(region)
end
end
Expand All @@ -271,8 +325,8 @@ for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan)
@eval begin
mutable struct $P{T,inplace,G} <: DummyPlan{T}
n::Integer
region::G # region (iterable) of dims that are transformed
pinv::DummyPlan{T}
region::G
pinv::Plan
$P{T,inplace,G}(n::Integer, region::G) where {T<:AbstractFloats, inplace, G} = new(n, region)
end
end
Expand All @@ -287,8 +341,8 @@ for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
end

# Specific for rfft, irfft and brfft:
# An rfft plan is parameterised by the real input element type and an irfft plan by
# the complex one, so the element type has to be switched when building the inverse.
plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{real(T),inplace,G}(p.n, p.region)
plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{Complex{T},inplace,G}(p.n, p.region)



Expand Down Expand Up @@ -331,6 +385,14 @@ end

# Complex input: always build a dummy plan (out-of-place / in-place variants).
plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{T,false,typeof(region)}(region)
plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{T,true,typeof(region)}(region)
# Real input: element types covered by `FFTW.fftwReal` are re-dispatched with `invoke`
# to the less specific method named in the tuple type; every other real float type
# gets a dummy plan parameterised by `Complex{T}`.
plan_fft(x::StridedArray{T}, region; kws...) where {T <: RealFloats} =
T <: FFTW.fftwReal ? invoke(plan_fft, Tuple{AbstractArray{<:Real}, Any}, x, region; kws...) : DummyFFTPlan{Complex{T},false,typeof(region)}(region)
plan_fft!(x::StridedArray{T}, region; kws...) where {T <: RealFloats} =
T <: FFTW.fftwReal ? invoke(plan_fft!, Tuple{AbstractArray, Any}, x, region; kws...) : DummyFFTPlan{Complex{T},true,typeof(region)}(region)

# intercept fft(x) before AbstractFFTs gets a chance for any non-FFTW float type.
# NOTE(review): `T<:AbstractFloats` also matches FFTW-supported element types such as
# Float64 and ComplexF64, so these two methods appear to send those through
# `generic_fft` as well (unlike the `plan_fft` methods above, which branch on
# `FFTW.fftwReal`) — confirm this is intended.
fft(x::StridedArray{T}) where {T<:AbstractFloats} = generic_fft(x)
fft(x::StridedArray{T}, region) where {T<:AbstractFloats} = generic_fft(x, region)

plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{T,false,typeof(region)}(region)
plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{T,true,typeof(region)}(region)
Expand All @@ -345,11 +407,11 @@ plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan
plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region)
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region)

# `n` stores the length along the first transformed dimension (not the total number
# of elements); `plan_inv` forwards it to the irfft plan as the output length.
plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(size(x, first(region)), region)
plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{T,false,typeof(region)}(n, region)

# A plan for irfft is created in terms of a plan for brfft.
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false,typeof(region)}(n, region)
# Explicitly define plan_irfft to ensure correct scaling
plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{T,false,typeof(region)}(n, region)

# These don't exist for now:
# plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}()
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
Loading
Loading