Skip to content

Commit

Permalink
Rewrite A[c|t]_ldiv_B! specializations for UmfpackLU-StridedVecOrMat …
Browse files Browse the repository at this point in the history
…combinations, without generalized linear indexing and meta-fu.
  • Loading branch information
Sacha0 committed Jan 15, 2017
1 parent fa31b38 commit 32b7b4a
Showing 1 changed file with 47 additions and 34 deletions.
81 changes: 47 additions & 34 deletions base/sparse/umfpack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,43 +383,56 @@ function nnz(lu::UmfpackLU)
end

### Solve with Factorization
for (f!, umfpack) in ((:A_ldiv_B!, :UMFPACK_A),
(:Ac_ldiv_B!, :UMFPACK_At),
(:At_ldiv_B!, :UMFPACK_Aat))
@eval begin
function $f!{T<:UMFVTypes}(x::StridedVecOrMat{T}, lu::UmfpackLU{T}, b::StridedVecOrMat{T})
n = size(x, 2)
if n != size(b, 2)
throw(DimensionMismatch("in and output arrays must have the same number of columns"))
end
for j in 1:n
solve!(view(x, :, j), lu, view(b, :, j), $umfpack)
end
return x
end
$f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVector{T}) = $f!(b, lu, copy(b))
$f!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedMatrix{T}) = $f!(b, lu, copy(b))

function $f!{Tb<:Complex}(x::StridedVector{Tb}, lu::UmfpackLU{Float64}, b::StridedVector{Tb})
m, n = size(x, 1), size(x, 2)
if n != size(b, 2)
throw(DimensionMismatch("in and output arrays must have the same number of columns"))
end
# TODO: Optionally let user allocate these and pass in somehow
r = similar(b, Float64, m)
i = similar(b, Float64, m)
for j in 1:n
solve!(r, lu, convert(Vector{Float64}, real(view(b, :, j))), $umfpack)
solve!(i, lu, convert(Vector{Float64}, imag(view(b, :, j))), $umfpack)

map!((t,s) -> t + im*s, view(x, :, j), r, i)
end
return x
end
$f!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVector{Tb}) = $f!(b, lu, copy(b))
A_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = A_ldiv_B!(b, lu, copy(b))
At_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = At_ldiv_B!(b, lu, copy(b))
Ac_ldiv_B!{T<:UMFVTypes}(lu::UmfpackLU{T}, b::StridedVecOrMat{T}) = Ac_ldiv_B!(b, lu, copy(b))

A_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_A)
At_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)
Ac_ldiv_B!{T<:UMFVTypes}(X::StridedVecOrMat{T}, lu::UmfpackLU{T}, B::StridedVecOrMat{T}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)

A_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = A_ldiv_B!(b, lu, copy(b))
At_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = At_ldiv_B!(b, lu, copy(b))
Ac_ldiv_B!{Tb<:Complex}(lu::UmfpackLU{Float64}, b::StridedVecOrMat{Tb}) = Ac_ldiv_B!(b, lu, copy(b))

A_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_A)
At_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_At)
Ac_ldiv_B!{Tb<:Complex}(X::StridedVecOrMat{Tb}, lu::UmfpackLU{Float64}, B::StridedVecOrMat{Tb}) =
_Aq_ldiv_B!(X, lu, B, UMFPACK_Aat)

_Aq_ldiv_B!(X::StridedVecOrMat, lu::UmfpackLU, B::StridedVecOrMat, transtype) =
(_AqldivB_checkshapecompat(X, B); _AqldivB_kernel!(X, lu, B, transtype); return X)

_AqldivB_checkshapecompat(X::StridedVecOrMat, B::StridedVecOrMat) =
size(X, 2) == size(B, 2) || throw(DimensionMismatch("input and output must have same column count"))

_AqldivB_kernel!{T<:UMFVTypes}(x::StridedVector{T}, lu::UmfpackLU{T}, b::StridedVector{T}, transtype) =
solve!(x, lu, b, transtype)
_AqldivB_kernel!{T<:UMFVTypes}(X::StridedMatrix{T}, lu::UmfpackLU{T}, b::StridedMatrix{T}, transtype) =
for col in 1:size(X, 1) solve!(view(X, :, col), lu, view(B, :, col), transtype) end

function _AqldivB_kernel!{Tb<:Complex}(X::StridedVector{Tb}, lu::UmfpackLU{Float64}, B::StridedVector{Tb}, transtype)
r, i = similar(B, Float64), similar(B, Float64)
solve!(r, lu, Vector{Float64}(real(B)), transtype)
solve!(i, lu, Vector{Float64}(imag(B)), transtype)
map!(complex, X, r, i)
end
function _AqldivB_kernel!{Tb<:Complex}(X::StridedMatrix{Tb}, lu::UmfpackLU{Float64}, B::StridedMatrix{Tb}, transtype)
r = similar(B, Float64, size(B, 1))
i = similar(B, Float64, size(B, 1))
for j in 1:size(B, 2)
solve!(r, lu, Vector{Float64}(real(view(B, :, j))), transtype)
solve!(i, lu, Vector{Float64}(imag(view(B, :, j))), transtype)
map!(complex, view(X, :, j), r, i)
end
end


function getindex(lu::UmfpackLU, d::Symbol)
L,U,p,q,Rs = umf_extract(lu)
if d == :L
Expand Down

0 comments on commit 32b7b4a

Please sign in to comment.