Commit 094c435: add Metal extension for batched_mul
mcabbott committed Nov 15, 2024
1 parent 0213868 commit 094c435
Showing 6 changed files with 180 additions and 0 deletions.
21 changes: 21 additions & 0 deletions .buildkite/pipeline.yml
@@ -55,6 +55,27 @@ steps:
NNLIB_TEST_CPU: "false"
JULIA_NUM_THREADS: 4

- label: ":julia: Julia 1 + Metal GPU"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliaecosystem"
os: "macos"
arch: "aarch64"
timeout_in_minutes: 180
env:
NNLIB_TEST_METAL: "true"
NNLIB_TEST_CPU: "false"
JULIA_NUM_THREADS: 4

- label: "Benchmarks"
plugins:
- JuliaCI/julia#v1:
3 changes: 3 additions & 0 deletions Project.toml
@@ -19,6 +19,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
@@ -27,6 +28,7 @@ NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
NNlibForwardDiffExt = "ForwardDiff"
NNlibMetalExt = "Metal"

[compat]
AMDGPU = "0.9.4, 1"
@@ -40,6 +42,7 @@ ForwardDiff = "0.10.36"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
LinearAlgebra = "<0.0.1, 1"
Metal = "1.4.2"
Random = "<0.0.1, 1"
Statistics = "1"
cuDNN = "1"
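Because Metal is declared as a weak dependency alongside the other GPU backends, the NNlibMetalExt extension loads automatically on Julia 1.9+ once both packages are in the session. A quick sanity check, sketched here for illustration (not part of the commit):

using NNlib, Metal  # loading both packages triggers the NNlibMetalExt extension

# Base.get_extension (Julia 1.9+) returns the extension module, or `nothing` if it did not load
@assert Base.get_extension(NNlib, :NNlibMetalExt) !== nothing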
50 changes: 50 additions & 0 deletions ext/NNlibMetalExt/NNlibMetalExt.jl
@@ -0,0 +1,50 @@
module NNlibMetalExt

using Metal, NNlib
using NNlib: AbstractRNG # === Random.AbstractRNG

# Random
NNlib._rng_from_array(::MtlArray) = Metal.MPS.default_rng()

NNlib._rng_compat_array(rng::Metal.MPS.RNG, A::MtlArray) = nothing
NNlib._rng_compat_array(rng::AbstractRNG, A::MtlArray) = throw(ArgumentError(
"cannot use rng::$(typeof(rng)) with array::MtlArray, only Metal's own RNG type works"))

# Batched matrix multiplication
function NNlib._batched_gemm!(::Type{<:MtlArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
eltype(C) <: Complex && @warn "don't trust this on complex arrays!" transA transB
Metal.MPS.matmul!(C, A, B, α, β, transA != 'N', transB != 'N') # nb: gemm!-style argument order would be (transA, transB, α, A, B, β, C)
end

#=
help?> Metal.MPS.matmul!
  matMulMPS(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
            transpose_left=false, transpose_right=false)
  An MPSMatrixMultiplication kernel that computes: c = alpha * op(a) * op(b) + beta * c
  This function should not typically be used. Rather, use the normal LinearAlgebra interface with
  any MtlArray and it should be accelerated using Metal Performance Shaders.
=#

using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
using Adapt
using Adapt: WrappedArray

const MetalBatchedAdjoint{T} = BatchedAdjoint{T, <: MtlArray{T}}
const MetalBatchedTranspose{T} = BatchedTranspose{T, <: MtlArray{T}}
const MetalBatchedAdjOrTrans{T} = Union{MetalBatchedAdjoint{T}, MetalBatchedTranspose{T}}
const WrappedMetalBatchedAdjOrTrans{T, N} = WrappedArray{T, N, MetalBatchedAdjOrTrans{T}, MetalBatchedAdjOrTrans{T}}

Base.print_array(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
Base._show_nonempty(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
Base.show_vector(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)

Base.convert(::Type{T}, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
Base.Array{T, N}(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
Base.collect(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = collect(adapt(Array, b))


end # module NNlibMetalExt
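For orientation, a minimal usage sketch of what this extension enables; the shapes and values are illustrative, not taken from the commit:

using NNlib, Metal

A = MtlArray(randn(Float32, 3, 4, 5))  # batch of five 3×4 matrices
B = MtlArray(randn(Float32, 4, 2, 5))  # batch of five 4×2 matrices

# batched_mul dispatches through the NNlib._batched_gemm! method defined above,
# which hands the whole batch to Metal.MPS.matmul! in a single call:
C = NNlib.batched_mul(A, B)            # 3×2×5 MtlArray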
82 changes: 82 additions & 0 deletions test/ext_metal/batched_mul.jl
@@ -0,0 +1,82 @@
@testset "batched_mul" begin
using NNlib: batched_mul, batched_mul!, batched_vec,
batched_adjoint, batched_transpose

A = randn(Float32, 3,3,2);
B = randn(Float32, 3,3,2);

C = batched_mul(A, B)
@test MtlArray(C) ≈ batched_mul(MtlArray(A), MtlArray(B))

Ct = batched_mul(batched_transpose(A), B)
@test MtlArray(Ct) ≈ batched_mul(batched_transpose(MtlArray(A)), MtlArray(B))

Ca = batched_mul(A, batched_adjoint(B))
@test MtlArray(Ca) ≈ batched_mul(MtlArray(A), batched_adjoint(MtlArray(B)))

# 5-arg batched_mul!
C .= pi
batched_mul!(C, A, B, 2f0, 3f0)
gpuCpi = MtlArray(similar(C)) .= pi
@test MtlArray(C) ≈ batched_mul!(gpuCpi, MtlArray(A), MtlArray(B), 2f0, 3f0)

# PermutedDimsArray
@test MtlArray(Ct) ≈ batched_mul(PermutedDimsArray(MtlArray(A), (2,1,3)), MtlArray(B))

D = permutedims(B, (1,3,2))
Cp = batched_mul(batched_adjoint(A), B)
@test_broken MtlArray(Cp) ≈ batched_mul(batched_adjoint(MtlArray(A)), PermutedDimsArray(MtlArray(D), (1,3,2)))

# Methods which reshape
M = randn(Float32, 3,3)

Cm = batched_mul(A, M)
@test MtlArray(Cm) ≈ batched_mul(MtlArray(A), MtlArray(M))

Cv = batched_vec(permutedims(A,(3,1,2)), M)
@test_broken MtlArray(Cv) ≈ batched_vec(PermutedDimsArray(MtlArray(A),(3,1,2)), MtlArray(M))
end

function print_array_strs(x)
str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)
return @view split(str, '\n')[2:end]
end

@testset "BatchedAdjOrTrans" begin
x = rand(Float32, 3, 4, 2)
y = MtlArray(x)

bax = batched_adjoint(x)
btx = batched_transpose(x)
bay = batched_adjoint(y)
bty = batched_transpose(y)

@test sprint(show, bax) == sprint(show, bay)
@test sprint(show, btx) == sprint(show, bty)

@test print_array_strs(bax) == print_array_strs(bay)
@test print_array_strs(btx) == print_array_strs(bty)

@test Array(bax) == Array(bay)
@test collect(bax) == collect(bay)
@test Array(btx) == Array(bty)
@test collect(btx) == collect(bty)

for shape in (:, (12, 2))
rbax = reshape(bax, shape)
rbtx = reshape(btx, shape)
rbay = reshape(bay, shape)
rbty = reshape(bty, shape)

@test sprint(show, rbax) == sprint(show, rbay)
@test sprint(show, rbtx) == sprint(show, rbty)

@test print_array_strs(rbax) == print_array_strs(rbay)
@test print_array_strs(rbtx) == print_array_strs(rbty)

@test Array(rbax) == Array(rbay)
@test collect(rbax) == collect(rbay)
@test Array(rbtx) == Array(rbty)
@test collect(rbtx) == collect(rbty)
end
end
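The sprint/show/collect tests above exercise the display overloads from the extension: without them, printing or collecting a batched wrapper around an MtlArray would fall back to elementwise getindex, which Metal.allowscalar(false) forbids. A hypothetical session showing the behavior being tested:

using NNlib, Metal
Metal.allowscalar(false)

b = NNlib.batched_transpose(MtlArray(rand(Float32, 3, 4, 2)))

# The extension's Base.collect / Base.print_array methods first move the wrapper
# to the CPU via adapt(Array, b), so neither of these hits scalar indexing:
collect(b)                             # plain Array{Float32, 3}
sprint(show, MIME"text/plain"(), b)    # prints like the CPU equivalent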
6 changes: 6 additions & 0 deletions test/ext_metal/runtests.jl
@@ -0,0 +1,6 @@

Metal.allowscalar(false)

@testset "Batched multiplication" begin
include("batched_mul.jl")
end
18 changes: 18 additions & 0 deletions test/runtests.jl
@@ -23,6 +23,7 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv

# ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
# ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
# ENV["NNLIB_TEST_METAL"] = "true" # uncomment to run Metal tests
# ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests

const rng = StableRNG(123)
@@ -174,4 +175,21 @@ end
else
@info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
end

if get(ENV, "NNLIB_TEST_METAL", "false") == "true"
Pkg.add(["Metal"])

using Metal
if Metal.functional()
@testset "Metal" begin
# nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather")))

include("ext_metal/runtests.jl")
end
else
@info "Metal.jl package is not functional. Skipping Metal tests."
end
else
@info "Skipping Metal tests, set NNLIB_TEST_METAL=true to run them"
end
end
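To run the new Metal tests locally (assuming an Apple-silicon machine where Metal.functional() returns true), something like the following should work, mirroring the CI environment above:

# in a Julia session started in the NNlib.jl repo (e.g. julia --project)
ENV["NNLIB_TEST_METAL"] = "true"
ENV["NNLIB_TEST_CPU"] = "false"    # optional: skip CPU tests, as the CI job does
using Pkg
Pkg.test()                         # the test subprocess inherits these ENV settings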
