Merge pull request #153 from omlins/amdparams
Add AMD-specific launch parameters
omlins authored Jul 5, 2024
2 parents e1194ff + fc489a7 commit 1a944d8
Showing 6 changed files with 53 additions and 40 deletions.
9 changes: 6 additions & 3 deletions src/ParallelKernel/parallel.jl
@@ -269,8 +269,9 @@ function parallel_call_gpu(nblocks::Union{Symbol,Expr}, nthreads::Union{Symbol,E
 end
 
 function parallel_call_gpu(ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool, package::Symbol; stream::Union{Symbol,Expr}=default_stream(package), shmem::Union{Symbol,Expr,Nothing}=nothing, launch::Bool=true, configcall::Expr=kernelcall)
+    nthreads_x_max = determine_nthreads_x_max(package)
     maxsize = :(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)))
-    nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize) )
+    nthreads = :( ParallelStencil.ParallelKernel.compute_nthreads($maxsize; nthreads_x_max=$nthreads_x_max) )
     nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
     parallel_call_gpu(ranges, nblocks, nthreads, kernelcall, backend_kwargs_expr, async, package; stream=stream, shmem=shmem, launch=launch)
 end
@@ -522,9 +523,9 @@ function compute_ranges(maxsize)
     return (1:maxsize[1], 1:maxsize[2], 1:maxsize[3])
 end
 
-function compute_nthreads(maxsize; nthreads_max=NTHREADS_MAX, flatdim=0) # This is a heuristic, which results in (32,8,1) threads, except if maxsize[1] < 32 or maxsize[2] < 8.
+function compute_nthreads(maxsize; nthreads_x_max=NTHREADS_X_MAX, nthreads_max=NTHREADS_MAX, flatdim=0) # This is a heuristic, which results in (32,8,1) threads, except if maxsize[1] < 32 or maxsize[2] < 8.
     maxsize = promote_maxsize(maxsize)
-    nthreads_x = min(32, (flatdim==1) ? 1 : maxsize[1])
+    nthreads_x = min(nthreads_x_max, (flatdim==1) ? 1 : maxsize[1])
     nthreads_y = min(ceil(Int,nthreads_max/nthreads_x), (flatdim==2) ? 1 : maxsize[2])
     nthreads_z = min(ceil(Int,nthreads_max/(nthreads_x*nthreads_y)), (flatdim==3) ? 1 : maxsize[3])
     return (nthreads_x, nthreads_y , nthreads_z)
@@ -536,6 +537,8 @@ function compute_nblocks(maxsize, nthreads)
     return ceil.(Int, maxsize./nthreads)
 end
 
+determine_nthreads_x_max(package::Symbol) = (package == PKG_AMDGPU) ? NTHREADS_X_MAX_AMDGPU : NTHREADS_X_MAX
+
 
 ## FUNCTIONS TO CREATE KERNEL LAUNCH AND SYNCHRONIZATION CALLS
2 changes: 2 additions & 0 deletions src/ParallelKernel/shared.jl
@@ -20,6 +20,8 @@ const INT_CUDA = Int64 # NOTE: unsigned integers are not yet
 const INT_AMDGPU = Int64 # NOTE: ...
 const INT_POLYESTER = Int64 # NOTE: ...
 const INT_THREADS = Int64 # NOTE: ...
+const NTHREADS_X_MAX = 32
+const NTHREADS_X_MAX_AMDGPU = 64
 const NTHREADS_MAX = 256
 const INDICES = (gensym_world("ix", @__MODULE__), gensym_world("iy", @__MODULE__), gensym_world("iz", @__MODULE__))
 const RANGES_VARNAME = gensym_world("ranges", @__MODULE__)
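Taken together, the hunks in these two files set the default launch geometry per backend: the cap on threads in x now follows the hardware's native SIMD width (32-wide warps on NVIDIA, typically 64-wide wavefronts on AMD) instead of the hard-coded 32. The following is a minimal standalone sketch of the heuristic, with the logic and constants copied from the diffs above; the symbols :CUDA and :AMDGPU stand in for the internal PKG_* constants:

```julia
# Standalone sketch of the backend-dependent launch heuristic (copied from the diffs above).
const NTHREADS_X_MAX        = 32   # NVIDIA: warp size
const NTHREADS_X_MAX_AMDGPU = 64   # AMD: wavefront size
const NTHREADS_MAX          = 256

determine_nthreads_x_max(package::Symbol) = (package == :AMDGPU) ? NTHREADS_X_MAX_AMDGPU : NTHREADS_X_MAX

function compute_nthreads(maxsize; nthreads_x_max=NTHREADS_X_MAX, nthreads_max=NTHREADS_MAX, flatdim=0)
    nthreads_x = min(nthreads_x_max, (flatdim==1) ? 1 : maxsize[1])
    nthreads_y = min(ceil(Int, nthreads_max/nthreads_x),              (flatdim==2) ? 1 : maxsize[2])
    nthreads_z = min(ceil(Int, nthreads_max/(nthreads_x*nthreads_y)), (flatdim==3) ? 1 : maxsize[3])
    return (nthreads_x, nthreads_y, nthreads_z)
end

compute_nblocks(maxsize, nthreads) = ceil.(Int, maxsize ./ nthreads)

maxsize = (512, 512, 512)
compute_nthreads(maxsize; nthreads_x_max=determine_nthreads_x_max(:CUDA))    # -> (32, 8, 1)
compute_nthreads(maxsize; nthreads_x_max=determine_nthreads_x_max(:AMDGPU))  # -> (64, 4, 1)
```

The total thread count per block stays at NTHREADS_MAX = 256 in both cases; only the block shape changes, keeping the x dimension wide enough for coalesced memory access on the respective hardware.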
31 changes: 19 additions & 12 deletions src/parallel.jl
@@ -162,7 +162,7 @@ function parallel(source::LineNumberNode, caller::Module, args::Union{Symbol,Exp
             if (length(posargs) > 1) @ArgumentError("maximum one positional argument (ranges) is allowed in a @parallel memopt=true call.") end
             parallel_call_memopt(caller, posargs..., kernelarg, backend_kwargs_expr, async; kwargs...)
         else
-            ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package)
+            ParallelKernel.parallel(caller, posargs..., backend_kwargs_expr..., configcall_kwarg_expr, kernelarg; package=package, async=async)
         end
     end
 end
@@ -321,6 +321,9 @@ end
 
 function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
     if haskey(backend_kwargs_expr, :shmem) @KeywordArgumentError("@parallel <kernelcall>: keyword `shmem` is not allowed when memopt=true is set.") end
+    package = get_package(caller)
+    nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
+    nthreads_max_memopt = determine_nthreads_max_memopt(package)
     configcall_kwarg_expr = :(configcall=$configcall)
     metadata_call = create_metadata_call(configcall)
     metadata_module = metadata_call
@@ -331,7 +334,7 @@ function parallel_call_memopt(caller::Module, ranges::Union{Symbol,Expr}, kernel
     loopsize = :($(metadata_module).loopsize)
     loopsizes = :(($loopdim==3) ? (1, 1, $loopsize) : ($loopdim==2) ? (1, $loopsize, 1) : ($loopsize, 1, 1))
     maxsize = :(cld.(length.(ParallelStencil.ParallelKernel.promote_ranges($ranges)), $loopsizes))
-    nthreads = :( ParallelStencil.compute_nthreads_memopt($maxsize, $loopdim, $stencilranges) )
+    nthreads = :( ParallelStencil.compute_nthreads_memopt($nthreads_x_max, $nthreads_max_memopt, $maxsize, $loopdim, $stencilranges) )
     nblocks = :( ParallelStencil.ParallelKernel.compute_nblocks($maxsize, $nthreads) )
     numbertype = get_numbertype(caller) # not :(eltype($(optvars)[1])) # TODO: see how to obtain the number type properly for each array: the type of the call arguments corresponding to the optimization variables should be checked
     dim1 = :(($loopdim==3) ? 1 : ($loopdim==2) ? 1 : 2) # TODO: to be determined if that is what is desired for loopdim 1 and 2.
@@ -344,11 +347,14 @@ end
 end
 
 function parallel_call_memopt(caller::Module, kernelcall::Expr, backend_kwargs_expr::Array, async::Bool; memopt::Bool=false, configcall::Expr=kernelcall)
-    metadata_call = create_metadata_call(configcall)
-    metadata_module = metadata_call
-    loopdim = :($(metadata_module).loopdim)
-    is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
-    ranges = :( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...)))
+    package = get_package(caller)
+    nthreads_x_max = ParallelKernel.determine_nthreads_x_max(package)
+    nthreads_max_memopt = determine_nthreads_max_memopt(package)
+    metadata_call = create_metadata_call(configcall)
+    metadata_module = metadata_call
+    loopdim = :($(metadata_module).loopdim)
+    is_parallel_kernel = :($(metadata_module).is_parallel_kernel)
+    ranges = :( ($is_parallel_kernel) ? ParallelStencil.get_ranges_memopt($nthreads_x_max, $nthreads_max_memopt, $loopdim, $(configcall.args[2:end]...)) : ParallelStencil.ParallelKernel.get_ranges($(configcall.args[2:end]...)))
     parallel_call_memopt(caller, ranges, kernelcall, backend_kwargs_expr, async; memopt=memopt, configcall=configcall)
 end
 
@@ -362,15 +368,16 @@ end
 
 ## FUNCTIONS TO DETERMINE OPTIMIZATION PARAMETERS
 
+determine_nthreads_max_memopt(package::Symbol) = (package == PKG_AMDGPU) ? NTHREADS_MAX_MEMOPT_AMDGPU : NTHREADS_MAX_MEMOPT_CUDA
 determine_loopdim(indices::Union{Symbol,Expr}) = isa(indices,Expr) && (length(indices.args)==3) ? 3 : LOOPDIM_NONE # TODO: currently only loopdim=3 is supported.
-compute_loopsize() = LOOPSIZE
+compute_loopsize() = LOOPSIZE
 
 
 ## FUNCTIONS TO COMPUTE NTHREADS, NBLOCKS
 
-function compute_nthreads_memopt(maxsize, loopdim, stencilranges) # This is a heuristic, which typically results in (32,4,1) threads for a 3-D case.
+function compute_nthreads_memopt(nthreads_x_max, nthreads_max_memopt, maxsize, loopdim, stencilranges) # This is a heuristic, which typically results in (32,4,1) threads for a 3-D case.
     maxsize = promote_maxsize(maxsize)
-    nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_max=NTHREADS_MAX_LOOPOPT, flatdim=loopdim)
+    nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_x_max=nthreads_x_max, nthreads_max=nthreads_max_memopt, flatdim=loopdim)
     for stencilranges_A in values(stencilranges)
         haloextensions = ((length(stencilranges_A[1])-1)*(loopdim!=1), (length(stencilranges_A[2])-1)*(loopdim!=2), (length(stencilranges_A[3])-1)*(loopdim!=3))
         if (2*prod(nthreads) < prod(nthreads .+ haloextensions)) @ArgumentError("@parallel <kernelcall>: the automatic determination of nthreads is not possible for this case. Please specify `nthreads` and `nblocks`.") end # NOTE: this is a simple heuristic to compare the number of threads to the total number of cells including halo.
@@ -380,10 +387,10 @@ function compute_nthreads_memopt(maxsize, loopdim, stencilranges) # This is a he
     return nthreads
 end
 
-function get_ranges_memopt(loopdim, args...)
+function get_ranges_memopt(nthreads_x_max, nthreads_max_memopt, loopdim, args...)
     ranges = ParallelKernel.get_ranges(args...)
     maxsize = length.(ranges)
-    nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_max=NTHREADS_MAX_LOOPOPT, flatdim=loopdim)
+    nthreads = ParallelKernel.compute_nthreads(maxsize; nthreads_x_max=nthreads_x_max, nthreads_max=nthreads_max_memopt, flatdim=loopdim)
     # TODO: the following code reduces performance from ~482 GB/s to ~478 GB/s
     rests = maxsize .% nthreads
     ranges_adjustment = ( (rests[1] != 0) ? (nthreads[1] - rests[1]) : 0,
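On the memopt path, the loop dimension is flattened (flatdim = loopdim) and the thread budget is now backend-specific via determine_nthreads_max_memopt. Reusing the compute_nthreads sketch from above, here is a rough check of the figures quoted in the comments; the 512×512×512 problem size is hypothetical, and the real maxsize is additionally divided by the loop size along loopdim:

```julia
# Memopt thread budgets (constants from src/shared.jl below); loopdim=3 flattens z.
const NTHREADS_MAX_MEMOPT_CUDA   = 128
const NTHREADS_MAX_MEMOPT_AMDGPU = 256

maxsize = (512, 512, 32)  # e.g. 512^3 ranges with the z range divided by LOOPSIZE=16
compute_nthreads(maxsize; nthreads_x_max=32, nthreads_max=NTHREADS_MAX_MEMOPT_CUDA,   flatdim=3)  # -> (32, 4, 1)
compute_nthreads(maxsize; nthreads_x_max=64, nthreads_max=NTHREADS_MAX_MEMOPT_AMDGPU, flatdim=3)  # -> (64, 4, 1)
```

With flatdim=3 the z extent of the block is pinned to 1, so each block iterates over z internally; the halo check in compute_nthreads_memopt then rejects configurations whose stencil halo would more than double the cell count per block.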
33 changes: 17 additions & 16 deletions src/shared.jl
@@ -15,22 +15,23 @@ Return an expression that evaluates to `true` if the indices generated by @paral
 This macro is not intended for explicit manual usage. Calls to it are automatically added by @parallel where required.
 """
 
-const SUPPORTED_NDIMS = [1, 2, 3]
-const NDIMS_NONE = 0
-const ERRMSG_KERNEL_UNSUPPORTED = "unsupported kernel statements in @parallel kernel definition: @parallel is only applicable to kernels that contain exclusively array assignments using macros from FiniteDifferences{1|2|3}D or from another compatible computation submodule. @parallel_indices supports any kind of statements in the kernels."
-const ERRMSG_CHECK_NDIMS = "ndims must be evaluatable at parse time (e.g. literal or constant) and has to be one of the following Integers: $(join(SUPPORTED_NDIMS,", "))"
-const ERRMSG_CHECK_MEMOPT = "memopt must be evaluatable at parse time (e.g. literal or constant) and has to be of type Bool."
-const PSNumber = PKNumber
-const LOOPSIZE = 16
-const LOOPDIM_NONE = 0
-const NTHREADS_MAX_LOOPOPT = 128
-const USE_SHMEMHALO_DEFAULT = true
-const USE_SHMEMHALO_1D_DEFAULT = true
-const USE_FULLRANGE_DEFAULT = (false, false, true)
-const FULLRANGE_THRESHOLD = 1
-const NOEXPR = :(begin end)
-const MOD_METADATA = :__metadata__ # gensym_world("__metadata__", @__MODULE__) # TODO: name mangling should be used here later; or, if it makes sense to keep it like this, a check whether the name is available must be done before creating it
-const META_FUNCTION_PREFIX = string(gensym_world("META", @__MODULE__))
+const SUPPORTED_NDIMS = [1, 2, 3]
+const NDIMS_NONE = 0
+const ERRMSG_KERNEL_UNSUPPORTED = "unsupported kernel statements in @parallel kernel definition: @parallel is only applicable to kernels that contain exclusively array assignments using macros from FiniteDifferences{1|2|3}D or from another compatible computation submodule. @parallel_indices supports any kind of statements in the kernels."
+const ERRMSG_CHECK_NDIMS = "ndims must be evaluatable at parse time (e.g. literal or constant) and has to be one of the following Integers: $(join(SUPPORTED_NDIMS,", "))"
+const ERRMSG_CHECK_MEMOPT = "memopt must be evaluatable at parse time (e.g. literal or constant) and has to be of type Bool."
+const PSNumber = PKNumber
+const LOOPSIZE = 16
+const LOOPDIM_NONE = 0
+const NTHREADS_MAX_MEMOPT_CUDA = 128
+const NTHREADS_MAX_MEMOPT_AMDGPU = 256
+const USE_SHMEMHALO_DEFAULT = true
+const USE_SHMEMHALO_1D_DEFAULT = true
+const USE_FULLRANGE_DEFAULT = (false, false, true)
+const FULLRANGE_THRESHOLD = 1
+const NOEXPR = :(begin end)
+const MOD_METADATA = :__metadata__ # gensym_world("__metadata__", @__MODULE__) # TODO: name mangling should be used here later; or, if it makes sense to keep it like this, a check whether the name is available must be done before creating it
+const META_FUNCTION_PREFIX = string(gensym_world("META", @__MODULE__))
 
 
 ## FUNCTIONS TO DEAL WITH KERNEL DEFINITIONS
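When the automatic determination fails with the @ArgumentError shown in src/parallel.jl above, or when the heuristics should be overridden, nthreads and nblocks can be passed to @parallel explicitly as positional arguments. A usage sketch under stated assumptions: the kernel and array sizes are hypothetical, and depending on the ParallelStencil version, AMDGPU may need to be loaded alongside it:

```julia
# Hypothetical example of explicit launch parameters on an AMD GPU.
using AMDGPU, ParallelStencil
@init_parallel_stencil(AMDGPU, Float64, 3)

@parallel_indices (ix, iy, iz) function copy3D!(T2, T)
    T2[ix, iy, iz] = T[ix, iy, iz]
    return
end

nx, ny, nz = 512, 512, 512
T  = @rand(nx, ny, nz)
T2 = @zeros(nx, ny, nz)

nthreads = (64, 4, 1)                       # wavefront-friendly x dimension on AMD
nblocks  = cld.((nx, ny, nz), nthreads)
@parallel nblocks nthreads copy3D!(T2, T)   # bypasses the automatic determination
```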
[The diffs of the remaining two changed files did not load.]