Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Metal extension for batched_mul #614

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,27 @@ steps:
NNLIB_TEST_CPU: "false"
JULIA_NUM_THREADS: 4

- label: ":julia: Julia 1 + Metal GPU"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
agents:
queue: "juliaecosystem"
os: "macos"
arch: "aarch64"
timeout_in_minutes: 180
env:
NNLIB_TEST_METAL: "true"
NNLIB_TEST_CPU: "false"
JULIA_NUM_THREADS: 4

- label: "Benchmarks"
plugins:
- JuliaCI/julia#v1:
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[extensions]
NNlibAMDGPUExt = "AMDGPU"
Expand All @@ -27,6 +28,7 @@ NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
NNlibFFTWExt = "FFTW"
NNlibForwardDiffExt = "ForwardDiff"
NNlibMetalExt = "Metal"

[compat]
AMDGPU = "0.9.4, 1"
Expand All @@ -40,6 +42,7 @@ ForwardDiff = "0.10.36"
GPUArraysCore = "0.1"
KernelAbstractions = "0.9.2"
LinearAlgebra = "<0.0.1, 1"
Metal = "1.4.2"
Random = "<0.0.1, 1"
Statistics = "1"
cuDNN = "1"
Expand Down
50 changes: 50 additions & 0 deletions ext/NNlibMetalExt/NNlibMetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module NNlibMetalExt

using Metal, NNlib
using NNlib: AbstractRNG # === Random.AbstractRNG

# Random
NNlib._rng_from_array(::MtlArray) = Metal.MPS.default_rng()

NNlib._rng_compat_array(rng::Metal.MPS.RNG, A::MtlArray) = nothing
NNlib._rng_compat_array(rng::AbstractRNG, A::MtlArray) = throw(ArgumentError(
"cannot use rng::$(typeof(rng)) with array::MtlArray, only Metal's own RNG type works"))

