Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unwrap 1.11 regression + performance improvements #576

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 75 additions & 87 deletions src/unwrap.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module Unwrap
using Random: GLOBAL_RNG, AbstractRNG
using Random: AbstractRNG, default_rng
export unwrap, unwrap!

"""
Expand Down Expand Up @@ -62,7 +62,7 @@ of an image, as each pixel is wrapped to stay within (-pi, pi].
- `circular_dims=(false, ...)`: When an element of this tuple is `true`, the
unwrapping process will consider the edges along the corresponding axis
of the array to be connected.
- `rng=GLOBAL_RNG`: Unwrapping of arrays with dimension > 1 uses a random
- `rng=default_rng()`: Unwrapping of arrays with dimension > 1 uses a random
initialization. A user can pass their own RNG through this argument.
"""
function unwrap(m::AbstractArray; kwargs...)
    # Allocate the destination with `similar(m)` and forward every keyword
    # argument unchanged to the in-place `unwrap!`.
    dest = similar(m)
    return unwrap!(dest, m; kwargs...)
end
Expand All @@ -80,7 +80,7 @@ unwrap(m::AbstractArray; kwargs...) = unwrap!(similar(m), m; kwargs...)

mutable struct Pixel{T}
periods::Int
val::T
const val::T
reliability::Float64
groupsize::Int
head::Pixel{T}
Expand All @@ -97,38 +97,39 @@ end
# Convenience constructor: builds a `Pixel{typeof(v)}` from the value `v`,
# passing 0, `v`, one draw of `rand(rng)` and 1 to the parametric constructor.
# NOTE(review): presumably these map to periods, val, reliability and
# groupsize (the struct's field order) -- the inner constructor is not
# visible here, confirm against it.
function Pixel(v, rng)
    T = typeof(v)
    return Pixel{T}(0, v, rand(rng), 1)
end
# The `length` of a pixel is the `groupsize` stored on the pixel's `head`.
@inline function Base.length(p::Pixel)
    group_head = p.head
    return group_head.groupsize
end

struct Edge{N}
struct Edge{T}
reliability::Float64
periods::Int
pixel_1::CartesianIndex{N}
pixel_2::CartesianIndex{N}
pixel_1::Pixel{T}
pixel_2::Pixel{T}
end
function Edge{N}(pixel_image::AbstractArray, ind1::CartesianIndex{N}, ind2::CartesianIndex{N}, range) where N
@inbounds rel = pixel_image[ind1].reliability + pixel_image[ind2].reliability
@inbounds periods = find_period(pixel_image[ind1].val, pixel_image[ind2].val, range)
return Edge{N}(rel, periods, ind1, ind2)
function Edge{T}(p1::Pixel{T}, p2::Pixel{T}, range) where {T}
rel = p1.reliability + p2.reliability
periods = find_period(p1.val, p2.val, range)
return Edge{T}(rel, periods, p1, p2)
end
# Edges are ordered by their `reliability` field only; the other fields do
# not take part in the comparison.
@inline function Base.isless(e1::Edge, e2::Edge)
    return isless(e1.reliability, e2.reliability)
end

function unwrap_nd!(dest::AbstractArray{T, N},
src::AbstractArray{T, N};
range::Number=2*convert(T, pi),
circular_dims::NTuple{N, Bool}=ntuple(_->false, Val(N)),
rng::AbstractRNG=GLOBAL_RNG) where {T, N}
rng::AbstractRNG=default_rng()) where {T, N}

range_T = convert(T, range)

pixel_image = init_pixels(src, rng)
calculate_reliability(pixel_image, circular_dims, range_T)
edges = Edge{N}[]
edges = Edge{T}[]
num_edges = _predict_num_edges(size(src), circular_dims)
sizehint!(edges, num_edges)
for idx_dim=1:N
for idx_dim = 1:N
populate_edges!(edges, pixel_image, idx_dim, circular_dims[idx_dim], range_T)
end

sort!(edges, alg=MergeSort)
gather_pixels!(pixel_image, edges)
perm = sortperm(map(x -> x.reliability, edges); alg=MergeSort)
edges = edges[perm]
Comment on lines -130 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this better? And why the `map`? Shouldn't the custom `isless` method take care of that? And if not, wouldn't a `by=` keyword be better, since it avoids a temporary array?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it faster in empirical testing. I'm not sure why; possibly cache-locality issues when sorting, because an `Edge` is four times the size of a float.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also allocations elsewhere that could be reduced, although I feel that would obscure the code.

