Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wavelet Multiplication + Nonstandard transforms #49

Merged
merged 20 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
Expand Down
4 changes: 4 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
WaveletsExt = "8f464e1e-25db-479f-b0a5-b7680379e03f"

[compat]
BenchmarkTools = "1.2"
Documenter = "0.27.6"
Images = "0.24, 0.25"
Plots = "1.21.3"
Expand Down
4 changes: 3 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ makedocs(
"Transforms" => "manual/transforms.md",
"Best Basis" => "manual/bestbasis.md",
"Denoising" => "manual/denoising.md",
"Local Discriminant Basis" => "manual/localdiscriminantbasis.md"
"Local Discriminant Basis" => "manual/localdiscriminantbasis.md",
"Wavelets and Fast Numerical Algorithms" => "manual/wavemult.md"
],
"API" => Any[
"DWT" => "api/dwt.md",
"ACWT" => "api/acwt.md",
"SWT" => "api/swt.md",
"SIWPD" => "api/siwpd.md",
"WaveMult" => "api/wavemult.md",
"Best Basis" => "api/bestbasis.md",
"Denoising" => "api/denoising.md",
"LDB" => "api/ldb.md",
Expand Down
30 changes: 30 additions & 0 deletions docs/src/api/wavemult.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
```@index
Modules = [WaveMult]
```

## Wavelet Multiplication
```@docs
WaveMult.nonstd_wavemult
WaveMult.std_wavemult
```

## Matrix to Sparse Format
```@docs
WaveMult.mat2sparseform_nonstd
WaveMult.mat2sparseform_std
```

## Fast Transform Algorithms
```@docs
WaveMult.ns_dwt
WaveMult.ns_idwt
WaveMult.sft
WaveMult.isft
```

## Miscellaneous
```@docs
WaveMult.dyadlength
WaveMult.ndyad
WaveMult.stretchmatrix
```
84 changes: 84 additions & 0 deletions docs/src/manual/wavemult.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# [Fast Numerical Algorithms using Wavelet Transforms](@id wavemult_manual)
This demonstration is based on the paper *[Fast wavelet transforms and numerical algorithms](https://onlinelibrary.wiley.com/doi/abs/10.1002/cpa.3160440202)* by G. Beylkin, R. Coifman, and V. Rokhlin and the lecture notes *[Wavelets and Fast Numerical ALgorithms](https://arxiv.org/abs/comp-gas/9304004)* by G. Beylkin.

Assume we have a matrix ``M \in \mathbb{R}^{n \times n}``, and a vector ``x \in \mathbb{R}^n``, there are two ways to compute ``y = Mx``:

1) Use the regular matrix multiplication, which takes ``O(n^2)`` time.
2) Transform both ``M`` and ``x`` to their standard/nonstandard forms ``\tilde M`` and ``\tilde x`` respectively. Compute the multiplication of ``\tilde y = \tilde M \tilde x``, then find the inverse of ``\tilde y``.If the matrix ``\tilde M`` is sparse, this can take ``O(n)`` time.

## Examples and Comparisons
In our following examples, we compare the methods described in the previous section. We focus on two important metrics: time taken and accuracy of the algorithm.

We start off by setting up the matrix `M`, the vector `x`, and the wavelet `wt`. As this algorithm is more efficient when ``n`` is large, we set ``n=2048``.
```@example wavemult
using Wavelets, WaveletsExt, Random, LinearAlgebra, BenchmarkTools

# Construct matrix M, vector x, and wavelet wt
function data_setup(n::Integer, random_state = 1234)
M = zeros(n,n)
for i in axes(M,1)
for j in axes(M,2)
M[i,j] = i==j ? 0 : 1/abs(i-j)
end
end

Random.seed!(random_state)
x = randn(n)

return M, x
end

M, x = data_setup(2048)
wt = wavelet(WT.haar)
nothing # hide
```

The following code block computes the regular matrix multiplication.
```@example wavemult
y₀ = M * x
nothing # hide
```

The following code block computes the nonstandard wavelet multiplication. Note that we will only be running 4 levels of wavelet transform, as running too many levels result in large computation time, without much improvement in accuracy.
```@example wavemult
L = 4
NM = mat2sparseform_nonstd(M, wt, L)
y₁ = nonstd_wavemult(NM, x, wt, L)
nothing # hide
```

The following code block computes the standard wavelet multiplication. Once again, we will only be running 4 levels of wavelet transform here.
```@example wavemult
SM = mat2sparseform_std(M, wt, L)
y₂ = std_wavemult(SM, x, wt, L)
nothing # hide
```

!!! tip "Performance Tip"
Running more levels of transform (i.e. setting large values of `L`) does not necessarily result in more accurate approximations. Setting `L` to be a reasonably large value (e.g. 3 or 4) not only reduces computational time, but could also produce better results.

