diff --git a/Project.toml b/Project.toml new file mode 100644 index 0000000..123dc9b --- /dev/null +++ b/Project.toml @@ -0,0 +1,13 @@ +name = "MKLSparse" +uuid = "0c723cd3-b8cd-5d40-b370-ba682dde9aae" +version = "1.0.0" + +[deps] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[extras] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Test"] \ No newline at end of file diff --git a/src/BLAS/level_2_3/matmul.jl b/src/BLAS/level_2_3/matmul.jl index 62ae9e1..6cd6c75 100644 --- a/src/BLAS/level_2_3/matmul.jl +++ b/src/BLAS/level_2_3/matmul.jl @@ -24,8 +24,7 @@ for (tchar, ttype) in (('N', :()), ('T', :Transpose)) AT = tchar == 'N' ? :(SparseMatrixCSC{$T,BlasInt}) : :($ttype{$T,SparseMatrixCSC{$T,BlasInt}}) @eval begin - function mul!(α::$T, adjA::$AT, - B::$mat{$T}, β::$T, C::$mat{$T}) + function mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}, α::$T, β::$T) A = _unwrap_adj(adjA) if isa(B, AbstractVector) return cscmv!($tchar, α, matdescra(A), A, B, β, C) @@ -34,7 +33,7 @@ for (tchar, ttype) in (('N', :()), end end - mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}) = mul!(one($T), adjA, B, zero($T), C) + mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}) = mul!(C, adjA, B, one($T), zero($T)) function (*)(adjA::$AT, B::$mat{$T}) A = _unwrap_adj(adjA) @@ -51,8 +50,7 @@ for (tchar, ttype) in (('N', :()), :($w{$T,SparseMatrixCSC{$T,BlasInt}}) : :($ttype{$T,$w{$T,SparseMatrixCSC{$T,BlasInt}}}) @eval begin - function mul!(α::$T, adjA::$AT, - B::$mat{$T}, β::$T, C::$mat{$T}) + function mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}, α::$T, β::$T) A = _unwrap_adj(adjA) if isa(B,AbstractVector) return cscmv!($tchar, α, matdescra(A), _get_data(A), B, β, C) @@ -61,7 +59,7 @@ for (tchar, ttype) in (('N', :()), end end - mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}) = mul!(one($T), adjA, B, zero($T), C) + mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}) = mul!(C, adjA, B, one($T), zero($T)) function (*)(adjA::$AT, B::$mat{$T}) A = _unwrap_adj(adjA) diff --git a/test/test_BLAS.jl b/test/test_BLAS.jl index 64b6d35..ddfff55 100644 --- a/test/test_BLAS.jl +++ b/test/test_BLAS.jl @@ -63,6 +63,9 @@ end @test_blas (maximum(abs.(mul!(similar(c), a, c) - Array(a)*c)) < 100*eps()) @test_blas (maximum(abs.(mul!(similar(b), transpose(a), b) - transpose(Array(a))*b)) < 100*eps()) @test_blas (maximum(abs.(mul!(similar(c), transpose(a), c) - transpose(Array(a))*c)) < 100*eps()) + @test_blas (maximum(abs.(mul!(copy(b), a, b, α, β) - (α*(Array(a)*b) + β*b))) < 100*eps()) + @test_blas (maximum(abs.(mul!(copy(b), transpose(a), b, α, β) - (α*(transpose(Array(a))*b) + β*b))) < 100*eps()) + @test_blas (maximum(abs.(mul!(copy(c), transpose(a), c, α, β) - (α*(transpose(Array(a))*c) + β*c))) < 100*eps()) c = randn(6) + im*randn(6) @test_throws DimensionMismatch transpose(a)*c