gather_pixels!(edges)
unwrap_image!(dest, pixel_image, range_T)

return dest
Expand All @@ -145,80 +146,80 @@ end
# function to broadcast
function init_pixels(wrapped_image::AbstractArray{T, N}, rng) where {T, N}
pixel_image = similar(wrapped_image, Pixel{T})
Threads.@threads for i in eachindex(wrapped_image)
@inbounds pixel_image[i] = Pixel(wrapped_image[i], rng)
Threads.@threads for i in eachindex(wrapped_image, pixel_image)
pixel_image[i] = Pixel(wrapped_image[i], rng)
end
return pixel_image
end

function gather_pixels!(pixel_image, edges)
function gather_pixels!(edges)
for edge in edges
@inbounds p1 = pixel_image[edge.pixel_1]
@inbounds p2 = pixel_image[edge.pixel_2]
merge_groups!(edge, p1, p2)
p1 = edge.pixel_1
p2 = edge.pixel_2
if is_differentgroup(p1, p2)
periods = edge.periods
merge_groups!(periods, p1, p2)
end
end
end

function unwrap_image!(dest, pixel_image, range)
Threads.@threads for i in eachindex(dest)
@inbounds dest[i] = muladd(range, pixel_image[i].periods, pixel_image[i].val)
Threads.@threads for i in eachindex(dest, pixel_image)
p = pixel_image[i]
dest[i] = muladd(range, p.periods, p.val)
end
end

function wrap_val(val, range)
wrapped_val = val
wrapped_val += ifelse(val > range/2, -range, zero(val))
wrapped_val += ifelse(val < -range/2, range, zero(val))
wrapped_val -= ifelse(val > range / 2, range, zero(val))
wrapped_val += ifelse(val < -range / 2, range, zero(val))
return wrapped_val
end

function find_period(val_left, val_right, range)
difference = val_left - val_right
period = 0
period += ifelse(difference > range/2, -1, 0)
period += ifelse(difference < -range/2, 1, 0)
period -= (difference > range / 2)
period += (difference < -range / 2)
return period
end

function merge_groups!(edge, pixel_1, pixel_2)
if is_differentgroup(pixel_1, pixel_2)
# pixel 2 is alone in group
if is_pixelalone(pixel_2)
merge_pixels!(pixel_1, pixel_2, -edge.periods)
elseif is_pixelalone(pixel_1)
merge_pixels!(pixel_2, pixel_1, edge.periods)
function merge_groups!(periods, base, target)
# target is alone in group
if is_pixelalone(target)
periods = -periods
elseif is_pixelalone(base)
base, target = target, base
else
if is_bigger(base, target)
periods = -periods
else
if is_bigger(pixel_1, pixel_2)
merge_into_group!(pixel_1, pixel_2, -edge.periods)
else
merge_into_group!(pixel_2, pixel_1, edge.periods)
end
base, target = target, base
end
merge_into_group!(base, target, periods)
return
end
merge_pixels!(base, target, periods)
end

# True when the two pixels do not share the same `head` object
# (identity comparison, not value equality).
@inline is_differentgroup(p1::Pixel, p2::Pixel) = !(p1.head === p2.head)
# True when the pixel's `head` and `last` fields refer to the same object.
@inline is_pixelalone(pixel::Pixel) = pixel.last === pixel.head
# True when the group of `p1` has at least as many pixels as the group of `p2`.
# `length(::Pixel)` reports the group size stored on the pixel's head.
# The comparison operator had been lost from this line (leaving the invalid
# `length(p1) length(p2)`); it is restored here in its ASCII spelling `>=`.
@inline function is_bigger(p1::Pixel, p2::Pixel)
    return length(p1) >= length(p2)
end
# Two pixels are in different groups when their `head` fields are not the
# same object.
@inline function is_differentgroup(p1::Pixel, p2::Pixel)
    head_1 = p1.head
    head_2 = p2.head
    return head_1 !== head_2
end
# A pixel counts as alone when its `head` field and its `last` field are the
# same object.
@inline function is_pixelalone(pixel::Pixel)
    return pixel.head === pixel.last
