Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support early returns in nested functions #118

Merged
merged 6 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/ParallelKernel/kernel_language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,24 @@ Call a macro analogue to `Base.@println`, compatible with the package for parall
macro pk_println(args...) check_initialized(); esc(pk_println(args...)); end


## @return_value: check initialization, require exactly one argument, then splice `return <value>` into the caller scope.
macro return_value(args...) check_initialized(); checksinglearg(args...); esc(return_value(args...)); end


## @return_nothing: check initialization, require zero arguments, then splice a bare `return` into the caller scope.
macro return_nothing(args...) check_initialized(); checknoargs(args...); esc(return_nothing(args...)); end


## ARGUMENT CHECKS

# Argument check: reject any positional arguments.
function checknoargs(args...)
    if !isempty(args) @ArgumentError("no arguments allowed.") end
end

# Argument check: require exactly one positional argument.
function checksinglearg(args...)
    if !isone(length(args)) @ArgumentError("wrong number of arguments.") end
end

# Argument check: accept only 2 or 3 positional arguments (presumably for @sharedMem — the macro itself is not visible here).
function checkargs_sharedMem(args...)
    nargs = length(args)
    if (nargs < 2 || nargs > 3) @ArgumentError("wrong number of arguments.") end
end
Expand Down Expand Up @@ -177,6 +189,17 @@ function pk_println(args...; package::Symbol=get_package())
end


## FUNCTION FOR DIVERSE TASKS

# Build an explicit `return <value>` expression for splicing into caller code.
function return_value(value)
    return Expr(:return, value)
end

# Build a bare `return` expression (evaluates to `nothing` at run time).
function return_nothing()
    return Expr(:return, nothing)
end


## CPU TARGET IMPLEMENTATIONS

macro gridDim_cpu() esc(:(ParallelStencil.ParallelKernel.Dim3($(RANGELENGTHS_VARNAMES[1]), $(RANGELENGTHS_VARNAMES[2]), $(RANGELENGTHS_VARNAMES[3])))) end
Expand Down
2 changes: 1 addition & 1 deletion src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType,
if (!isa(indices,Symbol) && !isa(indices.head,Symbol)) @ArgumentError("@parallel_indices: argument 'indices' must be a tuple of indices or a single index (e.g. (ix, iy, iz) or (ix, iy) or ix ).") end
indices = extract_tuple(indices)
body = get_body(kernel)
body = remove_returns(body)
body = remove_return(body)
use_aliases = !all(indices .== INDICES[1:length(indices)])
if use_aliases # NOTE: we treat explicit parallel indices as aliases to the statically retrievable indices INDICES.
indices_aliases = indices
Expand Down
14 changes: 7 additions & 7 deletions src/ParallelKernel/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,36 +142,36 @@ function push_to_signature!(kernel::Expr, arg::Expr)
return kernel
end

"""
    remove_return(body::Expr)

Validate and strip the mandatory trailing `return nothing` statement from a kernel
body, after disguising `return` statements inside nested function definitions so
that the single-return check below does not trip on them.

Throws an argument error if the last statement is not `return`, `return nothing`
or `nothing`, or if any other `return` remains in the body.

NOTE(review): this span was a diff-interleaved scrape containing both the old
(`remove_returns` / `make_nested_returns_implicit`) and new lines; the
post-change version is reconstructed here.
"""
function remove_return(body::Expr)
    if !(body.args[end] in [:(return), :(return nothing), :(nothing)])
        @ArgumentError("invalid kernel in @parallel kernel definition: the last statement must be a `return nothing` statement ('return' or 'return nothing' or 'nothing') as required for any GPU kernels.")
    end
    body      = disguise_nested_returns(body)
    remainder = copy(body)
    remainder.args = body.args[1:end-2]  # drop the last two args — the trailing return and, presumably, its LineNumberNode (TODO confirm)
    if inexpr_walk(remainder, :return) @ArgumentError("invalid kernel in @parallel kernel definition: only one return statement is allowed in the kernel (exception: nested function definitions) and it must return nothing and be the last statement (required to ensure equal behaviour with different packages for parallellization).") end
    return remainder
end

