Skip to content

Commit

Permalink
Wavelet Multiplication + Nonstandard transforms (#49)
Browse files Browse the repository at this point in the history
* Reorganize source files

Organize source code of the same purpose into a subdirectory. Eg:
```
src/ACWT.jl                     src/ACWT.jl
src/acwt_all.jl         ->      src/acwt/acwt_all.jl
src/acwt_one_level.jl           src/acwt/acwt_one_level.jl
src/acwt_utils.jl               src/acwt/acwt_utils.jl
```

* Add nonstandard transform functionalities.

* Update function documentation for nonstandard transforms

- Update function documentations
- Add `ns_idwt`, `nonstd_wavemult` to be exported from module.

* Added standard form wavelet multiplication
- Standard form wavelet multiplication.
- Documentation for `WaveMult` module.

* Add API documentation page into docs/make.jl

* Fix broadcasting of logical operators in v1.6

* Bug fixes in WaveMult module
- Added parentheses to condition evaluation in `stretchmatrix` function. This improves the
  condition evaluation process that was previously buggy and produces the wrong logical expressions.
- Added a line in `mat2sparseform_nonstd` function. Previous implementation did not produce
  the right answer.

* Added documentation to demonstrate the usage of WaveMult module.

* Add additional assert for stretchmatrix.

* Add unit tests and fix bugs that arose from it.

* Syntax change to deal with v1.6 bug
In both `mat2sparse_nonstd` and `mat2sparse_std` functions, the line
```
maxcolnorm = (maximum ∘ mapslices)(norm, Mw, dims = 1)
```
is changed to
```
maxcolnorm = mapslices(norm, Mw, dims = 1) |> maximum
```

* Change @test_warn to @test_logs for unit tests in v1.6
  • Loading branch information
zengfung authored Feb 5, 2022
1 parent f0ac240 commit 228188e
Show file tree
Hide file tree
Showing 34 changed files with 888 additions and 19 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
Expand Down
4 changes: 4 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
WaveletsExt = "8f464e1e-25db-479f-b0a5-b7680379e03f"

[compat]
BenchmarkTools = "1.2"
Documenter = "0.27.6"
Images = "0.24, 0.25"
Plots = "1.21.3"
Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ makedocs(
"Transforms" => "manual/transforms.md",
"Best Basis" => "manual/bestbasis.md",
"Denoising" => "manual/denoising.md",
"Local Discriminant Basis" => "manual/localdiscriminantbasis.md"
"Local Discriminant Basis" => "manual/localdiscriminantbasis.md",
"Wavelets and Fast Numerical Algorithms" => "manual/wavemult.md"
],
"API" => Any[
"DWT" => "api/dwt.md",
"ACWT" => "api/acwt.md",
"SWT" => "api/swt.md",
"SIWPD" => "api/siwpd.md",
"WaveMult" => "api/wavemult.md",
"Best Basis" => "api/bestbasis.md",
"Denoising" => "api/denoising.md",
"LDB" => "api/ldb.md",
Expand Down
30 changes: 30 additions & 0 deletions docs/src/api/wavemult.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
```@index
Modules = [WaveMult]
```

## Wavelet Multiplication
```@docs
WaveMult.nonstd_wavemult
WaveMult.std_wavemult
```

## Matrix to Sparse Format
```@docs
WaveMult.mat2sparseform_nonstd
WaveMult.mat2sparseform_std
```

## Fast Transform Algorithms
```@docs
WaveMult.ns_dwt
WaveMult.ns_idwt
WaveMult.sft
WaveMult.isft
```

## Miscellaneous
```@docs
WaveMult.dyadlength
WaveMult.ndyad
WaveMult.stretchmatrix
```
84 changes: 84 additions & 0 deletions docs/src/manual/wavemult.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# [Fast Numerical Algorithms using Wavelet Transforms](@id wavemult_manual)
This demonstration is based on the paper *[Fast wavelet transforms and numerical algorithms](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.3160440202)* by G. Beylkin, R. Coifman, and V. Rokhlin and the lecture notes *[Wavelets and Fast Numerical ALgorithms](https://arxiv.org/abs/comp-gas/9304004)* by G. Beylkin.

Assume we have a matrix ``M \in \mathbb{R}^{n \times n}``, and a vector ``x \in \mathbb{R}^n``, there are two ways to compute ``y = Mx``:

1) Use the regular matrix multiplication, which takes ``O(n^2)`` time.
2) Transform both ``M`` and ``x`` to their standard/nonstandard forms ``\tilde M`` and ``\tilde x`` respectively. Compute the multiplication of ``\tilde y = \tilde M \tilde x``, then find the inverse of ``\tilde y``.If the matrix ``\tilde M`` is sparse, this can take ``O(n)`` time.

## Examples and Comparisons
In our following examples, we compare the methods described in the previous section. We focus on two important metrics: time taken and accuracy of the algorithm.

We start off by setting up the matrix `M`, the vector `x`, and the wavelet `wt`. As this algorithm is more efficient when ``n`` is large, we set ``n=2048``.
```@example wavemult
using Wavelets, WaveletsExt, Random, LinearAlgebra, BenchmarkTools
# Construct matrix M, vector x, and wavelet wt
function data_setup(n::Integer, random_state = 1234)
M = zeros(n,n)
for i in axes(M,1)
for j in axes(M,2)
M[i,j] = i==j ? 0 : 1/abs(i-j)
end
end
Random.seed!(random_state)
x = randn(n)
return M, x
end
M, x = data_setup(2048)
wt = wavelet(WT.haar)
nothing # hide
```

The following code block computes the regular matrix multiplication.
```@example wavemult
y₀ = M * x
nothing # hide
```

The following code block computes the nonstandard wavelet multiplication. Note that we will only be running 4 levels of wavelet transform, as running too many levels result in large computation time, without much improvement in accuracy.
```@example wavemult
L = 4
NM = mat2sparseform_nonstd(M, wt, L)
y₁ = nonstd_wavemult(NM, x, wt, L)
nothing # hide
```

The following code block computes the standard wavelet multiplication. Once again, we will only be running 4 levels of wavelet transform here.
```@example wavemult
SM = mat2sparseform_std(M, wt, L)
y₂ = std_wavemult(SM, x, wt, L)
nothing # hide
```

!!! tip "Performance Tip"
Running more levels of transform (i.e. setting large values of `L`) does not necessarily result in more accurate approximations. Setting `L` to be a reasonably large value (e.g. 3 or 4) not only reduces computational time, but could also produce better results.

We assume that the regular matrix multiplication produces the true results (bar some negligible machine errors), and we compare our results from the nonstandard and standard wavelet multiplications based on their relative errors (``\frac{||\hat y - y_0||_2}{||y_0||_2}``).
```@repl wavemult
relativenorm(y₁, y₀) # Comparing nonstandard wavelet multiplication with true value
relativenorm(y₂, y₀) # Comparing standard wavelet multiplication with true value
```

Lastly, we compare the computation time for each algorithm.
```@repl wavemult
@benchmark $M * $x # Benchmark regular matrix multiplication
@benchmark nonstd_wavemult($NM, $x, $wt, L) # Benchmark nonstandard wavelet multiplication
@benchmark std_wavemult($SM, $x, $wt, L) # Benchmark standard wavelet multiplication
```

As we can see above, the nonstandard and standard wavelet multiplication methods are significantly faster than the regular method and provides reasonable approximations to the true values.

!!! tip "Performance Tips"
This method should only be used when the dimensions of ``M`` and ``x`` are large. Additionally, this would be more useful when one aims to compute ``Mx_0, Mx_1, \ldots, Mx_k`` where ``M`` remains constant throughout the ``k`` computations. In this case, it is more efficient to use
```julia
NM = mat2sparseform_nonstd(M, wt)
y = nonstd_wavemult(NM, x, wt)
```
instead of
```julia
y = nonstd_wavemult(M, x, wt)
```
4 changes: 3 additions & 1 deletion src/WaveletsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include("mod/SWT.jl")
include("mod/Denoising.jl")
include("mod/LDB.jl")
include("mod/Visualizations.jl")
include("mod/WaveMult.jl")

using Reexport
@reexport using .DWT,
Expand All @@ -21,6 +22,7 @@ using Reexport
.SWT,
.SIWPD,
.ACWT,
.Visualizations
.Visualizations,
.WaveMult

end
6 changes: 3 additions & 3 deletions src/mod/ACWT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1006,8 +1006,8 @@ end
# return (v₀ + v₁) / √2
# end

include("acwt_utils.jl")
include("acwt_one_level.jl")
include("acwt_all.jl")
include("acwt/acwt_utils.jl")
include("acwt/acwt_one_level.jl")
include("acwt/acwt_all.jl")

end # end module
4 changes: 2 additions & 2 deletions src/mod/BestBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ using
..DWT,
..SIWPD

include("bestbasis_costs.jl")
include("bestbasis_tree.jl")
include("bestbasis/bestbasis_costs.jl")
include("bestbasis/bestbasis_tree.jl")

## ----- BEST TREE SELECTION -----
# Tree selection for 1D signals
Expand Down
4 changes: 2 additions & 2 deletions src/mod/DWT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ function Wavelets.Transforms.iwpt!(x̂::AbstractArray{T,2},
return
end

include("dwt_one_level.jl")
include("dwt_all.jl")
include("dwt/dwt_one_level.jl")
include("dwt/dwt_all.jl")

end # end module
4 changes: 2 additions & 2 deletions src/mod/LDB.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ using ..Utils,
..DWT
import ..BestBasis: bestbasis_treeselection

include("ldb_energymap.jl")
include("ldb_measures.jl")
include("ldb/ldb_energymap.jl")
include("ldb/ldb_measures.jl")

## LOCAL DISCRIMINANT BASIS
"""
Expand Down
4 changes: 2 additions & 2 deletions src/mod/SWT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ function iswpd!(x::AbstractMatrix{T},
return x
end

include("swt_one_level.jl")
include("swt_all.jl")
include("swt/swt_one_level.jl")
include("swt/swt_all.jl")

end # end module
6 changes: 3 additions & 3 deletions src/mod/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,8 @@ function getcolrange(n::Integer, idx::T) where T<:Integer
end
end

include("utils_tree.jl")
include("utils_metrics.jl")
include("utils_dataset.jl")
include("utils/utils_tree.jl")
include("utils/utils_metrics.jl")
include("utils/utils_dataset.jl")

end # end module
19 changes: 19 additions & 0 deletions src/mod/WaveMult.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module WaveMult
export mat2sparseform_std,
mat2sparseform_nonstd,
ns_dwt,
ns_idwt,
std_wavemult,
nonstd_wavemult

using Wavelets,
LinearAlgebra,
SparseArrays

import ..DWT: dwt_step!, idwt_step!

include("wavemult/utils.jl")
include("wavemult/mat2sparse.jl")
include("wavemult/wavemult.jl")
include("wavemult/transforms.jl")
end # module
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
100 changes: 100 additions & 0 deletions src/mod/wavemult/mat2sparse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
mat2sparseform_nonstd(M, wt, [L], [ϵ])
Transform the matrix `M` into the wavelet basis. Then, it is stretched into its nonstandard
form. Elements exceeding `ϵ * maximum column norm` are set to zero. The resulting output
sparse matrix, and can be used as input to `nonstd_wavemult`.
# Arguments
- `M::AbstractMatrix{T} where T<:AbstractFloat`: `n` by `n` matrix (`n` dyadic) to be put in
Sparse Nonstandard form.
- `wt::OrthoFilter`: Wavelet filter.
- `L::Integer`: (Default: `maxtransformlevels(M)`) Number of decomposition levels.
- `ϵ::T where T<:AbstractFloat`: (Default: `1e-4`) Truncation Criterion.
# Returns
- `NM::SparseMatrixCSC{T, Integer}`: Sparse nonstandard form of matrix of size 2n x 2n.
# Examples
```jldoctest
julia> using Wavelets, WaveletsExt; import Random: seed!
julia> seed!(1234); M = randn(4,4); wt = wavelet(WT.haar);
julia> mat2sparseform_nonstd(M, wt)
8×8 SparseArrays.SparseMatrixCSC{Float64, Int64} with 16 stored entries:
1.88685 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 0.363656 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 2.49634 -1.08139 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ -1.0187 0.539411
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1.68141 0.0351839
⋅ ⋅ ⋅ ⋅ -1.39713 -1.21352 0.552745 0.427717
⋅ ⋅ ⋅ ⋅ -1.05882 0.16666 -0.124156 -0.218902
```
**See also:** [`mat2sparseform_std`](@ref), [`stretchmatrix`](@ref)
"""
function mat2sparseform_nonstd(M::AbstractMatrix{T},
wt::OrthoFilter,
L::Integer = maxtransformlevels(M),
ϵ::T = 1e-4) where T<:AbstractFloat
@assert size(M,1) == size(M,2)
n = size(M, 1)
Mw = dwt(M, wt, L) # 2D wavelet transform
# Find maximum column norm of Mw
maxcolnorm = mapslices(norm, Mw, dims = 1) |> maximum
nilMw = Mw .* (abs.(Mw) .> ϵ * maxcolnorm) # "Remove" values close to zero
nz_vals = nilMw[nilMw .≠ 0] # Extract all nonzero values
nz_indx = findall(!iszero, nilMw) # Extract indices of nonzero values
i = getindex.(nz_indx, 1)
j = getindex.(nz_indx, 2)
ie, je = stretchmatrix(i, j, n, L) # Stretch matrix Mw
NM = sparse(ie, je, nz_vals, 2*n, 2*n) # Convert to sparse matrix
return NM
end

"""
mat2sparseform_std(M, wt, [L], [ϵ])
Transform the matrix `M` into the standard form. Then, elements exceeding `ϵ * maximum
column norm` are set to zero. The resulting output sparse matrix, and can be used as input
to `std_wavemult`.
# Arguments
- `M::AbstractMatrix{T} where T<:AbstractFloat`: Matrix to be put in Sparse Standard form.
- `wt::OrthoFilter`: Wavelet filter.
- `L::Integer`: (Default: `maxtransformlevels(M)`) Number of decomposition levels.
- `ϵ::T where T<:AbstractFloat`: (Default: `1e-4`) Truncation Criterion.
# Returns
- `SM::SparseMatrixCSC{T, Integer}`: Sparse standard form of matrix of size ``n \times n``.
# Examples
```jldoctest
julia> using Wavelets, WaveletsExt; import Random: seed!
julia> seed!(1234); M = randn(4,4); wt = wavelet(WT.haar);
julia> mat2sparseform_std(M, wt)
4×4 SparseArrays.SparseMatrixCSC{Float64, Int64} with 16 stored entries:
1.88685 0.363656 0.468602 0.4063
2.49634 -1.08139 1.90927 -0.356542
-1.84601 0.129829 0.552745 0.427717
-0.630852 0.866545 -0.124156 -0.218902
```
**See also:** [`mat2sparseform_nonstd`](@ref), [`sft`](@ref)
"""
function mat2sparseform_std(M::AbstractMatrix{T},
wt::OrthoFilter,
L::Integer = maxtransformlevels(M),
ϵ::T = 1e-4) where T<:AbstractFloat
@assert size(M,1) == size(M,2)
Mw = sft(M, wt, L) # Transform to standard form
# Find maximum column norm of Mw
maxcolnorm = mapslices(norm, Mw, dims = 1) |> maximum
nilMw = Mw .* (abs.(Mw) .> ϵ * maxcolnorm) # Remove values close to zero
SM = sparse(nilMw) # Convert to sparse matrix
return SM
end
Loading

2 comments on commit 228188e

@zengfung
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.1.17 already exists and points to a different commit"

Please sign in to comment.