From 73817a5c8c2a8e3d9bc3849b128e6e92bc5fc508 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Sat, 2 Oct 2021 18:04:33 +0330 Subject: [PATCH] remove type promote (#633) --- src/ffjord.jl | 3 +-- test/cnf_test.jl | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/ffjord.jl b/src/ffjord.jl index e75dc21a34..b16841b9f6 100644 --- a/src/ffjord.jl +++ b/src/ffjord.jl @@ -221,8 +221,7 @@ function forward_ffjord(n::FFJORD, x, p=n.p, e=randn(eltype(x), size(x)); λ₁ = λ₂ = _z[1, :] end - # logpdf promotes the type to Float64 by default - logpz = eltype(x).(reshape(logpdf(pz, z), 1, size(x, 2))) + logpz = reshape(logpdf(pz, z), 1, size(x, 2)) logpx = logpz .- delta_logp logpx, λ₁, λ₂ diff --git a/test/cnf_test.jl b/test/cnf_test.jl index 5c3f559aac..c4fbe9c752 100644 --- a/test/cnf_test.jl +++ b/test/cnf_test.jl @@ -40,13 +40,13 @@ end regularize = false monte_carlo = false - @test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10)) + @test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10)) end @testset "regularize=false & monte_carlo=true" begin regularize = false monte_carlo = true - @test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10)) + @test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10)) end @testset "regularize=true & monte_carlo=false" begin regularize = true @@ -58,7 +58,7 @@ end regularize = true monte_carlo = true - @test_broken !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10)) + @test !isnothing(DiffEqFlux.sciml_train(θ -> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10)) end end @testset "AutoReverseDiff as adtype" begin