Skip to content

Commit

Permalink
Refactored the lo and hi_indices functions to work better with halodi…
Browse files Browse the repository at this point in the history
…ms and type stability
  • Loading branch information
smillerc committed Oct 20, 2022
1 parent a1305ea commit 176ab07
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions src/halo_exchange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,56 +301,56 @@ function updateblockhalo!(A::BlockHaloArray, current_block_id::Integer, include_
end

"""Get the upper indices for an array `A` given a number of halo entries `nhalo`"""
function hi_indices(A::AbstractArray, nhalo::Integer)
hi_halo_end = last.(axes(A)) # end index of the halo region
hi_halo_start = hi_halo_end .- nhalo .+ 1 # start index of the halo region
hi_domn_end = hi_halo_start .- 1 # end index of the inner domain
hi_domn_start = hi_domn_end .- nhalo .+ 1 # start index of the inner domain
function hi_indices(A, nhalo)
hmod = 1 .* (nhalo .== 0)
hi_halo_end = last.(axes(A))
hi_halo_start = hi_halo_end .- nhalo .+ 1 .- hmod
hi_domn_end = hi_halo_start .- 1 .+ hmod
hi_domn_start = hi_domn_end .- nhalo .+ 1 .- hmod

return (hi_domn_start, hi_domn_end, hi_halo_start, hi_halo_end)
end

"""Get the upper indices for an array `A` given a number of halo entries `nhalo`. This
version properly accounts for non-halo dimensions"""
function hi_indices(A::AbstractArray, nhalo::Integer, halodims::NTuple{N,Integer}) where {N}
function hi_indices(A::AbstractArray{T,N}, nhalo, halodims) where {N, T}

indices = collect.(hi_indices(A, nhalo))

for index in indices
for dim in eachindex(index)
if !(dim in halodims)
index[dim] = last(axes(A, dim))
function f(i, halodims, nhalo)
if i in halodims
return nhalo
else
return 0
end
end
halo_tuple = ntuple(i -> f(i, halodims, nhalo), Val(N))
hi_indices(A, halo_tuple)
end
return Tuple.(indices)
end

"""Get the lower indices for an array `A` given a number of halo entries `nhalo`"""
function lo_indices(A::AbstractArray, nhalo::Integer)
lo_halo_start = first.(axes(A)) # start index of the halo region
lo_halo_end = lo_halo_start .+ nhalo .- 1 # end index of the halo region
lo_domn_start = lo_halo_end .+ 1 # start index of the inner domain
lo_domn_end = lo_domn_start .+ nhalo .- 1 # end index of the inner domain
function lo_indices(A, nhalo)
lmod = 1 .* (nhalo .== 0)
lo_halo_start = first.(axes(A))
lo_halo_end = lo_halo_start .+ nhalo .- 1 .+ lmod
lo_domn_start = lo_halo_end .+ 1 .- lmod
lo_domn_end = lo_domn_start .+ nhalo .- 1 .+ lmod

return (lo_halo_start, lo_halo_end, lo_domn_start, lo_domn_end)
end

"""Get the lower indices for an array `A` given a number of halo entries `nhalo`. This
version properly accounts for non-halo dimensions"""
function lo_indices(A::AbstractArray, nhalo::Integer, halodims::NTuple{N,Integer}) where {N}
function lo_indices(A::AbstractArray{T,N}, nhalo, halodims) where {N, T}

indices = collect.(lo_indices(A, nhalo))

for index in indices
for dim in eachindex(index)
if !(dim in halodims)
index[dim] = first(axes(A, dim))
function f(i, halodims, nhalo)
if i in halodims
return nhalo
else
return 0
end
end
halo_tuple = ntuple(i -> f(i, halodims, nhalo), Val(N))
lo_indices(A, halo_tuple)
end
return Tuple.(indices)
end

"""Get the neighbor block id's for a 1D decomposition"""
function get_neighbor_blocks_no_periodic(tile_dims::NTuple{1,Integer})
Expand Down

0 comments on commit 176ab07

Please sign in to comment.