From 6f32c3f5752cc595558015e24a35786c59c7fa21 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jun 2023 07:41:32 +0330 Subject: [PATCH 1/8] have mode in dists --- src/base.jl | 3 +-- src/core_cond_icnf.jl | 39 +++++++++++++++++++++++++-------------- src/core_icnf.jl | 21 +++++++++++++++------ 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/src/base.jl b/src/base.jl index b77087b3..3baaa63e 100644 --- a/src/base.jl +++ b/src/base.jl @@ -3,8 +3,7 @@ export construct function construct( aicnf::Type{<:AbstractFlows}, nn, - nvars::Integer, - ; + nvars::Integer; data_type::Type{<:AbstractFloat} = Float32, array_type::Type{<:AbstractArray} = Array, compute_mode::Type{<:ComputeMode} = ADVectorMode, diff --git a/src/core_cond_icnf.jl b/src/core_cond_icnf.jl index 4c4066fb..fb03b6ae 100644 --- a/src/core_cond_icnf.jl +++ b/src/core_cond_icnf.jl @@ -31,8 +31,7 @@ end function CondICNFModel( m::AbstractCondICNF{T, AT, CM}, - loss::Function = loss, - ; + loss::Function = loss; optimizers::AbstractVector = [Optimisers.Adam()], n_epochs::Integer = 300, adtype::ADTypes.AbstractADType = AutoZygote(), @@ -195,13 +194,33 @@ struct CondICNFDist <: ICNFDistribution ys::AbstractVecOrMat{<:Real} ps::Any st::Any + mode::Mode +end + +function CondICNFDist( + m::AbstractCondICNF, + ys::AbstractVecOrMat{<:Real}, + ps::Any, + st::Any; + mode::Mode = TestMode(), +) + CondICNFDist(m, ys, ps, st, mode) +end + +function CondICNFDist( + mach::Machine{<:CondICNFModel}, + ys::AbstractVecOrMat{<:Real}; + mode::Mode = TestMode(), +) + (ps, st) = MLJBase.fitted_params(mach) + CondICNFDist(mach.model.m, ys, ps, st, mode) end Base.length(d::CondICNFDist) = d.m.nvars Base.eltype(d::CondICNFDist) = typeof(d.m).parameters[1] function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) if d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} - first(inference(d.m, TestMode(), x, d.ys, d.ps, d.st)) + first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} first(Distributions._logpdf(d, hcat(x))) else @@ -212,14 +231,14 @@ function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real}) if d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} broadcast(x -> Distributions._logpdf(d, x), eachcol(A)) elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} - first(inference(d.m, TestMode(), A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) + first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) else error("Not Implemented") end end function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, x::AbstractVector{<:Real}) if d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} - x .= generate(d.m, TestMode(), d.ys, d.ps, d.st; rng) + x .= generate(d.m, d.mode, d.ys, d.ps, d.st; rng) elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} x .= Distributions._rand!(rng, d, hcat(x)) else @@ -230,15 +249,7 @@ function Distributions._rand!(rng::AbstractRNG, d::CondICNFDist, A::AbstractMatr if d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} A .= hcat(broadcast(x -> Distributions._rand!(rng, d, x), eachcol(A))...) 
elseif d.m isa AbstractCondICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} - A .= generate( - d.m, - TestMode(), - d.ys[:, begin:size(A, 2)], - d.ps, - d.st, - size(A, 2); - rng, - ) + A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2); rng) else error("Not Implemented") end diff --git a/src/core_icnf.jl b/src/core_icnf.jl index e678d27c..ffdb180e 100644 --- a/src/core_icnf.jl +++ b/src/core_icnf.jl @@ -31,8 +31,7 @@ end function ICNFModel( m::AbstractICNF{T, AT, CM}, - loss::Function = loss, - ; + loss::Function = loss; optimizers::AbstractVector = [Optimisers.Adam()], n_epochs::Integer = 300, adtype::ADTypes.AbstractADType = AutoZygote(), @@ -185,13 +184,23 @@ struct ICNFDist <: ICNFDistribution m::AbstractICNF ps::Any st::Any + mode::Mode +end + +function ICNFDist(m::AbstractICNF, ps::Any, st::Any; mode::Mode = TestMode()) + ICNFDist(m, ps, st, mode) +end + +function ICNFDist(mach::Machine{<:ICNFModel}; mode::Mode = TestMode()) + (ps, st) = MLJBase.fitted_params(mach) + ICNFDist(mach.model.m, ps, st, mode) end Base.length(d::ICNFDist) = d.m.nvars Base.eltype(d::ICNFDist) = typeof(d.m).parameters[1] function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real}) if d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} - first(inference(d.m, TestMode(), x, d.ps, d.st)) + first(inference(d.m, d.mode, x, d.ps, d.st)) elseif d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} first(Distributions._logpdf(d, hcat(x))) else @@ -202,14 +211,14 @@ function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real}) if d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} broadcast(x -> Distributions._logpdf(d, x), eachcol(A)) elseif d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} - first(inference(d.m, TestMode(), A, d.ps, d.st)) + first(inference(d.m, d.mode, A, d.ps, d.st)) else error("Not Implemented") end end function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, x::AbstractVector{<:Real}) if d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} - x .= generate(d.m, TestMode(), d.ps, d.st; rng) + x .= generate(d.m, d.mode, d.ps, d.st; rng) elseif d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} x .= Distributions._rand!(rng, d, hcat(x)) else @@ -220,7 +229,7 @@ function Distributions._rand!(rng::AbstractRNG, d::ICNFDist, A::AbstractMatrix{< if d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:VectorMode} A .= hcat(broadcast(x -> Distributions._rand!(rng, d, x), eachcol(A))...) 
elseif d.m isa AbstractICNF{<:AbstractFloat, <:AbstractArray, <:MatrixMode} - A .= generate(d.m, TestMode(), d.ps, d.st, size(A, 2); rng) + A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2); rng) else error("Not Implemented") end From bc97307b370d9a8f968c048ead8725b7b29b7c7a Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jun 2023 07:55:54 +0330 Subject: [PATCH 2/8] add it to tests --- test/benchmark_tests.jl | 29 ++++++++++++++------ test/call_tests.jl | 61 ++++++++++++++++++++++------------------- 2 files changed, 53 insertions(+), 37 deletions(-) diff --git a/test/benchmark_tests.jl b/test/benchmark_tests.jl index 44fe8351..d51d2bac 100644 --- a/test/benchmark_tests.jl +++ b/test/benchmark_tests.jl @@ -8,18 +8,29 @@ icnf = construct(RNODE, nn, nvars; compute_mode = ZygoteMatrixMode) ps, st = Lux.setup(rng, icnf) - diff_loss(x) = loss(icnf, r, x, st) - grad_diff_loss() = Zygote.gradient(diff_loss, ps) - t_loss() = loss(icnf, r, ps, st) + diff_loss_train(x) = loss(icnf, r, x, st; mode = TrainMode()) + diff_loss_test(x) = loss(icnf, r, x, st; mode = TestMode()) + grad_diff_loss_train() = Zygote.gradient(diff_loss_train, ps) + grad_diff_loss_test() = Zygote.gradient(diff_loss_test, ps) + t_loss_train() = loss(icnf, r, ps, st; mode = TrainMode()) + t_loss_test() = loss(icnf, r, ps, st; mode = TestMode()) - ben_1 = BenchmarkTools.@benchmark $t_loss() - ben_2 = BenchmarkTools.@benchmark $grad_diff_loss() + ben_loss_train = BenchmarkTools.@benchmark $t_loss_train() + ben_loss_test = BenchmarkTools.@benchmark $t_loss_test() + ben_grad_train = BenchmarkTools.@benchmark $grad_diff_loss_train() + ben_grad_test = BenchmarkTools.@benchmark $grad_diff_loss_test() - @info "t_loss" - display(ben_1) + @info "t_loss_train" + display(ben_loss_train) - @info "grad_diff_loss" - display(ben_2) + @info "t_loss_test" + display(ben_loss_test) + + @info "grad_diff_loss_train" + display(ben_grad_train) + + @info "grad_diff_loss_test" + display(ben_grad_test) @test true end diff --git a/test/call_tests.jl b/test/call_tests.jl index 0c2fefa4..827589ba 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -19,6 +19,7 @@ SDVecJacMatrixMode, SDJacVecMatrixMode, ] + omodes = Type{<:ContinuousNormalizingFlows.Mode}[TrainMode(), TestMode()] nvars_ = (1:2) adb_list = AbstractDifferentiation.AbstractBackend[ AbstractDifferentiation.ZygoteBackend(), @@ -34,6 +35,7 @@ tp in tps, adb_u in adb_list, nvars in nvars_, + omode in omodes, mt in mts adb_u isa AbstractDifferentiation.FiniteDifferencesBackend && continue @@ -60,14 +62,12 @@ ps, st = Lux.setup(rng, icnf) ps = ComponentArrays.ComponentArray(map(at{tp}, ps)) - @test !isnothing(inference(icnf, TestMode(), r, ps, st)) - @test !isnothing(inference(icnf, TrainMode(), r, ps, st)) - @test !isnothing(generate(icnf, TestMode(), ps, st)) - @test !isnothing(generate(icnf, TrainMode(), ps, st)) + @test !isnothing(inference(icnf, omode, r, ps, st)) + @test !isnothing(generate(icnf, omode, ps, st)) - @test !isnothing(loss(icnf, r, ps, st)) + @test !isnothing(loss(icnf, r, ps, st; mode = omode)) - diff_loss(x) = loss(icnf, r, x, st) + diff_loss(x) = loss(icnf, r, x, st; mode = omode) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -111,7 +111,9 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = ICNFDist(icnf, ps, st) + d = ICNFDist(icnf, ps, st; mode = omode) + d2 = ICNFDist(mach; mode = omode) + 
@test d == d2 @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) @@ -124,6 +126,7 @@ tp in tps, cmode in cmodes, nvars in nvars_, + omode in omodes, mt in mts cmode <: SDJacVecMatrixMode && continue @@ -142,14 +145,12 @@ ps, st = Lux.setup(rng, icnf) ps = ComponentArrays.ComponentArray(map(at{tp}, ps)) - @test !isnothing(inference(icnf, TestMode(), r_arr, ps, st)) - @test !isnothing(inference(icnf, TrainMode(), r_arr, ps, st)) - @test !isnothing(generate(icnf, TestMode(), ps, st, 2)) - @test !isnothing(generate(icnf, TrainMode(), ps, st, 2)) + @test !isnothing(inference(icnf, omode, r_arr, ps, st)) + @test !isnothing(generate(icnf, omode, ps, st, 2)) - @test !isnothing(loss(icnf, r_arr, ps, st)) + @test !isnothing(loss(icnf, r_arr, ps, st; mode = omode)) - diff_loss(x) = loss(icnf, r_arr, x, st) + diff_loss(x) = loss(icnf, r_arr, x, st; mode = omode) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -193,7 +194,9 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = ICNFDist(icnf, ps, st) + d = ICNFDist(icnf, ps, st; mode = omode) + d2 = ICNFDist(mach; mode = omode) + @test d == d2 @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) @@ -206,6 +209,7 @@ tp in tps, adb_u in adb_list, nvars in nvars_, + omode in omodes, mt in cmts adb_u isa AbstractDifferentiation.FiniteDifferencesBackend && continue @@ -236,14 +240,12 @@ ps, st = Lux.setup(rng, icnf) ps = ComponentArrays.ComponentArray(map(at{tp}, ps)) - @test !isnothing(inference(icnf, TestMode(), r, r2, ps, st)) - @test !isnothing(inference(icnf, TrainMode(), r, r2, ps, st)) - @test !isnothing(generate(icnf, TestMode(), r2, ps, st)) - @test !isnothing(generate(icnf, TrainMode(), r2, ps, st)) + @test !isnothing(inference(icnf, omode, r, r2, ps, st)) + @test !isnothing(generate(icnf, omode, r2, ps, st)) - @test !isnothing(loss(icnf, r, r2, ps, st)) + @test !isnothing(loss(icnf, r, r2, ps, st; mode = omode)) - diff_loss(x) = loss(icnf, r, r2, x, st) + diff_loss(x) = loss(icnf, r, r2, x, st; mode = omode) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -287,7 +289,9 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = CondICNFDist(icnf, r2, ps, st) + d = CondICNFDist(icnf, r2, ps, st; mode = omode) + d2 = CondICNFDist(mach, r2; mode = omode) + @test d == d2 @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) @@ -300,6 +304,7 @@ tp in tps, cmode in cmodes, nvars in nvars_, + omode in omodes, mt in cmts cmode <: SDJacVecMatrixMode && continue @@ -321,14 +326,12 @@ ps, st = Lux.setup(rng, icnf) ps = ComponentArrays.ComponentArray(map(at{tp}, ps)) - @test !isnothing(inference(icnf, TestMode(), r_arr, r2_arr, ps, st)) - @test !isnothing(inference(icnf, TrainMode(), r_arr, r2_arr, ps, st)) - @test !isnothing(generate(icnf, TestMode(), r2_arr, ps, st, 2)) - @test !isnothing(generate(icnf, TrainMode(), r2_arr, ps, st, 2)) + @test !isnothing(inference(icnf, omode, r_arr, r2_arr, ps, st)) + @test !isnothing(generate(icnf, omode, r2_arr, ps, st, 2)) - @test !isnothing(loss(icnf, r_arr, r2_arr, ps, st)) + @test !isnothing(loss(icnf, r_arr, r2_arr, ps, st; mode = omode)) - diff_loss(x) = loss(icnf, r_arr, r2_arr, 
x, st) + diff_loss(x) = loss(icnf, r_arr, r2_arr, x, st; mode = omode) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -372,7 +375,9 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = CondICNFDist(icnf, r2_arr, ps, st) + d = CondICNFDist(icnf, r2_arr, ps, st; mode = omode) + d2 = CondICNFDist(mach, r2_arr; mode = omode) + @test d == d2 @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) From 652561d73c18bbce6ec3ecdc9b05c0738b469e7f Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jun 2023 12:09:43 +0330 Subject: [PATCH 3/8] update testmode jacs --- src/utils.jl | 99 +++++++++++++--------------------------------------- 1 file changed, 24 insertions(+), 75 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 04fb2700..95f07414 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,57 +1,3 @@ -function jacobian_batched( - f, - xs::AbstractMatrix{<:AbstractFloat}, - T::Type{<:AbstractFloat}, - AT::Type{<:AbstractArray}, - CM::Type{<:ZygoteMatrixMode}, -)::Tuple - y, back = Zygote.pullback(f, xs) - z::AT = zeros(T, size(xs)) - res::AT = zeros(T, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(y, 1) - z[i, :] .= one(T) - res[i, :, :] .= only(back(z)) - z[i, :] .= zero(T) - end - y, res -end - -function jacobian_batched( - f, - xs::AbstractMatrix{<:AbstractFloat}, - T::Type{<:AbstractFloat}, - AT::Type{<:AbstractArray}, - CM::Type{<:SDVecJacMatrixMode}, -)::Tuple - y = f(xs) - z::AT = zeros(T, size(xs)) - res::AT = zeros(T, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(y, 1) - z[i, :] .= one(T) - res[i, :, :] .= reshape(auto_vecjac(f, xs, z), size(xs)) - z[i, :] .= zero(T) - end - y, res -end - -function jacobian_batched( - f, - xs::AbstractMatrix{<:AbstractFloat}, - T::Type{<:AbstractFloat}, - AT::Type{<:AbstractArray}, - CM::Type{<:SDJacVecMatrixMode}, -)::Tuple - y = f(xs) - z::AT = zeros(T, size(xs)) - res::AT = zeros(T, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(y, 1) - z[i, :] .= one(T) - res[i, :, :] .= reshape(auto_jacvec(f, xs, z), size(xs)) - z[i, :] .= zero(T) - end - y, res -end - function jacobian_batched( f, xs::AbstractMatrix{<:Real}, @@ -60,14 +6,15 @@ function jacobian_batched( CM::Type{<:ZygoteMatrixMode}, )::Tuple y, back = Zygote.pullback(f, xs) - z::AT = zeros(T, size(xs)) - res::AT{Real} = zeros(T, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(y, 1) - z[i, :] .= one(T) - res[i, :, :] .= only(back(z)) - z[i, :] .= zero(T) + z = Zygote.Buffer(xs) + z[:, :] = convert(AT, zeros(T, size(xs))) + res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + for i in axes(xs, 1) + z[i, :] = one(T) + res[i, :, :] = only(back(z)) + z[i, :] = zero(T) end - y, res + y, copy(res) end function jacobian_batched( @@ -78,14 +25,15 @@ function jacobian_batched( CM::Type{<:SDVecJacMatrixMode}, )::Tuple y = f(xs) - z::AT = zeros(T, size(xs)) - res::AT{Real} = zeros(T, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(y, 1) - z[i, :] .= one(T) - res[i, :, :] .= reshape(auto_vecjac(f, xs, z), size(xs)) - z[i, :] .= zero(T) + z = Zygote.Buffer(xs) + z[:, :] = convert(AT, zeros(T, size(xs))) + res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + for i in axes(xs, 1) + z[i, :] = one(T) + res[i, :, :] = reshape(auto_vecjac(f, xs, z), size(xs)) + z[i, :] = zero(T) end - y, res + y, copy(res) end function jacobian_batched( @@ 
-96,12 +44,13 @@ function jacobian_batched( CM::Type{<:SDJacVecMatrixMode}, )::Tuple y = f(xs) - z::AT = zeros(T, size(xs)) - res::AT{Real} = zeros(T, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(y, 1) - z[i, :] .= one(T) - res[i, :, :] .= reshape(auto_jacvec(f, xs, z), size(xs)) - z[i, :] .= zero(T) + z = Zygote.Buffer(xs) + z[:, :] = convert(AT, zeros(T, size(xs))) + res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + for i in axes(xs, 1) + z[i, :] = one(T) + res[i, :, :] = reshape(auto_jacvec(f, xs, z), size(xs)) + z[i, :] = zero(T) end - y, res + y, copy(res) end From 9f5fec6d7c9b80ddff0aa87dee113adc9e5331a6 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jun 2023 14:06:18 +0330 Subject: [PATCH 4/8] change api to have `mode` after `icnf` --- README.md | 62 +++++++++++++++++++++---------------------- src/base_cond_icnf.jl | 4 +-- src/base_icnf.jl | 4 +-- src/cond_rnode.jl | 4 +-- src/core_cond_icnf.jl | 22 +++++---------- src/core_icnf.jl | 14 ++++------ src/precompile.jl | 17 ++++++------ src/rnode.jl | 4 +-- src/utils.jl | 21 +++++++-------- test/call_tests.jl | 34 +++++++++--------------- test/fit_tests.jl | 12 +++++++++ 11 files changed, 93 insertions(+), 105 deletions(-) diff --git a/README.md b/README.md index c3aad807..b4734e8e 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,34 @@ -# ContinuousNormalizingFlows.jl - -[![deps](https://juliahub.com/docs/ContinuousNormalizingFlows/deps.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo?t=2) -[![version](https://juliahub.com/docs/ContinuousNormalizingFlows/version.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) -[![pkgeval](https://juliahub.com/docs/ContinuousNormalizingFlows/pkgeval.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable) -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev) -[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl) -[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main) -[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) - -Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia - -## Citing - -See [`CITATION.bib`](CITATION.bib) for the relevant reference(s). 
- -## Usage - -To add this package, we can do it by - -```julia +# ContinuousNormalizingFlows.jl + +[![deps](https://juliahub.com/docs/ContinuousNormalizingFlows/deps.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo?t=2) +[![version](https://juliahub.com/docs/ContinuousNormalizingFlows/version.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) +[![pkgeval](https://juliahub.com/docs/ContinuousNormalizingFlows/pkgeval.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev) +[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl) +[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) + +Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia + +## Citing + +See [`CITATION.bib`](CITATION.bib) for the relevant reference(s). 
+
+## Usage
+
+To add this package, we can do it by
+
+```julia
 using Pkg
 Pkg.add("ContinuousNormalizingFlows")
-```
-
-To use this package, here is an example:
-
-```julia
+```
+
+To use this package, here is an example:
+
+```julia
 using ContinuousNormalizingFlows
 using Distributions, Lux
 # using Flux
@@ -63,7 +63,7 @@ fit!(mach)
 ps, st = fitted_params(mach)
 
 # Use It
-d = ICNFDist(icnf, ps, st)
+d = ICNFDist(mach, TestMode())
 actual_pdf = pdf.(data_dist, vec(r))
 estimated_pdf = pdf(d, r)
 new_data = rand(d, n)
@@ -79,4 +79,4 @@ using Plots
 p = plot(x -> pdf(data_dist, x), 0, 1; label = "actual")
 p = plot!(p, x -> pdf(d, convert.(Float32, vcat(x))), 0, 1; label = "estimated")
 savefig(p, "plot.png")
-```
+```
diff --git a/src/base_cond_icnf.jl b/src/base_cond_icnf.jl
index d38aa76a..758da907 100644
--- a/src/base_cond_icnf.jl
+++ b/src/base_cond_icnf.jl
@@ -111,12 +111,12 @@ end
 
 function loss(
     icnf::AbstractCondICNF{T, AT, <:VectorMode},
+    mode::Mode,
     xs::AbstractVector{<:Real},
     ys::AbstractVector{<:Real},
     ps::Any,
     st::Any;
     differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
-    mode::Mode = TrainMode(),
     rng::AbstractRNG = Random.default_rng(),
     sol_args::Tuple = icnf.sol_args,
     sol_kwargs::Dict = icnf.sol_kwargs,
@@ -138,12 +138,12 @@ end
 
 function loss(
     icnf::AbstractCondICNF{T, AT, <:MatrixMode},
+    mode::Mode,
     xs::AbstractMatrix{<:Real},
     ys::AbstractMatrix{<:Real},
     ps::Any,
     st::Any;
     differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
-    mode::Mode = TrainMode(),
     rng::AbstractRNG = Random.default_rng(),
     sol_args::Tuple = icnf.sol_args,
     sol_kwargs::Dict = icnf.sol_kwargs,
diff --git a/src/base_icnf.jl b/src/base_icnf.jl
index e528073a..c81afade 100644
--- a/src/base_icnf.jl
+++ b/src/base_icnf.jl
@@ -107,11 +107,11 @@ end
 
 function loss(
     icnf::AbstractICNF{T, AT, <:VectorMode},
+    mode::Mode,
     xs::AbstractVector{<:Real},
     ps::Any,
     st::Any;
     differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
-    mode::Mode = TrainMode(),
     rng::AbstractRNG = Random.default_rng(),
     sol_args::Tuple = icnf.sol_args,
     sol_kwargs::Dict = icnf.sol_kwargs,
@@ -132,11 +132,11 @@ end
 
 function loss(
     icnf::AbstractICNF{T, AT, <:MatrixMode},
+    mode::Mode,
     xs::AbstractMatrix{<:Real},
     ps::Any,
     st::Any;
     differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
-    mode::Mode = TrainMode(),
     rng::AbstractRNG = Random.default_rng(),
     sol_args::Tuple = icnf.sol_args,
     sol_kwargs::Dict = icnf.sol_kwargs,
diff --git a/src/cond_rnode.jl b/src/cond_rnode.jl
index 30599d8f..d9cf3dad 100644
--- a/src/cond_rnode.jl
+++ b/src/cond_rnode.jl
@@ -178,6 +178,7 @@ end
 
 function loss(
     icnf::CondRNODE{T, AT, <:VectorMode},
+    mode::TrainMode,
     xs::AbstractVector{<:Real},
     ys::AbstractVector{<:Real},
     ps::Any,
@@ -185,7 +186,6 @@ function loss(
     λ₁::T = convert(T, 1e-2),
     λ₂::T = convert(T, 1e-2);
     differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
-    mode::Mode = TrainMode(),
     rng::AbstractRNG = Random.default_rng(),
     sol_args::Tuple = icnf.sol_args,
     sol_kwargs::Dict = icnf.sol_kwargs,
@@ -207,6 +207,7 @@ end
 
 function loss(
     icnf::CondRNODE{T, AT, <:MatrixMode},
+    mode::TrainMode,
     xs::AbstractMatrix{<:Real},
     ys::AbstractMatrix{<:Real},
     ps::Any,
@@ -214,7 +215,6 @@ function loss(
     λ₁::T = convert(T, 1e-2),
     λ₂::T = convert(T, 1e-2);
     differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend,
-    mode::Mode = 
TrainMode(), rng::AbstractRNG = Random.default_rng(), sol_args::Tuple = icnf.sol_args, sol_kwargs::Dict = icnf.sol_kwargs, diff --git a/src/core_cond_icnf.jl b/src/core_cond_icnf.jl index fb03b6ae..09dab244 100644 --- a/src/core_cond_icnf.jl +++ b/src/core_cond_icnf.jl @@ -4,7 +4,7 @@ export CondICNFModel, CondICNFDist function loss_f(icnf::AbstractCondICNF, loss::Function, st::Any)::Function function f(ps, θ, xs, ys) - loss(icnf, xs, ys, ps, st) + loss(icnf, TrainMode(), xs, ys, ps, st) end f end @@ -191,29 +191,19 @@ MLJBase.metadata_model( struct CondICNFDist <: ICNFDistribution m::AbstractCondICNF + mode::Mode ys::AbstractVecOrMat{<:Real} ps::Any st::Any - mode::Mode -end - -function CondICNFDist( - m::AbstractCondICNF, - ys::AbstractVecOrMat{<:Real}, - ps::Any, - st::Any; - mode::Mode = TestMode(), -) - CondICNFDist(m, ys, ps, st, mode) end function CondICNFDist( mach::Machine{<:CondICNFModel}, - ys::AbstractVecOrMat{<:Real}; - mode::Mode = TestMode(), + mode::Mode, + ys::AbstractVecOrMat{<:Real}, ) - (ps, st) = MLJBase.fitted_params(mach) - CondICNFDist(mach.model.m, ys, ps, st, mode) + (ps, st) = fitted_params(mach) + CondICNFDist(mach.model.m, mode, ys, ps, st) end Base.length(d::CondICNFDist) = d.m.nvars diff --git a/src/core_icnf.jl b/src/core_icnf.jl index ffdb180e..39ee026c 100644 --- a/src/core_icnf.jl +++ b/src/core_icnf.jl @@ -4,7 +4,7 @@ export ICNFModel, ICNFDist function loss_f(icnf::AbstractICNF, loss::Function, st::Any)::Function function f(ps, θ, xs) - loss(icnf, xs, ps, st) + loss(icnf, TrainMode(), xs, ps, st) end f end @@ -182,18 +182,14 @@ MLJBase.metadata_model( struct ICNFDist <: ICNFDistribution m::AbstractICNF + mode::Mode ps::Any st::Any - mode::Mode -end - -function ICNFDist(m::AbstractICNF, ps::Any, st::Any; mode::Mode = TestMode()) - ICNFDist(m, ps, st, mode) end -function ICNFDist(mach::Machine{<:ICNFModel}; mode::Mode = TestMode()) - (ps, st) = MLJBase.fitted_params(mach) - ICNFDist(mach.model.m, ps, st, mode) +function ICNFDist(mach::Machine{<:ICNFModel}, mode::Mode) + (ps, st) = fitted_params(mach) + ICNFDist(mach.model.m, mode, ps, st) end Base.length(d::ICNFDist) = d.m.nvars diff --git a/src/precompile.jl b/src/precompile.jl index f8b0d023..f884e66f 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,16 +1,17 @@ @setup_workload begin @compile_workload begin rng = Random.default_rng() - mts = Type{<:AbstractICNF}[RNODE, FFJORD, Planar] - cmts = Type{<:AbstractCondICNF}[CondRNODE, CondFFJORD, CondPlanar] + mts = Type{<:AbstractICNF}[RNODE] + cmts = Type{<:AbstractCondICNF}[CondRNODE] cmodes = Type{<:ComputeMode}[ADVectorMode, ZygoteMatrixMode, SDVecJacMatrixMode] + omodes = Mode[TrainMode(), TestMode()] nvars = 2 r = rand(Float32, nvars) r_arr = rand(Float32, nvars, 2) r2 = rand(Float32, nvars) r2_arr = rand(Float32, nvars, 2) - for cmode in cmodes, mt in mts + for cmode in cmodes, omode in omodes, mt in mts if mt <: Planar nn = PlanarLayer(nvars, tanh) else @@ -20,12 +21,12 @@ ps, st = Lux.setup(rng, icnf) ps = ComponentArray(ps) if cmode <: VectorMode - L = loss(icnf, r, ps, st) + L = loss(icnf, omode, r, ps, st) elseif cmode <: MatrixMode - L = loss(icnf, r_arr, ps, st) + L = loss(icnf, omode, r_arr, ps, st) end end - for cmode in cmodes, mt in cmts + for cmode in cmodes, omode in omodes, mt in cmts if mt <: CondPlanar nn = PlanarLayer(nvars, tanh; cond = true) else @@ -35,9 +36,9 @@ ps, st = Lux.setup(rng, icnf) ps = ComponentArray(ps) if cmode <: VectorMode - L = loss(icnf, r, r2, ps, st) + L = loss(icnf, omode, r, r2, ps, st) elseif cmode 
<: MatrixMode - L = loss(icnf, r_arr, r2_arr, ps, st) + L = loss(icnf, omode, r_arr, r2_arr, ps, st) end end end diff --git a/src/rnode.jl b/src/rnode.jl index 2a31ac7a..42fed346 100644 --- a/src/rnode.jl +++ b/src/rnode.jl @@ -167,13 +167,13 @@ end function loss( icnf::RNODE{T, AT, <:VectorMode}, + mode::TrainMode, xs::AbstractVector{<:Real}, ps::Any, st::Any, λ₁::T = convert(T, 1e-2), λ₂::T = convert(T, 1e-2); differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend, - mode::Mode = TrainMode(), rng::AbstractRNG = Random.default_rng(), sol_args::Tuple = icnf.sol_args, sol_kwargs::Dict = icnf.sol_kwargs, @@ -194,13 +194,13 @@ end function loss( icnf::RNODE{T, AT, <:MatrixMode}, + mode::TrainMode, xs::AbstractMatrix{<:Real}, ps::Any, st::Any, λ₁::T = convert(T, 1e-2), λ₂::T = convert(T, 1e-2); differentiation_backend::AbstractDifferentiation.AbstractBackend = icnf.differentiation_backend, - mode::Mode = TrainMode(), rng::AbstractRNG = Random.default_rng(), sol_args::Tuple = icnf.sol_args, sol_kwargs::Dict = icnf.sol_kwargs, diff --git a/src/utils.jl b/src/utils.jl index 95f07414..33dcb3c5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,13 +6,12 @@ function jacobian_batched( CM::Type{<:ZygoteMatrixMode}, )::Tuple y, back = Zygote.pullback(f, xs) - z = Zygote.Buffer(xs) - z[:, :] = convert(AT, zeros(T, size(xs))) + z::AT = zeros(T, size(xs)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) - z[i, :] = one(T) + z[i, :] .= one(T) res[i, :, :] = only(back(z)) - z[i, :] = zero(T) + z[i, :] .= zero(T) end y, copy(res) end @@ -25,13 +24,12 @@ function jacobian_batched( CM::Type{<:SDVecJacMatrixMode}, )::Tuple y = f(xs) - z = Zygote.Buffer(xs) - z[:, :] = convert(AT, zeros(T, size(xs))) + z::AT = zeros(T, size(xs)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) - z[i, :] = one(T) + z[i, :] .= one(T) res[i, :, :] = reshape(auto_vecjac(f, xs, z), size(xs)) - z[i, :] = zero(T) + z[i, :] .= zero(T) end y, copy(res) end @@ -44,13 +42,12 @@ function jacobian_batched( CM::Type{<:SDJacVecMatrixMode}, )::Tuple y = f(xs) - z = Zygote.Buffer(xs) - z[:, :] = convert(AT, zeros(T, size(xs))) + z::AT = zeros(T, size(xs)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) - z[i, :] = one(T) + z[i, :] .= one(T) res[i, :, :] = reshape(auto_jacvec(f, xs, z), size(xs)) - z[i, :] = zero(T) + z[i, :] .= zero(T) end y, copy(res) end diff --git a/test/call_tests.jl b/test/call_tests.jl index 827589ba..3c0e43bf 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -19,7 +19,7 @@ SDVecJacMatrixMode, SDJacVecMatrixMode, ] - omodes = Type{<:ContinuousNormalizingFlows.Mode}[TrainMode(), TestMode()] + omodes = ContinuousNormalizingFlows.Mode[TrainMode(), TestMode()] nvars_ = (1:2) adb_list = AbstractDifferentiation.AbstractBackend[ AbstractDifferentiation.ZygoteBackend(), @@ -65,9 +65,9 @@ @test !isnothing(inference(icnf, omode, r, ps, st)) @test !isnothing(generate(icnf, omode, ps, st)) - @test !isnothing(loss(icnf, r, ps, st; mode = omode)) + @test !isnothing(loss(icnf, omode, r, ps, st)) - diff_loss(x) = loss(icnf, r, x, st; mode = omode) + diff_loss(x) = loss(icnf, omode, r, x, st) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -111,9 +111,7 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = ICNFDist(icnf, ps, 
st; mode = omode) - d2 = ICNFDist(mach; mode = omode) - @test d == d2 + d = ICNFDist(icnf, omode, ps, st) @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) @@ -148,9 +146,9 @@ @test !isnothing(inference(icnf, omode, r_arr, ps, st)) @test !isnothing(generate(icnf, omode, ps, st, 2)) - @test !isnothing(loss(icnf, r_arr, ps, st; mode = omode)) + @test !isnothing(loss(icnf, omode, r_arr, ps, st)) - diff_loss(x) = loss(icnf, r_arr, x, st; mode = omode) + diff_loss(x) = loss(icnf, omode, r_arr, x, st) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -194,9 +192,7 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = ICNFDist(icnf, ps, st; mode = omode) - d2 = ICNFDist(mach; mode = omode) - @test d == d2 + d = ICNFDist(icnf, omode, ps, st) @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) @@ -243,9 +239,9 @@ @test !isnothing(inference(icnf, omode, r, r2, ps, st)) @test !isnothing(generate(icnf, omode, r2, ps, st)) - @test !isnothing(loss(icnf, r, r2, ps, st; mode = omode)) + @test !isnothing(loss(icnf, omode, r, r2, ps, st)) - diff_loss(x) = loss(icnf, r, r2, x, st; mode = omode) + diff_loss(x) = loss(icnf, omode, r, r2, x, st) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -289,9 +285,7 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = CondICNFDist(icnf, r2, ps, st; mode = omode) - d2 = CondICNFDist(mach, r2; mode = omode) - @test d == d2 + d = CondICNFDist(icnf, omode, r2, ps, st) @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) @@ -329,9 +323,9 @@ @test !isnothing(inference(icnf, omode, r_arr, r2_arr, ps, st)) @test !isnothing(generate(icnf, omode, r2_arr, ps, st, 2)) - @test !isnothing(loss(icnf, r_arr, r2_arr, ps, st; mode = omode)) + @test !isnothing(loss(icnf, omode, r_arr, r2_arr, ps, st)) - diff_loss(x) = loss(icnf, r_arr, r2_arr, x, st; mode = omode) + diff_loss(x) = loss(icnf, omode, r_arr, r2_arr, x, st) @testset "Using $(typeof(adb).name.name) For Loss" for adb in adb_list adb isa AbstractDifferentiation.TrackerBackend && continue @@ -375,9 +369,7 @@ @test_throws MethodError !isnothing(Calculus.gradient(diff_loss, ps)) # @test !isnothing(Calculus.hessian(diff_loss, ps)) - d = CondICNFDist(icnf, r2_arr, ps, st; mode = omode) - d2 = CondICNFDist(mach, r2_arr; mode = omode) - @test d == d2 + d = CondICNFDist(icnf, omode, r2_arr, ps, st) @test !isnothing(Distributions.logpdf(d, r)) @test !isnothing(Distributions.logpdf(d, r_arr)) diff --git a/test/fit_tests.jl b/test/fit_tests.jl index 70f21328..f633edf7 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -79,6 +79,9 @@ @test !isnothing(MLJBase.fit!(mach)) @test !isnothing(MLJBase.transform(mach, df)) @test !isnothing(MLJBase.fitted_params(mach)) + + @test !isnothing(ICNFDist(mach, TrainMode())) + @test !isnothing(ICNFDist(mach, TestMode())) end @testset "$at | $tp | $cmode | $(typeof(go_ad).name.name) for fitting | $nvars Vars | $mt" for at in ats, @@ -114,6 +117,9 @@ @test !isnothing(MLJBase.fit!(mach)) @test !isnothing(MLJBase.transform(mach, df)) @test !isnothing(MLJBase.fitted_params(mach)) + + @test !isnothing(ICNFDist(mach, TrainMode())) + @test !isnothing(ICNFDist(mach, 
TestMode())) end @testset "$at | $tp | $(typeof(adb_u).name.name) for internal | $(typeof(go_ad).name.name) for fitting | $nvars Vars | $mt" for at in ats, @@ -161,6 +167,9 @@ @test !isnothing(MLJBase.fit!(mach)) @test !isnothing(MLJBase.transform(mach, (df, df2))) @test !isnothing(MLJBase.fitted_params(mach)) + + @test !isnothing(CondICNFDist(mach, TrainMode())) + @test !isnothing(CondICNFDist(mach, TestMode())) end @testset "$at | $tp | $cmode | $(typeof(go_ad).name.name) for fitting | $nvars Vars | $mt" for at in ats, @@ -199,5 +208,8 @@ @test !isnothing(MLJBase.fit!(mach)) @test !isnothing(MLJBase.transform(mach, (df, df2))) @test !isnothing(MLJBase.fitted_params(mach)) + + @test !isnothing(CondICNFDist(mach, TrainMode())) + @test !isnothing(CondICNFDist(mach, TestMode())) end end From 6638938e5ce4042be62736bd0d0940de1f0fdf96 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jun 2023 15:20:45 +0330 Subject: [PATCH 5/8] fix benchmark test --- test/benchmark_tests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/benchmark_tests.jl b/test/benchmark_tests.jl index d51d2bac..8f53a1f7 100644 --- a/test/benchmark_tests.jl +++ b/test/benchmark_tests.jl @@ -8,12 +8,12 @@ icnf = construct(RNODE, nn, nvars; compute_mode = ZygoteMatrixMode) ps, st = Lux.setup(rng, icnf) - diff_loss_train(x) = loss(icnf, r, x, st; mode = TrainMode()) - diff_loss_test(x) = loss(icnf, r, x, st; mode = TestMode()) + diff_loss_train(x) = loss(icnf, TrainMode(), r, x, st) + diff_loss_test(x) = loss(icnf, TestMode(), r, x, st) grad_diff_loss_train() = Zygote.gradient(diff_loss_train, ps) grad_diff_loss_test() = Zygote.gradient(diff_loss_test, ps) - t_loss_train() = loss(icnf, r, ps, st; mode = TrainMode()) - t_loss_test() = loss(icnf, r, ps, st; mode = TestMode()) + t_loss_train() = loss(icnf, TrainMode(), r, ps, st) + t_loss_test() = loss(icnf, TestMode(), r, ps, st) ben_loss_train = BenchmarkTools.@benchmark $t_loss_train() ben_loss_test = BenchmarkTools.@benchmark $t_loss_test() From 683d1828d92a9aacb1c43e79b707f7201c2fa6e4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 27 Jun 2023 15:22:20 +0330 Subject: [PATCH 6/8] Format .jl files (#240) Co-authored-by: prbzrg --- README.md | 60 +++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index b4734e8e..2fbf817c 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,34 @@ -# ContinuousNormalizingFlows.jl - -[![deps](https://juliahub.com/docs/ContinuousNormalizingFlows/deps.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo?t=2) -[![version](https://juliahub.com/docs/ContinuousNormalizingFlows/version.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) -[![pkgeval](https://juliahub.com/docs/ContinuousNormalizingFlows/pkgeval.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable) -[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev) -[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain) 
-[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl) -[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main) -[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) - -Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia - -## Citing - -See [`CITATION.bib`](CITATION.bib) for the relevant reference(s). - -## Usage - -To add this package, we can do it by - -```julia +# ContinuousNormalizingFlows.jl + +[![deps](https://juliahub.com/docs/ContinuousNormalizingFlows/deps.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo?t=2) +[![version](https://juliahub.com/docs/ContinuousNormalizingFlows/version.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) +[![pkgeval](https://juliahub.com/docs/ContinuousNormalizingFlows/pkgeval.svg)](https://juliahub.com/ui/Packages/ContinuousNormalizingFlows/iP1wo) +[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable) +[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev) +[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain) +[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl) +[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac) + +Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia + +## Citing + +See [`CITATION.bib`](CITATION.bib) for the relevant reference(s). 
+ +## Usage + +To add this package, we can do it by + +```julia using Pkg Pkg.add("ContinuousNormalizingFlows") -``` - -To use this package, here is an example: - -```julia +``` + +To use this package, here is an example: + +```julia using ContinuousNormalizingFlows using Distributions, Lux # using Flux @@ -79,4 +79,4 @@ using Plots p = plot(x -> pdf(data_dist, x), 0, 1; label = "actual") p = plot!(p, x -> pdf(d, convert.(Float32, vcat(x))), 0, 1; label = "estimated") savefig(p, "plot.png") -``` +``` From 09e1a2fdbd3c4416b3c7faa868cdc4da471d07cc Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jun 2023 15:51:06 +0330 Subject: [PATCH 7/8] fix a miss in tests --- test/fit_tests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/fit_tests.jl b/test/fit_tests.jl index f633edf7..817914bf 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -168,8 +168,8 @@ @test !isnothing(MLJBase.transform(mach, (df, df2))) @test !isnothing(MLJBase.fitted_params(mach)) - @test !isnothing(CondICNFDist(mach, TrainMode())) - @test !isnothing(CondICNFDist(mach, TestMode())) + @test !isnothing(CondICNFDist(mach, TrainMode(), r2)) + @test !isnothing(CondICNFDist(mach, TestMode(), r2)) end @testset "$at | $tp | $cmode | $(typeof(go_ad).name.name) for fitting | $nvars Vars | $mt" for at in ats, @@ -209,7 +209,7 @@ @test !isnothing(MLJBase.transform(mach, (df, df2))) @test !isnothing(MLJBase.fitted_params(mach)) - @test !isnothing(CondICNFDist(mach, TrainMode())) - @test !isnothing(CondICNFDist(mach, TestMode())) + @test !isnothing(CondICNFDist(mach, TrainMode(), r2)) + @test !isnothing(CondICNFDist(mach, TestMode(), r2)) end end From 5e7d05128f7db238e7861871b25e0c0e81af6e40 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Tue, 27 Jun 2023 16:58:09 +0330 Subject: [PATCH 8/8] ignore mutation --- Project.toml | 2 ++ src/ContinuousNormalizingFlows.jl | 1 + src/utils.jl | 12 ++++++------ 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index b839ee5f..035d709a 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.6.0" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" @@ -44,6 +45,7 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" ADTypes = "0.1" AbstractDifferentiation = "0.5" CUDA = "4" +ChainRulesCore = "1" ComponentArrays = "0.13" ComputationalResources = "0.3" DataFrames = "1" diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 218ad0b0..877ae7ad 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -3,6 +3,7 @@ module ContinuousNormalizingFlows using AbstractDifferentiation, ADTypes, Base.Iterators, + ChainRulesCore, ComponentArrays, ComputationalResources, CUDA, diff --git a/src/utils.jl b/src/utils.jl index 33dcb3c5..cedb1631 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,9 +9,9 @@ function jacobian_batched( z::AT = zeros(T, size(xs)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) - z[i, :] .= one(T) + @ignore_derivatives z[i, :] .= one(T) res[i, :, :] = only(back(z)) - z[i, :] .= zero(T) + @ignore_derivatives z[i, :] .= zero(T) end y, copy(res) end @@ 
-27,9 +27,9 @@ function jacobian_batched( z::AT = zeros(T, size(xs)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) - z[i, :] .= one(T) + @ignore_derivatives z[i, :] .= one(T) res[i, :, :] = reshape(auto_vecjac(f, xs, z), size(xs)) - z[i, :] .= zero(T) + @ignore_derivatives z[i, :] .= zero(T) end y, copy(res) end @@ -45,9 +45,9 @@ function jacobian_batched( z::AT = zeros(T, size(xs)) res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) - z[i, :] .= one(T) + @ignore_derivatives z[i, :] .= one(T) res[i, :, :] = reshape(auto_jacvec(f, xs, z), size(xs)) - z[i, :] .= zero(T) + @ignore_derivatives z[i, :] .= zero(T) end y, copy(res) end
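
A minimal usage sketch of the API this series converges on (not part of the patches above; the toy network, batch size, and data here are illustrative assumptions). After the series, `mode` is the second positional argument of `loss`, `inference`, `generate`, and the distribution constructors:

```julia
using ContinuousNormalizingFlows
using ComponentArrays, Distributions, Lux, Random

rng = Random.default_rng()
nvars = 1
nn = Dense(nvars => nvars, tanh)  # toy network; any Lux layer works
icnf = construct(RNODE, nn, nvars; compute_mode = ZygoteMatrixMode)
ps, st = Lux.setup(rng, icnf)
ps = ComponentArray(ps)

xs = rand(rng, Float32, nvars, 8)  # toy batch; each column is a sample

# `mode` now comes right after the flow object:
l = loss(icnf, TrainMode(), xs, ps, st)
logp = first(inference(icnf, TestMode(), xs, ps, st))

# Distributions also carry the mode as their second positional argument:
d = ICNFDist(icnf, TestMode(), ps, st)
logpdf(d, xs)
rand(rng, d, 8)

# From a fitted MLJ machine (assuming `mach` was trained as in the README):
# d = ICNFDist(mach, TestMode())
# Conditional flows mirror this, e.g.:
# loss(cicnf, TrainMode(), xs, ys, ps, st)
# CondICNFDist(cicnf, TestMode(), ys, ps, st)
```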