Skip to content

Commit

Permalink
add type assertions to work around inference problem (fix JuliaLang#9772
Browse files Browse the repository at this point in the history
)
  • Loading branch information
stevengj committed May 29, 2015
1 parent 883bf9a commit 674f970
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 25 deletions.
45 changes: 24 additions & 21 deletions base/fft/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,21 +569,22 @@ fftwfloat{T<:fftwReal}(X::StridedArray{T}) = X
fftwfloat{T<:Real}(X::AbstractArray{T}) = copy!(Array(Float64, size(X)), X)
fftwfloat{T<:Complex}(X::AbstractArray{T}) = fftwcomplex(X)

for (f,direction) in ((:fft,:FORWARD), (:bfft,:BACKWARD))
for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
plan_f = symbol("plan_",f)
plan_f! = symbol("plan_",f,"!")
idirection = -direction
@eval begin
function $plan_f{T<:fftwComplex}(X::StridedArray{T}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT)
cFFTWPlan(X, fakesimilar(flags, X, T), region,
$direction, flags,timelimit)
$direction, flags,timelimit)::cFFTWPlan{T,$direction,false}
end

function $plan_f!{T<:fftwComplex}(X::StridedArray{T}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT)
cFFTWPlan(X, X, region, $direction, flags, timelimit)
cFFTWPlan(X, X, region, $direction, flags, timelimit)::cFFTWPlan{T,$direction,true}
end
$plan_f{T<:fftwComplex}(X::StridedArray{T}; kws...) =
$plan_f(X, 1:ndims(X); kws...)
Expand All @@ -593,8 +594,8 @@ for (f,direction) in ((:fft,:FORWARD), (:bfft,:BACKWARD))
function plan_inv{T<:fftwComplex,inplace}(p::cFFTWPlan{T,$direction,inplace})
X = Array(T, p.sz)
Y = inplace ? X : fakesimilar(p.flags, X, T)
ScaledPlan(cFFTWPlan(X, Y, p.region, -$direction,
p.flags, NO_TIMELIMIT),
ScaledPlan(cFFTWPlan(X, Y, p.region, $idirection,
p.flags, NO_TIMELIMIT)::cFFTWPlan{T,$idirection,inplace},
normalization(X, p.region))
end
end
Expand All @@ -606,9 +607,9 @@ function A_mul_B!{T}(y::StridedArray{T}, p::cFFTWPlan{T}, x::StridedArray{T})
return y
end

function *{T,K}(p::cFFTWPlan{T,K,false}, x::StridedArray{T})
function *{T,K,N}(p::cFFTWPlan{T,K,false}, x::StridedArray{T,N})
assert_applicable(p, x)
y = Array(T, p.osz)
y = Array(T, p.osz)::Array{T,N}
unsafe_execute!(p, x, y)
return y
end
Expand All @@ -622,12 +623,13 @@ end
# rfft/brfft and planned variants. No in-place version for now.

for (Tr,Tc) in ((:Float32,:Complex64),(:Float64,:Complex128))
# Note: use $FORWARD and $BACKWARD below because of issue #9775
@eval begin
function plan_rfft(X::StridedArray{$Tr}, region;
flags::Integer=ESTIMATE, timelimit::Real=NO_TIMELIMIT)
osize = rfft_output_size(X, region)
Y = flags&ESTIMATE != 0 ? FakeArray($Tc,osize...) : Array($Tc,osize...)
rFFTWPlan(X, Y, region, flags, timelimit)
rFFTWPlan(X, Y, region, flags, timelimit)::rFFTWPlan{$Tr,$FORWARD,false}
end

function plan_brfft(X::StridedArray{$Tc}, d::Integer, region;
Expand All @@ -643,54 +645,55 @@ for (Tr,Tc) in ((:Float32,:Complex64),(:Float64,:Complex128))
else
Xc = copy(X)
rFFTWPlan(X, Y, region, flags, timelimit)
end
end::rFFTWPlan{$Tc,$BACKWARD,false}
end

plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...)

function plan_inv(p::rFFTWPlan{$Tr,FORWARD,false})
function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false})
X = Array($Tr, p.sz)
Y = p.flags&ESTIMATE != 0 ? FakeArray($Tc,p.osz) : Array($Tc,p.osz)
ScaledPlan(rFFTWPlan(Y, X, p.region,
length(p.region)<=1 ? p.flags | PRESERVE_INPUT
: p.flags, NO_TIMELIMIT),
: p.flags, NO_TIMELIMIT)::rFFTWPlan{$Tc,$BACKWARD,false},
normalization(X, p.region))
end

