diff --git a/src/unwrap.jl b/src/unwrap.jl index 2212080b..8776f2b5 100644 --- a/src/unwrap.jl +++ b/src/unwrap.jl @@ -1,5 +1,5 @@ module Unwrap -using Random: GLOBAL_RNG, AbstractRNG +using Random: AbstractRNG, default_rng export unwrap, unwrap! """ @@ -62,7 +62,7 @@ of an image, as each pixel is wrapped to stay within (-pi, pi]. - `circular_dims=(false, ...)`: When an element of this tuple is `true`, the unwrapping process will consider the edges along the corresponding axis of the array to be connected. -- `rng=GLOBAL_RNG`: Unwrapping of arrays with dimension > 1 uses a random +- `rng=default_rng()`: Unwrapping of arrays with dimension > 1 uses a random initialization. A user can pass their own RNG through this argument. """ unwrap(m::AbstractArray; kwargs...) = unwrap!(similar(m), m; kwargs...) @@ -80,7 +80,7 @@ unwrap(m::AbstractArray; kwargs...) = unwrap!(similar(m), m; kwargs...) mutable struct Pixel{T} periods::Int - val::T + const val::T reliability::Float64 groupsize::Int head::Pixel{T} @@ -97,16 +97,16 @@ end Pixel(v, rng) = Pixel{typeof(v)}(0, v, rand(rng), 1) @inline Base.length(p::Pixel) = p.head.groupsize -struct Edge{N} +struct Edge{T} reliability::Float64 periods::Int - pixel_1::CartesianIndex{N} - pixel_2::CartesianIndex{N} + pixel_1::Pixel{T} + pixel_2::Pixel{T} end -function Edge{N}(pixel_image::AbstractArray, ind1::CartesianIndex{N}, ind2::CartesianIndex{N}, range) where N - @inbounds rel = pixel_image[ind1].reliability + pixel_image[ind2].reliability - @inbounds periods = find_period(pixel_image[ind1].val, pixel_image[ind2].val, range) - return Edge{N}(rel, periods, ind1, ind2) +function Edge{T}(p1::Pixel{T}, p2::Pixel{T}, range) where {T} + rel = p1.reliability + p2.reliability + periods = find_period(p1.val, p2.val, range) + return Edge{T}(rel, periods, p1, p2) end @inline Base.isless(e1::Edge, e2::Edge) = isless(e1.reliability, e2.reliability) @@ -114,21 +114,22 @@ function unwrap_nd!(dest::AbstractArray{T, N}, src::AbstractArray{T, N}; range::Number=2*convert(T, pi), circular_dims::NTuple{N, Bool}=ntuple(_->false, Val(N)), - rng::AbstractRNG=GLOBAL_RNG) where {T, N} + rng::AbstractRNG=default_rng()) where {T, N} range_T = convert(T, range) pixel_image = init_pixels(src, rng) calculate_reliability(pixel_image, circular_dims, range_T) - edges = Edge{N}[] + edges = Edge{T}[] num_edges = _predict_num_edges(size(src), circular_dims) sizehint!(edges, num_edges) - for idx_dim=1:N + for idx_dim = 1:N populate_edges!(edges, pixel_image, idx_dim, circular_dims[idx_dim], range_T) end - sort!(edges, alg=MergeSort) - gather_pixels!(pixel_image, edges) + perm = sortperm(map(x -> x.reliability, edges); alg=MergeSort) + edges = edges[perm] + gather_pixels!(edges) unwrap_image!(dest, pixel_image, range_T) return dest @@ -145,67 +146,66 @@ end # function to broadcast function init_pixels(wrapped_image::AbstractArray{T, N}, rng) where {T, N} pixel_image = similar(wrapped_image, Pixel{T}) - Threads.@threads for i in eachindex(wrapped_image) - @inbounds pixel_image[i] = Pixel(wrapped_image[i], rng) + Threads.@threads for i in eachindex(wrapped_image, pixel_image) + pixel_image[i] = Pixel(wrapped_image[i], rng) end return pixel_image end -function gather_pixels!(pixel_image, edges) +function gather_pixels!(edges) for edge in edges - @inbounds p1 = pixel_image[edge.pixel_1] - @inbounds p2 = pixel_image[edge.pixel_2] - merge_groups!(edge, p1, p2) + p1 = edge.pixel_1 + p2 = edge.pixel_2 + if is_differentgroup(p1, p2) + periods = edge.periods + merge_groups!(periods, p1, p2) + end end end function unwrap_image!(dest, pixel_image, range) - Threads.@threads for i in eachindex(dest) - @inbounds dest[i] = muladd(range, pixel_image[i].periods, pixel_image[i].val) + Threads.@threads for i in eachindex(dest, pixel_image) + p = pixel_image[i] + dest[i] = muladd(range, p.periods, p.val) end end function wrap_val(val, range) wrapped_val = val - wrapped_val += ifelse(val > range/2, -range, zero(val)) - wrapped_val += ifelse(val < -range/2, range, zero(val)) + wrapped_val -= ifelse(val > range / 2, range, zero(val)) + wrapped_val += ifelse(val < -range / 2, range, zero(val)) return wrapped_val end function find_period(val_left, val_right, range) difference = val_left - val_right period = 0 - period += ifelse(difference > range/2, -1, 0) - period += ifelse(difference < -range/2, 1, 0) + period -= (difference > range / 2) + period += (difference < -range / 2) return period end -function merge_groups!(edge, pixel_1, pixel_2) - if is_differentgroup(pixel_1, pixel_2) - # pixel 2 is alone in group - if is_pixelalone(pixel_2) - merge_pixels!(pixel_1, pixel_2, -edge.periods) - elseif is_pixelalone(pixel_1) - merge_pixels!(pixel_2, pixel_1, edge.periods) +function merge_groups!(periods, base, target) + # target is alone in group + if is_pixelalone(target) + periods = -periods + elseif is_pixelalone(base) + base, target = target, base + else + if is_bigger(base, target) + periods = -periods else - if is_bigger(pixel_1, pixel_2) - merge_into_group!(pixel_1, pixel_2, -edge.periods) - else - merge_into_group!(pixel_2, pixel_1, edge.periods) - end + base, target = target, base end + merge_into_group!(base, target, periods) + return end + merge_pixels!(base, target, periods) end -@inline function is_differentgroup(p1::Pixel, p2::Pixel) - return p1.head !== p2.head -end -@inline function is_pixelalone(pixel::Pixel) - return pixel.head === pixel.last -end -@inline function is_bigger(p1::Pixel, p2::Pixel) - return length(p1) ≥ length(p2) -end +@inline is_differentgroup(p1::Pixel, p2::Pixel) = p1.head !== p2.head +@inline is_pixelalone(pixel::Pixel) = pixel.head === pixel.last +@inline is_bigger(p1::Pixel, p2::Pixel) = length(p1) ≥ length(p2) function merge_pixels!(pixel_base::Pixel, pixel_target::Pixel, periods) pixel_base.head.groupsize += pixel_target.head.groupsize @@ -213,12 +213,13 @@ function merge_pixels!(pixel_base::Pixel, pixel_target::Pixel, periods) pixel_base.head.last = pixel_target.head.last pixel_target.head = pixel_base.head pixel_target.periods = pixel_base.periods + periods + return nothing end function merge_into_group!(pixel_base::Pixel, pixel_target::Pixel, periods) add_periods = pixel_base.periods + periods - pixel_target.periods pixel = pixel_target.head - while pixel ≠ nothing + while !isnothing(pixel) # merge all pixels in pixel_target's group to pixel_base's group if pixel !== pixel_target pixel.periods += add_periods @@ -230,25 +231,18 @@ function merge_into_group!(pixel_base::Pixel, pixel_target::Pixel, periods) merge_pixels!(pixel_base, pixel_target, periods) end -function populate_edges!(edges, pixel_image::Array{T, N}, dim, connected, range) where {T, N} - size_img = collect(size(pixel_image)) - size_img[dim] -= 1 - idx_step = fill(0, N) - idx_step[dim] += 1 - idx_step_cart = CartesianIndex{N}(NTuple{N,Int}(idx_step)) - idx_size = CartesianIndex{N}(NTuple{N,Int}(size_img)) - for i in CartesianIndices(idx_size) - push!(edges, Edge{N}(pixel_image, i, i+idx_step_cart, range)) +function populate_edges!(edges::Vector{Edge{T}}, pixel_image::AbstractArray{Pixel{T},N}, dim, connected, range) where {T,N} + idx_step = ntuple(i -> Int(i == dim), Val(N)) + idx_step_cart = CartesianIndex{N}(idx_step) + image_inds = CartesianIndices(pixel_image) + fi, li = first(image_inds), last(image_inds) + for i in fi:li-idx_step_cart + push!(edges, Edge{T}(pixel_image[i], pixel_image[i + idx_step_cart], range)) end if connected - idx_step = fill!(idx_step, 0) - idx_step[dim] = -size_img[dim] - idx_step_cart = CartesianIndex{N}(NTuple{N,Int}(idx_step)) - edge_begin = ones(Int, N) - edge_begin[dim] = size(pixel_image)[dim] - edge_begin_cart = CartesianIndex{N}(NTuple{N,Int}(edge_begin)) - for i in CartesianIndices(ntuple(dim_idx -> edge_begin_cart[dim_idx]:size(pixel_image, dim_idx), N)) - push!(edges, Edge{N}(pixel_image, i, i+idx_step_cart, range)) + idx_step_cart *= size(pixel_image, dim) - 1 + for i in fi+idx_step_cart:li + push!(edges, Edge{T}(pixel_image[i], pixel_image[i - idx_step_cart], range)) end end end @@ -256,11 +250,14 @@ end function calculate_reliability(pixel_image::AbstractArray{T, N}, circular_dims, range) where {T, N} # get the shifted pixel indices in CartesinanIndex form # This gets all the nearest neighbors (CartesionIndex{N}() = one(CartesianIndex{N})) - pixel_shifts = CartesianIndices(ntuple(i -> -1:1, N)) + one_cart = oneunit(CartesianIndex{N}) + pixel_shifts = -one_cart:one_cart + image_inds = CartesianIndices(pixel_image) + fi, li = first(image_inds) + one_cart, last(image_inds) - one_cart size_img = size(pixel_image) # inner loop - for i in CartesianIndices(ntuple(dim -> 2:(size(pixel_image, dim)-1), N)) - @inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts, range) + for i in fi:li + pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts, range) end if !(true in circular_dims) @@ -276,54 +273,45 @@ function calculate_reliability(pixel_image::AbstractArray{T, N}, circular_dims, for (idx_ps, ps) in enumerate(pixel_shifts_border) # if the pixel shift goes out of bounds, we make the shift wrap if ps[idx_dim] == 1 - fill!(new_ps, 0) new_ps[idx_dim] = -size_img[idx_dim]+1 pixel_shifts_border[idx_ps] = CartesianIndex{N}(NTuple{N,Int}(new_ps)) + new_ps[idx_dim] = 0 end end - border_range = get_border_range(size_img, idx_dim, size_img[idx_dim]) + border_range = get_border_range(fi:li, idx_dim, li[idx_dim] + 1) for i in CartesianIndices(border_range) - @inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range) + pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range) end # second border pixel_shifts_border = copyto!(pixel_shifts_border, pixel_shifts) for (idx_ps, ps) in enumerate(pixel_shifts_border) # if the pixel shift goes out of bounds, we make the shift wrap, this time to the other side if ps[idx_dim] == -1 - fill!(new_ps, 0) new_ps[idx_dim] = size_img[idx_dim]-1 pixel_shifts_border[idx_ps] = CartesianIndex{N}(NTuple{N,Int}(new_ps)) + new_ps[idx_dim] = 0 end end - border_range = get_border_range(size_img, idx_dim, 1) + border_range = get_border_range(fi:li, idx_dim, fi[idx_dim] - 1) for i in CartesianIndices(border_range) - @inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range) + pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range) end end end end -function get_border_range(size_img::NTuple{N, T}, border_dim, border_idx) where {N, T} - border_range = [2:(size_img[dim]-1) for dim=1:N] +function get_border_range(C::CartesianIndices{N}, border_dim, border_idx) where {N} + border_range = [C.indices[dim] for dim=1:N] border_range[border_dim] = border_idx:border_idx return NTuple{N,UnitRange{Int}}(border_range) end function calculate_pixel_reliability(pixel_image::AbstractArray{Pixel{T},N}, pixel_index, pixel_shifts, range) where {T,N} pix_val = pixel_image[pixel_index].val - rel_contrib(shift) = @inbounds wrap_val(pixel_image[pixel_index+shift].val - pix_val, range)^2 + rel_contrib(shift) = wrap_val(pixel_image[pixel_index+shift].val - pix_val, range)^2 # for N=3, pixel_shifts[14] is null shift, can avoid if manually unrolling loop sum_val = sum(rel_contrib, pixel_shifts) return sum_val end -# specialized pixel reliability calculations for different N -@inbounds function calculate_pixel_reliability(pixel_image::AbstractArray{Pixel{T}, 2}, pixel_index, pixel_shifts, range) where T - D1 = wrap_val(pixel_image[pixel_index+pixel_shifts[2]].val - pixel_image[pixel_index].val, range) - D2 = wrap_val(pixel_image[pixel_index+pixel_shifts[4]].val - pixel_image[pixel_index].val, range) - H = wrap_val(pixel_image[pixel_index+pixel_shifts[6]].val - pixel_image[pixel_index].val, range) - V = wrap_val(pixel_image[pixel_index+pixel_shifts[8]].val - pixel_image[pixel_index].val, range) - return H*H + V*V + D1*D1 + D2*D2 -end - end diff --git a/test/unwrap.jl b/test/unwrap.jl index 5d775fc7..b6ed068a 100644 --- a/test/unwrap.jl +++ b/test/unwrap.jl @@ -1,4 +1,5 @@ using DSP, Test +using Random: MersenneTwister @testset "Unwrap 1D" begin @test unwrap([0.1, 0.2, 0.3, 0.4]) ≈ [0.1, 0.2, 0.3, 0.4]