Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MKLSparse.jl to support both LP64 and ILP64 interfaces #44

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/MKLSparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,18 @@ end
INTERFACE_GNU
end

function set_threading_layer(layer::Threading = THREADING_SEQUENTIAL)
function set_threading_layer(layer::Threading)
err = @ccall libmkl_rt.MKL_Set_Threading_Layer(layer::Cint)::Cint
(err == -1) && error("MKL_Set_Threading_Layer() returned -1")
return nothing
end

function set_interface_layer(interface::Interface = INTERFACE_ILP64)
function set_interface_layer(interface::Interface)
err = @ccall libmkl_rt.MKL_Set_Interface_Layer(interface::Cint)::Cint
(err == -1) && error("MKL_Set_Interface_Layer() returned -1")
return nothing
end

function __init__()
set_interface_layer(Base.USE_BLAS64 ? INTERFACE_ILP64 : INTERFACE_LP64)
end

# Wrappers generated by Clang.jl
include("libmklsparse.jl")
Expand Down
169 changes: 119 additions & 50 deletions src/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,61 +27,130 @@ function _check_mat_mult_matvec(C, A, B, tA)
end
end

function cscmv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
β::T, y::StridedVector{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1
for (fname, T) in ((:mkl_scscmv, :Float32 ),
(:mkl_dcscmv, :Float64 ),
(:mkl_ccscmv, :ComplexF32),
(:mkl_zcscmv, :ComplexF64))
@eval begin
function cscmv!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int32}, x::StridedVector{$T},
β::$T, y::StridedVector{$T})
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1
$fname(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y)
return y
end

T == Float32 && (mkl_scscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
T == Float64 && (mkl_dcscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
T == ComplexF32 && (mkl_ccscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
T == ComplexF64 && (mkl_zcscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
return y
function cscmv!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int64}, x::StridedVector{$T},
β::$T, y::StridedVector{$T})
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1
set_interface_layer(INTERFACE_ILP64)
$fname(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y)
set_interface_layer(INTERFACE_LP64)
return y
end
end
end

function cscmm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
β::T, C::StridedMatrix{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
mB, nB = size(B)
mC, nC = size(C)
__counter[] += 1
T == Float32 && (mkl_scscmm(transa, A.m, nC, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, β, C, mC))
T == Float64 && (mkl_dcscmm(transa, A.m, nC, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, β, C, mC))
T == ComplexF32 && (mkl_ccscmm(transa, A.m, nC, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, β, C, mC))
T == ComplexF64 && (mkl_zcscmm(transa, A.m, nC, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, β, C, mC))
return C
for (fname, T) in ((:mkl_scscmm, :Float32 ),
(:mkl_dcscmm, :Float64 ),
(:mkl_ccscmm, :ComplexF32),
(:mkl_zcscmm, :ComplexF64))
@eval begin
function cscmm!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int32}, B::StridedMatrix{$T},
β::$T, C::StridedMatrix{$T})
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
mB, nB = size(B)
mC, nC = size(C)
__counter[] += 1
$fname(transa, A.m, nC, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, β, C, mC)
return C
end

function cscmm!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int64}, B::StridedMatrix{$T},
β::$T, C::StridedMatrix{$T})
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
mB, nB = size(B)
mC, nC = size(C)
__counter[] += 1
set_interface_layer(INTERFACE_ILP64)
$fname(transa, A.m, nC, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, β, C, mC)
set_interface_layer(INTERFACE_LP64)
return C
end
end
end

function cscsv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
y::StridedVector{T}) where {T <: BlasFloat}
n = checksquare(A)
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1
T == Float32 && (mkl_scscsv(transa, A.m, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, y))
T == Float64 && (mkl_dcscsv(transa, A.m, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, y))
T == ComplexF32 && (mkl_ccscsv(transa, A.m, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, y))
T == ComplexF64 && (mkl_zcscsv(transa, A.m, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, y))
return y
for (fname, T) in ((:mkl_scscsv, :Float32 ),
(:mkl_dcscsv, :Float64 ),
(:mkl_ccscsv, :ComplexF32),
(:mkl_zcscsv, :ComplexF64))
@eval begin
function cscsv!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int32}, x::StridedVector{$T},
y::StridedVector{$T})
n = checksquare(A)
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1
$fname(transa, A.m, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, y)
return y
end

function cscsv!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int64}, x::StridedVector{$T},
y::StridedVector{$T})
n = checksquare(A)
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1
set_interface_layer(INTERFACE_ILP64)
$fname(transa, A.m, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, y)
set_interface_layer(INTERFACE_LP64)
return y
end
end
end

