diff --git a/src/mod/LDB.jl b/src/mod/LDB.jl index 410b8fc..13061c4 100644 --- a/src/mod/LDB.jl +++ b/src/mod/LDB.jl @@ -26,6 +26,7 @@ export # local discriminant basis LocalDiscriminantBasis, fit!, + fitdec!, fit_transform, transform, inverse_transform, @@ -95,7 +96,7 @@ following field values: dp::DiscriminantPower = BasisDiscriminantMeasure() top_k::Union{Integer, Nothing} = nothing n_features::Union{Integer, Nothing} = nothing - # to be computed in fit! + # to be computed in fit! or fitdec! sz::Union{Tuple{Vararg{T,N}} where {T<:Integer,N}, Nothing} = nothing Γ::Union{AbstractArray{<:AbstractFloat}, AbstractArray{NamedTuple{(:coef, :weight), Tuple{S1, S2}}} where @@ -183,13 +184,12 @@ function fit!(f::LocalDiscriminantBasis, X::AbstractArray{S}, y::AbstractVector{ Xw = wpdall(X, f.wt, f.max_dec_level) # fit local discriminant basis - fit!(f, Xw, y) + fitdec!(f, Xw, y) return nothing end -# TODO: parameter types are same as the one above -function fit!(f::LocalDiscriminantBasis, Xw::AbstractArray{S}, y::AbstractVector{T}) where - {S<:AbstractFloat, T} +function fitdec!(f::LocalDiscriminantBasis, Xw::AbstractArray{S}, y::AbstractVector{T}) where + {S<:AbstractFloat, T} # basic summary of data @assert 3 ≤ ndims(Xw) ≤ 4 c = unique(y) # unique classes @@ -316,7 +316,7 @@ function fit_transform(f::LocalDiscriminantBasis, Xw = wpdall(X, f.wt, f.max_dec_level) # fit LDB and return best features - fit!(f, Xw, y) + fitdec!(f, Xw, y) Xw = getbasiscoefall(Xw, f.tree) # Extract best features Xc = Array{S,2}(undef, f.n_features, N) @@ -373,7 +373,7 @@ function change_nfeatures(f::LocalDiscriminantBasis, x::AbstractArray{T,2}, # check measurements @assert !isnothing(f.n_features) @assert size(x,1) == f.n_features || throw(ArgumentError("f.n_features and number of rows of x do not match!")) - @assert 1 ≤ n_features ≤ f.n + @assert 1 ≤ n_features ≤ prod(f.sz) # change number of features if f.n_features ≥ n_features diff --git a/src/mod/ldb_measures.jl b/src/mod/ldb_measures.jl index 729cccc..9d5edd8 100644 --- a/src/mod/ldb_measures.jl +++ b/src/mod/ldb_measures.jl @@ -376,8 +376,13 @@ function discriminant_power(coefs::AbstractArray{T}, y::AbstractVector{S}, Eα = mean(Eαᵢ, dims = N) # overall mean of each entry pᵢ = Nᵢ / sum(Nᵢ) # proportions of each class - # TODO: This does not work when dealing with 2D signals (multiplication issue) - power = ((Eαᵢ - (Eα .* Eαᵢ)).^2 * pᵢ) ./ (Varαᵢ * pᵢ) + if N==2 # For 1D signals, can be done via matrix multiplication + power = ((Eαᵢ - (Eα .* Eαᵢ)).^2 * pᵢ) ./ (Varαᵢ * pᵢ) + else # For 2D signals, requires some manual work + pᵢ = reshape(pᵢ,1,1,:) + power = sum((Eαᵢ-(Eα.*Eαᵢ)).^2 .* pᵢ, dims=N) ./ sum(Varαᵢ.*pᵢ, dims=N) + power = reshape(power, sz...) + end order = sortperm(vec(power), rev = true) return (power, order) @@ -411,8 +416,13 @@ function discriminant_power(coefs::AbstractArray{T}, y::AbstractVector{S}, Medα = median(Medαᵢ, dims = N) # overall mean of each entry pᵢ = Nᵢ / sum(Nᵢ) # proportions of each class - # TODO: This does not work when dealing with 2D signals (multiplication issue) - power = ((Medαᵢ - (Medα .* Medαᵢ)).^2 * pᵢ) ./ (Madαᵢ * pᵢ) + if N==2 # For 1D signals, can be done via matrix multiplication + power = ((Medαᵢ - (Medα.*Medαᵢ)).^2 *pᵢ) ./ (Madαᵢ * pᵢ) + else # For 2D signals, requires some manual work + pᵢ = reshape(pᵢ,1,1,:) + power = sum((Medαᵢ-(Medα.*Medαᵢ)).^2 .* pᵢ, dims=N) ./ sum(Madαᵢ.*pᵢ, dims=N) + power = reshape(power, sz...) + end order = sortperm(vec(power), rev = true) return (power, order) diff --git a/test/ldb.jl b/test/ldb.jl index 27e1909..8e05650 100644 --- a/test/ldb.jl +++ b/test/ldb.jl @@ -1,86 +1,177 @@ -X, y = generateclassdata(ClassData(:tri, 5, 5, 5)) -wt = wavelet(WT.haar) +@testset "1D LDB" begin + X, y = generateclassdata(ClassData(:tri, 5, 5, 5)) + wt = wavelet(WT.haar) + + # AsymmetricRelativeEntropy + TimeFrequency + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (32, 15) + + # SymmetricRelativeEntropy + TimeFrequency + FishersClassSeparability + f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=SymmetricRelativeEntropy(), + dp=FishersClassSeparability(), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (32, 15) + + # LpDistance + TimeFrequency + RobustFishersClassSeparability + f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=LpDistance(), + dp=RobustFishersClassSeparability(), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (32, 15) + + # HellingerDistance + ProbabilityDensity + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=HellingerDistance(), + en=ProbabilityDensity(), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (32, 15) + + # EarthMoverDistance + Signatures(equal weight) + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(), + en=Signatures(:equal), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (32, 15) + + # EarthMoverDistance + Signatures(pdf weight) + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(), + en=Signatures(:pdf), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (32, 15) + + # change number of features + @test_nowarn change_nfeatures(f, Xc, 5) + x = change_nfeatures(f, Xc, 5) + @test size(x) == (5, 15) + @test_logs (:warn, "Proposed n_features larger than currently saved n_features. Results will be less accurate since inverse_transform and transform is involved.") change_nfeatures(f, Xc, 10) + @test_throws ArgumentError change_nfeatures(f, Xc, 10) +end -# AsymmetricRelativeEntropy + TimeFrequency + BasisDiscriminantMeasure -f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, top_k=5, n_features=5) -@test typeof(f) == LocalDiscriminantBasis -@test_nowarn fit_transform(f, X, y) -@test_nowarn fit!(f, X, y) -@test_nowarn transform(f, X) -Xc = transform(f, X) -@test size(Xc) == (5,15) -@test_nowarn inverse_transform(f, Xc) -X̂ = inverse_transform(f, Xc) -@test size(X̂) == (32, 15) - -# SymmetricRelativeEntropy + TimeFrequency + FishersClassSeparability -f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=SymmetricRelativeEntropy(), - dp=FishersClassSeparability(), top_k=5, n_features=5) -@test typeof(f) == LocalDiscriminantBasis -@test_nowarn fit_transform(f, X, y) -@test_nowarn fit!(f, X, y) -@test_nowarn transform(f, X) -Xc = transform(f, X) -@test size(Xc) == (5,15) -@test_nowarn inverse_transform(f, Xc) -X̂ = inverse_transform(f, Xc) -@test size(X̂) == (32, 15) - -# LpDistance + TimeFrequency + RobustFishersClassSeparability -f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=LpDistance(), - dp=RobustFishersClassSeparability(), top_k=5, n_features=5) -@test typeof(f) == LocalDiscriminantBasis -@test_nowarn fit_transform(f, X, y) -@test_nowarn fit!(f, X, y) -@test_nowarn transform(f, X) -Xc = transform(f, X) -@test size(Xc) == (5,15) -@test_nowarn inverse_transform(f, Xc) -X̂ = inverse_transform(f, Xc) -@test size(X̂) == (32, 15) - -# HellingerDistance + ProbabilityDensity + BasisDiscriminantMeasure -f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=HellingerDistance(), - en=ProbabilityDensity(), top_k=5, n_features=5) -@test typeof(f) == LocalDiscriminantBasis -@test_nowarn fit_transform(f, X, y) -@test_nowarn fit!(f, X, y) -@test_nowarn transform(f, X) -Xc = transform(f, X) -@test size(Xc) == (5,15) -@test_nowarn inverse_transform(f, Xc) -X̂ = inverse_transform(f, Xc) -@test size(X̂) == (32, 15) - -# EarthMoverDistance + Signatures(equal weight) + BasisDiscriminantMeasure -f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(), - en=Signatures(:equal), top_k=5, n_features=5) -@test typeof(f) == LocalDiscriminantBasis -@test_nowarn fit_transform(f, X, y) -@test_nowarn fit!(f, X, y) -@test_nowarn transform(f, X) -Xc = transform(f, X) -@test size(Xc) == (5,15) -@test_nowarn inverse_transform(f, Xc) -X̂ = inverse_transform(f, Xc) -@test size(X̂) == (32, 15) - -# EarthMoverDistance + Signatures(pdf weight) + BasisDiscriminantMeasure -f = LocalDiscriminantBasis(wt=wt, max_dec_level=4, dm=EarthMoverDistance(), - en=Signatures(:pdf), top_k=5, n_features=5) -@test typeof(f) == LocalDiscriminantBasis -@test_nowarn fit_transform(f, X, y) -@test_nowarn fit!(f, X, y) -@test_nowarn transform(f, X) -Xc = transform(f, X) -@test size(Xc) == (5,15) -@test_nowarn inverse_transform(f, Xc) -X̂ = inverse_transform(f, Xc) -@test size(X̂) == (32, 15) - -# change number of features -@test_nowarn change_nfeatures(f, Xc, 5) -x = change_nfeatures(f, Xc, 5) -@test size(x) == (5, 15) -@test_logs (:warn, "Proposed n_features larger than currently saved n_features. Results will be less accurate since inverse_transform and transform is involved.") change_nfeatures(f, Xc, 10) -@test_throws ArgumentError change_nfeatures(f, Xc, 10) +@testset "2D LDB" begin + X = cat(rand(Normal(0,1), 8,8,5), rand(Normal(1,1), 8,8,5), rand(Normal(2,1), 8,8,5), dims=3) + y = [repeat([1], 5); repeat([2],5); repeat([3],5)] + + # AsymmetricRelativeEntropy + TimeFrequency + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(max_dec_level=2, top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (8, 8, 15) + + # SymmetricRelativeEntropy + TimeFrequency + FishersClassSeparability + f = LocalDiscriminantBasis(max_dec_level=2, dm=SymmetricRelativeEntropy(), + dp=FishersClassSeparability(), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (8, 8, 15) + + # LpDistance + TimeFrequency + RobustFishersClassSeparability + f = LocalDiscriminantBasis(max_dec_level=2, dm=LpDistance(), + dp=RobustFishersClassSeparability(), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (8, 8, 15) + + # HellingerDistance + ProbabilityDensity + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(max_dec_level=2, dm=HellingerDistance(), + en=ProbabilityDensity(), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (8, 8, 15) + + # EarthMoverDistance + Signatures(equal weight) + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(max_dec_level=2, dm=EarthMoverDistance(), + en=Signatures(:equal), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (8, 8, 15) + + # EarthMoverDistance + Signatures(pdf weight) + BasisDiscriminantMeasure + f = LocalDiscriminantBasis(max_dec_level=2, dm=EarthMoverDistance(), + en=Signatures(:pdf), top_k=5, n_features=5) + @test typeof(f) == LocalDiscriminantBasis + @test_nowarn fit_transform(f, X, y) + @test_nowarn fit!(f, X, y) + @test_nowarn transform(f, X) + Xc = transform(f, X) + @test size(Xc) == (5,15) + @test_nowarn inverse_transform(f, Xc) + X̂ = inverse_transform(f, Xc) + @test size(X̂) == (8, 8, 15) + + # change number of features + @test_nowarn change_nfeatures(f, Xc, 5) + x = change_nfeatures(f, Xc, 5) + @test size(x) == (5, 15) + @test_logs (:warn, "Proposed n_features larger than currently saved n_features. Results will be less accurate since inverse_transform and transform is involved.") change_nfeatures(f, Xc, 10) + @test_throws ArgumentError change_nfeatures(f, Xc, 10) +end \ No newline at end of file