Skip to content

Commit

Permalink
add type assertions to work around inference problem (fix JuliaLang#9772
Browse files Browse the repository at this point in the history
)
  • Loading branch information
stevengj committed Jan 14, 2015
1 parent 5059e08 commit 7da59d7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
34 changes: 18 additions & 16 deletions base/fft/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,18 +570,19 @@ fftwfloat{T<:Complex}(X::AbstractArray{T}) = fftwcomplex(X)
for (f,direction) in ((:fft,:FORWARD), (:bfft,:BACKWARD))
plan_f = symbol(string("plan_",f))
plan_f! = symbol(string("plan_",f,"!"))
idirection = direction == :FORWARD ? :BACKWARD : :FORWARD
@eval begin
function $plan_f{T<:fftwComplex}(X::StridedArray{T}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT)
cFFTWPlan(X, fakesimilar(flags, X, T), region,
$direction, flags,timelimit)
$direction, flags,timelimit)::cFFTWPlan{T,$direction,false}
end

function $plan_f!{T<:fftwComplex}(X::StridedArray{T}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT)
cFFTWPlan(X, X, region, $direction, flags, timelimit)
cFFTWPlan(X, X, region, $direction, flags, timelimit)::cFFTWPlan{T,$direction,true}
end
$plan_f{T<:fftwComplex}(X::StridedArray{T}; kws...) =
$plan_f(X, 1:ndims(X); kws...)
Expand All @@ -591,8 +592,8 @@ for (f,direction) in ((:fft,:FORWARD), (:bfft,:BACKWARD))
function plan_inv{T<:fftwComplex,inplace}(p::cFFTWPlan{T,$direction,inplace})
X = Array(T, p.sz)
Y = inplace ? X : fakesimilar(p.flags, X, T)
ScaledPlan(cFFTWPlan(X, Y, p.region, -$direction,
p.flags, NO_TIMELIMIT),
ScaledPlan(cFFTWPlan(X, Y, p.region, $idirection,
p.flags, NO_TIMELIMIT)::cFFTWPlan{T,$idirection,inplace},
normalization(X, p.region))
end
end
Expand All @@ -604,9 +605,9 @@ function A_mul_B!{T}(y::StridedArray{T}, p::cFFTWPlan{T}, x::StridedArray{T})
return y
end

function *{T,K}(p::cFFTWPlan{T,K,false}, x::StridedArray{T})
function *{T,K,N}(p::cFFTWPlan{T,K,false}, x::StridedArray{T,N})
assert_applicable(p, x)
y = Array(T, p.osz)
y = Array(T, p.osz)::Array{T,N}
unsafe_execute!(p, x, y)
return y
end
Expand All @@ -625,7 +626,7 @@ for (Tr,Tc) in ((:Float32,:Complex64),(:Float64,:Complex128))
flags::Integer=ESTIMATE, timelimit::Real=NO_TIMELIMIT)
osize = rfft_output_size(X, region)
Y = flags&ESTIMATE != 0 ? FakeArray($Tc,osize...) : Array($Tc,osize...)
rFFTWPlan(X, Y, region, flags, timelimit)
rFFTWPlan(X, Y, region, flags, timelimit)::rFFTWPlan{$Tr,FORWARD,false}
end

function plan_brfft(X::StridedArray{$Tc}, d::Integer, region;
Expand All @@ -641,7 +642,7 @@ for (Tr,Tc) in ((:Float32,:Complex64),(:Float64,:Complex128))
else
Xc = copy(X)
rFFTWPlan(X, Y, region, flags, timelimit)
end
end::rFFTWPlan{$Tc,BACKWARD,false}
end

plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
Expand All @@ -652,14 +653,14 @@ for (Tr,Tc) in ((:Float32,:Complex64),(:Float64,:Complex128))
Y = p.flags&ESTIMATE != 0 ? FakeArray($Tc,p.osz) : Array($Tc,p.osz)
ScaledPlan(rFFTWPlan(Y, X, p.region,
length(p.region)<=1 ? p.flags | PRESERVE_INPUT
: p.flags, NO_TIMELIMIT),
: p.flags, NO_TIMELIMIT)::rFFTWPlan{$Tc,BACKWARD,false},
normalization(X, p.region))
end

function plan_inv(p::rFFTWPlan{$Tc,BACKWARD,false})
X = Array($Tc, p.sz)
Y = p.flags&ESTIMATE != 0 ? FakeArray($Tr,p.osz) : Array($Tr,p.osz)
ScaledPlan(rFFTWPlan(Y, X, p.region, p.flags, NO_TIMELIMIT),
ScaledPlan(rFFTWPlan(Y, X, p.region, p.flags, NO_TIMELIMIT)::rFFTWPlan{$Tr,FORWARD,false},
normalization(Y, p.region))
end

Expand All @@ -674,21 +675,22 @@ for (Tr,Tc) in ((:Float32,:Complex64),(:Float64,:Complex128))
return y
end

function *(p::rFFTWPlan{$Tr,FORWARD,false}, x::StridedArray{$Tr})
function *{N}(p::rFFTWPlan{$Tr,FORWARD,false}, x::StridedArray{$Tr,N})
assert_applicable(p, x)
y = Array($Tc, p.osz)
y = Array($Tc, p.osz)::Array{$Tc,N}
unsafe_execute!(p, x, y)
return y
end

function *(p::rFFTWPlan{$Tc,BACKWARD,false}, x::StridedArray{$Tc})
y = Array($Tr, p.osz)
function *{N}(p::rFFTWPlan{$Tc,BACKWARD,false}, x::StridedArray{$Tc,N})
if p.flags & PRESERVE_INPUT != 0
assert_applicable(p, x)
y = Array($Tr, p.osz)::Array{$Tr,N}
unsafe_execute!(p, x, y)
else # need to make a copy to avoid overwriting x
xc = copy(x)
assert_applicable(p, xc)
y = Array($Tr, p.osz)::Array{$Tr,N}
unsafe_execute!(p, xc, y)
end
return y
Expand Down Expand Up @@ -758,9 +760,9 @@ function A_mul_B!{T}(y::StridedArray{T}, p::r2rFFTWPlan{T}, x::StridedArray{T})
return y
end

function *{T,K}(p::r2rFFTWPlan{T,K,false}, x::StridedArray{T})
function *{T,K,N}(p::r2rFFTWPlan{T,K,false}, x::StridedArray{T,N})
assert_applicable(p, x)
y = Array(T, p.osz)
y = Array(T, p.osz)::Array{T,N}
unsafe_execute!(p, x, y)
return y
end
Expand Down
13 changes: 13 additions & 0 deletions test/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,16 @@ for T in (Complex128, Complex{BigFloat})
end
end
end

# issue #9772
for x in (randn(10),randn(10,12))
z = complex(x)
y = rfft(x)
@inferred rfft(x)
@inferred brfft(x,18)
@inferred brfft(y,10)
for f in (fft,bfft,fft_)
@inferred f(x)
@inferred f(z)
end
end

0 comments on commit 7da59d7

Please sign in to comment.