Skip to content

Commit

Permalink
Merge pull request #180 from omlins/padding
Browse files Browse the repository at this point in the history
Improve index handling
  • Loading branch information
omlins authored Dec 4, 2024
2 parents d7b00ca + a06e0ea commit 2d4a27e
Show file tree
Hide file tree
Showing 18 changed files with 1,727 additions and 1,445 deletions.
15 changes: 15 additions & 0 deletions src/FieldAllocators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ To see a description of a macro type `?<macroname>` (including the `@`).
"""
module FieldAllocators
import ..ParallelKernel
import ..ParallelStencil: check_initialized
@doc replace(ParallelKernel.FieldAllocators.ALLOCATE_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro allocate(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@allocate($(args...)))); end
@doc replace(ParallelKernel.FieldAllocators.FIELD_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro Field(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@Field($(args...)))); end
@doc replace(ParallelKernel.FieldAllocators.VECTORFIELD_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro VectorField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@VectorField($(args...)))); end
Expand All @@ -46,5 +47,19 @@ module FieldAllocators
@doc replace(ParallelKernel.FieldAllocators.TENSORFIELD_COMP_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro XZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XZField($(args...)))); end
@doc replace(ParallelKernel.FieldAllocators.TENSORFIELD_COMP_DOC, "@init_parallel_kernel" => "@init_parallel_stencil") macro YZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@YZField($(args...)))); end

macro IField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@IField($(args...)))); end
macro XXYField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XXYField($(args...)))); end
macro XYYField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XYYField($(args...)))); end
macro XYZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XYZField($(args...)))); end
macro XXYZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XXYZField($(args...)))); end
macro XYYZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XYYZField($(args...)))); end
macro XYZZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XYZZField($(args...)))); end
macro XXYYField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XXYYField($(args...)))); end
macro XXZZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XXZZField($(args...)))); end
macro YYZZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@YYZZField($(args...)))); end
macro XXYYZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XXYYZField($(args...)))); end
macro XYYZZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XYYZZField($(args...)))); end
macro XXYZZField(args...) check_initialized(__module__); esc(:(ParallelStencil.ParallelKernel.FieldAllocators.@XXYZZField($(args...)))); end

export @allocate, @Field, @VectorField, @BVectorField, @TensorField, @XField, @BXField, @YField, @BYField, @ZField, @BZField, @XXField, @YYField, @ZZField, @XYField, @XZField, @YZField
end
173 changes: 88 additions & 85 deletions src/FiniteDifferences.jl

Large diffs are not rendered by default.

172 changes: 159 additions & 13 deletions src/ParallelKernel/FieldAllocators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,126 @@ macro YZField(args...)
end


## FIELDS FOR UNIT TESTS

macro IField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@IField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:I))
end

macro XXYField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XXYField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XXY))
end

macro XYYField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XYYField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XYY))
end

macro XYZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XYZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XYZ))
end

macro XXYZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XXYZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XXYZ))
end

macro XYYZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XYYZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XYYZ))
end

macro XYZZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XYZZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XYZZ))
end

macro XXYYField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XXYYField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XXYY))
end

macro XXZZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XXZZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XXZZ))
end

macro YYZZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@YYZZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:YYZZ))
end

macro XXYYZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XXYYZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XXYYZ))
end

macro XXYZZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XXYZZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XXYZZ))
end

macro XYYZZField(args...)
check_initialized(__module__)
checksargs_field_macros(args...)
posargs, kwargs_expr = split_args(args)
eltype, = extract_kwargvalues(kwargs_expr, (:eltype,), "@XYYZZField")
posargs = clean_args(posargs)
esc(_field(__module__, posargs...; eltype=eltype, sizetemplate=:XYYZZ))
end


## ARGUMENT CHECKS

function checkargs_allocate(args...)
Expand Down Expand Up @@ -450,13 +570,14 @@ function _field(caller::Module, gridsize, allocator=:@zeros; eltype=nothing, siz
padding = get_padding(caller)
eltype = determine_eltype(caller, eltype)
if padding
if (sizetemplate in (:X, :BX)) arraysize = :(map(+, $gridsize, (+1, 0, 0)))
elseif (sizetemplate in (:Y, :BY)) arraysize = :(map(+, $gridsize, ( 0,+1, 0)))
elseif (sizetemplate in (:Z, :BZ)) arraysize = :(map(+, $gridsize, ( 0, 0,+1)))
elseif (sizetemplate == :XY) arraysize = :(map(+, $gridsize, (+1,+1, 0)))
elseif (sizetemplate == :XZ) arraysize = :(map(+, $gridsize, (+1, 0,+1)))
elseif (sizetemplate == :YZ) arraysize = :(map(+, $gridsize, ( 0,+1,+1)))
elseif (isnothing(sizetemplate) || sizetemplate in (:XX, :YY, :ZZ)) arraysize = gridsize
if (sizetemplate in (:X, :BX, :XYY, :XYYZZ)) arraysize = :(map(+, $gridsize, (+1, 0, 0)))
elseif (sizetemplate in (:Y, :BY, :XXY, :XXYZZ)) arraysize = :(map(+, $gridsize, ( 0,+1, 0)))
elseif (sizetemplate in (:Z, :BZ, :XXYYZ)) arraysize = :(map(+, $gridsize, ( 0, 0,+1)))
elseif (sizetemplate in (:XY, :XYZZ)) arraysize = :(map(+, $gridsize, (+1,+1, 0)))
elseif (sizetemplate in (:XZ, :XYYZ)) arraysize = :(map(+, $gridsize, (+1, 0,+1)))
elseif (sizetemplate in (:YZ, :XXYZ)) arraysize = :(map(+, $gridsize, ( 0,+1,+1)))
elseif (sizetemplate == :XYZ) arraysize = :(map(+, $gridsize, (+1,+1,+1)))
elseif (isnothing(sizetemplate) || sizetemplate in (:XX, :YY, :ZZ, :I, :XXYY, :XXZZ, :YYZZ)) arraysize = gridsize
else @ModuleInternalError("unexpected sizetemplate.")
end
else
Expand All @@ -472,6 +593,19 @@ function _field(caller::Module, gridsize, allocator=:@zeros; eltype=nothing, siz
elseif (sizetemplate == :XY) arraysize = :(map(+, $gridsize, (-1,-1,-2)))
elseif (sizetemplate == :XZ) arraysize = :(map(+, $gridsize, (-1,-2,-1)))
elseif (sizetemplate == :YZ) arraysize = :(map(+, $gridsize, (-2,-1,-1)))
elseif (sizetemplate == :I) arraysize = :(map(+, $gridsize, (-2,-2,-2)))
elseif (sizetemplate == :XXY) arraysize = :(map(+, $gridsize, ( 0,-1,-2)))
elseif (sizetemplate == :XYY) arraysize = :(map(+, $gridsize, (-1, 0,-2)))
elseif (sizetemplate == :XYZ) arraysize = :(map(+, $gridsize, (-1,-1,-1)))
elseif (sizetemplate == :XXYZ) arraysize = :(map(+, $gridsize, ( 0,-1,-1)))
elseif (sizetemplate == :XYYZ) arraysize = :(map(+, $gridsize, (-1, 0,-1)))
elseif (sizetemplate == :XYZZ) arraysize = :(map(+, $gridsize, (-1,-1, 0)))
elseif (sizetemplate == :XXYY) arraysize = :(map(+, $gridsize, ( 0, 0,-2)))
elseif (sizetemplate == :XXZZ) arraysize = :(map(+, $gridsize, ( 0,-2, 0)))
elseif (sizetemplate == :YYZZ) arraysize = :(map(+, $gridsize, (-2, 0, 0)))
elseif (sizetemplate == :XXYYZ) arraysize = :(map(+, $gridsize, ( 0, 0,-1)))
elseif (sizetemplate == :XXYZZ) arraysize = :(map(+, $gridsize, ( 0,-1, 0)))
elseif (sizetemplate == :XYYZZ) arraysize = :(map(+, $gridsize, (-1, 0, 0)))
elseif isnothing(sizetemplate) arraysize = gridsize
else @ModuleInternalError("unexpected sizetemplate.")
end
Expand All @@ -486,11 +620,15 @@ function _field(caller::Module, gridsize, allocator=:@zeros; eltype=nothing, siz
end

if padding
if (sizetemplate in (:X, :Y, :Z, :XY, :XZ, :YZ)) return :(view($arrayalloc, (:).(2, $arraysize.-1)...))
elseif (sizetemplate == :XX) return :(view($arrayalloc, (:).(map(+, $gridsize.*0, (1,2,2)), map(+, $arraysize, ( 0,-1,-1)))...))
elseif (sizetemplate == :YY) return :(view($arrayalloc, (:).(map(+, $gridsize.*0, (2,1,2)), map(+, $arraysize, (-1, 0,-1)))...))
elseif (sizetemplate == :ZZ) return :(view($arrayalloc, (:).(map(+, $gridsize.*0, (2,2,1)), map(+, $arraysize, (-1,-1, 0)))...))
elseif (isnothing(sizetemplate) || sizetemplate in (:BX, :BY, :BZ)) return :(view($arrayalloc, (:).(1, $arraysize)...))
subarray = :(ParallelStencil.ParallelKernel.FieldAllocators.subarray)
if (sizetemplate in (:X, :Y, :Z, :XY, :XZ, :YZ, :I, :XYZ)) return :($subarray($arrayalloc, (:).(2, $arraysize.-1)...))
elseif (sizetemplate in (:XX, :XXY, :XXYZ)) return :($subarray($arrayalloc, (:).(map(+, $gridsize.*0, (1,2,2)), map(+, $arraysize, ( 0,-1,-1)))...))
elseif (sizetemplate in (:YY, :XYY, :XYYZ)) return :($subarray($arrayalloc, (:).(map(+, $gridsize.*0, (2,1,2)), map(+, $arraysize, (-1, 0,-1)))...))
elseif (sizetemplate in (:ZZ, :XYZZ)) return :($subarray($arrayalloc, (:).(map(+, $gridsize.*0, (2,2,1)), map(+, $arraysize, (-1,-1, 0)))...))
elseif (sizetemplate in (:XXYY, :XXYYZ)) return :($subarray($arrayalloc, (:).(map(+, $gridsize.*0, (1,1,2)), map(+, $arraysize, ( 0, 0,-1)))...))
elseif (sizetemplate in (:XXZZ, :XXYZZ)) return :($subarray($arrayalloc, (:).(map(+, $gridsize.*0, (1,2,1)), map(+, $arraysize, ( 0,-1, 0)))...))
elseif (sizetemplate in (:YYZZ, :XYYZZ)) return :($subarray($arrayalloc, (:).(map(+, $gridsize.*0, (2,1,1)), map(+, $arraysize, (-1, 0, 0)))...))
elseif (isnothing(sizetemplate) || sizetemplate in (:BX, :BY, :BZ)) return :($subarray($arrayalloc, (:).(1, $arraysize)...))
else @ModuleInternalError("unexpected sizetemplate.")
end
else
Expand Down Expand Up @@ -539,10 +677,18 @@ function determine_eltype(caller::Module, eltype)
return eltype
end

function subarray(A, indices...)
B = view(A, indices...)
if B isa SubArray
return B
else
return SubArray(A, indices)
end
end

## Exports

export @allocate, @Field, @VectorField, @BVectorField, @TensorField, @XField, @BXField, @YField, @BYField, @ZField, @BZField, @XXField, @YYField, @ZZField, @XYField, @XZField, @YZField
export @allocate, @Field, @VectorField, @BVectorField, @TensorField, @XField, @BXField, @YField, @BYField, @ZField, @BZField, @XXField, @YYField, @ZZField, @XYField, @XZField, @YZField, @IField, @XXYField, @XYYField


end # Module FieldAllocators
4 changes: 2 additions & 2 deletions src/ParallelKernel/kernel_language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,14 @@ function threads(caller::Module, args...; package::Symbol=get_package(caller))
end
end

function _firstindex(caller::Module, A::Expr, dim::Expr, padding::Union{Bool, Symbol, Expr}=false)
function _firstindex(caller::Module, A::Union{Symbol, Expr}, dim::Union{Integer, Symbol, Expr}, padding::Union{Bool, Symbol, Expr}=false)
padding = eval_arg(caller, padding)
if (padding) return :($A.indices[$dim][1])
else return :(1)
end
end

function _lastindex(caller::Module, A::Expr, dim::Expr, padding::Union{Bool, Symbol, Expr}=false)
function _lastindex(caller::Module, A::Union{Symbol, Expr}, dim::Union{Integer, Symbol, Expr}, padding::Union{Bool, Symbol, Expr}=false)
padding = eval_arg(caller, padding)
if (padding) return :($A.indices[$dim][end])
else return :(size($A, $dim))
Expand Down
Loading

0 comments on commit 2d4a27e

Please sign in to comment.