Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Wrap BLIS #431

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Expand All @@ -44,6 +45,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

[extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBLISExt = "blis_jll"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
Expand All @@ -58,6 +60,7 @@ LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
[compat]
ArrayInterface = "7.4.11"
BandedMatrices = "1"
blis_jll = "0.9.0"
BlockDiagonals = "0.1"
ConcreteStructs = "0.2"
DocStringExtensions = "0.9"
Expand Down
248 changes: 248 additions & 0 deletions ext/LinearSolveBLISExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
module LinearSolveBLISExt

using Libdl
using blis_jll
using LinearAlgebra
using LinearSolve

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase

const global libblis = blis_jll.blis

function getrf!(A::AbstractMatrix{<:ComplexF64};

Check warning on line 15 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L15

Added line #L15 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 25 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L19-L25

Added lines #L19 - L25 were not covered by tests
end
ccall((@blasfunc(zgetrf_), libblis), Cvoid,

Check warning on line 27 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L27

Added line #L27 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 32 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L31-L32

Added lines #L31 - L32 were not covered by tests
end

function getrf!(A::AbstractMatrix{<:ComplexF32};

Check warning on line 35 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L35

Added line #L35 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 45 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L39-L45

Added lines #L39 - L45 were not covered by tests
end
ccall((@blasfunc(cgetrf_), libblis), Cvoid,

Check warning on line 47 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L47

Added line #L47 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 52 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L51-L52

Added lines #L51 - L52 were not covered by tests
end

function getrf!(A::AbstractMatrix{<:Float64};

Check warning on line 55 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L55

Added line #L55 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 65 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L59-L65

Added lines #L59 - L65 were not covered by tests
end
ccall((@blasfunc(dgetrf_), libblis), Cvoid,

Check warning on line 67 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L67

Added line #L67 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 72 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L71-L72

Added lines #L71 - L72 were not covered by tests
end

function getrf!(A::AbstractMatrix{<:Float32};

Check warning on line 75 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L75

Added line #L75 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 85 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L79-L85

Added lines #L79 - L85 were not covered by tests
end
ccall((@blasfunc(sgetrf_), libblis), Cvoid,

Check warning on line 87 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L87

Added line #L87 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 92 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 95 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L95

Added line #L95 was not covered by tests
A::AbstractMatrix{<:ComplexF64},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:ComplexF64};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 105 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L100-L105

Added lines #L100 - L105 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 108 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end
nrhs = size(B, 2)
ccall(("zgetrs_", libblis), Cvoid,

Check warning on line 111 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L110-L111

Added lines #L110 - L111 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 117 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L116-L117

Added lines #L116 - L117 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 120 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L120

Added line #L120 was not covered by tests
A::AbstractMatrix{<:ComplexF32},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:ComplexF32};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 130 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L125-L130

Added lines #L125 - L130 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 133 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L132-L133

Added lines #L132 - L133 were not covered by tests
end
nrhs = size(B, 2)
ccall(("cgetrs_", libblis), Cvoid,

Check warning on line 136 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 142 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L141-L142

Added lines #L141 - L142 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 145 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L145

Added line #L145 was not covered by tests
A::AbstractMatrix{<:Float64},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:Float64};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 155 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L150-L155

Added lines #L150 - L155 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 158 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L157-L158

Added lines #L157 - L158 were not covered by tests
end
nrhs = size(B, 2)
ccall(("dgetrs_", libblis), Cvoid,

Check warning on line 161 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L160-L161

Added lines #L160 - L161 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 167 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L166-L167

Added lines #L166 - L167 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 170 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L170

Added line #L170 was not covered by tests
A::AbstractMatrix{<:Float32},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:Float32};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 180 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L175-L180

Added lines #L175 - L180 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 183 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
end
nrhs = size(B, 2)
ccall(("sgetrs_", libblis), Cvoid,

Check warning on line 186 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L185-L186

Added lines #L185 - L186 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 192 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L191-L192

Added lines #L191 - L192 were not covered by tests
end

default_alias_A(::BLISLUFactorization, ::Any, ::Any) = false
default_alias_b(::BLISLUFactorization, ::Any, ::Any) = false

Check warning on line 196 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L195-L196

Added lines #L195 - L196 were not covered by tests

const PREALLOCATED_BLIS_LU = begin
A = rand(0, 0)
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
end

function LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr,

Check warning on line 203 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L203

Added line #L203 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
PREALLOCATED_BLIS_LU

Check warning on line 206 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L206

Added line #L206 was not covered by tests
end

function LinearSolve.init_cacheval(alg::BLISLUFactorization, A::AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}}, b, u, Pl, Pr,

Check warning on line 209 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L209

Added line #L209 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
A = rand(eltype(A), 0, 0)
ArrayInterface.lu_instance(A), Ref{BlasInt}()

Check warning on line 213 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L212-L213

Added lines #L212 - L213 were not covered by tests
end

function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;

Check warning on line 216 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L216

Added line #L216 was not covered by tests
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :BLISLUFactorization)
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
fact = LU(res[1:3]...), res[4]
cache.cacheval = fact
cache.isfresh = false

Check warning on line 225 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L218-L225

Added lines #L218 - L225 were not covered by tests
end

y = ldiv!(cache.u, @get_cacheval(cache, :BLISLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)

Check warning on line 229 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L228-L229

Added lines #L228 - L229 were not covered by tests

#=
A, info = @get_cacheval(cache, :BLISLUFactorization)
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
m, n = size(A, 1), size(A, 2)
if m > n
Bc = copy(cache.b)
getrs!('N', A.factors, A.ipiv, Bc; info)
return copyto!(cache.u, 1, Bc, 1, n)
else
copyto!(cache.u, cache.b)
getrs!('N', A.factors, A.ipiv, cache.u; info)
end

SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
=#
end

end
2 changes: 2 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,5 @@ A wrapper over Apple's Metal GPU library. Direct calls to Metal in a way that pr
to avoid allocations and automatically offloads to the GPU.
"""
struct MetalLUFactorization <: AbstractFactorization end

struct BLISLUFactorization <: AbstractFactorization end
Loading