We assume that the regular matrix multiplication produces the true results (bar some negligible machine errors), and we compare our results from the nonstandard and standard wavelet multiplications based on their relative errors (``\frac{||\hat y - y_0||_2}{||y_0||_2}``).
```@repl wavemult
relativenorm(y₁, y₀) # Comparing nonstandard wavelet multiplication with true value
relativenorm(y₂, y₀) # Comparing standard wavelet multiplication with true value
```

Lastly, we compare the computation time for each algorithm.
```@repl wavemult
@benchmark $M * $x # Benchmark regular matrix multiplication
@benchmark nonstd_wavemult($NM, $x, $wt, L) # Benchmark nonstandard wavelet multiplication
@benchmark std_wavemult($SM, $x, $wt, L) # Benchmark standard wavelet multiplication
```

As we can see above, the nonstandard and standard wavelet multiplication methods are significantly faster than the regular method and provides reasonable approximations to the true values.

!!! tip "Performance Tips"
This method should only be used when the dimensions of ``M`` and ``x`` are large. Additionally, this would be more useful when one aims to compute ``Mx_0, Mx_1, \ldots, Mx_k`` where ``M`` remains constant throughout the ``k`` computations. In this case, it is more efficient to use
```julia
NM = mat2sparseform_nonstd(M, wt)
y = nonstd_wavemult(NM, x, wt)
```
instead of
```julia
y = nonstd_wavemult(M, x, wt)
```
4 changes: 3 additions & 1 deletion src/WaveletsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ include("mod/SWT.jl")
include("mod/Denoising.jl")
include("mod/LDB.jl")
include("mod/Visualizations.jl")
include("mod/WaveMult.jl")

using Reexport
@reexport using .DWT,
Expand All @@ -21,6 +22,7 @@ using Reexport
.SWT,
.SIWPD,
.ACWT,
.Visualizations
.Visualizations,
.WaveMult

end
6 changes: 3 additions & 3 deletions src/mod/ACWT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1006,8 +1006,8 @@ end
# return (v₀ + v₁) / √2
# end

include("acwt_utils.jl")
include("acwt_one_level.jl")
include("acwt_all.jl")
include("acwt/acwt_utils.jl")
include("acwt/acwt_one_level.jl")
include("acwt/acwt_all.jl")

end # end module
4 changes: 2 additions & 2 deletions src/mod/BestBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ using
..DWT,
..SIWPD

include("bestbasis_costs.jl")
include("bestbasis_tree.jl")
include("bestbasis/bestbasis_costs.jl")
include("bestbasis/bestbasis_tree.jl")

