diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f617635..90c1655 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -8,6 +8,8 @@ on: - test/** - Project.toml pull_request: + branches: + - master paths: - src/** - test/** diff --git a/docs/make.jl b/docs/make.jl index ad850fd..39231b5 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -20,7 +20,7 @@ makedocs( "DWT" => "api/dwt.md", "ACWT" => "api/acwt.md", "SWT" => "api/swt.md", - "SIWPD" => "api/siwpd.md", + "SIWT" => "api/siwt.md", "WaveMult" => "api/wavemult.md", "Best Basis" => "api/bestbasis.md", "Denoising" => "api/denoising.md", diff --git a/docs/src/api/bestbasis.md b/docs/src/api/bestbasis.md index 2d323f4..dec370a 100644 --- a/docs/src/api/bestbasis.md +++ b/docs/src/api/bestbasis.md @@ -17,13 +17,11 @@ BestBasis.ShannonEntropyCost BestBasis.LogEnergyEntropyCost BestBasis.coefcost BestBasis.tree_costs -BestBasis.tree_costs(::AbstractMatrix{T}, ::AbstractVector{BitVector}, ::SIBB) where T<:Number ``` # Best Basis Tree Selection ```@docs BestBasis.bestbasis_treeselection -BestBasis.bestbasis_treeselection(::AbstractVector{Tc}, ::AbstractVector{Tt}) where {Tc<:AbstractVector{<:Union{Number,Nothing}}, Tt<:BitVector} BestBasis.delete_subtree! ``` @@ -33,8 +31,6 @@ BestBasis.BestBasisType BestBasis.LSDB BestBasis.JBB BestBasis.BB -BestBasis.SIBB Wavelets.Threshold.bestbasistree -Wavelets.Threshold.bestbasistree(::AbstractMatrix{T}, ::Integer, ::SIBB) where T<:Number BestBasis.bestbasistreeall ``` diff --git a/docs/src/api/siwpd.md b/docs/src/api/siwpd.md deleted file mode 100644 index 1cc911a..0000000 --- a/docs/src/api/siwpd.md +++ /dev/null @@ -1,11 +0,0 @@ -# Shift Invariant Wavelet Packet Decomposition - -```@index -Modules = [SIWPD] -``` - -## Public API -```@autodocs -Modules = [SIWPD] -Private = false -``` diff --git a/docs/src/api/siwt.md b/docs/src/api/siwt.md new file mode 100644 index 0000000..3fa68fc --- /dev/null +++ b/docs/src/api/siwt.md @@ -0,0 +1,49 @@ +# [Shift Invariant Wavelet Packet Decomposition](@id siwt_api) + +```@index +Modules = [SIWT] +``` + +## Data Structures +```@docs +SIWT.ShiftInvariantWaveletTransformNode +SIWT.ShiftInvariantWaveletTransformObject +``` + +## Signal Transform and Reconstruction +### Public API +```@docs +SIWT.siwpd +SIWT.isiwpd +``` + +### Private API +```@docs +SIWT.siwpd_subtree! +SIWT.isiwpd_subtree! +``` + +## Best Basis Search +### Public API +```@docs +SIWT.bestbasistree! +``` + +### Private API +```@docs +SIWT.bestbasis_treeselection! +``` + +## Single Step Transforms +### Private API +```@docs +SIWT.sidwt_step! +SIWT.isidwt_step! +``` + +## Other Utils +### Private API +```@docs +Wavelets.Util.isvalidtree(::ShiftInvariantWaveletTransformObject) +SIWT.delete_node! +``` \ No newline at end of file diff --git a/docs/src/api/utils.md b/docs/src/api/utils.md index 083fe6e..0cc8a18 100644 --- a/docs/src/api/utils.md +++ b/docs/src/api/utils.md @@ -18,7 +18,7 @@ Utils.finestdetailrange ## Tree traversing functions ```@docs -Wavelets.Util.isvalidtree +Wavelets.Util.isvalidtree(::AbstractMatrix,::BitVector) Wavelets.Util.maketree Utils.getchildindex Utils.getparentindex diff --git a/docs/src/manual/bestbasis.md b/docs/src/manual/bestbasis.md index cf58702..da63060 100644 --- a/docs/src/manual/bestbasis.md +++ b/docs/src/manual/bestbasis.md @@ -126,13 +126,15 @@ plot(p1, p2, p3, p4, p5, p6, layout=(3,2)) ## [Best Basis of Shift-Invariant Wavelet Packet Decomposition](@id si_bestbasis) One can think of searching for the best basis of the shift-invariant wavelet packet decomposition as a problem of finding ``\min_{b \in B} \sum_{x \in X} M_x(b)``, where ``X`` is all the possible shifted versions of an original signal ``y``. One can compute the best basis tree as follows: -```@example wt -xw = siwpd(x, wt) - -# SIBB -tree = bestbasistree(xw, 7, SIBB()); -nothing #hide -``` - -!!! warning - SIWPD is still undergoing large changes in terms of data structures and efficiency improvements. Syntax changes may occur in the next patch updates. +```@repl +x = [2,3,-4,5.0]; +wt = wavelet(WT.haar); +xwObj = siwpd(x, wt, 1, 1); +xwObj.BestTree # Original tree (all decomposed nodes) + +bestbasistree!(xwObj) +xwObj.BestTree # Best basis tree + +x̂ = isiwpd(xwObj) # Reconstruction of signal +x̂ == x +``` \ No newline at end of file diff --git a/docs/src/manual/transforms.md b/docs/src/manual/transforms.md index 1a34593..8187e8a 100644 --- a/docs/src/manual/transforms.md +++ b/docs/src/manual/transforms.md @@ -142,14 +142,15 @@ plot(plot(img, title="Original"), p0, p1, p2, layout=(2,2)) The [Shift-Invariant Wavelet Decomposition (SIWPD)](https://israelcohen.com/wp-content/uploads/2018/05/ICASSP95.pdf) is developed by Cohen et. al.. While it is also a type of redundant transform, it does not follow the same methodology as the SWT and the ACWT. Cohen's main goal for developing this algorithm was to obtain a global minimum entropy from a signal and all its shifted versions. See [its best basis implementation](@ref si_bestbasis) for more information. One can compute the SIWPD of a single signal as follows. -```@example wt -# decomposition -xw = siwpd(x, wt); -nothing # hide +```@repl +x = [2,3,-4,5.0]; +wt = wavelet(WT.haar); +xwObj = siwpd(x, wt, 1, 1); +xwObj.Nodes[(0,0,0)] # Example of looking into a node +xwObj.BestTree # Looking into what's within the tree structure ``` +For more information on the API, visit the [SIWT API page](@ref siwt_api). -!!! note - As of right now, there is not too many functions written based on the SIWPD, as it does not follow the conventional style of wavelet transforms. There is a lot of ongoing work to develop more functions catered for the SIWPD such as it's inverse transforms and group-implementations. diff --git a/src/WaveletsExt.jl b/src/WaveletsExt.jl index 429e5dc..3780e62 100644 --- a/src/WaveletsExt.jl +++ b/src/WaveletsExt.jl @@ -4,9 +4,9 @@ module WaveletsExt include("mod/Utils.jl") include("mod/DWT.jl") -include("mod/SIWPD.jl") include("mod/ACWT.jl") include("mod/BestBasis.jl") +include("mod/SIWT.jl") include("mod/SWT.jl") include("mod/Denoising.jl") include("mod/LDB.jl") @@ -20,7 +20,7 @@ using Reexport .Utils, .LDB, .SWT, - .SIWPD, + .SIWT, .ACWT, .Visualizations, .WaveMult diff --git a/src/mod/BestBasis.jl b/src/mod/BestBasis.jl index d0fc3f2..64bd55d 100644 --- a/src/mod/BestBasis.jl +++ b/src/mod/BestBasis.jl @@ -17,7 +17,6 @@ export LSDB, JBB, BB, - SIBB, # best basis tree bestbasistreeall @@ -30,8 +29,7 @@ using using ..Utils, - ..DWT, - ..SIWPD + ..DWT include("bestbasis/bestbasis_costs.jl") include("bestbasis/bestbasis_tree.jl") @@ -111,72 +109,6 @@ function bestbasis_treeselection(costs::AbstractVector{T}, return tree end -# SIWPD tree selection -""" - bestbasis_treeselection(costs, tree) - -Best basis tree selection on SIWPD. - -# Arguments -- `costs::AbstractVector{Tc} where Tc<:AbstractVector{<:Union{Number, Nothing}}`: Cost of - each node. -- `tree::AbstractVector{Tt} where Tt<:BitVector`: SIWPD tree. - -!!! warning - Current implementation works but is unstable, ie. we are still working on better - syntax/more optimized computations/better data structure. -""" -function bestbasis_treeselection(costs::AbstractVector{Tc}, - tree::AbstractVector{Tt}) where - {Tc<:AbstractVector{<:Union{Number,Nothing}}, Tt<:BitVector} - - @assert length(costs) == length(tree) - bt = deepcopy(tree) - bc = deepcopy(costs) - nn = length(tree) - for i in reverse(eachindex(bt)) - if getchildindex(i,:left) > nn # current node is at bottom level - continue - end - level = floor(Int, log2(i)) - for j in eachindex(bt[i]) # iterate through all available shifts - if !bt[i][j] # node of current shift does not exist - continue - end - @assert bt[getchildindex(i,:left)][j] == - bt[getchildindex(i,:right)][j] == - bt[getchildindex(i,:left)][j+1< sum(node), bt) .<= 1) - return bt -end - # Deletes subtree due to inferior cost """ delete_subtree!(bt, i, tree_type) @@ -207,28 +139,6 @@ function delete_subtree!(bt::BitVector, i::Integer, tree_type::Symbol) return bt end -function delete_subtree!(bt::AbstractVector{BitVector}, i::Integer, j::Integer) - @assert 1 <= i <= length(bt) - level = floor(Int, log2(i)) - bt[i][j] = false - if (getchildindex(i,:left)) < length(bt) # current node can contain subtrees - if bt[getchildindex(i,:left)][j] # left child of current shift - delete_subtree!(bt, getchildindex(i,:left), j) - end - if bt[getchildindex(i,:right)][j] # right child of current shift - delete_subtree!(bt, getchildindex(i,:right), j) - end - if bt[getchildindex(i,:left)][j+1<0. -If shift is 0, ie. ss=0, the main root will not be further decomposed even if d>0. -=======================================================================================# -function siwpd_subtree!(y::AbstractMatrix{T₁}, h::Vector{T₂}, g::Vector{T₂}, - i::S, d::S, ss::S; - L::S = getdepth(i, :binary), - len::S = nodelength(size(y,1), L)) where - {T₁<:AbstractFloat, T₂<:AbstractFloat, S<:Integer} - # Sanity check - n, k = size(y) - @assert 0 ≤ L ≤ maxtransformlevels(n) - @assert 0 ≤ ss < n - @assert 0 ≤ d ≤ getdepth(k,:binary)-L - @assert iseven(len) || len == 1 - - # --- Base case --- - d > 0 || return y - - # --- One level of decomposition --- - v_start = ss*len+1 # Start index for parent node - v_end = (ss+1)*len # End index for parent node - @inbounds v = @view y[v_start:v_end, i] # Parent node - w_start = ss*(len÷2)+1 # Start index for child node - w_end = w_start+(len÷2)-1 # End index for child node - @inbounds w₁ = @view y[w_start:w_end, getchildindex(i, :left)] # Left child node - @inbounds w₂ = @view y[w_start:w_end, getchildindex(i, :right)] # Right child node - sidwt_step!(w₁, w₂, v, h, g, false) # 1 step of non-shifted decomposition - w_start = (ss+1<0 --- - if ss > 0 # Decomposition of current shift - siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss, L=L+1, len=len÷2) - siwpd_subtree!(y, h, g, getchildindex(i, :right), d-1, ss, L=L+1, len=len÷2) - end - # Decomposition of shifted version - siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss+1< k₁ -> k₁>n ? mod1(k₁,n) : k₁ - k₂ = k₂-1 |> k₂ -> k₂≤0 ? mod1(k₂,n) : k₂ - @inbounds w₁[i] += g[end-j+1] * v[k₁] - @inbounds w₂[i] += h[j] * v[k₂] - end - end - return w₁, w₂ -end - -""" - makesiwpdtree(n, L, d) - -Returns the multi-level, multi-depth binary tree corresponding to the Shift- -Invariant Wavelet Packet Decomposition. -""" -function makesiwpdtree(n::Integer, L::Integer, d::Integer) - @assert 0 ≤ L ≤ maxtransformlevels(n) - @assert 1 ≤ d ≤ L - - tree = Vector{BitVector}(undef, 2^(L+1)-1) - for i in eachindex(tree) - level = floor(Int, log2(i)) - len = nodelength(n, level) - nshift = n ÷ len # number of possible shifts for subspace Ω(i,j) - node = falses(nshift) - k = ceil(Int, nshift / 1< 0) + isShiftedTransformNode = nodeTransformShift > 0 + isShiftedTransformLeafNode = (!isShiftedTransform4NodeRequired && isShiftedTransformNode) + if (isLeafNode || isShiftedTransformLeafNode) + return nothing + end + + # --- General step --- + # - Decompose current node without additional shift + # - Decompose children nodes + childDepth = nodeDepth + 1 + (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, false) + childRemainingRelativeDepth4ShiftedTransform = isShiftedTransformNode ? + remainingRelativeDepth4ShiftedTransform-1 : + min(remainingRelativeDepth4ShiftedTransform, treeMaxTransformLevel-childDepth) + siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + + # Case: remainingRelativeDepth4ShiftedTransform > 0 + # - Decompose current node with additional shift + # - Decompose children (with additional shift) nodes + if isShiftedTransform4NodeRequired + (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, true) + childRemainingRelativeDepth4ShiftedTransform = remainingRelativeDepth4ShiftedTransform-1 + siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + end + + return nothing +end + +""" + isiwpd(siwtObj) + +Computes the Inverse Shift-Invariant Wavelet Packet Decomposition originally developed by +Cohen, Raz & Malah. + +# Arguments +- `siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂} where + {N,T₁<:Integer,T₂<:AbstractFloat`: SIWT object. + +# Returns +- `Vector{T₂}`: Reconstructed signal. +""" +function isiwpd(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}) where + {N, T₁<:Integer, T₂<:AbstractFloat} + g, h = WT.makereverseqmfpair(siwtObj.Wavelet, true) + rootNodeIndex = (0,0,0) + isiwpd_subtree!(siwtObj, rootNodeIndex, h, g) + + return siwtObj.Nodes[rootNodeIndex].Value +end + +""" + isiwpd_subtree!(siwtObj, index, h, g) + +Runs the recursive computation of Inverse Shift-Invariant Wavelet Transform (SIWT) at each +node `index`. + +# Arguments +- `siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂} where + {N,T₁<:Integer,T₂<:AbstractFloat`: SIWT object. +- `index::NTuple{3,T₁} where T₁<:Integer`: Index of current node to be decomposed. +- `h::Vector{T₃} where T₃<:AbstractFloat`: High pass filter. +- `g::Vector{T₃} where T₃<:AbstractFloat`: Low pass filter. +""" +function isiwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + # Check for children nodes + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + hasNonShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) ∈ siwtObj.BestTree + hasShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< k₁ -> k₁>n ? mod1(k₁,n) : k₁ + k₂ = k₂-1 |> k₂ -> k₂≤0 ? mod1(k₂,n) : k₂ + @inbounds w₁[i] += g[end-j+1] * v[k₁] + @inbounds w₂[i] += h[j] * v[k₂] + end + end + return w₁, w₂ +end + +""" + isidwt_step!(siwtObj, index, child1Index, child2Index, h, g) + +Computes one step of the inverse SIWT decomposition on the node `index`. + +# Arguments +- `siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂} where + {N,T₁<:Integer,T₂<:AbstractFloat`: SIWT object. +- `index::NTuple{3,T₁} where T₁<:Integer`: Index of current node to be decomposed. +- `child1Index::NTuple{T,T₁} where T₁<:Integer`: Index of child 1. +- `child2Index::NTuple{T,T₁} where T₁<:Integer`: Index of child 2. +- `h::Vector{T₃} where T₃<:AbstractFloat`: High pass filter. +- `g::Vector{T₃} where T₃<:AbstractFloat`: Low pass filter. +""" +function isidwt_step!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + nodeIndex::NTuple{3,T₁}, + child1Index::NTuple{3,T₁}, child2Index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + nodeObj = siwtObj.Nodes[nodeIndex] + child1Obj = siwtObj.Nodes[child1Index] + child2Obj = siwtObj.Nodes[child2Index] + + @assert child1Obj.TransformShift == child2Obj.TransformShift + isShiftedTransform = nodeObj.TransformShift == child1Obj.TransformShift + + nodeValue = nodeObj.Value + child1Value = child1Obj.Value + child2Value = child2Obj.Value + isidwt_step!(nodeValue, child1Value, child2Value, h, g, isShiftedTransform) + + return nothing +end + +""" + isidwt_step!(v, w₁, w₂, h, g, s) + +Computes one step of the inverse SIWT decomposition on the node `w₁` and `w₂`. + +# Arguments +- `v::AbstractVector{T} where T<:AbstractFloat`: Vector allocation for reconstructed coefficients. +- `w₁::AbstractVector{T} where T<:AbstractFloat`: Vector allocation for output from low pass + filter. +- `w₂::AbstractVector{T} where T<:AbstractFloat`: Vector allocation for output from high pass + filter. +- `h::Vector{S} where S<:AbstractFloat`: High pass filter. +- `g::Vector{S} where S<:AbstractFloat`: Low pass filter. +- `s::Bool`: Whether a shifted inverse transform should be performed. + +# Returns +- `v::Vector{T}`: Reconstructed coefficients. +""" +function isidwt_step!(v::AbstractVector{T}, + w₁::AbstractVector{T}, w₂::AbstractVector{T}, + h::Array{S,1}, g::Array{S,1}, + s::Bool) where {T<:AbstractFloat, S<:AbstractFloat} + # Sanity check + @assert length(w₁) == length(w₂) == length(v)÷2 + @assert length(h) == length(g) + + # Setup + n = length(v) # Parent length + n₁ = length(w₁) # Child length + filtlen = length(h) # Filter length + + # One step of inverse discrete transform + for i in 1:n + ℓ = mod1(i-s,n) # Index of reconstructed vector + j₀ = mod1(i,2) # Pivot point to determine start index for filter + j₁ = filtlen-j₀+1 # Index for low pass filter g + j₂ = mod1(i+1,2) # Index for high pass filter h + k₁ = (i+1)>>1 # Index for approx coefs w₁ + k₂ = (i+1)>>1 # Index for detail coefs w₂ + @inbounds v[ℓ] = g[j₁] * w₁[k₁] + h[j₂] * w₂[k₂] + for j in (j₀+2):2:filtlen + j₁ = filtlen-j+1 + j₂ = j + isodd(j) - iseven(j) + k₁ = k₁-1 |> k₁ -> k₁≤0 ? mod1(k₁,n₁) : k₁ + k₂ = k₂+1 |> k₂ -> k₂>n₁ ? mod1(k₂,n₁) : k₂ + @inbounds v[ℓ] += g[j₁] * w₁[k₁] + h[j₂] * w₂[k₂] + end + end + return v +end \ No newline at end of file diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl new file mode 100644 index 0000000..90f8f66 --- /dev/null +++ b/src/mod/siwt/siwt_utls.jl @@ -0,0 +1,224 @@ +""" + ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} + +Data structure to hold the index, coefficients, and cost value of an SIWT node. + +# Parameters +- `Depth::T₁`: The depth of current node. Root node has depth 0. +- `IndexAtDepth::T₁`: The node index at current depth. Index starts from 0 at each depth. +- `TransformShift::T₁`: The type of shift operated on the parent node before computing the + transform. Accepted type of shift values are: + - `0`: Coefficients of parent node are not shifted prior to transform. + - `1`: For both 1D and 2D signals, coefficients of parent node are circularly shifted to + the left by 1 index. + - `2`: For 2D signals, coefficients of parent node are circularly shifted from bottom to + top by 1 index. Not available for 1D signals. + - `3`: For 2D signals, coefficients of parent node are circularly shifted from bottom to + top and right to left by 1 index each. Not available for 1D signals. +- `Cost::T₂`: The [`ShannonEntropyCost`](@ref) of current node. +- `Value::Array{T₂,N}`: Coefficients of current node. +""" +mutable struct ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} + Depth::T₁ + IndexAtDepth::T₁ + TransformShift::T₁ + Cost::T₂ + Value::Array{T₂,N} + + function ShiftInvariantWaveletTransformNode{N, T₁, T₂}(Depth, IndexAtDepth, TransformShift, Cost, Value) where {N, T₁<:Integer, T₂<:AbstractFloat} + if N ≠ ndims(Value) + (throw ∘ TypeError)("Value array is not of $N dimension.") + end + + if N == 1 + maxIndexAtDepth = 1< maxIndexAtDepth) | (TransformShift > maxTransformShift) + (throw ∘ ArgumentError)("Invalid IndexAtDepth or TransformShift for $(N)D coefficients.") + end + return new(Depth, IndexAtDepth, TransformShift, Cost, Value) + end +end + +""" + ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:AbstractFloat} + +Data structure to hold all shift-invariant wavelet transform (SIWT) information of a signal. + +# Parameters +- `Nodes::Dict{NTuple{3,T₁}, ShiftInvariantWaveletTransformNode{N,T₁,T₂}}`: + Dictionary containing the information of each node within a tree. +- `SignalSize::Union{T₁, NTuple{N,T₁}}`: Size of the original signal. +- `MaxTransformLevel::T₁`: Maximum levels of transform set for signal. +- `Wavelet::OrthoFilter`: A discrete wavelet for transform purposes. +- `MinCost::Union{T₂, Nothing}`: The current minimum [`ShannonEntropyCost`](@ref) cost of + the decomposition tree. + +!!! note + `MinCost` parameter will contain the cost of the root node by default. To compute the + true minimum cost of the decomposition tree, one will need to first compute the SIWT + best basis by calling [`bestbasistree`](@ref). + +- `BestTree::Vector{NTuple{3,T₁}}`: A collection of node indices that belong in the best + basis tree. + +!!! note + `BestTree` parameter will contain all the nodes by default, ie. the full decomposition + will be the best tree. To get the SIWT best basis tree that produce the minimum cost, + one will neeed to call [`bestbasistree`](@ref). + +**See also:** [`ShiftInvariantWaveletTransformNode`](@ref) +""" +mutable struct ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:AbstractFloat} + Nodes::Dict{NTuple{3,T₁}, ShiftInvariantWaveletTransformNode{N,T₁,T₂}} + SignalSize::Union{T₁, NTuple{N,T₁}} + MaxTransformLevel::T₁ + MaxShiftedTransformLevels::T₁ + Wavelet::OrthoFilter + MinCost::Union{T₂, Nothing} + BestTree::Vector{NTuple{3,T₁}} + + function ShiftInvariantWaveletTransformObject{N,T₁,T₂}(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevels, wt, minCost, bestTree) where {N, T₁<:Integer, T₂<:AbstractFloat} + 0 ≤ maxTransformLevel ≤ maxtransformlevels(nodes[(0,0,0)].Value) || (throw ∘ ArgumentError)("Provided MaxTransformLevels is too large.") + 0 ≤ maxShiftedTransformLevels < length(nodes[(0,0,0)].Value) || (throw ∘ ArgumentError)("Provided MaxShiftedTransformLevels is too large.") + return new(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevels, wt, minCost, bestTree) + end +end + +""" + ShiftInvariantWaveletTransformNode(data, depth, indexAtDepth, transformShift) + +Outer constructor of SIWT node. + +# Arguments +- `data::Array{T} where T<:AbstractFloat`: Array of coefficients. +- `depth::S where S<:Integer`: Depth of current node. +- `indexAtDepth::S where S<:Integer`: Node index at current depth. +- `transformShift::S where S<:Integer`: The type of shift operated on the parent node before + computing the transform. +- `nrm::T where T<:AbstractFloat`: (Default: `norm(data)`) Norm of the signal. +""" +function ShiftInvariantWaveletTransformNode(data::Array{T}, + depth::S, + indexAtDepth::S, + transformShift::S, + nrm::T = norm(data)) where {T<:AbstractFloat, S<:Integer} + cost = coefcost(data, ShannonEntropyCost(), nrm) + N = ndims(data) + return ShiftInvariantWaveletTransformNode{N,S,T}(depth, indexAtDepth, transformShift, cost, data) +end + +""" + ShiftInvariantWaveletTransformObject(signal, wavelet) + +Outer constructor and initialization of SIWT object. + +# Arguments +- `signal::Array{T} where T<:AbstractFloat`: Input signal. +- `wavelet::OrthoFilter`: Wavelet filter. +- `maxTransformLevel::S where S<:Integer`: (Default: `0`) Max transform level. +- `maxShiftedTransformLevel::S where S<:Integer`: (Default: `0`) Max shifted transform + levels. +""" +function ShiftInvariantWaveletTransformObject(signal::Array{T}, + wavelet::OrthoFilter, + maxTransformLevel::S = 0, + maxShiftedTransformLevel::S = 0) where + {T<:AbstractFloat, S<:Integer} + signalDim = ndims(signal) + signalSize = signalDim == 1 ? length(signal) : size(signal) + cost = coefcost(signal, ShannonEntropyCost()) + signalNode = ShiftInvariantWaveletTransformNode{signalDim,S,T}(0, 0, 0, cost, signal) + index = (signalNode.Depth, signalNode.IndexAtDepth, signalNode.TransformShift) + nodes = Dict{NTuple{3,S}, ShiftInvariantWaveletTransformNode{signalDim,S,T}}(index => signalNode) + tree = [index] + return ShiftInvariantWaveletTransformObject{signalDim,S,T}(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevel, wavelet, cost, tree) +end + +""" + Wavelets.Util.isvalidtree(siwtObj) + +Checks if tree within SIWT object is a valid tree. + +# Arguments +- `siwtObj::ShiftInvariantWaveletTransformObject`: SIWT object. + +# Returns +- `bool`: `true` if tree is valid, `false` otherwise. + +!!! note + `isvalidtree` checks the criteria that each node has a parent (except root node) and one + set of children. A node can have either non-shifted or shifted children, but not both. + Using this function on a decomposed `siwtObj` prior to its best basis search will return + `false`. +""" +function Wavelets.Util.isvalidtree(siwtObj::ShiftInvariantWaveletTransformObject) + @assert (Set ∘ keys)(siwtObj.Nodes) == Set(siwtObj.BestTree) + + # Each node needs to: + # - Be a root node OR have a parent node + # - Have child nodes XOR have shifted child nodes XOR have no child nodes + nodeSet = Set(siwtObj.BestTree) + for index in nodeSet + depth, indexAtDepth, transformShift = index + + isRootNode = index == (0,0,0) + hasParentNode = (depth-1, indexAtDepth>>1, transformShift) ∈ nodeSet + + hasChildNodes = (depth+1, indexAtDepth<<1, transformShift) ∈ nodeSet && (depth+1, indexAtDepth<<1+1, transformShift) ∈ nodeSet + hasShiftedChildNodes = (depth+1, indexAtDepth<<1, transformShift+(1< x==index, siwtObj.BestTree)) + + # Delete children nodes + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + child1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) + child2Index = (nodeDepth+1, nodeIndexAtDepth<<1+1, nodeTransformShift) + shiftedChild1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< sum(node), bestbasistree(xw0, 4, SIBB())) -bt4 = map(node -> sum(node), bestbasistree(xw4, 4, SIBB())) -@test bt0 == bt4 - # misc @test_throws ArgumentError BestBasis.bestbasis_treeselection(randn(15), 8, :fail) @test_throws AssertionError BestBasis.bestbasis_treeselection(randn(7), 3, :fail) # true n=4 diff --git a/test/runtests.jl b/test/runtests.jl index 8395460..ee78eb5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,10 +12,10 @@ using WaveletsExt, SparseArrays -@testset "Utils" begin include("utils.jl") end -@testset "Transforms" begin include("transforms.jl") end -@testset "Wavelet Multiplication" begin include("wavemult.jl") end -@testset "Best Basis" begin include("bestbasis.jl") end -@testset "Denoising" begin include("denoising.jl") end -@testset "LDB" begin include("ldb.jl") end -@testset "Visualizations" begin include("visualizations.jl") end \ No newline at end of file +@testset verbose=true "Utils" begin include("utils.jl") end +@testset verbose=true "Transforms" begin include("transforms.jl") end +@testset verbose=true "Wavelet Multiplication" begin include("wavemult.jl") end +@testset verbose=true "Best Basis" begin include("bestbasis.jl") end +@testset verbose=true "Denoising" begin include("denoising.jl") end +@testset verbose=true "LDB" begin include("ldb.jl") end +@testset verbose=true "Visualizations" begin include("visualizations.jl") end \ No newline at end of file diff --git a/test/transforms.jl b/test/transforms.jl index 02a2c25..d5c3333 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -174,26 +174,95 @@ end @test iacwpd(acwpd(x, wt), tree) ≈ x end -@testset "SIWPD" begin - # siwpd - x = randn(4) - wt = wavelet(WT.haar) - y = siwpd(x, wt, 2, 1) - y0 = y[1,4:7] - y1 = y[2,4:7] - y2 = y[3,4:7] - y3 = y[4,4:7] - @test y0 == wpt(x, wt) - @test !all([isdefined(y1, i) for i in 1:4]) - @test y2 == wpt(circshift(x,2), wt) - @test !all([isdefined(y3, i) for i in 1:4]) - # make tree - tree = [ - trues(1), - repeat([trues(2)],2)..., - repeat([BitVector([1,0,1,0])],4)... - ] - @test makesiwpdtree(4, 2, 1) == tree +@testset verbose=true "SIWT" begin + signal = [2,3,-4,5.0]; + wt = wavelet(WT.haar); + @testset "Data structures" begin + cost = BestBasis.coefcost(signal, ShannonEntropyCost()); + @test isa(ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,0,signal), ShiftInvariantWaveletTransformNode); + @test_throws MethodError ShiftInvariantWaveletTransformNode{2,Int64,Float64}(0,0,0,0,signal); # signal dimension and N do not match + @test_throws ArgumentError ShiftInvariantWaveletTransformNode{1,Int64,Float64}(2,4,0,0,signal); # Invalid IndexAtDepth + @test_throws ArgumentError ShiftInvariantWaveletTransformNode{1,Int64,Float64}(2,0,4,0,signal); # Invalid TransformShift + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).Depth == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).Depth; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).IndexAtDepth == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).IndexAtDepth; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).TransformShift == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).TransformShift; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).Cost == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).Cost; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).Value == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).Value; + @test ShiftInvariantWaveletTransformNode(signal,1,0,0) != ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal); + + maxTransformLevels = 2; + maxTransformShift = 3; + siwtObject = ShiftInvariantWaveletTransformObject(signal, wt); + @test isa(siwtObject, ShiftInvariantWaveletTransformObject); + @test isa(siwtObject.Nodes[(0,0,0)], ShiftInvariantWaveletTransformNode); + @test siwtObject.SignalSize == 4; + @test siwtObject.MaxTransformLevel == 0; + @test siwtObject.MaxShiftedTransformLevels == 0; + @test siwtObject.Wavelet == wt; + @test siwtObject.MinCost == cost; + @test siwtObject.BestTree == [(0,0,0)]; + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, maxTransformLevels+1); + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, -1); + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, 0, maxTransformShift+1); + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, 0, -1); + end + + @testset "Transform" begin + expectedNodes = Dict{NTuple{3,Int64}, ShiftInvariantWaveletTransformNode{1,Int64,Float64}}(); + expectedNodes[(0,0,0)] = ShiftInvariantWaveletTransformNode(signal,0,0,0); + signalTransformedL1S0 = dwt(signal, wt, 1); + expectedNodes[(1,0,0)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S0[1:2],1,0,0); + expectedNodes[(1,1,0)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S0[3:4],1,1,0); + signalTransformedL1S1 = dwt(circshift(signal,1), wt, 1); + expectedNodes[(1,0,1)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S1[1:2],1,0,1); + expectedNodes[(1,1,1)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S1[3:4],1,1,1); + @test isa(siwpd(signal,wt,1,1), ShiftInvariantWaveletTransformObject) + end + + @testset "Best Basis" begin + # Best basis on original object without decomposition + siwtObject = ShiftInvariantWaveletTransformObject(signal, wt, 0, 0); + rootOnlyTree = [(0,0,0)]; + bestbasistree!(siwtObject); + @test siwtObject.BestTree == rootOnlyTree; + @test siwtObject.MinCost == siwtObject.Nodes[(0,0,0)].Cost; + @test isvalidtree(siwtObject); + + # Cost check before best basis operation + expectedNodeCost = Dict{NTuple{3,Int64}, Float64}( + (0,0,0) => 1.208, + (1,0,0) => 0.382, + (1,0,1) => 0.402, + (1,1,0) => 0.259, + (1,1,1) => 0.566 + ) + siwtObj = siwpd(signal, wt, 1); + for index in keys(expectedNodeCost) + @test expectedNodeCost[index] ≈ siwtObj.Nodes[index].Cost atol=1e-3; + end + + # Cost check after best basis operation + expectedNodeCost = Dict{NTuple{3,Int64}, Float64}( + (0,0,0) => 0.641, + (1,0,0) => 0.382, + (1,1,0) => 0.259 + ) + bestbasistree!(siwtObj) + @test (Set ∘ keys)(expectedNodeCost) == Set(siwtObj.BestTree) == (Set ∘ keys)(siwtObj.Nodes) + for index in keys(expectedNodeCost) + @test expectedNodeCost[index] ≈ siwtObj.Nodes[index].Cost atol=1e-3 + end + @test expectedNodeCost[(0,0,0)] ≈ siwtObj.MinCost atol=1e-3 + @test isvalidtree(siwtObj) + end + + @testset "Signal Reconstruction" begin + siwtObj = siwpd(signal, wt) + bestbasistree!(siwtObj) + @test isa(isiwpd(siwtObj), Vector) + reconstructedSignal = isiwpd(siwtObj) + @test reconstructedSignal ≈ signal + end end @testset "Transform All" begin