function plan_inv(p::rFFTWPlan{$Tc,BACKWARD,false})
function plan_inv(p::rFFTWPlan{$Tc,$BACKWARD,false})
X = Array($Tc, p.sz)
Y = p.flags&ESTIMATE != 0 ? FakeArray($Tr,p.osz) : Array($Tr,p.osz)
ScaledPlan(rFFTWPlan(Y, X, p.region, p.flags, NO_TIMELIMIT),
ScaledPlan(rFFTWPlan(Y, X, p.region, p.flags, NO_TIMELIMIT)::rFFTWPlan{$Tr,$FORWARD,false},
normalization(Y, p.region))
end

function A_mul_B!(y::StridedArray{$Tc}, p::rFFTWPlan{$Tr,FORWARD}, x::StridedArray{$Tr})
function A_mul_B!(y::StridedArray{$Tc}, p::rFFTWPlan{$Tr,$FORWARD}, x::StridedArray{$Tr})
assert_applicable(p, x, y)
unsafe_execute!(p, x, y)
return y
end
function A_mul_B!(y::StridedArray{$Tr}, p::rFFTWPlan{$Tc,BACKWARD}, x::StridedArray{$Tc})
function A_mul_B!(y::StridedArray{$Tr}, p::rFFTWPlan{$Tc,$BACKWARD}, x::StridedArray{$Tc})
assert_applicable(p, x, y)
unsafe_execute!(p, x, y) # note: may overwrite x as well as y!
return y
end

function *(p::rFFTWPlan{$Tr,FORWARD,false}, x::StridedArray{$Tr})
function *{N}(p::rFFTWPlan{$Tr,$FORWARD,false}, x::StridedArray{$Tr,N})
assert_applicable(p, x)
y = Array($Tc, p.osz)
y = Array($Tc, p.osz)::Array{$Tc,N}
unsafe_execute!(p, x, y)
return y
end

function *(p::rFFTWPlan{$Tc,BACKWARD,false}, x::StridedArray{$Tc})
y = Array($Tr, p.osz)
function *{N}(p::rFFTWPlan{$Tc,$BACKWARD,false}, x::StridedArray{$Tc,N})
if p.flags & PRESERVE_INPUT != 0
assert_applicable(p, x)
y = Array($Tr, p.osz)::Array{$Tr,N}
unsafe_execute!(p, x, y)
else # need to make a copy to avoid overwriting x
xc = copy(x)
assert_applicable(p, xc)
y = Array($Tr, p.osz)::Array{$Tr,N}
unsafe_execute!(p, xc, y)
end
return y
Expand Down Expand Up @@ -760,9 +763,9 @@ function A_mul_B!{T}(y::StridedArray{T}, p::r2rFFTWPlan{T}, x::StridedArray{T})
return y
end

function *{T,K}(p::r2rFFTWPlan{T,K,false}, x::StridedArray{T})
function *{T,K,N}(p::r2rFFTWPlan{T,K,false}, x::StridedArray{T,N})
assert_applicable(p, x)
y = Array(T, p.osz)
y = Array(T, p.osz)::Array{T,N}
unsafe_execute!(p, x, y)
return y
end
Expand Down
4 changes: 2 additions & 2 deletions base/fft/ctfft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ function CTPlan{Tr<:FloatingPoint}(::Type{Complex{Tr}}, forward::Bool, n::Int)
end

plan_fft{Tr<:FloatingPoint}(x::AbstractVector{Complex{Tr}}) =
CTPlan(Complex{Tr}, true, length(x))
CTPlan(Complex{Tr}, true, length(x))::CTPlan{Complex{Tr},true}
plan_bfft{Tr<:FloatingPoint}(x::AbstractVector{Complex{Tr}}) =
CTPlan(Complex{Tr}, false, length(x))
CTPlan(Complex{Tr}, false, length(x))::CTPlan{Complex{Tr},false}

function applystep{T}(p::CTPlan{T},
x::AbstractArray{T}, x0, xs,
Expand Down
4 changes: 2 additions & 2 deletions base/fft/fftn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ function MultiDimPlan{T<:Complex}(::Type{T}, forward::Bool, region, sz)
end

plan_fft{Tr<:FloatingPoint}(x::AbstractArray{Complex{Tr}}, region) =
MultiDimPlan(Complex{Tr}, true, region, size(x))
MultiDimPlan(Complex{Tr}, true, region, size(x))::MultiDimPlan{Complex{Tr}, true}
plan_bfft{Tr<:FloatingPoint}(x::AbstractArray{Complex{Tr}}, region) =
MultiDimPlan(Complex{Tr}, false, region, size(x))
MultiDimPlan(Complex{Tr}, false, region, size(x))::MultiDimPlan{Complex{Tr}, false}

# recursive execution of a MultiDim plan, starting at dimension d, for
# strided arrays (so that we can use linear indexing):
Expand Down
15 changes: 15 additions & 0 deletions test/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,18 @@ for T in (Complex128, Complex{BigFloat})
end
end
end

# issue #9772
for x in (randn(10),randn(10,12))
z = complex(x)
y = rfft(x)
@inferred rfft(x)
@inferred brfft(x,18)
@inferred brfft(y,10)
for f in (fft,plan_fft,bfft,plan_bfft,fft_)
@inferred f(x)
@inferred f(z)
end
# note: inference doesn't work for plan_fft_ since the
# algorithm steps are included in the CTPlan type
end

0 comments on commit 674f970

Please sign in to comment.