function cscsm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
C::StridedMatrix{T}) where {T <: BlasFloat}
mB, nB = size(B)
mC, nC = size(C)
n = checksquare(A)
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
__counter[] += 1
T == Float32 && (mkl_scscsm(transa, A.n, nC, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, C, mC))
T == Float64 && (mkl_dcscsm(transa, A.n, nC, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, C, mC))
T == ComplexF32 && (mkl_ccscsm(transa, A.n, nC, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, C, mC))
T == ComplexF64 && (mkl_zcscsm(transa, A.n, nC, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, C, mC))
return C
for (fname, T) in ((:mkl_scscsm, :Float32 ),
(:mkl_dcscsm, :Float64 ),
(:mkl_ccscsm, :ComplexF32),
(:mkl_zcscsm, :ComplexF64))
@eval begin
function cscsm!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int32}, B::StridedMatrix{$T},
C::StridedMatrix{$T})
mB, nB = size(B)
mC, nC = size(C)
n = checksquare(A)
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
__counter[] += 1
$fname(transa, A.n, nC, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, C, mC)
return C
end

function cscsm!(transa::Char, α::$T, matdescra::String,
A::SparseMatrixCSC{$T, Int64}, B::StridedMatrix{$T},
C::StridedMatrix{$T})
mB, nB = size(B)
mC, nC = size(C)
n = checksquare(A)
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
__counter[] += 1
set_interface_layer(INTERFACE_ILP64)
$fname(transa, A.n, nC, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), B, mB, C, mC)
set_interface_layer(INTERFACE_LP64)
return C
end
end
end
2 changes: 1 addition & 1 deletion src/matdescra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ matdescra(A::UnitLowerTriangular) = "TLUF"
matdescra(A::UnitUpperTriangular) = "TUUF"
matdescra(A::Symmetric) = string('S', A.uplo, 'N', 'F')
matdescra(A::Hermitian) = string('H', A.uplo, 'N', 'F')
matdescra(A::SparseMatrixCSC) = "GUUF"
matdescra(A::SparseMatrixCSC) = "GFNF"
22 changes: 9 additions & 13 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,18 @@ _get_data(A::UpperTriangular) = triu(A.data)
_get_data(A::UnitLowerTriangular) = tril(A.data)
_get_data(A::UnitUpperTriangular) = triu(A.data)
_get_data(A::Symmetric) = A.data
_get_data(A::Hermitian) = A.data

_unwrap_adj(x::Union{Adjoint,Transpose}) = parent(x)
_unwrap_adj(x) = x

const SparseMatrices{T} = Union{SparseMatrixCSC{T,BlasInt},
Symmetric{T,SparseMatrixCSC{T,BlasInt}},
LowerTriangular{T, SparseMatrixCSC{T,BlasInt}},
UnitLowerTriangular{T, SparseMatrixCSC{T,BlasInt}},
UpperTriangular{T, SparseMatrixCSC{T,BlasInt}},
UnitUpperTriangular{T, SparseMatrixCSC{T,BlasInt}}}

for T in [Complex{Float32}, Complex{Float64}, Float32, Float64]
for T in [Float32, Float64, Complex{Float32}, Complex{Float64}]
for INT in [Int32, Int64]
for mat in (:StridedVector, :StridedMatrix)
for (tchar, ttype) in (('N', :()),
('C', :Adjoint),
('T', :Transpose))
AT = tchar == 'N' ? :(SparseMatrixCSC{$T,BlasInt}) : :($ttype{$T,SparseMatrixCSC{$T,BlasInt}})
AT = tchar == 'N' ? :(SparseMatrixCSC{$T,$INT}) : :($ttype{$T,SparseMatrixCSC{$T,$INT}})
@eval begin
function mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}, α::Number, β::Number)
A = _unwrap_adj(adjA)
Expand All @@ -45,10 +40,10 @@ for (tchar, ttype) in (('N', :()),
end
end

for w in (:Symmetric, :LowerTriangular, :UnitLowerTriangular, :UpperTriangular, :UnitUpperTriangular)
for w in (:Symmetric, :Hermitian, :LowerTriangular, :UnitLowerTriangular, :UpperTriangular, :UnitUpperTriangular)
AT = tchar == 'N' ?
:($w{$T,SparseMatrixCSC{$T,BlasInt}}) :
:($ttype{$T,$w{$T,SparseMatrixCSC{$T,BlasInt}}})
:($w{$T,SparseMatrixCSC{$T,$INT}}) :
:($ttype{$T,$w{$T,SparseMatrixCSC{$T,$INT}}})
@eval begin
function mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}, α::Number, β::Number)
A = _unwrap_adj(adjA)
Expand All @@ -71,7 +66,7 @@ for (tchar, ttype) in (('N', :()),
end
end

if w != :Symmetric
if w != :Symmetric && w != :Hermitian
@eval begin
function ldiv!(α::Number, adjA::$AT,
B::$mat{$T}, C::$mat{$T})
Expand All @@ -98,4 +93,5 @@ for (tchar, ttype) in (('N', :()),
end
end
end # mat
end # INT
end # T
Loading
Loading