Skip to content

Commit

Permalink
Merge pull request #177 from omlins/padding
Browse files Browse the repository at this point in the history
Add additional handling of field arrays
  • Loading branch information
omlins authored Oct 31, 2024
2 parents 0390670 + 3b2112d commit d7b00ca
Show file tree
Hide file tree
Showing 24 changed files with 680 additions and 437 deletions.
185 changes: 92 additions & 93 deletions src/FiniteDifferences.jl

Large diffs are not rendered by default.

67 changes: 47 additions & 20 deletions src/ParallelKernel/FieldAllocators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ To see a description of a macro type `?<macroname>` (including the `@`).
module FieldAllocators

using ..Exceptions
import ..ParallelKernel: check_initialized, get_numbertype, extract_kwargvalues, split_args, clean_args, is_same, extract_tuple, extract_kwargs
import ..ParallelKernel: check_initialized, get_numbertype, get_padding, extract_kwargvalues, split_args, clean_args, is_same, extract_tuple, extract_kwargs
import ..ParallelKernel: NUMBERTYPE_NONE, FIELDTYPES


Expand Down Expand Up @@ -447,28 +447,55 @@ function _allocate(caller::Module; gridsize=nothing, fields=nothing, allocator=n
end

function _field(caller::Module, gridsize, allocator=:@zeros; eltype=nothing, sizetemplate=nothing)
eltype = determine_eltype(caller, eltype)
if (sizetemplate == :X) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-1,-2,-2) : (length($gridsize)==2) ? (-1,-2) : -1))
elseif (sizetemplate == :Y) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-2,-1,-2) : (length($gridsize)==2) ? (-2,-1) : -2))
elseif (sizetemplate == :Z) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-2,-2,-1) : (length($gridsize)==2) ? (-2,-2) : -2))
elseif (sizetemplate == :BX) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (+1, 0, 0) : (length($gridsize)==2) ? (+1, 0) : +1))
elseif (sizetemplate == :BY) arraysize = :($gridsize .+ ((length($gridsize)==3) ? ( 0,+1, 0) : (length($gridsize)==2) ? ( 0,+1) : 0))
elseif (sizetemplate == :BZ) arraysize = :($gridsize .+ ((length($gridsize)==3) ? ( 0, 0,+1) : (length($gridsize)==2) ? ( 0, 0) : 0))
elseif (sizetemplate == :XX) arraysize = :($gridsize .+ ((length($gridsize)==3) ? ( 0,-2,-2) : (length($gridsize)==2) ? ( 0,-2) : 0))
elseif (sizetemplate == :YY) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-2, 0,-2) : (length($gridsize)==2) ? (-2, 0) : -2))
elseif (sizetemplate == :ZZ) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-2,-2, 0) : (length($gridsize)==2) ? (-2,-2) : -2))
elseif (sizetemplate == :XY) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-1,-1,-2) : (length($gridsize)==2) ? (-1,-1) : -1))
elseif (sizetemplate == :XZ) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-1,-2,-1) : (length($gridsize)==2) ? (-1,-2) : -1))
elseif (sizetemplate == :YZ) arraysize = :($gridsize .+ ((length($gridsize)==3) ? (-2,-1,-1) : (length($gridsize)==2) ? (-2,-1) : -2))
else arraysize = gridsize
padding = get_padding(caller)
eltype = determine_eltype(caller, eltype)
if padding
if (sizetemplate in (:X, :BX)) arraysize = :(map(+, $gridsize, (+1, 0, 0)))
elseif (sizetemplate in (:Y, :BY)) arraysize = :(map(+, $gridsize, ( 0,+1, 0)))
elseif (sizetemplate in (:Z, :BZ)) arraysize = :(map(+, $gridsize, ( 0, 0,+1)))
elseif (sizetemplate == :XY) arraysize = :(map(+, $gridsize, (+1,+1, 0)))
elseif (sizetemplate == :XZ) arraysize = :(map(+, $gridsize, (+1, 0,+1)))
elseif (sizetemplate == :YZ) arraysize = :(map(+, $gridsize, ( 0,+1,+1)))
elseif (isnothing(sizetemplate) || sizetemplate in (:XX, :YY, :ZZ)) arraysize = gridsize
else @ModuleInternalError("unexpected sizetemplate.")
end
else
if (sizetemplate == :X) arraysize = :(map(+, $gridsize, (-1,-2,-2)))
elseif (sizetemplate == :Y) arraysize = :(map(+, $gridsize, (-2,-1,-2)))
elseif (sizetemplate == :Z) arraysize = :(map(+, $gridsize, (-2,-2,-1)))
elseif (sizetemplate == :BX) arraysize = :(map(+, $gridsize, (+1, 0, 0)))
elseif (sizetemplate == :BY) arraysize = :(map(+, $gridsize, ( 0,+1, 0)))
elseif (sizetemplate == :BZ) arraysize = :(map(+, $gridsize, ( 0, 0,+1)))
elseif (sizetemplate == :XX) arraysize = :(map(+, $gridsize, ( 0,-2,-2)))
elseif (sizetemplate == :YY) arraysize = :(map(+, $gridsize, (-2, 0,-2)))
elseif (sizetemplate == :ZZ) arraysize = :(map(+, $gridsize, (-2,-2, 0)))
elseif (sizetemplate == :XY) arraysize = :(map(+, $gridsize, (-1,-1,-2)))
elseif (sizetemplate == :XZ) arraysize = :(map(+, $gridsize, (-1,-2,-1)))
elseif (sizetemplate == :YZ) arraysize = :(map(+, $gridsize, (-2,-1,-1)))
elseif isnothing(sizetemplate) arraysize = gridsize
else @ModuleInternalError("unexpected sizetemplate.")
end
end
if is_same(allocator, :@zeros) return :(ParallelStencil.ParallelKernel.@zeros($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@ones) return :(ParallelStencil.ParallelKernel.@ones($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@rand) return :(ParallelStencil.ParallelKernel.@rand($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@falses) return :(ParallelStencil.ParallelKernel.@falses($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@trues) return :(ParallelStencil.ParallelKernel.@trues($arraysize..., eltype=$eltype))

if is_same(allocator, :@zeros) arrayalloc = :(ParallelStencil.ParallelKernel.@zeros($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@ones) arrayalloc = :(ParallelStencil.ParallelKernel.@ones($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@rand) arrayalloc = :(ParallelStencil.ParallelKernel.@rand($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@falses) arrayalloc = :(ParallelStencil.ParallelKernel.@falses($arraysize..., eltype=$eltype))
elseif is_same(allocator, :@trues) arrayalloc = :(ParallelStencil.ParallelKernel.@trues($arraysize..., eltype=$eltype))
else @ModuleInternalError("unexpected allocator macro.")
end

if padding
if (sizetemplate in (:X, :Y, :Z, :XY, :XZ, :YZ)) return :(view($arrayalloc, (:).(2, $arraysize.-1)...))
elseif (sizetemplate == :XX) return :(view($arrayalloc, (:).(map(+, $gridsize.*0, (1,2,2)), map(+, $arraysize, ( 0,-1,-1)))...))
elseif (sizetemplate == :YY) return :(view($arrayalloc, (:).(map(+, $gridsize.*0, (2,1,2)), map(+, $arraysize, (-1, 0,-1)))...))
elseif (sizetemplate == :ZZ) return :(view($arrayalloc, (:).(map(+, $gridsize.*0, (2,2,1)), map(+, $arraysize, (-1,-1, 0)))...))
elseif (isnothing(sizetemplate) || sizetemplate in (:BX, :BY, :BZ)) return :(view($arrayalloc, (:).(1, $arraysize)...))
else @ModuleInternalError("unexpected sizetemplate.")
end
else
return arrayalloc
end
end

function _vectorfield(caller::Module, gridsize, allocator=:@zeros; eltype=nothing, sizetemplate=nothing)
Expand Down
26 changes: 18 additions & 8 deletions src/ParallelKernel/init_parallel_kernel.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
"""
@init_parallel_kernel(package, numbertype)
@init_parallel_kernel(package, numbertype, inbounds=..., padding=...)
Initialize the package ParallelKernel, giving access to its main functionality. Creates a module `Data` in the module where `@init_parallel_kernel` is called from. The module `Data` contains types such as `Data.Number`, `Data.Array` and `Data.CellArray` (type `?Data` *after* calling `@init_parallel_kernel` to see the full description of the module).
# Arguments
- `package::Module`: the package used for parallelization (CUDA or AMDGPU or Metal for GPU, or Threads or Polyester for CPU).
- `numbertype::DataType`: the type of numbers used by @zeros, @ones, @rand and @fill and in all array types of module `Data` (e.g. Float32 or Float64). It is contained in `Data.Number` after @init_parallel_kernel.
- `inbounds::Bool=false`: whether to apply `@inbounds` to the kernels by default (overwritable in each kernel definition).
- `padding::Bool=false`: whether to apply padding to the fields allocated with macros from [`ParallelKernel.FieldAllocators`](@ref).
See also: [`Data`](@ref)
"""
macro init_parallel_kernel(args...)
check_already_initialized(__module__)
posargs, kwargs_expr = split_args(args)
if (length(args) > 3) @ArgumentError("too many arguments.")
if (length(args) > 4) @ArgumentError("too many arguments.")
elseif (0 < length(posargs) < 2) @ArgumentError("there must be either two or zero positional arguments.")
end
kwargs = split_kwargs(kwargs_expr)
if (length(posargs) == 2) package, numbertype_val = extract_posargs_init(__module__, posargs...)
else package, numbertype_val = extract_kwargs_init(__module__, kwargs)
end
inbounds_val = extract_kwargs_nopos(__module__, kwargs)
inbounds_val, padding_val = extract_kwargs_nopos(__module__, kwargs)
if (package == PKG_NONE) @ArgumentError("the package argument cannot be ommited.") end #TODO: this error message will disappear, once the package can be defined at runtime.
esc(init_parallel_kernel(__module__, package, numbertype_val, inbounds_val))
esc(init_parallel_kernel(__module__, package, numbertype_val, inbounds_val, padding_val))
end

function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataType, inbounds::Bool; datadoc_call=:(), parent_module::String="ParallelKernel")
function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataType, inbounds::Bool, padding::Bool; datadoc_call=:(), parent_module::String="ParallelKernel")
if package == PKG_CUDA
if (isinteractive() && !is_installed("CUDA")) @NotInstalledError("CUDA was selected as package for parallelization, but CUDA.jl is not installed. CUDA functionality is provided as an extension of $parent_module and CUDA.jl needs therefore to be installed independently (type `add CUDA` in the julia package manager).") end
indextype = INT_CUDA
Expand Down Expand Up @@ -79,6 +81,7 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
set_package(caller, package)
set_numbertype(caller, numbertype)
set_inbounds(caller, inbounds)
set_padding(caller, padding)
set_initialized(caller, true)
return nothing
end
Expand All @@ -88,12 +91,14 @@ macro is_initialized() is_initialized(__module__) end
# Macro accessors for the per-module settings stored by `@init_parallel_kernel`; each one looks up the value registered for the calling module.
macro get_package()
    # NOTE: escaping is required here, to avoid that the symbol is evaluated in this module, instead of just being returned as a symbol.
    return esc(get_package(__module__))
end
macro get_numbertype()
    return get_numbertype(__module__)
end
macro get_inbounds()
    return get_inbounds(__module__)
end
macro get_padding()
    return get_padding(__module__)
end
let
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, check_initialized, check_already_initialized
global is_initialized, set_initialized, set_package, get_package, set_numbertype, get_numbertype, set_inbounds, get_inbounds, set_padding, get_padding, check_initialized, check_already_initialized
_is_initialized::Dict{Module, Bool} = Dict{Module, Bool}()
package::Dict{Module, Symbol} = Dict{Module, Symbol}()
numbertype::Dict{Module, DataType} = Dict{Module, DataType}()
inbounds::Dict{Module, Bool} = Dict{Module, Bool}()
padding::Dict{Module, Bool} = Dict{Module, Bool}()
set_initialized(caller::Module, flag::Bool) = (_is_initialized[caller] = flag)
is_initialized(caller::Module) = haskey(_is_initialized, caller) && _is_initialized[caller]
set_package(caller::Module, pkg::Symbol) = (package[caller] = pkg)
Expand All @@ -102,6 +107,8 @@ let
get_numbertype(caller::Module) = numbertype[caller]
set_inbounds(caller::Module, flag::Bool) = (inbounds[caller] = flag)
get_inbounds(caller::Module) = inbounds[caller]
set_padding(caller::Module, flag::Bool) = (padding[caller] = flag)
get_padding(caller::Module) = padding[caller]
check_initialized(caller::Module) = if !is_initialized(caller) @NotInitializedError("no ParallelKernel macro or function can be called before @init_parallel_kernel in each module (missing call in $caller).") end
check_already_initialized(caller::Module) = if is_initialized(caller) @IncoherentCallError("ParallelKernel has already been initialized for the module $caller.") end
end
Expand All @@ -114,8 +121,8 @@ function extract_posargs_init(caller::Module, package, numbertype) # NOTE: this
end

function extract_kwargs_init(caller::Module, kwargs::Dict)
if (:package in keys(kwargs)) package = kwargs[:package]; check_package(package)
else package = PKG_NONE
if (:package in keys(kwargs)) package = kwargs[:package]; check_package(package)
else package = PKG_NONE
end
if (:numbertype in keys(kwargs)) numbertype_val = eval_arg(caller, kwargs[:numbertype]); check_numbertype(numbertype_val)
else numbertype_val = NUMBERTYPE_NONE
Expand All @@ -127,7 +134,10 @@ function extract_kwargs_nopos(caller::Module, kwargs::Dict)
if (:inbounds in keys(kwargs)) inbounds_val = eval_arg(caller, kwargs[:inbounds]); check_inbounds(inbounds_val)
else inbounds_val = false
end
return inbounds_val
if (:padding in keys(kwargs)) padding_val = eval_arg(caller, kwargs[:padding]); check_padding(padding_val)
else padding_val = false
end
return inbounds_val, padding_val
end

function define_import(caller::Module, package::Symbol, parent_module::String)
Expand Down
38 changes: 32 additions & 6 deletions src/ParallelKernel/kernel_language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ Call a macro analogue to `Base.@println`, compatible with the package for parall
macro pk_println(args...)
    # The calling module must have been initialized before any expansion takes place.
    check_initialized(__module__)
    return esc(pk_println(__module__, args...))
end


## INTERNAL MACROS

##
macro threads(args...) check_initialized(__module__); esc(threads(__module__, args...)); end


##
const FORALL_DOC = """
@∀ x ∈ X statement
Expand Down Expand Up @@ -139,6 +133,20 @@ Expand the `statement` for all `x` in `X`.
macro ∀(args...)
    # Initialization and argument checks run first; the expansion is only generated afterwards.
    check_initialized(__module__)
    checkforallargs(args...)
    return esc(∀(__module__, args...))
end


## INTERNAL MACROS

##
macro threads(args...)
    # The calling module must have been initialized before any expansion takes place.
    check_initialized(__module__)
    return esc(threads(__module__, args...))
end


##
macro firstindex(args...)
    # Check initialization and the argument count, then delegate to `_firstindex` and escape its result.
    check_initialized(__module__)
    checkargs_begin_end(args...)
    return esc(_firstindex(__module__, args...))
end


##
macro lastindex(args...)
    # Check initialization and the argument count, then delegate to `_lastindex` and escape its result.
    check_initialized(__module__)
    checkargs_begin_end(args...)
    return esc(_lastindex(__module__, args...))
end


##
macro return_value(args...)
    # Check initialization and that a single argument was given, then delegate to `return_value` and escape its result.
    check_initialized(__module__)
    checksinglearg(args...)
    return esc(return_value(args...))
end

Expand Down Expand Up @@ -166,6 +174,10 @@ function checkforallargs(args...)
if !((args[1].head == :call && args[1].args[1] in [:, :in]) || args[1].head == :(=)) @ArgumentError("the first argument must be of the form `x ∈ X, `x in X` or `x = X`.") end
end

# Argument check shared by `@firstindex` and `@lastindex`: two or three arguments are accepted.
function checkargs_begin_end(args...)
    nargs = length(args)
    if (nargs < 2 || nargs > 3) @ArgumentError("wrong number of arguments.") end
end


## FUNCTIONS FOR INDEXING AND DIMENSIONS

Expand Down Expand Up @@ -300,6 +312,20 @@ function threads(caller::Module, args...; package::Symbol=get_package(caller))
end
end

# Build the expression that `@firstindex` expands to for array `A` along dimension `dim`.
# - `A`: the array as written at the call site; a plain variable arrives as a `Symbol`, anything more complex as an `Expr`.
# - `dim`: the dimension; an integer literal arrives as an `Int`, a variable as a `Symbol`, anything else as an `Expr`.
#   (`Int` rather than `Integer`, so that a `Bool` padding flag can never be mistaken for `dim`.)
# - `padding`: a Bool, or a symbol/expression that `eval_arg` resolves to one in `caller`.
# With padding, the result reads the first entry of `A.indices[dim]` (NOTE(review): presumably `A` is then a view as created by the field allocators - confirm); without padding it is the literal `1`.
function _firstindex(caller::Module, A::Union{Symbol, Expr}, dim::Union{Int, Symbol, Expr}, padding::Union{Bool, Symbol, Expr}=false)
    padding = eval_arg(caller, padding)
    if (padding) return :($A.indices[$dim][1])
    else         return :(1)
    end
end

# Build the expression that `@lastindex` expands to for array `A` along dimension `dim`.
# - `A`: the array as written at the call site; a plain variable arrives as a `Symbol`, anything more complex as an `Expr`.
# - `dim`: the dimension; an integer literal arrives as an `Int`, a variable as a `Symbol`, anything else as an `Expr`.
#   (`Int` rather than `Integer`, so that a `Bool` padding flag can never be mistaken for `dim`.)
# - `padding`: a Bool, or a symbol/expression that `eval_arg` resolves to one in `caller`.
# With padding, the result reads the last entry of `A.indices[dim]` (NOTE(review): presumably `A` is then a view as created by the field allocators - confirm); without padding it is `size(A, dim)`.
function _lastindex(caller::Module, A::Union{Symbol, Expr}, dim::Union{Int, Symbol, Expr}, padding::Union{Bool, Symbol, Expr}=false)
    padding = eval_arg(caller, padding)
    if (padding) return :($A.indices[$dim][end])
    else         return :(size($A, $dim))
    end
end


## CPU TARGET IMPLEMENTATIONS

Expand Down
59 changes: 57 additions & 2 deletions src/ParallelKernel/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,22 @@ end
function parallel_kernel(caller::Module, package::Symbol, numbertype::DataType, inbounds::Bool, indices::Union{Symbol,Expr}, kernel::Expr)
if (!isa(indices,Symbol) && !isa(indices.head,Symbol)) @ArgumentError("@parallel_indices: argument 'indices' must be a tuple of indices or a single index (e.g. (ix, iy, iz) or (ix, iy) or ix ).") end
indices = extract_tuple(indices)
padding = get_padding(caller)
body = get_body(kernel)
body = remove_return(body)
body = macroexpand(caller, body)
use_aliases = !all(indices .== INDICES[1:length(indices)])
if use_aliases # NOTE: we treat explicit parallel indices as aliases to the statically retrievable indices INDICES.
indices_aliases = indices
indices = [INDICES[1:length(indices)]...]
body = macroexpand(caller, body)
for i=1:length(indices_aliases)
body = substitute(body, indices_aliases[i], indices[i])
end
end
if isgpu(package) kernel = insert_device_types(caller, kernel) end
kernel = adjust_signatures(kernel, package)
body = handle_padding(body, padding) # TODO: padding can later be made configurable per kernel (to enable working with arrays as before).
body = handle_inverses(body)
body = handle_indices_and_literals(body, indices, package, numbertype)
if (inbounds) body = add_inbounds(body) end
body = add_return(body)
Expand Down Expand Up @@ -359,7 +362,7 @@ function literaltypes(type1::DataType, type2::DataType, expr::Expr)
end


## FUNCTIONS TO HANDLE SIGNATURES AND INDICES
## FUNCTIONS TO HANDLE SIGNATURES, INDICES, INVERSES AND PADDING

function adjust_signatures(kernel::Expr, package::Symbol)
int_type = kernel_int_type(package)
Expand All @@ -370,6 +373,58 @@ function adjust_signatures(kernel::Expr, package::Symbol)
return kernel
end

# Rewrite every division whose numerator is one of the literals `1`, `1.0` or `1.0f0` into a call `inv(x)`; all other expressions are returned unchanged.
function handle_inverses(body::Expr)
    return postwalk(body) do ex
        @capture(ex, (1 | 1.0 | 1.0f0) / x_) || return ex
        return :(inv($x))
    end
end

# Apply the padding-related rewrites to a kernel body.
# The inner indices are always substituted; the `firstindex`/`lastindex` and array access rewrites are applied only when `padding` is true.
function handle_padding(body::Expr, padding::Bool)
    body = substitute_indices_inn(body, padding)
    padding || return body
    body = substitute_firstlastindex(body)
    return substitute_view_accesses(body, INDICES)
end

# Replace each entry of `INDICES_INN` found in `body` by an expression in terms of the corresponding entry of `INDICES`.
function substitute_indices_inn(body::Expr, padding::Bool)
    for (i, inner) in enumerate(INDICES_INN)
        # NOTE: expression of ixi with ix, etc.: if padding is not used, they must be shifted by 1.
        replacement = padding ? INDICES[i] : :($(INDICES[i]) + 1)
        body = substitute(body, inner, replacement)
    end
    return body
end

# Replace calls to `firstindex(...)` and `lastindex(...)` in `body` by the macros `@firstindex` and `@lastindex` of ParallelKernel, keeping the original arguments and appending the padding flag `true`.
function substitute_firstlastindex(body::Expr)
    padding = true
    return postwalk(body) do ex
        @capture(ex, f_(args__)) || return ex
        if     (f == :firstindex) return :(ParallelStencil.ParallelKernel.@firstindex($(args...), $padding))
        elseif (f == :lastindex)  return :(ParallelStencil.ParallelKernel.@lastindex($(args...), $padding))
        end
        return ex  # any other call is left untouched.
    end
end

# Redirect every array access `A[...]` that `is_access` recognises for the given `indices` to `A.parent[...]`, keeping the index expressions unchanged.
function substitute_view_accesses(expr::Expr, indices::NTuple{N,<:Union{Symbol,Expr}} where N)
    return postwalk(expr) do ex
        is_access(ex, indices...) || return ex
        @capture(ex, A_[indices_expr__]) || @ModuleInternalError("a stencil access could not be pattern matched.")
        return :($A.parent[$(indices_expr...)])
    end
end

function handle_indices_and_literals(body::Expr, indices::Array, package::Symbol, numbertype::DataType)
int_type = kernel_int_type(package)
ranges = [:($RANGES_VARNAME[1]), :($RANGES_VARNAME[2]), :($RANGES_VARNAME[3])]
Expand Down
Loading

0 comments on commit d7b00ca

Please sign in to comment.