"""
    disguise_nested_returns(body::Expr)

Walk `body` and, for every nested function definition found, rewrite the returns
in that definition's body via `disguise_returns`; all other expressions are left
untouched.

NOTE(review): this span was a diff-interleaved scrape containing both the old
(`make_nested_returns_implicit` / `make_returns_implicit`) and new lines; the
post-change version is reconstructed here.
"""
function disguise_nested_returns(body::Expr)
    return postwalk(body) do ex
        if isdef(ex)
            f_elems = splitdef(ex)
            f_elems[:body] = disguise_returns(f_elems[:body])
            return combinedef(f_elems)
        else
            return ex
        end
    end
end

function make_returns_implicit(body::Expr)
function disguise_returns(body::Expr)
return postwalk(body) do ex
if @capture(ex, return x_)
return x
return :(ParallelStencil.ParallelKernel.@return_value($x))
elseif @capture(ex, return)
return nothing
return :(ParallelStencil.ParallelKernel.@return_nothing)
else
return ex
end
Expand Down
9 changes: 0 additions & 9 deletions src/kernel_language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ macro memopt(args...) check_initialized(); checkargs_memopt(args...); esc(memopt
macro shortif(args...) check_initialized(); checktwoargs(args...); esc(shortif(args...)); end


##
macro return_nothing(args...) check_initialized(); checknoargs(args...); esc(return_nothing(args...)); end


## ARGUMENT CHECKS

function checknoargs(args...)
Expand Down Expand Up @@ -487,11 +483,6 @@ function shortif(else_val, if_expr; package::Symbol=get_package())
end


# Build a bare `return` expression (evaluates to `nothing` at run time).
function return_nothing()
    return Expr(:return, nothing)
end


## FUNCTIONS FOR SHARED MEMORY ALLOCATION


Expand Down
6 changes: 3 additions & 3 deletions src/parallel.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import .ParallelKernel: get_name, set_name, get_body, set_body!, add_return, remove_returns, extract_kwargs, split_parallel_args, extract_tuple, substitute, literaltypes, push_to_signature!, add_loop, add_threadids, promote_maxsize
import .ParallelKernel: get_name, set_name, get_body, set_body!, add_return, remove_return, extract_kwargs, split_parallel_args, extract_tuple, substitute, literaltypes, push_to_signature!, add_loop, add_threadids, promote_maxsize

const PARALLEL_DOC = """
@parallel kernel
Expand Down Expand Up @@ -185,7 +185,7 @@ function parallel_indices_memopt(metadata_module::Module, metadata_function::Exp
if (!isa(indices,Symbol) && !isa(indices.head,Symbol)) @ArgumentError("@parallel_indices: argument 'indices' must be a tuple of indices, a single index or a variable followed by the splat operator representing a tuple of indices (e.g. (ix, iy, iz) or (ix, iy) or ix or I...).") end
if (!isa(optvars,Symbol) && !isa(optvars.head,Symbol)) @ArgumentError("@parallel_indices: argument 'optvars' must be a tuple of optvars or a single optvar (e.g. (A, B, C) or A ).") end
body = get_body(kernel)
body = remove_returns(body)
body = remove_return(body)
body = add_memopt(metadata_module, is_parallel_kernel, caller, package, body, indices, optvars, loopdim, loopsize, optranges, useshmemhalos, optimize_halo_read)
body = add_return(body)
set_body!(kernel, body)
Expand All @@ -198,7 +198,7 @@ function parallel_kernel(metadata_module::Module, metadata_function::Expr, calle
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt()
indices = get_indices_expr(ndims).args
body = get_body(kernel)
body = remove_returns(body)
body = remove_return(body)
validate_body(body)
kernelargs = splitarg.(extract_kernel_args(kernel)[1])
argvars = (arg[1] for arg in kernelargs)
Expand Down
2 changes: 1 addition & 1 deletion src/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ elseif ENABLE_AMDGPU
using AMDGPU
end
import MacroTools: @capture, postwalk, splitarg # NOTE: inexpr_walk used instead of MacroTools.inexpr
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, is_kernel, is_call, gensym_world, isgpu, @isgpu, substitute, inexpr_walk, cast, @ranges, @rangelengths
import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, is_kernel, is_call, gensym_world, isgpu, @isgpu, substitute, inexpr_walk, cast, @ranges, @rangelengths, @return_value, @return_nothing
import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_THREADS, INDICES, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS
import .ParallelKernel: @require, @symbols, symbols, longnameof, @prettyexpand, @prettystring, prettystring, @gorgeousexpand, @gorgeousstring, gorgeousstring

Expand Down