# Batched matrix multiplication
function NNlib._batched_gemm!(::Type{<:MtlArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
eltype(C) <: Complex && @warn "don't trust this on complex arrays!" transA transB
Metal.MPS.matmul!(C, A, B, α, β, transA != 'N', transB != 'N') # transA, transB, α, A, B, β, C)
end

#=

help?> Metal.MPS.matmul!
matMulMPS(a::MtlMatrix, b::MtlMatrix, c::MtlMatrix, alpha=1, beta=1,
transpose_left=false, transpose_right=false)

A MPSMatrixMultiplication kernel thay computes: c = alpha * op(a) * beta * op(b) + beta * C

This function should not typically be used. Rather, use the normal LinearAlgebra interface with
any MtlArray and it should be accelerated using Metal Performance Shaders.

=#

using NNlib: BatchedAdjoint, BatchedTranspose, BatchedAdjOrTrans
using Adapt
using Adapt: WrappedArray

const MetalBatchedAdjoint{T} = BatchedAdjoint{T, <: MtlArray{T}}
const MetalBatchedTranspose{T} = BatchedTranspose{T, <: MtlArray{T}}
const MetalBatchedAdjOrTrans{T} = Union{MetalBatchedAdjoint{T}, MetalBatchedTranspose{T}}
const WrappedMetalBatchedAdjOrTrans{T, N} = WrappedArray{T, N, MetalBatchedAdjOrTrans{T}, MetalBatchedAdjOrTrans{T}}

Base.print_array(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = Base.print_array(io, adapt(Array, b))
Base._show_nonempty(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, prefix::String) = Base._show_nonempty(io, adapt(Array, b), prefix)
Base.show_vector(io::IO, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}, opn, cls) = Base.show_vector(io, adapt(Array, b), opn, cls)

Base.convert(::Type{T}, b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T<:Array} = Base.convert(T, adapt(Array, b))
Base.Array{T, N}(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) where {T, N} = Array{T, N}(adapt(Array, b))
Base.collect(b::Union{MetalBatchedAdjOrTrans, WrappedMetalBatchedAdjOrTrans}) = collect(adapt(Array, b))


end # module NNlibMetalExt
82 changes: 82 additions & 0 deletions test/ext_metal/batched_mul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
@testset "batched_mul" begin
using NNlib: batched_mul, batched_mul!, batched_vec,
batched_adjoint, batched_transpose

A = randn(Float32, 3,3,2);
B = randn(Float32, 3,3,2);

C = batched_mul(A, B)
@test MtlArray(C) ≈ batched_mul(MtlArray(A), MtlArray(B))

Ct = batched_mul(batched_transpose(A), B)
@test MtlArray(Ct) ≈ batched_mul(batched_transpose(MtlArray(A)), MtlArray(B))

Ca = batched_mul(A, batched_adjoint(B))
@test MtlArray(Ca) ≈ batched_mul(MtlArray(A), batched_adjoint(MtlArray(B)))

# 5-arg batched_mul!
C .= pi
batched_mul!(C, A, B, 2f0, 3f0)
gpuCpi = MtlArray(similar(C)) .= pi
@test MtlArray(C) ≈ batched_mul!(gpuCpi, MtlArray(A), MtlArray(B), 2f0, 3f0)

# PermutedDimsArray
@test MtlArray(Ct) ≈ batched_mul(PermutedDimsArray(MtlArray(A), (2,1,3)), MtlArray(B))

D = permutedims(B, (1,3,2))
Cp = batched_mul(batched_adjoint(A), B)
@test_broken MtlArray(Cp) ≈ batched_mul(batched_adjoint(MtlArray(A)), PermutedDimsArray(MtlArray(D), (1,3,2)))

# Methods which reshape
M = randn(Float32, 3,3)

Cm = batched_mul(A, M)
@test MtlArray(Cm) ≈ batched_mul(MtlArray(A), MtlArray(M))

Cv = batched_vec(permutedims(A,(3,1,2)), M)
@test_broken MtlArray(Cv) ≈ batched_vec(PermutedDimsArray(MtlArray(A),(3,1,2)), MtlArray(M))
end

function print_array_strs(x)
str = sprint((io, x)->show(io, MIME"text/plain"(), x), x)
return @view split(str, '\n')[2:end]
end

@testset "BatchedAdjOrTrans" begin
x = rand(Float32, 3, 4, 2)
y = MtlArray(x)

bax = batched_adjoint(x)
btx = batched_transpose(x)
bay = batched_adjoint(y)
bty = batched_transpose(y)

@test sprint(show, bax) == sprint(show, bay)
@test sprint(show, btx) == sprint(show, bty)

@test print_array_strs(bax) == print_array_strs(bay)
@test print_array_strs(btx) == print_array_strs(bty)

@test Array(bax) == Array(bay)
@test collect(bax) == collect(bay)
@test Array(btx) == Array(bty)
@test collect(btx) == collect(bty)

for shape in (:, (12, 2))
rbax = reshape(bax, shape)
rbtx = reshape(btx, shape)
rbay = reshape(bay, shape)
rbty = reshape(bty, shape)

@test sprint(show, rbax) == sprint(show, rbay)
@test sprint(show, rbtx) == sprint(show, rbty)

@test print_array_strs(rbax) == print_array_strs(rbay)
@test print_array_strs(rbtx) == print_array_strs(rbty)

@test Array(rbax) == Array(rbay)
@test collect(rbax) == collect(rbay)
@test Array(rbtx) == Array(rbty)
@test collect(rbtx) == collect(rbty)
end
end
6 changes: 6 additions & 0 deletions test/ext_metal/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

Metal.allowscalar(false)

@testset "Batched multiplication" begin
include("batched_mul.jl")
end
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursiv

# ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
# ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
# ENV["NNLIB_TEST_METAL"] = "true" # uncomment to run Metal tests
# ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests

const rng = StableRNG(123)
Expand Down Expand Up @@ -174,4 +175,21 @@ end
else
@info "Skipping AMDGPU tests, set NNLIB_TEST_AMDGPU=true to run them."
end

if get(ENV, "NNLIB_TEST_METAL", "false") == "true"
Pkg.add(["Metal"])

using Metal
if Metal.functional()
@testset "Metal" begin
# nnlib_testsuite(CUDABackend; skip_tests=Set(("Scatter", "Gather")))

include("ext_metal/runtests.jl")
end
else
@info "Metal.jl package is not functional. Skipping Metal tests."
end
else
@info "Skipping Metal tests, set NNLIB_TEST_METAL=true to run them"
end
end
Loading