Skip to content

Commit

Permalink
Fixed remaining bugs in LDB. Added tests for 2D LDB in CI.
Browse files Browse the repository at this point in the history
  • Loading branch information
zengfung committed Dec 1, 2021
1 parent 8145f68 commit 465b888
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 96 deletions.
14 changes: 7 additions & 7 deletions src/mod/LDB.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export
# local discriminant basis
LocalDiscriminantBasis,
fit!,
fitdec!,
fit_transform,
transform,
inverse_transform,
Expand Down Expand Up @@ -95,7 +96,7 @@ following field values:
dp::DiscriminantPower = BasisDiscriminantMeasure()
top_k::Union{Integer, Nothing} = nothing
n_features::Union{Integer, Nothing} = nothing
# to be computed in fit!
# to be computed in fit! or fitdec!
sz::Union{Tuple{Vararg{T,N}} where {T<:Integer,N}, Nothing} = nothing
Γ::Union{AbstractArray{<:AbstractFloat},
AbstractArray{NamedTuple{(:coef, :weight), Tuple{S1, S2}}} where
Expand Down Expand Up @@ -183,13 +184,12 @@ function fit!(f::LocalDiscriminantBasis, X::AbstractArray{S}, y::AbstractVector{
Xw = wpdall(X, f.wt, f.max_dec_level)

# fit local discriminant basis
fit!(f, Xw, y)
fitdec!(f, Xw, y)
return nothing
end

# TODO: parameter types are same as the one above
function fit!(f::LocalDiscriminantBasis, Xw::AbstractArray{S}, y::AbstractVector{T}) where
{S<:AbstractFloat, T}
function fitdec!(f::LocalDiscriminantBasis, Xw::AbstractArray{S}, y::AbstractVector{T}) where
{S<:AbstractFloat, T}
# basic summary of data
@assert 3 ndims(Xw) 4
c = unique(y) # unique classes
Expand Down Expand Up @@ -316,7 +316,7 @@ function fit_transform(f::LocalDiscriminantBasis,
Xw = wpdall(X, f.wt, f.max_dec_level)

# fit LDB and return best features
fit!(f, Xw, y)
fitdec!(f, Xw, y)
Xw = getbasiscoefall(Xw, f.tree)
# Extract best features
Xc = Array{S,2}(undef, f.n_features, N)
Expand Down Expand Up @@ -373,7 +373,7 @@ function change_nfeatures(f::LocalDiscriminantBasis, x::AbstractArray{T,2},
# check measurements
@assert !isnothing(f.n_features)
@assert size(x,1) == f.n_features || throw(ArgumentError("f.n_features and number of rows of x do not match!"))
@assert 1 n_features f.n
@assert 1 n_features prod(f.sz)

# change number of features
if f.n_features n_features
Expand Down
18 changes: 14 additions & 4 deletions src/mod/ldb_measures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,13 @@ function discriminant_power(coefs::AbstractArray{T}, y::AbstractVector{S},
= mean(Eαᵢ, dims = N) # overall mean of each entry
pᵢ = Nᵢ / sum(Nᵢ) # proportions of each class

# TODO: This does not work when dealing with 2D signals (multiplication issue)
power = ((Eαᵢ - (Eα .* Eαᵢ)).^2 * pᵢ) ./ (Varαᵢ * pᵢ)
if N==2 # For 1D signals, can be done via matrix multiplication
power = ((Eαᵢ - (Eα .* Eαᵢ)).^2 * pᵢ) ./ (Varαᵢ * pᵢ)
else # For 2D signals, requires some manual work
pᵢ = reshape(pᵢ,1,1,:)
power = sum((Eαᵢ-(Eα.*Eαᵢ)).^2 .* pᵢ, dims=N) ./ sum(Varαᵢ.*pᵢ, dims=N)
power = reshape(power, sz...)
end
order = sortperm(vec(power), rev = true)

return (power, order)
Expand Down Expand Up @@ -411,8 +416,13 @@ function discriminant_power(coefs::AbstractArray{T}, y::AbstractVector{S},
Medα = median(Medαᵢ, dims = N) # overall mean of each entry
pᵢ = Nᵢ / sum(Nᵢ) # proportions of each class

# TODO: This does not work when dealing with 2D signals (multiplication issue)
power = ((Medαᵢ - (Medα .* Medαᵢ)).^2 * pᵢ) ./ (Madαᵢ * pᵢ)
if N==2 # For 1D signals, can be done via matrix multiplication
power = ((Medαᵢ - (Medα.*Medαᵢ)).^2 *pᵢ) ./ (Madαᵢ * pᵢ)
else # For 2D signals, requires some manual work
pᵢ = reshape(pᵢ,1,1,:)
power = sum((Medαᵢ-(Medα.*Medαᵢ)).^2 .* pᵢ, dims=N) ./ sum(Madαᵢ.*pᵢ, dims=N)
power = reshape(power, sz...)
end
order = sortperm(vec(power), rev = true)

return (power, order)
Expand Down
261 changes: 176 additions & 85 deletions test/ldb.jl
Original file line number Diff line number Diff line change
@@ -1,86 +1,177 @@
X, y = generateclassdata(ClassData(:tri, 5, 5, 5))
wt = wavelet(WT.haar)
@testset "1D LDB" begin
X, y = generateclassdata(ClassData(:tri, 5, 5, 5))
wt = wavelet(WT.haar)

# AsymmetricRelativeEntropy + TimeFrequency + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# SymmetricRelativeEntropy + TimeFrequency + FishersClassSeparability
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=SymmetricRelativeEntropy(),
dp=FishersClassSeparability(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# LpDistance + TimeFrequency + RobustFishersClassSeparability
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=LpDistance(),
dp=RobustFishersClassSeparability(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# HellingerDistance + ProbabilityDensity + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=HellingerDistance(),
en=ProbabilityDensity(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# EarthMoverDistance + Signatures(equal weight) + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(),
en=Signatures(:equal), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# EarthMoverDistance + Signatures(pdf weight) + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(),
en=Signatures(:pdf), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# change number of features
@test_nowarn change_nfeatures(f, Xc, 5)
x = change_nfeatures(f, Xc, 5)
@test size(x) == (5, 15)
@test_logs (:warn, "Proposed n_features larger than currently saved n_features. Results will be less accurate since inverse_transform and transform is involved.") change_nfeatures(f, Xc, 10)
@test_throws ArgumentError change_nfeatures(f, Xc, 10)
end

# AsymmetricRelativeEntropy + TimeFrequency + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# SymmetricRelativeEntropy + TimeFrequency + FishersClassSeparability
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=SymmetricRelativeEntropy(),
dp=FishersClassSeparability(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# LpDistance + TimeFrequency + RobustFishersClassSeparability
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=LpDistance(),
dp=RobustFishersClassSeparability(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# HellingerDistance + ProbabilityDensity + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=HellingerDistance(),
en=ProbabilityDensity(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# EarthMoverDistance + Signatures(equal weight) + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(),
en=Signatures(:equal), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# EarthMoverDistance + Signatures(pdf weight) + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(),
en=Signatures(:pdf), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (32, 15)

# change number of features
@test_nowarn change_nfeatures(f, Xc, 5)
x = change_nfeatures(f, Xc, 5)
@test size(x) == (5, 15)
@test_logs (:warn, "Proposed n_features larger than currently saved n_features. Results will be less accurate since inverse_transform and transform is involved.") change_nfeatures(f, Xc, 10)
@test_throws ArgumentError change_nfeatures(f, Xc, 10)
@testset "2D LDB" begin
X = cat(rand(Normal(0,1), 8,8,5), rand(Normal(1,1), 8,8,5), rand(Normal(2,1), 8,8,5), dims=3)
y = [repeat([1], 5); repeat([2],5); repeat([3],5)]

# AsymmetricRelativeEntropy + TimeFrequency + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(max_dec_level=2, top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (8, 8, 15)

# SymmetricRelativeEntropy + TimeFrequency + FishersClassSeparability
f = LocalDiscriminantBasis(max_dec_level=2, dm=SymmetricRelativeEntropy(),
dp=FishersClassSeparability(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (8, 8, 15)

# LpDistance + TimeFrequency + RobustFishersClassSeparability
f = LocalDiscriminantBasis(max_dec_level=2, dm=LpDistance(),
dp=RobustFishersClassSeparability(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (8, 8, 15)

# HellingerDistance + ProbabilityDensity + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(max_dec_level=2, dm=HellingerDistance(),
en=ProbabilityDensity(), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (8, 8, 15)

# EarthMoverDistance + Signatures(equal weight) + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(max_dec_level=2, dm=EarthMoverDistance(),
en=Signatures(:equal), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (8, 8, 15)

# EarthMoverDistance + Signatures(pdf weight) + BasisDiscriminantMeasure
f = LocalDiscriminantBasis(max_dec_level=2, dm=EarthMoverDistance(),
en=Signatures(:pdf), top_k=5, n_features=5)
@test typeof(f) == LocalDiscriminantBasis
@test_nowarn fit_transform(f, X, y)
@test_nowarn fit!(f, X, y)
@test_nowarn transform(f, X)
Xc = transform(f, X)
@test size(Xc) == (5,15)
@test_nowarn inverse_transform(f, Xc)
= inverse_transform(f, Xc)
@test size(X̂) == (8, 8, 15)

# change number of features
@test_nowarn change_nfeatures(f, Xc, 5)
x = change_nfeatures(f, Xc, 5)
@test size(x) == (5, 15)
@test_logs (:warn, "Proposed n_features larger than currently saved n_features. Results will be less accurate since inverse_transform and transform is involved.") change_nfeatures(f, Xc, 10)
@test_throws ArgumentError change_nfeatures(f, Xc, 10)
end

0 comments on commit 465b888

Please sign in to comment.