Skip to content

Commit

Permalink
Update to use Metal.device() instead of Metal.current_device() (the…
Browse files Browse the repository at this point in the history
… latter is deprecated)

Also add tests that were TODO
  • Loading branch information
GiackAloZ committed Oct 30, 2024
1 parent 7826b4d commit 7179816
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/ParallelKernel/MetalExt/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ let
metalqueues = Array{MTL.MTLCommandQueue}(undef, 0)

# Return the entry at index `id` of `priority_metalqueues`, first appending new
# `MTL.MTLCommandQueue` objects until the array holds at least `id` entries.
# NOTE(review): this is a diff view without +/- markers - the two `while` lines below are
# the pre-change and post-change versions of the same statement; per the commit message,
# `Metal.current_device()` is deprecated and is replaced by `Metal.device()`.
function get_priority_metalstream(id::Integer)
while (id > length(priority_metalqueues)) push!(priority_metalqueues, MTL.MTLCommandQueue(Metal.current_device())) end # No priority setting available in Metal queues.
while (id > length(priority_metalqueues)) push!(priority_metalqueues, MTL.MTLCommandQueue(Metal.device())) end # No priority setting available in Metal queues.
return priority_metalqueues[id]
end

# Return the entry at index `id` of `metalqueues`, first appending new
# `MTL.MTLCommandQueue` objects until the array holds at least `id` entries.
# NOTE(review): this is a diff view without +/- markers - the two `while` lines below are
# the pre-change and post-change versions of the same statement; per the commit message,
# `Metal.current_device()` is deprecated and is replaced by `Metal.device()`.
function get_metalstream(id::Integer)
while (id > length(metalqueues)) push!(metalqueues, MTL.MTLCommandQueue(Metal.current_device())) end
while (id > length(metalqueues)) push!(metalqueues, MTL.MTLCommandQueue(Metal.device())) end
return metalqueues[id]
end
end
2 changes: 1 addition & 1 deletion src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ end
# Return a quoted expression that, once evaluated, yields the default stream (CUDA, AMDGPU)
# or queue (Metal) for the given GPU `package`. Any other `package` value is handed to
# `@ModuleInternalError` together with the offending value.
# NOTE(review): this is a diff view without +/- markers - the two `PKG_METAL` branches below
# are the pre-change and post-change versions of the same line; per the commit message,
# `Metal.current_device()` is deprecated and is replaced by `Metal.device()`.
function default_stream(package)
if (package == PKG_CUDA) return :(CUDA.stream()) # Use the default stream of the task.
elseif (package == PKG_AMDGPU) return :(AMDGPU.stream()) # Use the default stream of the task.
elseif (package == PKG_METAL) return :(Metal.global_queue(Metal.current_device())) # Use the default queue of the task.
elseif (package == PKG_METAL) return :(Metal.global_queue(Metal.device())) # Use the default queue of the task.
else @ModuleInternalError("unsupported GPU package (obtained: $package).")
end
end
4 changes: 2 additions & 2 deletions test/ParallelKernel/test_kernel_language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ eval(:(
@test @prettystring(1, @threadIdx()) == "Metal.thread_position_in_threadgroup_3d()"
@test @prettystring(1, @sync_threads()) == "Metal.threadgroup_barrier(; flag = Metal.MemoryFlagThreadGroup)"
@test @prettystring(1, @sharedMem($precision, (2,3))) == "ParallelStencil.ParallelKernel.@sharedMem_metal $(nameof($precision)) (2, 3)"
# @test @prettystring(1, @pk_show()) == "Metal.@mtlshow"
# @test @prettystring(1, @pk_println()) == "Metal.@mtlprintln"
# @test @prettystring(1, @pk_show()) == "Metal.@mtlshow" #TODO: not yet supported for Metal
# @test @prettystring(1, @pk_println()) == "Metal.@mtlprintln" #TODO: not yet supported for Metal
elseif @iscpu($package)
@test @prettystring(1, @gridDim()) == "ParallelStencil.ParallelKernel.@gridDim_cpu"
@test @prettystring(1, @blockIdx()) == "ParallelStencil.ParallelKernel.@blockIdx_cpu"
Expand Down
12 changes: 11 additions & 1 deletion test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,17 @@ eval(:(
call = @prettystring(1, @parallel nblocks nthreads stream=mystream f(A))
@test occursin("AMDGPU.@roc gridsize = nblocks groupsize = nthreads stream = mystream f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[3])))", call)
elseif $package == $PKG_METAL
## TODO
call = @prettystring(1, @parallel f(A))
@test occursin("Metal.@metal groups = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A))); nthreads_x_max = 32) queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))", call)
@test occursin("Metal.synchronize(Metal.global_queue(Metal.device()))", call)
call = @prettystring(1, @parallel ranges f(A))
@test occursin("Metal.@metal groups = ParallelStencil.ParallelKernel.compute_nblocks(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)), ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32)) threads = ParallelStencil.ParallelKernel.compute_nthreads(length.(ParallelStencil.ParallelKernel.promote_ranges(ranges)); nthreads_x_max = 32) queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
call = @prettystring(1, @parallel nblocks nthreads f(A))
@test occursin("Metal.@metal groups = nblocks threads = nthreads queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[3])))", call)
call = @prettystring(1, @parallel ranges nblocks nthreads f(A))
@test occursin("Metal.@metal groups = nblocks threads = nthreads queue = Metal.global_queue(Metal.device()) f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))", call)
call = @prettystring(1, @parallel nblocks nthreads stream=mystream f(A))
@test occursin("Metal.@metal groups = nblocks threads = nthreads queue = mystream f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.compute_ranges(nblocks .* nthreads)))[3])))", call)
elseif @iscpu($package)
@test @prettystring(1, @parallel f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ParallelStencil.ParallelKernel.get_ranges(A)))[3])))"
@test @prettystring(1, @parallel ranges f(A)) == "f(A, ParallelStencil.ParallelKernel.promote_ranges(ranges), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[1])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[2])), (Int64)(length((ParallelStencil.ParallelKernel.promote_ranges(ranges))[3])))"
Expand Down

0 comments on commit 7179816

Please sign in to comment.