Merge pull request #118 from omlins/nested-functions
Support early returns in nested functions
omlins authored Sep 8, 2023
2 parents 5f8306e + ba3d11b commit 5b9f625
Showing 6 changed files with 35 additions and 21 deletions.
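
The gist of the change: kernels passed to @parallel_indices must still end in a single trailing `return`, but `return` statements inside nested function definitions are now preserved as real early returns instead of being rewritten into bare values. A hedged sketch of the kind of kernel this enables (array, helper and range names are illustrative, Threads backend assumed; the exact initialization call may differ between package versions):

    using ParallelStencil
    @init_parallel_stencil(Threads, Float64, 2)

    @parallel_indices (ix, iy) function clamp01!(A)
        clamp01(x) = begin
            if (x < 0.0) return 0.0 end   # early returns in a nested function:
            if (x > 1.0) return 1.0 end   # preserved as real returns by this change
            return x
        end
        A[ix, iy] = clamp01(A[ix, iy])
        return                            # the kernel itself must still end with `return`
    end

    A = rand(4, 4) .* 2.0 .- 0.5
    @parallel (1:size(A, 1), 1:size(A, 2)) clamp01!(A)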
23 changes: 23 additions & 0 deletions src/ParallelKernel/kernel_language.jl
@@ -86,12 +86,24 @@ Call a macro analogue to `Base.@println`, compatible with the package for parall
macro pk_println(args...) check_initialized(); esc(pk_println(args...)); end


+ ##
+ macro return_value(args...) check_initialized(); checksinglearg(args...); esc(return_value(args...)); end
+
+
+ ##
+ macro return_nothing(args...) check_initialized(); checknoargs(args...); esc(return_nothing(args...)); end
+
+
## ARGUMENT CHECKS

function checknoargs(args...)
if (length(args) != 0) @ArgumentError("no arguments allowed.") end
end

+ function checksinglearg(args...)
+ if (length(args) != 1) @ArgumentError("wrong number of arguments.") end
+ end
+
function checkargs_sharedMem(args...)
if !(2 <= length(args) <= 3) @ArgumentError("wrong number of arguments.") end
end
@@ -177,6 +189,17 @@ function pk_println(args...; package::Symbol=get_package())
end


+ ## FUNCTION FOR DIVERSE TASKS
+
+ function return_value(value)
+ return :(return $value)
+ end
+
+ function return_nothing()
+ return :(return)
+ end
+
+
## CPU TARGET IMPLEMENTATIONS

macro gridDim_cpu() esc(:(ParallelStencil.ParallelKernel.Dim3($(RANGELENGTHS_VARNAMES[1]), $(RANGELENGTHS_VARNAMES[2]), $(RANGELENGTHS_VARNAMES[3])))) end
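
The two new macros expand back into plain `return` expressions at the call site; this is what turns the "disguised" nested returns (see src/ParallelKernel/shared.jl below) back into real early returns after macro expansion. A minimal standalone illustration of the pattern (hypothetical macro and function names, not the package code):

    # A macro that expands to a `return` expression returns from the function it is called in.
    macro my_return_value(x) esc(:(return $x)) end
    macro my_return_nothing() esc(:(return)) end

    function sign_or_nothing(a)
        if (a > 0) @my_return_value(1) end
        if (a < 0) @my_return_value(-1) end
        @my_return_nothing()
    end

    sign_or_nothing(5)   # -> 1
    sign_or_nothing(-3)  # -> -1
    sign_or_nothing(0)   # -> nothing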
2 changes: 1 addition & 1 deletion src/ParallelKernel/parallel.jl
@@ -156,7 +156,7 @@ function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType,
if (!isa(indices,Symbol) && !isa(indices.head,Symbol)) @ArgumentError("@parallel_indices: argument 'indices' must be a tuple of indices or a single index (e.g. (ix, iy, iz) or (ix, iy) or ix ).") end
indices = extract_tuple(indices)
body = get_body(kernel)
- body = remove_returns(body)
+ body = remove_return(body)
use_aliases = !all(indices .== INDICES[1:length(indices)])
if use_aliases # NOTE: we treat explicit parallel indices as aliases to the statically retrievable indices INDICES.
indices_aliases = indices
14 changes: 7 additions & 7 deletions src/ParallelKernel/shared.jl
@@ -142,36 +142,36 @@ function push_to_signature!(kernel::Expr, arg::Expr)
return kernel
end

- function remove_returns(body::Expr)
+ function remove_return(body::Expr)
if !(body.args[end] in [:(return), :(return nothing), :(nothing)])
@ArgumentError("invalid kernel in @parallel kernel definition: the last statement must be a `return nothing` statement ('return' or 'return nothing' or 'nothing') as required for any GPU kernels.")
end
- body = make_nested_returns_implicit(body)
+ body = disguise_nested_returns(body)
remainder = copy(body)
remainder.args = body.args[1:end-2]
if inexpr_walk(remainder, :return) @ArgumentError("invalid kernel in @parallel kernel definition: only one return statement is allowed in the kernel (exception: nested function definitions) and it must return nothing and be the last statement (required to ensure equal behaviour with different packages for parallellization).") end
return remainder
end

- function make_nested_returns_implicit(body::Expr)
+ function disguise_nested_returns(body::Expr)
return postwalk(body) do ex
if isdef(ex)
f_elems = splitdef(ex)
body = f_elems[:body]
- f_elems[:body] = make_returns_implicit(body)
+ f_elems[:body] = disguise_returns(body)
return combinedef(f_elems)
else
return ex
end
end
end

- function make_returns_implicit(body::Expr)
+ function disguise_returns(body::Expr)
return postwalk(body) do ex
if @capture(ex, return x_)
- return x
+ return :(ParallelStencil.ParallelKernel.@return_value($x))
elseif @capture(ex, return)
- return nothing
+ return :(ParallelStencil.ParallelKernel.@return_nothing)
else
return ex
end
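
The renaming makes the mechanism explicit: the old make_returns_implicit rewrote a nested `return x` into the bare value `x`, which silently dropped the early exit; disguise_returns instead rewrites it into a macro call that the single-return check on the kernel body does not see and that later expands back into a real `return`. A minimal standalone sketch of that rewriting with the same MacroTools primitives (the is_funcdef helper and the placeholder macro names are illustrative simplifications, not the package code):

    using MacroTools: postwalk, @capture, splitdef, combinedef

    # Simplified stand-in for the package's isdef check: long- and short-form function definitions.
    is_funcdef(e) = e isa Expr && (e.head === :function ||
                    (e.head === :(=) && e.args[1] isa Expr && e.args[1].head === :call))

    # Rewrite `return x` / `return` into placeholder macro calls (cf. disguise_returns).
    disguise(body) = postwalk(body) do ex
        if @capture(ex, return x_)
            :(@disguised_return($x))       # the package emits @return_value here
        elseif @capture(ex, return)
            :(@disguised_return_nothing)   # the package emits @return_nothing here
        else
            ex
        end
    end

    # Apply the rewrite only inside nested function definitions (cf. disguise_nested_returns).
    kernel_body = quote
        clamp01(x) = begin
            if (x < 0.0) return 0.0 end
            if (x > 1.0) return 1.0 end
            return x
        end
        A .= clamp01.(A)
        return
    end
    disguised = postwalk(kernel_body) do e
        is_funcdef(e) || return e
        d = splitdef(e); d[:body] = disguise(d[:body]); combinedef(d)
    end
    # The kernel-level trailing `return` is untouched; only the returns inside clamp01 are disguised.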
9 changes: 0 additions & 9 deletions src/kernel_language.jl
@@ -12,10 +12,6 @@ macro memopt(args...) check_initialized(); checkargs_memopt(args...); esc(memopt
macro shortif(args...) check_initialized(); checktwoargs(args...); esc(shortif(args...)); end


- ##
- macro return_nothing(args...) check_initialized(); checknoargs(args...); esc(return_nothing(args...)); end
-
-
## ARGUMENT CHECKS

function checknoargs(args...)
@@ -487,11 +483,6 @@ function shortif(else_val, if_expr; package::Symbol=get_package())
end


- function return_nothing()
- return :(return)
- end
-
-
## FUNCTIONS FOR SHARED MEMORY ALLOCATION


6 changes: 3 additions & 3 deletions src/parallel.jl
@@ -1,4 +1,4 @@
- import .ParallelKernel: get_name, set_name, get_body, set_body!, add_return, remove_returns, extract_kwargs, split_parallel_args, extract_tuple, substitute, literaltypes, push_to_signature!, add_loop, add_threadids, promote_maxsize
+ import .ParallelKernel: get_name, set_name, get_body, set_body!, add_return, remove_return, extract_kwargs, split_parallel_args, extract_tuple, substitute, literaltypes, push_to_signature!, add_loop, add_threadids, promote_maxsize

const PARALLEL_DOC = """
@parallel kernel
@@ -185,7 +185,7 @@ function parallel_indices_memopt(metadata_module::Module, metadata_function::Exp
if (!isa(indices,Symbol) && !isa(indices.head,Symbol)) @ArgumentError("@parallel_indices: argument 'indices' must be a tuple of indices, a single index or a variable followed by the splat operator representing a tuple of indices (e.g. (ix, iy, iz) or (ix, iy) or ix or I...).") end
if (!isa(optvars,Symbol) && !isa(optvars.head,Symbol)) @ArgumentError("@parallel_indices: argument 'optvars' must be a tuple of optvars or a single optvar (e.g. (A, B, C) or A ).") end
body = get_body(kernel)
- body = remove_returns(body)
+ body = remove_return(body)
body = add_memopt(metadata_module, is_parallel_kernel, caller, package, body, indices, optvars, loopdim, loopsize, optranges, useshmemhalos, optimize_halo_read)
body = add_return(body)
set_body!(kernel, body)
@@ -198,7 +198,7 @@ function parallel_kernel(metadata_module::Module, metadata_function::Expr, calle
memopt = haskey(kwargs, :memopt) ? kwargs.memopt : get_memopt()
indices = get_indices_expr(ndims).args
body = get_body(kernel)
- body = remove_returns(body)
+ body = remove_return(body)
validate_body(body)
kernelargs = splitarg.(extract_kernel_args(kernel)[1])
argvars = (arg[1] for arg in kernelargs)
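
The updated call sites follow the same shape: strip the mandatory trailing `return` from the kernel body via `remove_return`, transform the body (memopt, loops, thread ids), and, as visible in parallel_indices_memopt above, re-append a `return` with `add_return` before `set_body!`. A toy sketch of that strip/transform/re-append flow on a plain Expr (helper names are illustrative; the real remove_return also validates the trailing return and disguises nested returns as sketched earlier):

    # Drop the final `return` of a kernel body, transform it, then re-append a `return`.
    strip_trailing_return(body::Expr) = Expr(body.head, body.args[1:end-1]...)
    append_return(body::Expr)         = Expr(body.head, body.args..., :(return))

    body  = :(begin A[ix] = B[ix]^2; return end)
    inner = strip_trailing_return(body)        # body transformations would happen here
    body2 = append_return(inner)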
2 changes: 1 addition & 1 deletion src/shared.jl
@@ -9,7 +9,7 @@ elseif ENABLE_AMDGPU
using AMDGPU
end
import MacroTools: @capture, postwalk, splitarg # NOTE: inexpr_walk used instead of MacroTools.inexpr
- import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, is_kernel, is_call, gensym_world, isgpu, @isgpu, substitute, inexpr_walk, cast, @ranges, @rangelengths
+ import .ParallelKernel: eval_arg, split_args, split_kwargs, extract_posargs_init, extract_kernel_args, is_kernel, is_call, gensym_world, isgpu, @isgpu, substitute, inexpr_walk, cast, @ranges, @rangelengths, @return_value, @return_nothing
import .ParallelKernel: PKG_CUDA, PKG_AMDGPU, PKG_THREADS, PKG_NONE, NUMBERTYPE_NONE, SUPPORTED_NUMBERTYPES, SUPPORTED_PACKAGES, ERRMSG_UNSUPPORTED_PACKAGE, INT_CUDA, INT_AMDGPU, INT_THREADS, INDICES, PKNumber, RANGES_VARNAME, RANGES_TYPE, RANGELENGTH_XYZ_TYPE, RANGELENGTHS_VARNAMES, THREADIDS_VARNAMES, GENSYM_SEPARATOR, AD_SUPPORTED_ANNOTATIONS
import .ParallelKernel: @require, @symbols, symbols, longnameof, @prettyexpand, @prettystring, prettystring, @gorgeousexpand, @gorgeousstring, gorgeousstring