## ----- BEST TREE SELECTION -----
# Tree selection for 1D signals
Expand Down
4 changes: 2 additions & 2 deletions src/mod/DWT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ function Wavelets.Transforms.iwpt!(x̂::AbstractArray{T,2},
return x̂
end

include("dwt_one_level.jl")
include("dwt_all.jl")
include("dwt/dwt_one_level.jl")
include("dwt/dwt_all.jl")

end # end module
4 changes: 2 additions & 2 deletions src/mod/LDB.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ using ..Utils,
..DWT
import ..BestBasis: bestbasis_treeselection

include("ldb_energymap.jl")
include("ldb_measures.jl")
include("ldb/ldb_energymap.jl")
include("ldb/ldb_measures.jl")

## LOCAL DISCRIMINANT BASIS
"""
Expand Down
4 changes: 2 additions & 2 deletions src/mod/SWT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ function iswpd!(x::AbstractMatrix{T},
return x
end

include("swt_one_level.jl")
include("swt_all.jl")
include("swt/swt_one_level.jl")
include("swt/swt_all.jl")

end # end module
6 changes: 3 additions & 3 deletions src/mod/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,8 @@ function getcolrange(n::Integer, idx::T) where T<:Integer
end
end

include("utils_tree.jl")
include("utils_metrics.jl")
include("utils_dataset.jl")
include("utils/utils_tree.jl")
include("utils/utils_metrics.jl")
include("utils/utils_dataset.jl")

end # end module
19 changes: 19 additions & 0 deletions src/mod/WaveMult.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module WaveMult
export mat2sparseform_std,
mat2sparseform_nonstd,
ns_dwt,
ns_idwt,
std_wavemult,
nonstd_wavemult

using Wavelets,
LinearAlgebra,
SparseArrays

import ..DWT: dwt_step!, idwt_step!

include("wavemult/utils.jl")
include("wavemult/mat2sparse.jl")
include("wavemult/wavemult.jl")
include("wavemult/transforms.jl")
end # module
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
100 changes: 100 additions & 0 deletions src/mod/wavemult/mat2sparse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
mat2sparseform_nonstd(M, wt, [L], [ϵ])

Transform the matrix `M` into the wavelet basis. Then, it is stretched into its nonstandard
form. Elements exceeding `ϵ * maximum column norm` are set to zero. The resulting output
sparse matrix, and can be used as input to `nonstd_wavemult`.

# Arguments
- `M::AbstractMatrix{T} where T<:AbstractFloat`: `n` by `n` matrix (`n` dyadic) to be put in
Sparse Nonstandard form.
- `wt::OrthoFilter`: Wavelet filter.
- `L::Integer`: (Default: `maxtransformlevels(M)`) Number of decomposition levels.
- `ϵ::T where T<:AbstractFloat`: (Default: `1e-4`) Truncation Criterion.

# Returns
- `NM::SparseMatrixCSC{T, Integer}`: Sparse nonstandard form of matrix of size 2n x 2n.

# Examples
```jldoctest
julia> using Wavelets, WaveletsExt; import Random: seed!

julia> seed!(1234); M = randn(4,4); wt = wavelet(WT.haar);

julia> mat2sparseform_nonstd(M, wt)
8×8 SparseArrays.SparseMatrixCSC{Float64, Int64} with 16 stored entries:
1.88685 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 0.363656 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 2.49634 -1.08139 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ -1.0187 0.539411
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1.68141 0.0351839
⋅ ⋅ ⋅ ⋅ -1.39713 -1.21352 0.552745 0.427717
⋅ ⋅ ⋅ ⋅ -1.05882 0.16666 -0.124156 -0.218902
```

**See also:** [`mat2sparseform_std`](@ref), [`stretchmatrix`](@ref)
"""
function mat2sparseform_nonstd(M::AbstractMatrix{T},
wt::OrthoFilter,
L::Integer = maxtransformlevels(M),
ϵ::T = 1e-4) where T<:AbstractFloat
@assert size(M,1) == size(M,2)
n = size(M, 1)
Mw = dwt(M, wt, L) # 2D wavelet transform
# Find maximum column norm of Mw
maxcolnorm = mapslices(norm, Mw, dims = 1) |> maximum
nilMw = Mw .* (abs.(Mw) .> ϵ * maxcolnorm) # "Remove" values close to zero
nz_vals = nilMw[nilMw .≠ 0] # Extract all nonzero values
nz_indx = findall(!iszero, nilMw) # Extract indices of nonzero values
i = getindex.(nz_indx, 1)
j = getindex.(nz_indx, 2)
ie, je = stretchmatrix(i, j, n, L) # Stretch matrix Mw
NM = sparse(ie, je, nz_vals, 2*n, 2*n) # Convert to sparse matrix
return NM
end

"""
mat2sparseform_std(M, wt, [L], [ϵ])

Transform the matrix `M` into the standard form. Then, elements exceeding `ϵ * maximum
column norm` are set to zero. The resulting output sparse matrix, and can be used as input
to `std_wavemult`.

# Arguments
- `M::AbstractMatrix{T} where T<:AbstractFloat`: Matrix to be put in Sparse Standard form.
- `wt::OrthoFilter`: Wavelet filter.
- `L::Integer`: (Default: `maxtransformlevels(M)`) Number of decomposition levels.
- `ϵ::T where T<:AbstractFloat`: (Default: `1e-4`) Truncation Criterion.

# Returns
- `SM::SparseMatrixCSC{T, Integer}`: Sparse standard form of matrix of size ``n \times n``.

# Examples
```jldoctest
julia> using Wavelets, WaveletsExt; import Random: seed!

julia> seed!(1234); M = randn(4,4); wt = wavelet(WT.haar);

julia> mat2sparseform_std(M, wt)
4×4 SparseArrays.SparseMatrixCSC{Float64, Int64} with 16 stored entries:
1.88685 0.363656 0.468602 0.4063
2.49634 -1.08139 1.90927 -0.356542
-1.84601 0.129829 0.552745 0.427717
-0.630852 0.866545 -0.124156 -0.218902
```

**See also:** [`mat2sparseform_nonstd`](@ref), [`sft`](@ref)
"""
function mat2sparseform_std(M::AbstractMatrix{T},
wt::OrthoFilter,
L::Integer = maxtransformlevels(M),
ϵ::T = 1e-4) where T<:AbstractFloat
@assert size(M,1) == size(M,2)
Mw = sft(M, wt, L) # Transform to standard form
# Find maximum column norm of Mw
maxcolnorm = mapslices(norm, Mw, dims = 1) |> maximum
nilMw = Mw .* (abs.(Mw) .> ϵ * maxcolnorm) # Remove values close to zero
SM = sparse(nilMw) # Convert to sparse matrix
return SM
end
Loading