Skip to content

Commit

Permalink
remove type promote (#633)
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored Oct 2, 2021
1 parent 60361b4 commit 73817a5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
3 changes: 1 addition & 2 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@ function forward_ffjord(n::FFJORD, x, p=n.p, e=randn(eltype(x), size(x));
λ₁ = λ₂ = _z[1, :]
end

# logpdf promotes the type to Float64 by default
logpz = eltype(x).(reshape(logpdf(pz, z), 1, size(x, 2)))
logpz = reshape(logpdf(pz, z), 1, size(x, 2))
logpx = logpz .- delta_logp

logpx, λ₁, λ₂
Expand Down
6 changes: 3 additions & 3 deletions test/cnf_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ end
regularize = false
monte_carlo = false

@test_broken !isnothing(DiffEqFlux.sciml_train-> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10))
@test !isnothing(DiffEqFlux.sciml_train-> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10))
end
@testset "regularize=false & monte_carlo=true" begin
regularize = false
monte_carlo = true

@test_broken !isnothing(DiffEqFlux.sciml_train-> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10))
@test !isnothing(DiffEqFlux.sciml_train-> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10))
end
@testset "regularize=true & monte_carlo=false" begin
regularize = true
Expand All @@ -58,7 +58,7 @@ end
regularize = true
monte_carlo = true

@test_broken !isnothing(DiffEqFlux.sciml_train-> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10))
@test !isnothing(DiffEqFlux.sciml_train-> loss(θ; regularize, monte_carlo), ffjord_mdl.p, ADAM(0.1), adtype; cb, maxiters=10))
end
end
@testset "AutoReverseDiff as adtype" begin
Expand Down

0 comments on commit 73817a5

Please sign in to comment.