end
# True when the group of `p1` has at least as many pixels as the group of `p2`
# (`length` of a pixel is its group's size). The comparison operator was
# missing from this line, which made it a syntax error; it is restored as `>=`.
@inline is_bigger(p1::Pixel, p2::Pixel) = length(p1) >= length(p2)

# Append the group headed by `pixel_target.head` to the group headed by
# `pixel_base.head`, then re-home `pixel_target` itself:
#   * the base head's `groupsize` grows by the target head's `groupsize`;
#   * the base group's current last pixel gets the target head as `next`;
#   * the base head's `last` becomes the target group's `last`;
#   * `pixel_target.head` is pointed at the base head and its `periods` is set
#     to `pixel_base.periods + periods`.
# Only `pixel_target` has its `head`/`periods` rewritten here; other pixels of
# the target group are not touched by this function.
function merge_pixels!(pixel_base::Pixel, pixel_target::Pixel, periods)
    base_head = pixel_base.head
    target_head = pixel_target.head

    # Both heads are captured before any `head` field is reassigned below.
    base_head.groupsize += target_head.groupsize
    base_head.last.next = target_head
    base_head.last = target_head.last

    pixel_target.head = base_head
    pixel_target.periods = pixel_base.periods + periods
    return nothing
end

function merge_into_group!(pixel_base::Pixel, pixel_target::Pixel, periods)
add_periods = pixel_base.periods + periods - pixel_target.periods
pixel = pixel_target.head
while pixel nothing
while !isnothing(pixel)
# merge all pixels in pixel_target's group to pixel_base's group
if pixel !== pixel_target
pixel.periods += add_periods
Expand All @@ -230,37 +231,33 @@ function merge_into_group!(pixel_base::Pixel, pixel_target::Pixel, periods)
merge_pixels!(pixel_base, pixel_target, periods)
end

function populate_edges!(edges, pixel_image::Array{T, N}, dim, connected, range) where {T, N}
size_img = collect(size(pixel_image))
size_img[dim] -= 1
idx_step = fill(0, N)
idx_step[dim] += 1
idx_step_cart = CartesianIndex{N}(NTuple{N,Int}(idx_step))
idx_size = CartesianIndex{N}(NTuple{N,Int}(size_img))
for i in CartesianIndices(idx_size)
push!(edges, Edge{N}(pixel_image, i, i+idx_step_cart, range))
function populate_edges!(edges::Vector{Edge{T}}, pixel_image::AbstractArray{Pixel{T},N}, dim, connected, range) where {T,N}
idx_step = ntuple(i -> Int(i == dim), Val(N))
idx_step_cart = CartesianIndex{N}(idx_step)
image_inds = CartesianIndices(pixel_image)
fi, li = first(image_inds), last(image_inds)
for i in fi:li-idx_step_cart
push!(edges, Edge{T}(pixel_image[i], pixel_image[i + idx_step_cart], range))
end
if connected
idx_step = fill!(idx_step, 0)
idx_step[dim] = -size_img[dim]
idx_step_cart = CartesianIndex{N}(NTuple{N,Int}(idx_step))
edge_begin = ones(Int, N)
edge_begin[dim] = size(pixel_image)[dim]
edge_begin_cart = CartesianIndex{N}(NTuple{N,Int}(edge_begin))
for i in CartesianIndices(ntuple(dim_idx -> edge_begin_cart[dim_idx]:size(pixel_image, dim_idx), N))
push!(edges, Edge{N}(pixel_image, i, i+idx_step_cart, range))
idx_step_cart *= size(pixel_image, dim) - 1
for i in fi+idx_step_cart:li
push!(edges, Edge{T}(pixel_image[i], pixel_image[i - idx_step_cart], range))
end
end
end

function calculate_reliability(pixel_image::AbstractArray{T, N}, circular_dims, range) where {T, N}
# get the shifted pixel indices in CartesianIndex form
# This gets all the nearest neighbors (CartesianIndex{N}() = one(CartesianIndex{N}))
pixel_shifts = CartesianIndices(ntuple(i -> -1:1, N))
one_cart = oneunit(CartesianIndex{N})
pixel_shifts = -one_cart:one_cart
image_inds = CartesianIndices(pixel_image)
fi, li = first(image_inds) + one_cart, last(image_inds) - one_cart
size_img = size(pixel_image)
# inner loop
for i in CartesianIndices(ntuple(dim -> 2:(size(pixel_image, dim)-1), N))
@inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts, range)
for i in fi:li
pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts, range)
end

