diff --git a/lib/mps/linalg.jl b/lib/mps/linalg.jl index a140dfa74..6b9921def 100644 --- a/lib/mps/linalg.jl +++ b/lib/mps/linalg.jl @@ -192,30 +192,43 @@ function LinearAlgebra.transpose!(B::MtlMatrix{T}, A::MtlMatrix{T}) where {T} commit!(cmdbuf) + wait_completed(cmdbuf) + return B end +function LinearAlgebra.:(\)(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} + C = deepcopy(B) + LinearAlgebra.ldiv!(A, C) + return C +end + + function LinearAlgebra.ldiv!(A::LU{T, <:MtlMatrix{T}, <:MtlVector{UInt32}}, B::MtlVecOrMat{T}) where {T<:MtlFloat} M,N = size(B,1), size(B,2) dev = current_device() queue = global_queue(dev) - Bt = reshape(B, (N,M)) + At = similar(A.factors) + Bt = similar(B, (N,M)) P = reshape((A.ipiv .- UInt32(1)), (1,M)) - X = similar(B) + X = similar(B, (N,M)) + + transpose!(At, A.factors) + transpose!(Bt, B) - mps_a = MPSMatrix(A.factors) + mps_a = MPSMatrix(At) mps_b = MPSMatrix(Bt) mps_p = MPSMatrix(P) mps_x = MPSMatrix(X) MTLCommandBuffer(queue) do cmdbuf - kernel = MPSMatrixSolveLU(dev, true, M, N) + kernel = MPSMatrixSolveLU(dev, false, M, N) encode!(cmdbuf, kernel, mps_a, mps_b, mps_p, mps_x) end - Bt .= X + transpose!(B, X) return B end @@ -225,20 +238,24 @@ function LinearAlgebra.ldiv!(A::UpperTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM dev = current_device() queue = global_queue(dev) - Ad = MtlMatrix(A; storage=Private) - Bt = reshape(B, (N,M)) - X = similar(B) + Ad = MtlMatrix(A') + Br = similar(B, (M,M)) + X = similar(Br) + + transpose!(Br, B) mps_a = MPSMatrix(Ad) - mps_b = MPSMatrix(Bt) + mps_b = MPSMatrix(Br) mps_x = MPSMatrix(X) - MTLCommandBuffer(queue) do cmdbuf - kernel = MPSMatrixSolveTriangular(dev, false, false, false, false, M, N, 1.0) + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, N, M, 1.0) encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) end - Bt .= X + wait_completed(buf) + + copy!(B, X) return B end @@ -248,20 +265,23 @@ function LinearAlgebra.ldiv!(A::UnitUpperTriangular{T, <:MtlMatrix{T}}, B::MtlVe dev = current_device() queue = global_queue(dev) - Ad = MtlMatrix(A; storage=Private) - Bt = reshape(B, (N,M)) - X = similar(B) + Ad = MtlMatrix(A) + Br = reshape(B, (M,N)) + X = similar(Br) mps_a = MPSMatrix(Ad) - mps_b = MPSMatrix(Bt) + mps_b = MPSMatrix(Br) mps_x = MPSMatrix(X) - MTLCommandBuffer(queue) do cmdbuf - kernel = MPSMatrixSolveTriangular(dev, false, false, false, true, M, N, 1.0) + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, false, false, true, M, N, 1.0) encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) end - Bt .= X + wait_completed(buf) + + copy!(Br, X) return B end @@ -271,20 +291,23 @@ function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:MtlMatrix{T}}, B::MtlVecOrM dev = current_device() queue = global_queue(dev) - Ad = MtlMatrix(A; storage=Private) - Bt = reshape(B, (N,M)) - X = similar(B) + Ad = MtlMatrix(A) + Br = reshape(B, (M,N)) + X = similar(Br) mps_a = MPSMatrix(Ad) - mps_b = MPSMatrix(Bt) + mps_b = MPSMatrix(Br) mps_x = MPSMatrix(X) - MTLCommandBuffer(queue) do cmdbuf - kernel = MPSMatrixSolveTriangular(dev, false, true, false, false, M, N, 1.0) + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, false, M, N, 1.0) encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) end - Bt .= X + wait_completed(buf) + + copy!(Br, X) return B end @@ -294,19 +317,22 @@ function LinearAlgebra.ldiv!(A::UnitLowerTriangular{T, <:MtlMatrix{T}}, B::MtlVe dev = current_device() queue = global_queue(dev) - A = MtlMatrix(A; storage=Private) - Bt = reshape(B, (N,M)) - X = similar(B) + Ad = MtlMatrix(A) + Br = reshape(B, (M,N)) + X = similar(Br) - mps_a = MPSMatrix(A) - mps_b = MPSMatrix(Bt) + mps_a = MPSMatrix(Ad) + mps_b = MPSMatrix(Br) mps_x = MPSMatrix(X) - MTLCommandBuffer(queue) do cmdbuf - kernel = MPSMatrixSolveTriangular(dev, false, true, false, true, M, N, 1.0) + + buf = MTLCommandBuffer(queue) do cmdbuf + kernel = MPSMatrixSolveTriangular(dev, true, true, false, true, M, N, 1.0) encode!(cmdbuf, kernel, mps_a, mps_b, mps_x) end - Bt .= X + wait_completed(buf) + + copy!(Br, X) return B end \ No newline at end of file diff --git a/test/mps.jl b/test/mps.jl index d88ba4cc3..f8b1688a7 100644 --- a/test/mps.jl +++ b/test/mps.jl @@ -58,4 +58,39 @@ end @test_throws SingularException lu(A) end +@testset "solves" begin + b = MtlVector(rand(Float32, 1024)) + B = MtlMatrix(rand(Float32, 1024, 1024)) + + A = MtlMatrix(rand(Float32, 1024, 512)) + x = lu(A) \ b + @test A * x ≈ b + X = lu(A) \ B + @test A * X ≈ B + + A = UpperTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = UnitUpperTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = LowerTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B + + A = UnitLowerTriangular(MtlMatrix(rand(Float32, 1024, 1024))) + x = A \ b + @test A * x ≈ b + X = A \ B + @test A * X ≈ B +end + end \ No newline at end of file