From fa6e75166cad10b705d8ab647078426988a4e4dc Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Thu, 19 May 2022 23:16:25 -0700 Subject: [PATCH 01/15] Setup SIWT data structures --- src/mod/siwt/siwt_one_level.jl | 3 +++ src/mod/siwt/siwt_utls.jl | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 src/mod/siwt/siwt_one_level.jl create mode 100644 src/mod/siwt/siwt_utls.jl diff --git a/src/mod/siwt/siwt_one_level.jl b/src/mod/siwt/siwt_one_level.jl new file mode 100644 index 0000000..3a6462f --- /dev/null +++ b/src/mod/siwt/siwt_one_level.jl @@ -0,0 +1,3 @@ +# TODO: Function to compute one level of decomposition + +# TODO: Function to compute one level of recomposition \ No newline at end of file diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl new file mode 100644 index 0000000..391816a --- /dev/null +++ b/src/mod/siwt/siwt_utls.jl @@ -0,0 +1,19 @@ +mutable struct ShiftInvariantWaveletTransformObject{N₁, N₂, T₁<:Integer, T₂<:AbstractFloat} + Nodes::Dict{NTuple{N₁,T}, Vector{ShiftInvariantWaveletTransformNode{N₂,T₁,T₂}}} + SignalLength::T₁ + MaxTransformLevel::T₁ + Wavelet::OrthoFilter + MinCost::Union{T₂, Nothing} + Tree::Union{Vector{NTuple{N₁,T₁}}, Nothing} +end + +struct ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} + IsShiftedTransform::Bool + Depth::T₁ + IndexAtDepth::T₁ + IndexAtTree::T₁ + Cost::Union{T₂, Nothing} + Value::Array{T₂,N} +end + +# TODO: Create constructors \ No newline at end of file From e425b594ff6b486c2945b4e623cde73f1c3c3971 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Thu, 19 May 2022 23:24:15 -0700 Subject: [PATCH 02/15] Remove IndexAtTree from node data structure Does not seem to have any purpose. --- src/mod/siwt/siwt_utls.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl index 391816a..1ef7c3c 100644 --- a/src/mod/siwt/siwt_utls.jl +++ b/src/mod/siwt/siwt_utls.jl @@ -11,7 +11,6 @@ struct ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} IsShiftedTransform::Bool Depth::T₁ IndexAtDepth::T₁ - IndexAtTree::T₁ Cost::Union{T₂, Nothing} Value::Array{T₂,N} end From daaf1cc30ac8996a431d2731d3955ebbb4d88101 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Thu, 19 May 2022 23:32:35 -0700 Subject: [PATCH 03/15] Only run CI on PR to master branch --- .github/workflows/CI.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f617635..90c1655 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -8,6 +8,8 @@ on: - test/** - Project.toml pull_request: + branches: + - master paths: - src/** - test/** From dfa2a8ad28d8b9912ee7b6979ca72fc696eb8894 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Fri, 20 May 2022 23:45:14 -0700 Subject: [PATCH 04/15] Complete documentation & constructor for SIWT object. --- src/mod/SIWPD.jl | 3 +- src/mod/siwt/siwt_utls.jl | 123 +++++++++++++++++++++++++++++++++++--- 2 files changed, 117 insertions(+), 9 deletions(-) diff --git a/src/mod/SIWPD.jl b/src/mod/SIWPD.jl index d6423a3..d1c396f 100644 --- a/src/mod/SIWPD.jl +++ b/src/mod/SIWPD.jl @@ -8,7 +8,8 @@ using using ..Utils, - ..DWT + ..DWT, + ..BestBasis """ siwpd(x, wt[, L=maxtransformlevels(x), d=L]) diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl index 1ef7c3c..42c8012 100644 --- a/src/mod/siwt/siwt_utls.jl +++ b/src/mod/siwt/siwt_utls.jl @@ -1,18 +1,125 @@ -mutable struct ShiftInvariantWaveletTransformObject{N₁, N₂, T₁<:Integer, T₂<:AbstractFloat} - Nodes::Dict{NTuple{N₁,T}, Vector{ShiftInvariantWaveletTransformNode{N₂,T₁,T₂}}} - SignalLength::T₁ +""" + ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:AbstractFloat} + +Data structure to hold all shift-invariant wavelet transform (SIWT) information of a signal. + +# Parameters +- `Nodes::Dict{NTuple{3,T₁}, Vector{ShiftInvariantWaveletTransformNode{N,T₁,T₂}}}`: + Dictionary containing the information of each node within a tree. +- `SignalSize::Union{T₁, NTuple{N,T₁}}`: Size of the original signal. +- `MaxTransformLevel::T₁`: Maximum levels of transform set for signal. +- `Wavelet::OrthoFilter`: A discrete wavelet for transform purposes. +- `MinCost::Union{T₂, Nothing}`: The current minimum [`LogEnergyEntropyCost`](@ref) cost of + the decomposition tree. + +!!! note + `MinCost` parameter will contain the cost of the root node by default. To compute the + true minimum cost of the decomposition tree, one will need to first compute the SIWT + best basis by calling [`bestbasistree`](@ref). + +- `BestTree::Vector{NTuple{3,T₁}}`: A collection of node indices that belong in the best + basis tree. + +!!! note + `BestTree` parameter will contain all the nodes by default, ie. the full decomposition + will be the best tree. To get the SIWT best basis tree that produce the minimum cost, + one will neeed to call [`bestbasistree`](@ref). + +**See also:** [`ShiftInvariantWaveletTransformNode`](@ref) +""" +mutable struct ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:AbstractFloat} + Nodes::Dict{NTuple{3,T₁}, Vector{ShiftInvariantWaveletTransformNode{N,T₁,T₂}}} + SignalSize::Union{T₁, NTuple{N,T₁}} MaxTransformLevel::T₁ Wavelet::OrthoFilter MinCost::Union{T₂, Nothing} - Tree::Union{Vector{NTuple{N₁,T₁}}, Nothing} + BestTree::Vector{NTuple{3,T₁}} end -struct ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} - IsShiftedTransform::Bool +""" + ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} + +Data structure to hold the index, coefficients, and cost value of an SIWT node. + +# Parameters +- `Depth::T₁`: The depth of current node. Root node has depth 0. +- `IndexAtDepth::T₁`: The node index at current depth. Index starts from 0 at each depth. +- `TransformShift::T₁`: The type of shift operated on the parent node before computing the + transform. Accepted type of shift values are: + - `0`: Coefficients of parent node are not shifted prior to transform. + - `1`: For both 1D and 2D signals, coefficients of parent node are circularly shifted to + the left by 1 index. + - `2`: For 2D signals, coefficients of parent node are circularly shifted from bottom to + top by 1 index. Not available for 1D signals. + - `3`: For 2D signals, coefficients of parent node are circularly shifted from bottom to + top and right to left by 1 index each. Not available for 1D signals. +- `Cost::T₂`: The [`LogEnergyEntropyCost`](@ref) of current node. +- `Value::Array{T₂,N}`: Coefficients of current node. +""" +mutable struct ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} Depth::T₁ IndexAtDepth::T₁ - Cost::Union{T₂, Nothing} + TransformShift::T₁ + Cost::T₂ Value::Array{T₂,N} + + function ShiftInvariantWaveletTransformNode{N, T₁, T₂}(Depth, IndexAtDepth, TransformShift, Cost, Value) where {N, T₁<:Integer, T₂<:AbstractFloat} + valueDim = ndims(Value) + if valueDim == 1 + maxIndexAtDepth = 1< maxIndexAtDepth) | (TransformShift > maxTransformShift) + (throw ∘ ArgumentError)("Invalid IndexAtDepth or TransformShift for $(valueDim)D coefficients.") + end + return new(Depth, IndexAtDepth, TransformShift, Cost, Value) + end +end + +""" + ShiftInvariantWaveletTransformObject(signal, wavelet) + +Outer constructor and initialization of SIWT object. + +# Arguments +- `signal::Array{T} where T<:AbstractFloat`: Input signal. +- `wavelet::OrthoFilter`: Wavelet filter. +""" +function ShiftInvariantWaveletTransformObject(signal::Array{T}, wavelet::OrthoFilter) where T<:AbstractFloat + signalDim = ndims(signal) + signalSize = signalDim == 1 ? length(signal) : size(signal) + cost = coefcost(signal, LogEnergyEntropyCost()) + signalNode = ShiftInvariantWaveletTransformNode(0, 0, 0, cost, signal) + maxTransformLevel = 0 + index = (signalNode.Depth, signalNode.IndexAtDepth, signalNode.TransformShift) + S = typeof(signalNode.Depth) + nodes = Dict{NTuple{3,S}, Vector{ShiftInvariantWaveletTransformNode{signalDim,S,T}}}(index => signal) + tree = [index] + return ShiftInvariantWaveletTransformObject(nodes, signalSize, maxTransformLevel, wavelet, cost, tree) end -# TODO: Create constructors \ No newline at end of file +""" + ShiftInvariantWaveletTransformNode(data, depth, indexAtDepth, transformShift) + +Outer constructor of SIWT node. + +# Arguments +- `data::Array{T} where T<:AbstractFloat`: Array of coefficients. +- `depth::S where S<:Integer`: Depth of current node. +- `indexAtDepth::S where S<:Integer`: Node index at current depth. +- `transformShift::S where S<:Integer`: The type of shift operated on the parent node before + computing the transform. +""" +function ShiftInvariantWaveletTransformNode(data::Array{T}, + depth::S, + indexAtDepth::S, + transformShift::S) where {T<:AbstractFloat, S<:Integer} + cost = coefcost(data, LogEnergyEntropyCost()) + return ShiftInvariantWaveletTransformNode(depth, indexAtDepth, transformShift, cost, data) +end \ No newline at end of file From 2bb1071455f020a07e74488c02185dbf4cfe75ce Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Sun, 22 May 2022 20:56:47 -0700 Subject: [PATCH 05/15] Build SIWPD decomposition on SIWT Object - Currently API is `siwpd2` - Manual testing passed for 1D signals - Files and functions need to be rearranged - Once unit tests are up and running, will delete previous implementation --- src/mod/SIWPD.jl | 56 +++++------ src/mod/siwt/siwt_one_level.jl | 56 ++++++++++- src/mod/siwt/siwt_utls.jl | 166 ++++++++++++++++++++++----------- 3 files changed, 188 insertions(+), 90 deletions(-) diff --git a/src/mod/SIWPD.jl b/src/mod/SIWPD.jl index d1c396f..1ef25dc 100644 --- a/src/mod/SIWPD.jl +++ b/src/mod/SIWPD.jl @@ -4,12 +4,17 @@ export makesiwpdtree using - Wavelets + Wavelets, + Parameters, + LinearAlgebra using ..Utils, - ..DWT, - ..BestBasis + ..DWT + +include("siwt/siwt_utls.jl") +include("siwt/siwt_one_level.jl") +include("bestbasis/bestbasis_costs.jl") """ siwpd(x, wt[, L=maxtransformlevels(x), d=L]) @@ -41,6 +46,21 @@ function siwpd(x::AbstractVector{T}, wt::OrthoFilter, return y end +function siwpd2(x::AbstractVector{T}, + wt::OrthoFilter, + L::S = maxtransformlevels(x), + d::S = L) where {T<:AbstractFloat, S<:Integer} + # Sanity check + @assert 0 ≤ L ≤ maxtransformlevels(x) + @assert 1 ≤ d ≤ L + + g, h = WT.makereverseqmfpair(wt, true) + siwtObj = ShiftInvariantWaveletTransformObject(x, wt, L, d) + rootNodeIndex = siwtObj.BestTree[1] + siwpd_subtree!(siwtObj, rootNodeIndex, h, g, d) + return siwtObj +end + #======================================================================================= Recursive calls for decomposition while d>0. If shift is 0, ie. ss=0, the main root will not be further decomposed even if d>0. @@ -87,36 +107,6 @@ function siwpd_subtree!(y::AbstractMatrix{T₁}, h::Vector{T₂}, g::Vector{T₂ return y end -# Single step of decomposition -function sidwt_step!(w₁::AbstractVector{T}, w₂::AbstractVector{T}, - v::AbstractVector{T}, - h::Vector{S}, g::Vector{S}, - s::Bool) where {T<:AbstractFloat, S<:AbstractFloat} - # Sanity check - @assert length(w₁) == length(w₂) == length(v)÷2 - @assert length(h) == length(g) - - # Setup - n = length(v) # Parent length - n₁ = length(w₁) # Child length - filtlen = length(h) # Filter length - - # One step of discrete transform - for i in 1:n₁ - k₁ = mod1(2*i-1-s, n) # Start index for low pass filtering - k₂ = 2*i-s # Start index for high pass filtering - @inbounds w₁[i] = g[end] * v[k₁] - @inbounds w₂[i] = h[1] * v[k₂] - for j in 2:filtlen - k₁ = k₁+1 |> k₁ -> k₁>n ? mod1(k₁,n) : k₁ - k₂ = k₂-1 |> k₂ -> k₂≤0 ? mod1(k₂,n) : k₂ - @inbounds w₁[i] += g[end-j+1] * v[k₁] - @inbounds w₂[i] += h[j] * v[k₂] - end - end - return w₁, w₂ -end - """ makesiwpdtree(n, L, d) diff --git a/src/mod/siwt/siwt_one_level.jl b/src/mod/siwt/siwt_one_level.jl index 3a6462f..0d219f4 100644 --- a/src/mod/siwt/siwt_one_level.jl +++ b/src/mod/siwt/siwt_one_level.jl @@ -1,3 +1,57 @@ -# TODO: Function to compute one level of decomposition +function sidwt_step!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}, + shiftedTransform::Bool) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + nodeObj = siwtObj.Nodes[index] + nodeValue = nodeObj.Value + nodeLength = length(nodeValue) + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + + childLength = nodeLength ÷ 2 + child1Value = Vector{T₂}(undef, childLength) + child2Value = Vector{T₂}(undef, childLength) + sidwt_step!(child1Value, child2Value, nodeValue, h, g, shiftedTransform) + child1Obj = ShiftInvariantWaveletTransformNode(child1Value, nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< k₁ -> k₁>n ? mod1(k₁,n) : k₁ + k₂ = k₂-1 |> k₂ -> k₂≤0 ? mod1(k₂,n) : k₂ + @inbounds w₁[i] += g[end-j+1] * v[k₁] + @inbounds w₂[i] += h[j] * v[k₂] + end + end + return w₁, w₂ +end # TODO: Function to compute one level of recomposition \ No newline at end of file diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl index 42c8012..36e7400 100644 --- a/src/mod/siwt/siwt_utls.jl +++ b/src/mod/siwt/siwt_utls.jl @@ -1,41 +1,3 @@ -""" - ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:AbstractFloat} - -Data structure to hold all shift-invariant wavelet transform (SIWT) information of a signal. - -# Parameters -- `Nodes::Dict{NTuple{3,T₁}, Vector{ShiftInvariantWaveletTransformNode{N,T₁,T₂}}}`: - Dictionary containing the information of each node within a tree. -- `SignalSize::Union{T₁, NTuple{N,T₁}}`: Size of the original signal. -- `MaxTransformLevel::T₁`: Maximum levels of transform set for signal. -- `Wavelet::OrthoFilter`: A discrete wavelet for transform purposes. -- `MinCost::Union{T₂, Nothing}`: The current minimum [`LogEnergyEntropyCost`](@ref) cost of - the decomposition tree. - -!!! note - `MinCost` parameter will contain the cost of the root node by default. To compute the - true minimum cost of the decomposition tree, one will need to first compute the SIWT - best basis by calling [`bestbasistree`](@ref). - -- `BestTree::Vector{NTuple{3,T₁}}`: A collection of node indices that belong in the best - basis tree. - -!!! note - `BestTree` parameter will contain all the nodes by default, ie. the full decomposition - will be the best tree. To get the SIWT best basis tree that produce the minimum cost, - one will neeed to call [`bestbasistree`](@ref). - -**See also:** [`ShiftInvariantWaveletTransformNode`](@ref) -""" -mutable struct ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:AbstractFloat} - Nodes::Dict{NTuple{3,T₁}, Vector{ShiftInvariantWaveletTransformNode{N,T₁,T₂}}} - SignalSize::Union{T₁, NTuple{N,T₁}} - MaxTransformLevel::T₁ - Wavelet::OrthoFilter - MinCost::Union{T₂, Nothing} - BestTree::Vector{NTuple{3,T₁}} -end - """ ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:AbstractFloat} @@ -67,10 +29,11 @@ mutable struct ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:Abstra valueDim = ndims(Value) if valueDim == 1 maxIndexAtDepth = 1< signal) - tree = [index] - return ShiftInvariantWaveletTransformObject(nodes, signalSize, maxTransformLevel, wavelet, cost, tree) +mutable struct ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:AbstractFloat} + Nodes::Dict{NTuple{3,T₁}, ShiftInvariantWaveletTransformNode{N,T₁,T₂}} + SignalSize::Union{T₁, NTuple{N,T₁}} + MaxTransformLevel::T₁ + MaxShiftedTransformLevels::T₁ + Wavelet::OrthoFilter + MinCost::Union{T₂, Nothing} + BestTree::Vector{NTuple{3,T₁}} end """ @@ -121,5 +101,79 @@ function ShiftInvariantWaveletTransformNode(data::Array{T}, indexAtDepth::S, transformShift::S) where {T<:AbstractFloat, S<:Integer} cost = coefcost(data, LogEnergyEntropyCost()) - return ShiftInvariantWaveletTransformNode(depth, indexAtDepth, transformShift, cost, data) + N = ndims(data) + return ShiftInvariantWaveletTransformNode{N,S,T}(depth, indexAtDepth, transformShift, cost, data) +end + +""" + ShiftInvariantWaveletTransformObject(signal, wavelet) + +Outer constructor and initialization of SIWT object. + +# Arguments +- `signal::Array{T} where T<:AbstractFloat`: Input signal. +- `wavelet::OrthoFilter`: Wavelet filter. +""" +function ShiftInvariantWaveletTransformObject(signal::Array{T}, + wavelet::OrthoFilter, + maxTransformLevel::S = 0, + maxShiftedTransformLevel::S = 0) where + {T<:AbstractFloat, S<:Integer} + signalDim = ndims(signal) + signalSize = signalDim == 1 ? length(signal) : size(signal) + cost = coefcost(signal, LogEnergyEntropyCost()) + signalNode = ShiftInvariantWaveletTransformNode{signalDim,S,T}(0, 0, 0, cost, signal) + index = (signalNode.Depth, signalNode.IndexAtDepth, signalNode.TransformShift) + nodes = Dict{NTuple{3,S}, ShiftInvariantWaveletTransformNode{signalDim,S,T}}(index => signalNode) + tree = [index] + return ShiftInvariantWaveletTransformObject(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevel, wavelet, cost, tree) +end + + +""" + siwpd_subtree(siwtObj, index, h, g, remainingRelativeDepth4ShiftedTransform) +""" +function siwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}, + remainingRelativeDepth4ShiftedTransform::T₁) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + treeMaxTransformLevel = siwtObj.MaxTransformLevel + nodeDepth, _, nodeTransformShift = index + @info "$index: $remainingRelativeDepth4ShiftedTransform" + + @assert 0 ≤ nodeDepth ≤ treeMaxTransformLevel + @assert 0 ≤ remainingRelativeDepth4ShiftedTransform ≤ treeMaxTransformLevel-nodeDepth + + # --- Base case --- + isLeafNode = (nodeDepth == treeMaxTransformLevel) + isShiftedTransform4NodeRequired = (remainingRelativeDepth4ShiftedTransform > 0) + isShiftedTransformNode = nodeTransformShift > 0 + isShiftedTransformLeafNode = (!isShiftedTransform4NodeRequired && isShiftedTransformNode) + if (isLeafNode || isShiftedTransformLeafNode) + return nothing + end + + # General step: + # - Decompose current node without additional shift + # - Decompose children nodes + childDepth = nodeDepth + 1 + (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, false) + childRemainingRelativeDepth4ShiftedTransform = isShiftedTransformNode ? + remainingRelativeDepth4ShiftedTransform-1 : + min(remainingRelativeDepth4ShiftedTransform, treeMaxTransformLevel-childDepth) + siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform) + siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform) + + # Case: remainingRelativeDepth4ShiftedTransform > 0 + # - Decompose current node with additional shift + # - Decompose children (with additional shift) nodes + if isShiftedTransform4NodeRequired + (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, true) + childRemainingRelativeDepth4ShiftedTransform = remainingRelativeDepth4ShiftedTransform-1 + siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform) + siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform) + end + + return nothing end \ No newline at end of file From af00a5995d0ab9086ddfba758dd86c9a3f9f4f5d Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Tue, 24 May 2022 11:54:16 -0700 Subject: [PATCH 06/15] Added bestbasis functionality to SIWT Object, debugging required --- src/mod/SIWPD.jl | 1 + src/mod/siwt/siwt_bestbasis.jl | 88 ++++++++++++++++++++++++++++++++++ src/mod/siwt/siwt_utls.jl | 9 ++-- 3 files changed, 93 insertions(+), 5 deletions(-) create mode 100644 src/mod/siwt/siwt_bestbasis.jl diff --git a/src/mod/SIWPD.jl b/src/mod/SIWPD.jl index 1ef25dc..d1d9dd5 100644 --- a/src/mod/SIWPD.jl +++ b/src/mod/SIWPD.jl @@ -14,6 +14,7 @@ using include("siwt/siwt_utls.jl") include("siwt/siwt_one_level.jl") +include("siwt/siwt_bestbasis.jl") include("bestbasis/bestbasis_costs.jl") """ diff --git a/src/mod/siwt/siwt_bestbasis.jl b/src/mod/siwt/siwt_bestbasis.jl new file mode 100644 index 0000000..69aa559 --- /dev/null +++ b/src/mod/siwt/siwt_bestbasis.jl @@ -0,0 +1,88 @@ +function Wavelets.Threshold.bestbasistree(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}) where + {N, T₁<:Integer, T₂<:AbstractFloat} + rootNodeIndex = (0,0,0) + bestbasis_treeselection!(siwtObj, rootNodeIndex) + siwtObj.MinCost = siwtObj.Nodes[rootNodeIndex].Cost + return nothing +end + +function bestbasis_treeselection!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + index::NTuple{3,T₁}) where + {N, T₁<:Integer, T₂<:AbstractFloat} + # Base case: Check if tree contains desired index + if index ∉ siwtObj.BestTree + return nothing + end + + # Get the cost of current node and its children nodes + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + child1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) + child2Index = (nodeDepth+1, nodeIndexAtDepth<<1+1, nodeTransformShift) + shiftedChild1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< x==index, siwtObj.BestTree)) + + # Delete children nodes + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + child1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) + child2Index = (nodeDepth+1, nodeIndexAtDepth<<1+1, nodeTransformShift) + shiftedChild1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< signalNode) @@ -140,7 +140,6 @@ function siwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} treeMaxTransformLevel = siwtObj.MaxTransformLevel nodeDepth, _, nodeTransformShift = index - @info "$index: $remainingRelativeDepth4ShiftedTransform" @assert 0 ≤ nodeDepth ≤ treeMaxTransformLevel @assert 0 ≤ remainingRelativeDepth4ShiftedTransform ≤ treeMaxTransformLevel-nodeDepth From 09631fe39820b090f1190d7e1ddc11eb0bf64e63 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Fri, 3 Jun 2022 23:35:37 -0700 Subject: [PATCH 07/15] Minor fixes and unit tests for SIWT objects/data structures --- src/mod/SIWPD.jl | 2 ++ src/mod/siwt/siwt_utls.jl | 19 ++++++++++---- test/runtests.jl | 14 +++++------ test/transforms.jl | 52 ++++++++++++++++++++++++--------------- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/src/mod/SIWPD.jl b/src/mod/SIWPD.jl index d1d9dd5..aa4a2cf 100644 --- a/src/mod/SIWPD.jl +++ b/src/mod/SIWPD.jl @@ -1,5 +1,7 @@ module SIWPD export + ShiftInvariantWaveletTransformNode, + ShiftInvariantWaveletTransformObject, siwpd, makesiwpdtree diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl index a96b426..14ae136 100644 --- a/src/mod/siwt/siwt_utls.jl +++ b/src/mod/siwt/siwt_utls.jl @@ -26,11 +26,14 @@ mutable struct ShiftInvariantWaveletTransformNode{N, T₁<:Integer, T₂<:Abstra Value::Array{T₂,N} function ShiftInvariantWaveletTransformNode{N, T₁, T₂}(Depth, IndexAtDepth, TransformShift, Cost, Value) where {N, T₁<:Integer, T₂<:AbstractFloat} - valueDim = ndims(Value) - if valueDim == 1 + if N ≠ ndims(Value) + (throw ∘ TypeError)("Value array is not of $N dimension.") + end + + if N == 1 maxIndexAtDepth = 1< maxIndexAtDepth) | (TransformShift > maxTransformShift) - (throw ∘ ArgumentError)("Invalid IndexAtDepth or TransformShift for $(valueDim)D coefficients.") + (throw ∘ ArgumentError)("Invalid IndexAtDepth or TransformShift for $(N)D coefficients.") end return new(Depth, IndexAtDepth, TransformShift, Cost, Value) end @@ -82,6 +85,12 @@ mutable struct ShiftInvariantWaveletTransformObject{N, T₁<:Integer, T₂<:Abst Wavelet::OrthoFilter MinCost::Union{T₂, Nothing} BestTree::Vector{NTuple{3,T₁}} + + function ShiftInvariantWaveletTransformObject{N,T₁,T₂}(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevels, wt, minCost, bestTree) where {N, T₁<:Integer, T₂<:AbstractFloat} + 0 ≤ maxTransformLevel ≤ maxtransformlevels(nodes[(0,0,0)].Value) || (throw ∘ ArgumentError)("Provided MaxTransformLevels is too large.") + 0 ≤ maxShiftedTransformLevels < length(nodes[(0,0,0)].Value) || (throw ∘ ArgumentError)("Provided MaxShiftedTransformLevels is too large.") + return new(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevels, wt, minCost, bestTree) + end end """ @@ -126,7 +135,7 @@ function ShiftInvariantWaveletTransformObject(signal::Array{T}, index = (signalNode.Depth, signalNode.IndexAtDepth, signalNode.TransformShift) nodes = Dict{NTuple{3,S}, ShiftInvariantWaveletTransformNode{signalDim,S,T}}(index => signalNode) tree = [index] - return ShiftInvariantWaveletTransformObject(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevel, wavelet, cost, tree) + return ShiftInvariantWaveletTransformObject{signalDim,S,T}(nodes, signalSize, maxTransformLevel, maxShiftedTransformLevel, wavelet, cost, tree) end diff --git a/test/runtests.jl b/test/runtests.jl index 8395460..ee78eb5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,10 +12,10 @@ using WaveletsExt, SparseArrays -@testset "Utils" begin include("utils.jl") end -@testset "Transforms" begin include("transforms.jl") end -@testset "Wavelet Multiplication" begin include("wavemult.jl") end -@testset "Best Basis" begin include("bestbasis.jl") end -@testset "Denoising" begin include("denoising.jl") end -@testset "LDB" begin include("ldb.jl") end -@testset "Visualizations" begin include("visualizations.jl") end \ No newline at end of file +@testset verbose=true "Utils" begin include("utils.jl") end +@testset verbose=true "Transforms" begin include("transforms.jl") end +@testset verbose=true "Wavelet Multiplication" begin include("wavemult.jl") end +@testset verbose=true "Best Basis" begin include("bestbasis.jl") end +@testset verbose=true "Denoising" begin include("denoising.jl") end +@testset verbose=true "LDB" begin include("ldb.jl") end +@testset verbose=true "Visualizations" begin include("visualizations.jl") end \ No newline at end of file diff --git a/test/transforms.jl b/test/transforms.jl index 02a2c25..4a6b6ca 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -174,26 +174,38 @@ end @test iacwpd(acwpd(x, wt), tree) ≈ x end -@testset "SIWPD" begin - # siwpd - x = randn(4) - wt = wavelet(WT.haar) - y = siwpd(x, wt, 2, 1) - y0 = y[1,4:7] - y1 = y[2,4:7] - y2 = y[3,4:7] - y3 = y[4,4:7] - @test y0 == wpt(x, wt) - @test !all([isdefined(y1, i) for i in 1:4]) - @test y2 == wpt(circshift(x,2), wt) - @test !all([isdefined(y3, i) for i in 1:4]) - # make tree - tree = [ - trues(1), - repeat([trues(2)],2)..., - repeat([BitVector([1,0,1,0])],4)... - ] - @test makesiwpdtree(4, 2, 1) == tree +@testset verbose=true "SIWT" begin + @testset "Data structures" begin + signal = [2,3,-4,5.0]; + wt = wavelet(WT.haar); + cost = BestBasis.coefcost(signal, ShannonEntropyCost()); + @test isa(ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,0,signal), ShiftInvariantWaveletTransformNode); + @test_throws MethodError ShiftInvariantWaveletTransformNode{2,Int64,Float64}(0,0,0,0,signal); # signal dimension and N do not match + @test_throws ArgumentError ShiftInvariantWaveletTransformNode{1,Int64,Float64}(2,4,0,0,signal); # Invalid IndexAtDepth + @test_throws ArgumentError ShiftInvariantWaveletTransformNode{1,Int64,Float64}(2,0,4,0,signal); # Invalid TransformShift + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).Depth == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).Depth; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).IndexAtDepth == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).IndexAtDepth; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).TransformShift == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).TransformShift; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).Cost == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).Cost; + @test ShiftInvariantWaveletTransformNode(signal,0,0,0).Value == ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal).Value; + @test ShiftInvariantWaveletTransformNode(signal,1,0,0) != ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,cost,signal); + + maxTransformLevels = 2; + maxTransformShift = 3; + siwtObject = ShiftInvariantWaveletTransformObject(signal, wt); + @test isa(siwtObject, ShiftInvariantWaveletTransformObject); + @test isa(siwtObject.Nodes[(0,0,0)], ShiftInvariantWaveletTransformNode); + @test siwtObject.SignalSize == 4; + @test siwtObject.MaxTransformLevel == 0; + @test siwtObject.MaxShiftedTransformLevels == 0; + @test siwtObject.Wavelet == wt; + @test siwtObject.MinCost == cost; + @test siwtObject.BestTree == [(0,0,0)]; + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, maxTransformLevels+1); + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, -1); + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, 0, maxTransformShift+1); + @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, 0, -1); + end end @testset "Transform All" begin From aec4a67cc1e82c8433825b6a5406807f663a16db Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Sat, 4 Jun 2022 22:52:04 -0700 Subject: [PATCH 08/15] Fix bugs, add unit tests for SIWT, and commented out old code that are no longer useful --- src/WaveletsExt.jl | 2 +- src/mod/BestBasis.jl | 189 +++++++++++++------------- src/mod/SIWPD.jl | 203 +++++++++++++++------------- src/mod/bestbasis/bestbasis_tree.jl | 108 +++++++-------- src/mod/siwt/siwt_bestbasis.jl | 6 +- src/mod/siwt/siwt_one_level.jl | 7 +- src/mod/siwt/siwt_utls.jl | 16 ++- test/transforms.jl | 51 ++++++- 8 files changed, 324 insertions(+), 258 deletions(-) diff --git a/src/WaveletsExt.jl b/src/WaveletsExt.jl index 429e5dc..cffea03 100644 --- a/src/WaveletsExt.jl +++ b/src/WaveletsExt.jl @@ -4,9 +4,9 @@ module WaveletsExt include("mod/Utils.jl") include("mod/DWT.jl") -include("mod/SIWPD.jl") include("mod/ACWT.jl") include("mod/BestBasis.jl") +include("mod/SIWPD.jl") include("mod/SWT.jl") include("mod/Denoising.jl") include("mod/LDB.jl") diff --git a/src/mod/BestBasis.jl b/src/mod/BestBasis.jl index d0fc3f2..e5de635 100644 --- a/src/mod/BestBasis.jl +++ b/src/mod/BestBasis.jl @@ -17,7 +17,6 @@ export LSDB, JBB, BB, - SIBB, # best basis tree bestbasistreeall @@ -30,8 +29,7 @@ using using ..Utils, - ..DWT, - ..SIWPD + ..DWT include("bestbasis/bestbasis_costs.jl") include("bestbasis/bestbasis_tree.jl") @@ -112,70 +110,70 @@ function bestbasis_treeselection(costs::AbstractVector{T}, end # SIWPD tree selection -""" - bestbasis_treeselection(costs, tree) - -Best basis tree selection on SIWPD. - -# Arguments -- `costs::AbstractVector{Tc} where Tc<:AbstractVector{<:Union{Number, Nothing}}`: Cost of - each node. -- `tree::AbstractVector{Tt} where Tt<:BitVector`: SIWPD tree. - -!!! warning - Current implementation works but is unstable, ie. we are still working on better - syntax/more optimized computations/better data structure. -""" -function bestbasis_treeselection(costs::AbstractVector{Tc}, - tree::AbstractVector{Tt}) where - {Tc<:AbstractVector{<:Union{Number,Nothing}}, Tt<:BitVector} - - @assert length(costs) == length(tree) - bt = deepcopy(tree) - bc = deepcopy(costs) - nn = length(tree) - for i in reverse(eachindex(bt)) - if getchildindex(i,:left) > nn # current node is at bottom level - continue - end - level = floor(Int, log2(i)) - for j in eachindex(bt[i]) # iterate through all available shifts - if !bt[i][j] # node of current shift does not exist - continue - end - @assert bt[getchildindex(i,:left)][j] == - bt[getchildindex(i,:right)][j] == - bt[getchildindex(i,:left)][j+1< sum(node), bt) .<= 1) - return bt -end +# """ +# bestbasis_treeselection(costs, tree) + +# Best basis tree selection on SIWPD. + +# # Arguments +# - `costs::AbstractVector{Tc} where Tc<:AbstractVector{<:Union{Number, Nothing}}`: Cost of +# each node. +# - `tree::AbstractVector{Tt} where Tt<:BitVector`: SIWPD tree. + +# !!! warning +# Current implementation works but is unstable, ie. we are still working on better +# syntax/more optimized computations/better data structure. +# """ +# function bestbasis_treeselection(costs::AbstractVector{Tc}, +# tree::AbstractVector{Tt}) where +# {Tc<:AbstractVector{<:Union{Number,Nothing}}, Tt<:BitVector} + +# @assert length(costs) == length(tree) +# bt = deepcopy(tree) +# bc = deepcopy(costs) +# nn = length(tree) +# for i in reverse(eachindex(bt)) +# if getchildindex(i,:left) > nn # current node is at bottom level +# continue +# end +# level = floor(Int, log2(i)) +# for j in eachindex(bt[i]) # iterate through all available shifts +# if !bt[i][j] # node of current shift does not exist +# continue +# end +# @assert bt[getchildindex(i,:left)][j] == +# bt[getchildindex(i,:right)][j] == +# bt[getchildindex(i,:left)][j+1< sum(node), bt) .<= 1) +# return bt +# end # Deletes subtree due to inferior cost """ @@ -301,35 +299,36 @@ function Wavelets.Threshold.bestbasistree(X::AbstractArray{T}, method::BB) where end # SIWPD Best basis tree # TODO: find a way to compute bestbasis_tree without input d -""" - bestbasistree(y, d, method) - -Computes the best basis tree for the shift invariant wavelet packet decomposition (SIWPD). - -# Arguments -- `y::AbstractArray{T,2} where T<:Number`: A SIWPD decomposed signal. -- `d::Integer`: The number of depth computed for the decomposition. -- `method::SIBB`: The `SIBB()` method. - -# Returns -- `Vector{BitVector}`: SIWPD best basis tree. - -!!! warning - Current implementation works but is unstable, ie. we are still working on better - syntax/more optimized computations/better data structure. -""" -function Wavelets.Threshold.bestbasistree(y::AbstractArray{T,2}, - d::Integer, - method::SIBB) where T<:Number +# """ +# bestbasistree(y, d, method) + +# Computes the best basis tree for the shift invariant wavelet packet decomposition (SIWPD). + +# # Arguments +# - `y::AbstractArray{T,2} where T<:Number`: A SIWPD decomposed signal. +# - `d::Integer`: The number of depth computed for the decomposition. +# - `method::SIBB`: The `SIBB()` method. + +# # Returns +# - `Vector{BitVector}`: SIWPD best basis tree. + +# !!! warning +# Current implementation works but is unstable, ie. we are still working on better +# syntax/more optimized computations/better data structure. +# """ +# function Wavelets.Threshold.bestbasistree(y::AbstractArray{T,2}, +# d::Integer, +# method::SIBB) where T<:Number - nn = size(y,2) - L = maxtransformlevels((nn+1)÷2) - ns = size(y,1) - tree = makesiwpdtree(ns, L, d) - costs = tree_costs(y, tree, method) - besttree = bestbasis_treeselection(costs, tree) - return besttree -end +# nn = size(y,2) +# L = maxtransformlevels((nn+1)÷2) +# ns = size(y,1) +# tree = makesiwpdtree(ns, L, d) +# costs = tree_costs(y, tree, method) +# besttree = bestbasis_treeselection(costs, tree) +# return besttree +# end + # Default best basis tree search function Wavelets.Threshold.bestbasistree(X::AbstractArray{T}, method::BestBasisType = JBB()) where diff --git a/src/mod/SIWPD.jl b/src/mod/SIWPD.jl index aa4a2cf..42d3c2b 100644 --- a/src/mod/SIWPD.jl +++ b/src/mod/SIWPD.jl @@ -1,9 +1,9 @@ module SIWPD export + bestbasistree!, ShiftInvariantWaveletTransformNode, ShiftInvariantWaveletTransformObject, - siwpd, - makesiwpdtree + siwpd using Wavelets, @@ -14,42 +14,59 @@ using ..Utils, ..DWT +import ..BestBasis: coefcost, ShannonEntropyCost + include("siwt/siwt_utls.jl") include("siwt/siwt_one_level.jl") include("siwt/siwt_bestbasis.jl") -include("bestbasis/bestbasis_costs.jl") -""" - siwpd(x, wt[, L=maxtransformlevels(x), d=L]) +# """ +# siwpd(x, wt[, L=maxtransformlevels(x), d=L]) + +# Computes the Shift-Invariant Wavelet Packet Decomposition originally developed +# by Cohen, Raz & Malah on the vector `x` using the discrete wavelet filter `wt` +# for `L` levels with depth `d`. +# """ +# function siwpd(x::AbstractVector{T}, wt::OrthoFilter, +# L::S = maxtransformlevels(x), d::S = L) where {T<:AbstractFloat, S<:Integer} +# # Sanity check +# n = length(x) +# @assert 0 ≤ L ≤ maxtransformlevels(n) +# @assert 1 ≤ d ≤ L + +# # Setup +# g, h = WT.makereverseqmfpair(wt, true) +# # y = Matrix{T}(undef, (n, gettreelength(1<<(L+1)))) +# y = zeros(T, (n, gettreelength(1<<(L+1)))) +# @inbounds y[:,1] = x + +# # Decomposition +# for i in axes(y,2) +# lvl = getdepth(i, :binary) +# len = nodelength(n, lvl) +# dₗ = (0 ≤ lvl ≤ L-d) ? d : L-lvl +# siwpd_subtree!(y, h, g, i, dₗ, 0, L=lvl, len=len) +# end +# return y +# end -Computes the Shift-Invariant Wavelet Packet Decomposition originally developed -by Cohen, Raz & Malah on the vector `x` using the discrete wavelet filter `wt` -for `L` levels with depth `d`. """ -function siwpd(x::AbstractVector{T}, wt::OrthoFilter, - L::S = maxtransformlevels(x), d::S = L) where {T<:AbstractFloat, S<:Integer} - # Sanity check - n = length(x) - @assert 0 ≤ L ≤ maxtransformlevels(n) - @assert 1 ≤ d ≤ L + siwpd(x, wt[, L, d]) - # Setup - g, h = WT.makereverseqmfpair(wt, true) - # y = Matrix{T}(undef, (n, gettreelength(1<<(L+1)))) - y = zeros(T, (n, gettreelength(1<<(L+1)))) - @inbounds y[:,1] = x - - # Decomposition - for i in axes(y,2) - lvl = getdepth(i, :binary) - len = nodelength(n, lvl) - dₗ = (0 ≤ lvl ≤ L-d) ? d : L-lvl - siwpd_subtree!(y, h, g, i, dₗ, 0, L=lvl, len=len) - end - return y -end +Computes the Shift-Invariant Wavelet Packet Decomposition originally developed by Cohen, Raz +& Malah on the vector `x` using the discrete wavelet filter `wt` for `L` levels with depth +`d`. + +# Arguments +- `x::AbstractVector{T} where T<:AbstractFloat`: 1D-signal. +- `wt::OrthoFilter`: Wavelet filter. +- `L::S where S<:Integer`: (Default: `maxtransformlevels(x)`) Number of transform levels. +- `d::S where S<:Integer`: (Default: `L`) Depth of shifted transform for each node. -function siwpd2(x::AbstractVector{T}, +# Returns +- `ShiftInvariantWaveletTransformObject` containing node and tree details. +""" +function siwpd(x::AbstractVector{T}, wt::OrthoFilter, L::S = maxtransformlevels(x), d::S = L) where {T<:AbstractFloat, S<:Integer} @@ -68,69 +85,69 @@ end Recursive calls for decomposition while d>0. If shift is 0, ie. ss=0, the main root will not be further decomposed even if d>0. =======================================================================================# -function siwpd_subtree!(y::AbstractMatrix{T₁}, h::Vector{T₂}, g::Vector{T₂}, - i::S, d::S, ss::S; - L::S = getdepth(i, :binary), - len::S = nodelength(size(y,1), L)) where - {T₁<:AbstractFloat, T₂<:AbstractFloat, S<:Integer} - # Sanity check - n, k = size(y) - @assert 0 ≤ L ≤ maxtransformlevels(n) - @assert 0 ≤ ss < n - @assert 0 ≤ d ≤ getdepth(k,:binary)-L - @assert iseven(len) || len == 1 - - # --- Base case --- - d > 0 || return y - - # --- One level of decomposition --- - v_start = ss*len+1 # Start index for parent node - v_end = (ss+1)*len # End index for parent node - @inbounds v = @view y[v_start:v_end, i] # Parent node - w_start = ss*(len÷2)+1 # Start index for child node - w_end = w_start+(len÷2)-1 # End index for child node - @inbounds w₁ = @view y[w_start:w_end, getchildindex(i, :left)] # Left child node - @inbounds w₂ = @view y[w_start:w_end, getchildindex(i, :right)] # Right child node - sidwt_step!(w₁, w₂, v, h, g, false) # 1 step of non-shifted decomposition - w_start = (ss+1< 0 || return y + +# # --- One level of decomposition --- +# v_start = ss*len+1 # Start index for parent node +# v_end = (ss+1)*len # End index for parent node +# @inbounds v = @view y[v_start:v_end, i] # Parent node +# w_start = ss*(len÷2)+1 # Start index for child node +# w_end = w_start+(len÷2)-1 # End index for child node +# @inbounds w₁ = @view y[w_start:w_end, getchildindex(i, :left)] # Left child node +# @inbounds w₂ = @view y[w_start:w_end, getchildindex(i, :right)] # Right child node +# sidwt_step!(w₁, w₂, v, h, g, false) # 1 step of non-shifted decomposition +# w_start = (ss+1<0 --- - if ss > 0 # Decomposition of current shift - siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss, L=L+1, len=len÷2) - siwpd_subtree!(y, h, g, getchildindex(i, :right), d-1, ss, L=L+1, len=len÷2) - end - # Decomposition of shifted version - siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss+1<0 --- +# if ss > 0 # Decomposition of current shift +# siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss, L=L+1, len=len÷2) +# siwpd_subtree!(y, h, g, getchildindex(i, :right), d-1, ss, L=L+1, len=len÷2) +# end +# # Decomposition of shifted version +# siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss+1< 0 # - Decompose current node with additional shift @@ -179,8 +181,8 @@ function siwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T if isShiftedTransform4NodeRequired (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, true) childRemainingRelativeDepth4ShiftedTransform = remainingRelativeDepth4ShiftedTransform-1 - siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform) - siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform) + siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) end return nothing diff --git a/test/transforms.jl b/test/transforms.jl index 4a6b6ca..bcdf7a5 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -175,9 +175,9 @@ end end @testset verbose=true "SIWT" begin + signal = [2,3,-4,5.0]; + wt = wavelet(WT.haar); @testset "Data structures" begin - signal = [2,3,-4,5.0]; - wt = wavelet(WT.haar); cost = BestBasis.coefcost(signal, ShannonEntropyCost()); @test isa(ShiftInvariantWaveletTransformNode{1,Int64,Float64}(0,0,0,0,signal), ShiftInvariantWaveletTransformNode); @test_throws MethodError ShiftInvariantWaveletTransformNode{2,Int64,Float64}(0,0,0,0,signal); # signal dimension and N do not match @@ -206,6 +206,53 @@ end @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, 0, maxTransformShift+1); @test_throws ArgumentError ShiftInvariantWaveletTransformObject(signal, wt, 0, -1); end + + @testset "Transform" begin + expectedNodes = Dict{NTuple{3,Int64}, ShiftInvariantWaveletTransformNode{1,Int64,Float64}}(); + expectedNodes[(0,0,0)] = ShiftInvariantWaveletTransformNode(signal,0,0,0); + signalTransformedL1S0 = dwt(signal, wt, 1); + expectedNodes[(1,0,0)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S0[1:2],1,0,0); + expectedNodes[(1,1,0)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S0[3:4],1,1,0); + signalTransformedL1S1 = dwt(circshift(signal,1), wt, 1); + expectedNodes[(1,0,1)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S1[1:2],1,0,1); + expectedNodes[(1,1,1)] = ShiftInvariantWaveletTransformNode(signalTransformedL1S1[3:4],1,1,1); + @test isa(siwpd(signal,wt,1,1), ShiftInvariantWaveletTransformObject) + end + + @testset "Best Basis" begin + # Best basis on original object without decomposition + siwtObject = ShiftInvariantWaveletTransformObject(signal, wt, 0, 0); + rootOnlyTree = [(0,0,0)]; + bestbasistree!(siwtObject); + @test siwtObject.BestTree == rootOnlyTree; + @test siwtObject.MinCost == siwtObject.Nodes[(0,0,0)].Cost; + + # Cost check before best basis operation + expectedNodeCost = Dict{NTuple{3,Int64}, Float64}( + (0,0,0) => 1.208, + (1,0,0) => 0.382, + (1,0,1) => 0.402, + (1,1,0) => 0.259, + (1,1,1) => 0.566 + ) + siwtObj = siwpd(signal, wt, 1); + for index in keys(expectedNodeCost) + @test expectedNodeCost[index] ≈ siwtObj.Nodes[index].Cost atol=1e-3; + end + + # Cost check after best basis operation + expectedNodeCost = Dict{NTuple{3,Int64}, Float64}( + (0,0,0) => 0.641, + (1,0,0) => 0.382, + (1,1,0) => 0.259 + ) + bestbasistree!(siwtObj) + @test (Set ∘ keys)(expectedNodeCost) == Set(siwtObj.BestTree) == (Set ∘ keys)(siwtObj.Nodes) + for index in keys(expectedNodeCost) + @test expectedNodeCost[index] ≈ siwtObj.Nodes[index].Cost atol=1e-3 + end + @test expectedNodeCost[(0,0,0)] ≈ siwtObj.MinCost atol=1e-3 + end end @testset "Transform All" begin From 827b9e368974e136b1d9090d58a5a20777880be1 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Sat, 4 Jun 2022 23:05:04 -0700 Subject: [PATCH 09/15] Remove unnecessary code from SIWT and BestBasis modules --- src/WaveletsExt.jl | 4 +- src/mod/BestBasis.jl | 119 ---------------------- src/mod/SIWPD.jl | 153 ---------------------------- src/mod/SIWT.jl | 54 ++++++++++ src/mod/bestbasis/bestbasis_tree.jl | 61 ----------- test/bestbasis.jl | 8 -- 6 files changed, 56 insertions(+), 343 deletions(-) delete mode 100644 src/mod/SIWPD.jl create mode 100644 src/mod/SIWT.jl diff --git a/src/WaveletsExt.jl b/src/WaveletsExt.jl index cffea03..3780e62 100644 --- a/src/WaveletsExt.jl +++ b/src/WaveletsExt.jl @@ -6,7 +6,7 @@ include("mod/Utils.jl") include("mod/DWT.jl") include("mod/ACWT.jl") include("mod/BestBasis.jl") -include("mod/SIWPD.jl") +include("mod/SIWT.jl") include("mod/SWT.jl") include("mod/Denoising.jl") include("mod/LDB.jl") @@ -20,7 +20,7 @@ using Reexport .Utils, .LDB, .SWT, - .SIWPD, + .SIWT, .ACWT, .Visualizations, .WaveMult diff --git a/src/mod/BestBasis.jl b/src/mod/BestBasis.jl index e5de635..dff18c1 100644 --- a/src/mod/BestBasis.jl +++ b/src/mod/BestBasis.jl @@ -109,72 +109,6 @@ function bestbasis_treeselection(costs::AbstractVector{T}, return tree end -# SIWPD tree selection -# """ -# bestbasis_treeselection(costs, tree) - -# Best basis tree selection on SIWPD. - -# # Arguments -# - `costs::AbstractVector{Tc} where Tc<:AbstractVector{<:Union{Number, Nothing}}`: Cost of -# each node. -# - `tree::AbstractVector{Tt} where Tt<:BitVector`: SIWPD tree. - -# !!! warning -# Current implementation works but is unstable, ie. we are still working on better -# syntax/more optimized computations/better data structure. -# """ -# function bestbasis_treeselection(costs::AbstractVector{Tc}, -# tree::AbstractVector{Tt}) where -# {Tc<:AbstractVector{<:Union{Number,Nothing}}, Tt<:BitVector} - -# @assert length(costs) == length(tree) -# bt = deepcopy(tree) -# bc = deepcopy(costs) -# nn = length(tree) -# for i in reverse(eachindex(bt)) -# if getchildindex(i,:left) > nn # current node is at bottom level -# continue -# end -# level = floor(Int, log2(i)) -# for j in eachindex(bt[i]) # iterate through all available shifts -# if !bt[i][j] # node of current shift does not exist -# continue -# end -# @assert bt[getchildindex(i,:left)][j] == -# bt[getchildindex(i,:right)][j] == -# bt[getchildindex(i,:left)][j+1< sum(node), bt) .<= 1) -# return bt -# end - # Deletes subtree due to inferior cost """ delete_subtree!(bt, i, tree_type) @@ -205,28 +139,6 @@ function delete_subtree!(bt::BitVector, i::Integer, tree_type::Symbol) return bt end -function delete_subtree!(bt::AbstractVector{BitVector}, i::Integer, j::Integer) - @assert 1 <= i <= length(bt) - level = floor(Int, log2(i)) - bt[i][j] = false - if (getchildindex(i,:left)) < length(bt) # current node can contain subtrees - if bt[getchildindex(i,:left)][j] # left child of current shift - delete_subtree!(bt, getchildindex(i,:left), j) - end - if bt[getchildindex(i,:right)][j] # right child of current shift - delete_subtree!(bt, getchildindex(i,:right), j) - end - if bt[getchildindex(i,:left)][j+1<0. -If shift is 0, ie. ss=0, the main root will not be further decomposed even if d>0. -=======================================================================================# -# function siwpd_subtree!(y::AbstractMatrix{T₁}, h::Vector{T₂}, g::Vector{T₂}, -# i::S, d::S, ss::S; -# L::S = getdepth(i, :binary), -# len::S = nodelength(size(y,1), L)) where -# {T₁<:AbstractFloat, T₂<:AbstractFloat, S<:Integer} -# # Sanity check -# n, k = size(y) -# @assert 0 ≤ L ≤ maxtransformlevels(n) -# @assert 0 ≤ ss < n -# @assert 0 ≤ d ≤ getdepth(k,:binary)-L -# @assert iseven(len) || len == 1 - -# # --- Base case --- -# d > 0 || return y - -# # --- One level of decomposition --- -# v_start = ss*len+1 # Start index for parent node -# v_end = (ss+1)*len # End index for parent node -# @inbounds v = @view y[v_start:v_end, i] # Parent node -# w_start = ss*(len÷2)+1 # Start index for child node -# w_end = w_start+(len÷2)-1 # End index for child node -# @inbounds w₁ = @view y[w_start:w_end, getchildindex(i, :left)] # Left child node -# @inbounds w₂ = @view y[w_start:w_end, getchildindex(i, :right)] # Right child node -# sidwt_step!(w₁, w₂, v, h, g, false) # 1 step of non-shifted decomposition -# w_start = (ss+1<0 --- -# if ss > 0 # Decomposition of current shift -# siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss, L=L+1, len=len÷2) -# siwpd_subtree!(y, h, g, getchildindex(i, :right), d-1, ss, L=L+1, len=len÷2) -# end -# # Decomposition of shifted version -# siwpd_subtree!(y, h, g, getchildindex(i, :left), d-1, ss+1< sum(node), bestbasistree(xw0, 4, SIBB())) -bt4 = map(node -> sum(node), bestbasistree(xw4, 4, SIBB())) -@test bt0 == bt4 - # misc @test_throws ArgumentError BestBasis.bestbasis_treeselection(randn(15), 8, :fail) @test_throws AssertionError BestBasis.bestbasis_treeselection(randn(7), 3, :fail) # true n=4 From c7f4566ded925835c876bd7d33c4db39cb24b221 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Sun, 5 Jun 2022 11:59:47 -0700 Subject: [PATCH 10/15] add check to validate best basis trees in SIWT --- src/mod/siwt/siwt_bestbasis.jl | 1 + src/mod/siwt/siwt_utls.jl | 27 +++++++++++++++++++++++++++ test/transforms.jl | 2 ++ 3 files changed, 30 insertions(+) diff --git a/src/mod/siwt/siwt_bestbasis.jl b/src/mod/siwt/siwt_bestbasis.jl index c279691..7ea4609 100644 --- a/src/mod/siwt/siwt_bestbasis.jl +++ b/src/mod/siwt/siwt_bestbasis.jl @@ -3,6 +3,7 @@ function bestbasistree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T rootNodeIndex = (0,0,0) bestbasis_treeselection!(siwtObj, rootNodeIndex) siwtObj.MinCost = siwtObj.Nodes[rootNodeIndex].Cost + @assert isvalidtree(siwtObj) return siwtObj.BestTree end diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl index 34f7328..68dda5d 100644 --- a/src/mod/siwt/siwt_utls.jl +++ b/src/mod/siwt/siwt_utls.jl @@ -186,4 +186,31 @@ function siwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T end return nothing +end + +function Wavelets.Util.isvalidtree(siwtObj::ShiftInvariantWaveletTransformObject) + @assert (Set ∘ keys)(siwtObj.Nodes) == Set(siwtObj.BestTree) + + # Each node needs to: + # - Be a root node OR have a parent node + # - Have child nodes XOR have shifted child nodes XOR have no child nodes + nodeSet = Set(siwtObj.BestTree) + for index in nodeSet + depth, indexAtDepth, transformShift = index + + isRootNode = index == (0,0,0) + hasParentNode = (depth-1, indexAtDepth>>1, transformShift) ∈ nodeSet + + hasChildNodes = (depth+1, indexAtDepth<<1, transformShift) ∈ nodeSet && (depth+1, indexAtDepth<<1+1, transformShift) ∈ nodeSet + hasShiftedChildNodes = (depth+1, indexAtDepth<<1, transformShift+(1< Date: Sun, 5 Jun 2022 23:31:52 -0700 Subject: [PATCH 11/15] SIWT signal reconstruction functionality --- src/mod/SIWT.jl | 10 +++++++ src/mod/siwt/siwt_one_level.jl | 54 +++++++++++++++++++++++++++++++++- src/mod/siwt/siwt_utls.jl | 50 ++++++++++++++++++++++++++++++- 3 files changed, 112 insertions(+), 2 deletions(-) diff --git a/src/mod/SIWT.jl b/src/mod/SIWT.jl index 1d4645a..8a264e8 100644 --- a/src/mod/SIWT.jl +++ b/src/mod/SIWT.jl @@ -1,6 +1,7 @@ module SIWT export bestbasistree!, + isiwpd, ShiftInvariantWaveletTransformNode, ShiftInvariantWaveletTransformObject, siwpd @@ -51,4 +52,13 @@ function siwpd(x::AbstractVector{T}, return siwtObj end +function isiwpd(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}) where + {N, T₁<:Integer, T₂<:AbstractFloat} + g, h = WT.makereverseqmfpair(siwtObj.Wavelet, true) + rootNodeIndex = (0,0,0) + isiwpd_subtree!(siwtObj, rootNodeIndex, h, g) + + return siwtObj.Nodes[rootNodeIndex].Value +end + end # end module \ No newline at end of file diff --git a/src/mod/siwt/siwt_one_level.jl b/src/mod/siwt/siwt_one_level.jl index 7c399f6..84f95f5 100644 --- a/src/mod/siwt/siwt_one_level.jl +++ b/src/mod/siwt/siwt_one_level.jl @@ -55,4 +55,56 @@ function sidwt_step!(w₁::AbstractVector{T}, w₂::AbstractVector{T}, return w₁, w₂ end -# TODO: Function to compute one level of recomposition \ No newline at end of file +# TODO: Function to compute one level of recomposition +function isidwt_step!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + nodeIndex::NTuple{3,T₁}, + child1Index::NTuple{3,T₁}, child2Index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + nodeObj = siwtObj.Nodes[nodeIndex] + child1Obj = siwtObj.Nodes[child1Index] + child2Obj = siwtObj.Nodes[child2Index] + + @assert child1Obj.TransformShift == child2Obj.TransformShift + isShiftedTransform = nodeObj.TransformShift == child1Obj.TransformShift + + nodeValue = nodeObj.Value + child1Value = child1Obj.Value + child2Value = child2Obj.Value + isidwt_step!(nodeValue, child1Value, child2Value, h, g, isShiftedTransform) + + return nothing +end + +function isidwt_step!(v::AbstractVector{T}, + w₁::AbstractVector{T}, w₂::AbstractVector{T}, + h::Array{S,1}, g::Array{S,1}, + s::Bool) where {T<:AbstractFloat, S<:AbstractFloat} + # Sanity check + @assert length(w₁) == length(w₂) == length(v)÷2 + @assert length(h) == length(g) + + # Setup + n = length(v) # Parent length + n₁ = length(w₁) # Child length + filtlen = length(h) # Filter length + + # One step of inverse discrete transform + for i in 1:n + ℓ = mod1(i-s,n) # Index of reconstructed vector + j₀ = mod1(i,2) # Pivot point to determine start index for filter + j₁ = filtlen-j₀+1 # Index for low pass filter g + j₂ = mod1(i+1,2) # Index for high pass filter h + k₁ = (i+1)>>1 # Index for approx coefs w₁ + k₂ = (i+1)>>1 # Index for detail coefs w₂ + @inbounds v[ℓ] = g[j₁] * w₁[k₁] + h[j₂] * w₂[k₂] + for j in (j₀+2):2:filtlen + j₁ = filtlen-j+1 + j₂ = j + isodd(j) - iseven(j) + k₁ = k₁-1 |> k₁ -> k₁≤0 ? mod1(k₁,n₁) : k₁ + k₂ = k₂+1 |> k₂ -> k₂>n₁ ? mod1(k₂,n₁) : k₂ + @inbounds v[ℓ] += g[j₁] * w₁[k₁] + h[j₂] * w₂[k₂] + end + end + return v +end \ No newline at end of file diff --git a/src/mod/siwt/siwt_utls.jl b/src/mod/siwt/siwt_utls.jl index 68dda5d..b3631f3 100644 --- a/src/mod/siwt/siwt_utls.jl +++ b/src/mod/siwt/siwt_utls.jl @@ -164,7 +164,7 @@ function siwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T return nothing end - # General step: + # --- General step --- # - Decompose current node without additional shift # - Decompose children nodes childDepth = nodeDepth + 1 @@ -188,6 +188,54 @@ function siwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T return nothing end +""" + isiwpd_subtree!(siwtObj, index, h, g) +""" +function isiwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + # Check for children nodes + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + hasNonShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) ∈ siwtObj.BestTree + hasShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< Date: Mon, 6 Jun 2022 22:42:22 -0700 Subject: [PATCH 12/15] Reorganize codes and update API documentation. --- docs/make.jl | 2 +- docs/src/api/siwpd.md | 11 --- docs/src/api/siwt.md | 49 ++++++++++++ src/mod/SIWT.jl | 131 ++++++++++++++++++++++++++++++ src/mod/siwt/siwt_bestbasis.jl | 47 +++++------ src/mod/siwt/siwt_one_level.jl | 73 ++++++++++++++++- src/mod/siwt/siwt_utls.jl | 142 ++++++++++++--------------------- test/transforms.jl | 8 ++ 8 files changed, 333 insertions(+), 130 deletions(-) delete mode 100644 docs/src/api/siwpd.md create mode 100644 docs/src/api/siwt.md diff --git a/docs/make.jl b/docs/make.jl index ad850fd..39231b5 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -20,7 +20,7 @@ makedocs( "DWT" => "api/dwt.md", "ACWT" => "api/acwt.md", "SWT" => "api/swt.md", - "SIWPD" => "api/siwpd.md", + "SIWT" => "api/siwt.md", "WaveMult" => "api/wavemult.md", "Best Basis" => "api/bestbasis.md", "Denoising" => "api/denoising.md", diff --git a/docs/src/api/siwpd.md b/docs/src/api/siwpd.md deleted file mode 100644 index 1cc911a..0000000 --- a/docs/src/api/siwpd.md +++ /dev/null @@ -1,11 +0,0 @@ -# Shift Invariant Wavelet Packet Decomposition - -```@index -Modules = [SIWPD] -``` - -## Public API -```@autodocs -Modules = [SIWPD] -Private = false -``` diff --git a/docs/src/api/siwt.md b/docs/src/api/siwt.md new file mode 100644 index 0000000..23a6971 --- /dev/null +++ b/docs/src/api/siwt.md @@ -0,0 +1,49 @@ +# Shift Invariant Wavelet Packet Decomposition + +```@index +Modules = [SIWT] +``` + +## Data Structures +```@docs +SIWT.ShiftInvariantWaveletTransformNode +SIWT.ShiftInvariantWaveletTransformObject +``` + +## Signal Transform and Reconstruction +### Public API +```@docs +SIWT.siwpd +SIWT.isiwpd +``` + +### Private API +```@docs +SIWT.siwpd_subtree! +SIWT.isiwpd_subtree! +``` + +## Best Basis Search +### Public API +```@docs +SIWT.bestbasistree! +``` + +### Private API +```@docs +SIWT.bestbasis_treeselection! +``` + +## Single Step Transforms +### Private API +```@docs +SIWT.sidwt_step! +SIWT.isidwt_step! +``` + +## Other Utils +### Private API +```@docs +Wavelets.Util.isvalidtree +SIWT.delete_node! +``` \ No newline at end of file diff --git a/src/mod/SIWT.jl b/src/mod/SIWT.jl index 8a264e8..7a858e4 100644 --- a/src/mod/SIWT.jl +++ b/src/mod/SIWT.jl @@ -52,6 +52,83 @@ function siwpd(x::AbstractVector{T}, return siwtObj end +""" + siwpd_subtree!(siwtObj, index, h, g, remainingRelativeDepth4ShiftedTransform[; signalNorm]) + +Runs the recursive computation of Shift-Invariant Wavelet Transform (SIWT) at each node +`index`. + +# Arguments +- `siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂} where + {N,T₁<:Integer,T₂<:AbstractFloat`: SIWT object. +- `index::NTuple{3,T₁} where T₁<:Integer`: Index of current node to be decomposed. +- `h::Vector{T₃} where T₃<:AbstractFloat`: High pass filter. +- `g::Vector{T₃} where T₃<:AbstractFloat`: Low pass filter. +- `remainingRelativeDepth4ShiftedTransform::T₁ where T₁<:Integer`: Remaining relative depth + for shifted transform. + +# Keyword Arguments +- `signalNorm::T₂ where T₂<:AbstractFloat`: (Default: `norm(siwtObj.Nodes[(0,0,0)].Value)`) + Signal Euclidean-norm. +""" +function siwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}, + remainingRelativeDepth4ShiftedTransform::T₁; + signalNorm::T₂ = norm(siwtObj.Nodes[(0,0,0)].Value)) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + treeMaxTransformLevel = siwtObj.MaxTransformLevel + nodeDepth, _, nodeTransformShift = index + + @assert 0 ≤ nodeDepth ≤ treeMaxTransformLevel + @assert 0 ≤ remainingRelativeDepth4ShiftedTransform ≤ treeMaxTransformLevel-nodeDepth + + # --- Base case --- + isLeafNode = (nodeDepth == treeMaxTransformLevel) + isShiftedTransform4NodeRequired = (remainingRelativeDepth4ShiftedTransform > 0) + isShiftedTransformNode = nodeTransformShift > 0 + isShiftedTransformLeafNode = (!isShiftedTransform4NodeRequired && isShiftedTransformNode) + if (isLeafNode || isShiftedTransformLeafNode) + return nothing + end + + # --- General step --- + # - Decompose current node without additional shift + # - Decompose children nodes + childDepth = nodeDepth + 1 + (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, false) + childRemainingRelativeDepth4ShiftedTransform = isShiftedTransformNode ? + remainingRelativeDepth4ShiftedTransform-1 : + min(remainingRelativeDepth4ShiftedTransform, treeMaxTransformLevel-childDepth) + siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + + # Case: remainingRelativeDepth4ShiftedTransform > 0 + # - Decompose current node with additional shift + # - Decompose children (with additional shift) nodes + if isShiftedTransform4NodeRequired + (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, true) + childRemainingRelativeDepth4ShiftedTransform = remainingRelativeDepth4ShiftedTransform-1 + siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) + end + + return nothing +end + +""" + isiwpd(siwtObj) + +Computes the Inverse Shift-Invariant Wavelet Packet Decomposition originally developed by +Cohen, Raz & Malah. + +# Arguments +- `siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂} where + {N,T₁<:Integer,T₂<:AbstractFloat`: SIWT object. + +# Returns +- `Vector{T₂}`: Reconstructed signal. +""" function isiwpd(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}) where {N, T₁<:Integer, T₂<:AbstractFloat} g, h = WT.makereverseqmfpair(siwtObj.Wavelet, true) @@ -61,4 +138,58 @@ function isiwpd(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}) wher return siwtObj.Nodes[rootNodeIndex].Value end +""" + isiwpd_subtree!(siwtObj, index, h, g) + +Runs the recursive computation of Inverse Shift-Invariant Wavelet Transform (SIWT) at each +node `index`. + +# Arguments +- `siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂} where + {N,T₁<:Integer,T₂<:AbstractFloat`: SIWT object. +- `index::NTuple{3,T₁} where T₁<:Integer`: Index of current node to be decomposed. +- `h::Vector{T₃} where T₃<:AbstractFloat`: High pass filter. +- `g::Vector{T₃} where T₃<:AbstractFloat`: Low pass filter. +""" +function isiwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, + index::NTuple{3,T₁}, + h::Vector{T₃}, g::Vector{T₃}) where + {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} + # Check for children nodes + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + hasNonShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) ∈ siwtObj.BestTree + hasShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< x==index, siwtObj.BestTree)) - - # Delete children nodes - nodeDepth, nodeIndexAtDepth, nodeTransformShift = index - child1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) - child2Index = (nodeDepth+1, nodeIndexAtDepth<<1+1, nodeTransformShift) - shiftedChild1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< 0) - isShiftedTransformNode = nodeTransformShift > 0 - isShiftedTransformLeafNode = (!isShiftedTransform4NodeRequired && isShiftedTransformNode) - if (isLeafNode || isShiftedTransformLeafNode) - return nothing - end - - # --- General step --- - # - Decompose current node without additional shift - # - Decompose children nodes - childDepth = nodeDepth + 1 - (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, false) - childRemainingRelativeDepth4ShiftedTransform = isShiftedTransformNode ? - remainingRelativeDepth4ShiftedTransform-1 : - min(remainingRelativeDepth4ShiftedTransform, treeMaxTransformLevel-childDepth) - siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) - siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) - - # Case: remainingRelativeDepth4ShiftedTransform > 0 - # - Decompose current node with additional shift - # - Decompose children (with additional shift) nodes - if isShiftedTransform4NodeRequired - (child1Index, child2Index) = sidwt_step!(siwtObj, index, h, g, true) - childRemainingRelativeDepth4ShiftedTransform = remainingRelativeDepth4ShiftedTransform-1 - siwpd_subtree!(siwtObj, child1Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) - siwpd_subtree!(siwtObj, child2Index, h, g, childRemainingRelativeDepth4ShiftedTransform, signalNorm=signalNorm) - end + Wavelets.Util.isvalidtree(siwtObj) - return nothing -end +Checks if tree within SIWT object is a valid tree. -""" - isiwpd_subtree!(siwtObj, index, h, g) -""" -function isiwpd_subtree!(siwtObj::ShiftInvariantWaveletTransformObject{N,T₁,T₂}, - index::NTuple{3,T₁}, - h::Vector{T₃}, g::Vector{T₃}) where - {N, T₁<:Integer, T₂<:AbstractFloat, T₃<:AbstractFloat} - # Check for children nodes - nodeDepth, nodeIndexAtDepth, nodeTransformShift = index - hasNonShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) ∈ siwtObj.BestTree - hasShiftedChildren = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< x==index, siwtObj.BestTree)) + + # Delete children nodes + nodeDepth, nodeIndexAtDepth, nodeTransformShift = index + child1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift) + child2Index = (nodeDepth+1, nodeIndexAtDepth<<1+1, nodeTransformShift) + shiftedChild1Index = (nodeDepth+1, nodeIndexAtDepth<<1, nodeTransformShift+(1< Date: Mon, 6 Jun 2022 22:55:30 -0700 Subject: [PATCH 13/15] Update manual documentation. --- docs/src/api/siwt.md | 2 +- docs/src/manual/bestbasis.md | 22 ++++++++++++---------- docs/src/manual/transforms.md | 13 +++++++------ 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/docs/src/api/siwt.md b/docs/src/api/siwt.md index 23a6971..f9fcb45 100644 --- a/docs/src/api/siwt.md +++ b/docs/src/api/siwt.md @@ -1,4 +1,4 @@ -# Shift Invariant Wavelet Packet Decomposition +# [Shift Invariant Wavelet Packet Decomposition](@id siwt_api) ```@index Modules = [SIWT] diff --git a/docs/src/manual/bestbasis.md b/docs/src/manual/bestbasis.md index cf58702..da63060 100644 --- a/docs/src/manual/bestbasis.md +++ b/docs/src/manual/bestbasis.md @@ -126,13 +126,15 @@ plot(p1, p2, p3, p4, p5, p6, layout=(3,2)) ## [Best Basis of Shift-Invariant Wavelet Packet Decomposition](@id si_bestbasis) One can think of searching for the best basis of the shift-invariant wavelet packet decomposition as a problem of finding ``\min_{b \in B} \sum_{x \in X} M_x(b)``, where ``X`` is all the possible shifted versions of an original signal ``y``. One can compute the best basis tree as follows: -```@example wt -xw = siwpd(x, wt) - -# SIBB -tree = bestbasistree(xw, 7, SIBB()); -nothing #hide -``` - -!!! warning - SIWPD is still undergoing large changes in terms of data structures and efficiency improvements. Syntax changes may occur in the next patch updates. +```@repl +x = [2,3,-4,5.0]; +wt = wavelet(WT.haar); +xwObj = siwpd(x, wt, 1, 1); +xwObj.BestTree # Original tree (all decomposed nodes) + +bestbasistree!(xwObj) +xwObj.BestTree # Best basis tree + +x̂ = isiwpd(xwObj) # Reconstruction of signal +x̂ == x +``` \ No newline at end of file diff --git a/docs/src/manual/transforms.md b/docs/src/manual/transforms.md index 1a34593..8187e8a 100644 --- a/docs/src/manual/transforms.md +++ b/docs/src/manual/transforms.md @@ -142,14 +142,15 @@ plot(plot(img, title="Original"), p0, p1, p2, layout=(2,2)) The [Shift-Invariant Wavelet Decomposition (SIWPD)](https://israelcohen.com/wp-content/uploads/2018/05/ICASSP95.pdf) is developed by Cohen et. al.. While it is also a type of redundant transform, it does not follow the same methodology as the SWT and the ACWT. Cohen's main goal for developing this algorithm was to obtain a global minimum entropy from a signal and all its shifted versions. See [its best basis implementation](@ref si_bestbasis) for more information. One can compute the SIWPD of a single signal as follows. -```@example wt -# decomposition -xw = siwpd(x, wt); -nothing # hide +```@repl +x = [2,3,-4,5.0]; +wt = wavelet(WT.haar); +xwObj = siwpd(x, wt, 1, 1); +xwObj.Nodes[(0,0,0)] # Example of looking into a node +xwObj.BestTree # Looking into what's within the tree structure ``` +For more information on the API, visit the [SIWT API page](@ref siwt_api). -!!! note - As of right now, there is not too many functions written based on the SIWPD, as it does not follow the conventional style of wavelet transforms. There is a lot of ongoing work to develop more functions catered for the SIWPD such as it's inverse transforms and group-implementations. From 446a0b5d67de88667a87d4fa770c97d826c08655 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Mon, 6 Jun 2022 23:10:51 -0700 Subject: [PATCH 14/15] Remove old code from best basis documentation --- docs/src/api/bestbasis.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/src/api/bestbasis.md b/docs/src/api/bestbasis.md index 2d323f4..dec370a 100644 --- a/docs/src/api/bestbasis.md +++ b/docs/src/api/bestbasis.md @@ -17,13 +17,11 @@ BestBasis.ShannonEntropyCost BestBasis.LogEnergyEntropyCost BestBasis.coefcost BestBasis.tree_costs -BestBasis.tree_costs(::AbstractMatrix{T}, ::AbstractVector{BitVector}, ::SIBB) where T<:Number ``` # Best Basis Tree Selection ```@docs BestBasis.bestbasis_treeselection -BestBasis.bestbasis_treeselection(::AbstractVector{Tc}, ::AbstractVector{Tt}) where {Tc<:AbstractVector{<:Union{Number,Nothing}}, Tt<:BitVector} BestBasis.delete_subtree! ``` @@ -33,8 +31,6 @@ BestBasis.BestBasisType BestBasis.LSDB BestBasis.JBB BestBasis.BB -BestBasis.SIBB Wavelets.Threshold.bestbasistree -Wavelets.Threshold.bestbasistree(::AbstractMatrix{T}, ::Integer, ::SIBB) where T<:Number BestBasis.bestbasistreeall ``` From 7473e66d3f8878e6f4b4618970ff8481ee2669f3 Mon Sep 17 00:00:00 2001 From: Zeng Fung Liew Date: Mon, 6 Jun 2022 23:42:13 -0700 Subject: [PATCH 15/15] Bug fixes on documentation --- docs/src/api/siwt.md | 2 +- docs/src/api/utils.md | 2 +- src/mod/BestBasis.jl | 3 +-- src/mod/bestbasis/bestbasis_tree.jl | 7 +++---- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/docs/src/api/siwt.md b/docs/src/api/siwt.md index f9fcb45..3fa68fc 100644 --- a/docs/src/api/siwt.md +++ b/docs/src/api/siwt.md @@ -44,6 +44,6 @@ SIWT.isidwt_step! ## Other Utils ### Private API ```@docs -Wavelets.Util.isvalidtree +Wavelets.Util.isvalidtree(::ShiftInvariantWaveletTransformObject) SIWT.delete_node! ``` \ No newline at end of file diff --git a/docs/src/api/utils.md b/docs/src/api/utils.md index 083fe6e..0cc8a18 100644 --- a/docs/src/api/utils.md +++ b/docs/src/api/utils.md @@ -18,7 +18,7 @@ Utils.finestdetailrange ## Tree traversing functions ```@docs -Wavelets.Util.isvalidtree +Wavelets.Util.isvalidtree(::AbstractMatrix,::BitVector) Wavelets.Util.maketree Utils.getchildindex Utils.getparentindex diff --git a/src/mod/BestBasis.jl b/src/mod/BestBasis.jl index dff18c1..64bd55d 100644 --- a/src/mod/BestBasis.jl +++ b/src/mod/BestBasis.jl @@ -147,8 +147,7 @@ end Extension to the best basis tree function from Wavelets.jl. Given a set of decomposed signals, returns different types of best basis trees based on the methods specified. Available methods are the joint best basis ([`JBB`](@ref)), least statistically dependent -basis ([`LSDB`](@ref)), individual regular best basis ([`BB`](@ref)), and shift-invariant -best basis ([`SIBB`](@ref)). +basis ([`LSDB`](@ref)), and individual regular best basis ([`BB`](@ref)). # Arguments - `X::AbstractArray{T} where T<:AbstractFloat`: A set of decomposed signals, of sizes diff --git a/src/mod/bestbasis/bestbasis_tree.jl b/src/mod/bestbasis/bestbasis_tree.jl index 81ad8fc..a28dda2 100644 --- a/src/mod/bestbasis/bestbasis_tree.jl +++ b/src/mod/bestbasis/bestbasis_tree.jl @@ -6,7 +6,6 @@ Abstract type for best basis. Current available types are: - [`LSDB`](@ref) - [`JBB`](@ref) - [`BB`](@ref) -- [`SIBB`](@ref) """ abstract type BestBasisType end @@ -21,7 +20,7 @@ Least Statistically Dependent Basis (LSDB). redundant. Set `redundant=true` when running LSDB with redundant wavelet transforms such as SWT or ACWT. -**See also:** [`BestBasisType`](@ref), [`JBB`](@ref), [`BB`](@ref), [`SIBB`](@ref) +**See also:** [`BestBasisType`](@ref), [`JBB`](@ref), [`BB`](@ref) """ @with_kw struct LSDB <: BestBasisType cost::LSDBCost = DifferentialEntropyCost() @@ -39,7 +38,7 @@ Joint Best Basis (JBB). redundant. Set `redundant=true` when running LSDB with redundant wavelet transforms such as SWT or ACWT. -**See also:** [`BestBasisType`](@ref), [`LSDB`](@ref), [`BB`](@ref), [`SIBB`](@ref) +**See also:** [`BestBasisType`](@ref), [`LSDB`](@ref), [`BB`](@ref) """ @with_kw struct JBB <: BestBasisType # Joint Best Basis cost::JBBCost = LoglpCost(2) @@ -57,7 +56,7 @@ Standard Best Basis (BB). redundant. Set `redundant=true` when running LSDB with redundant wavelet transforms such as SWT or ACWT. -**See also:** [`BestBasisType`](@ref), [`LSDB`](@ref), [`JBB`](@ref), [`SIBB`](@ref) +**See also:** [`BestBasisType`](@ref), [`LSDB`](@ref), [`JBB`](@ref) """ @with_kw struct BB <: BestBasisType # Individual Best Basis cost::BBCost = ShannonEntropyCost()