if !(true in circular_dims)
Expand All @@ -276,54 +273,45 @@ function calculate_reliability(pixel_image::AbstractArray{T, N}, circular_dims,
for (idx_ps, ps) in enumerate(pixel_shifts_border)
# if the pixel shift goes out of bounds, we make the shift wrap
if ps[idx_dim] == 1
fill!(new_ps, 0)
new_ps[idx_dim] = -size_img[idx_dim]+1
pixel_shifts_border[idx_ps] = CartesianIndex{N}(NTuple{N,Int}(new_ps))
new_ps[idx_dim] = 0
end
end
border_range = get_border_range(size_img, idx_dim, size_img[idx_dim])
border_range = get_border_range(fi:li, idx_dim, li[idx_dim] + 1)
for i in CartesianIndices(border_range)
@inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
end
# second border
pixel_shifts_border = copyto!(pixel_shifts_border, pixel_shifts)
for (idx_ps, ps) in enumerate(pixel_shifts_border)
# if the pixel shift goes out of bounds, we make the shift wrap, this time to the other side
if ps[idx_dim] == -1
fill!(new_ps, 0)
new_ps[idx_dim] = size_img[idx_dim]-1
pixel_shifts_border[idx_ps] = CartesianIndex{N}(NTuple{N,Int}(new_ps))
new_ps[idx_dim] = 0
end
end
border_range = get_border_range(size_img, idx_dim, 1)
border_range = get_border_range(fi:li, idx_dim, fi[idx_dim] - 1)
for i in CartesianIndices(border_range)
@inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
end
end
end
end

function get_border_range(size_img::NTuple{N, T}, border_dim, border_idx) where {N, T}
border_range = [2:(size_img[dim]-1) for dim=1:N]
function get_border_range(C::CartesianIndices{N}, border_dim, border_idx) where {N}
border_range = [C.indices[dim] for dim=1:N]
border_range[border_dim] = border_idx:border_idx
return NTuple{N,UnitRange{Int}}(border_range)
end

function calculate_pixel_reliability(pixel_image::AbstractArray{Pixel{T},N}, pixel_index, pixel_shifts, range) where {T,N}
pix_val = pixel_image[pixel_index].val
rel_contrib(shift) = @inbounds wrap_val(pixel_image[pixel_index+shift].val - pix_val, range)^2
rel_contrib(shift) = wrap_val(pixel_image[pixel_index+shift].val - pix_val, range)^2
# for N=3, pixel_shifts[14] is null shift, can avoid if manually unrolling loop
sum_val = sum(rel_contrib, pixel_shifts)
return sum_val
end

# specialized pixel reliability calculations for different N
# Two-dimensional specialisation: sums the squared wrapped differences between
# the pixel at `pixel_index` and the four pixels reached through entries 2, 4,
# 6 and 8 of `pixel_shifts`.
# NOTE(review): which neighbours those entries select depends on the ordering
# of `pixel_shifts` supplied by the caller -- confirm there.
@inbounds function calculate_pixel_reliability(pixel_image::AbstractArray{Pixel{T}, 2}, pixel_index, pixel_shifts, range) where T
    center = pixel_image[pixel_index].val
    d1 = wrap_val(pixel_image[pixel_index + pixel_shifts[2]].val - center, range)
    d2 = wrap_val(pixel_image[pixel_index + pixel_shifts[4]].val - center, range)
    h = wrap_val(pixel_image[pixel_index + pixel_shifts[6]].val - center, range)
    v = wrap_val(pixel_image[pixel_index + pixel_shifts[8]].val - center, range)
    # Keep the original summation order so floating-point results are unchanged.
    return h*h + v*v + d1*d1 + d2*d2
end

end
1 change: 1 addition & 0 deletions test/unwrap.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DSP, Test
using Random: MersenneTwister

@testset "Unwrap 1D" begin
@test unwrap([0.1, 0.2, 0.3, 0.4]) [0.1, 0.2, 0.3, 0.4]
Expand Down
Loading