Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix logpdf_grad for BroadcastedNormal. #433

Merged
merged 20 commits into from
Aug 2, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
56cdef4
Update `logpdf_grad` for BroadcastedNormal.
MrVPlusOne Jul 14, 2021
76e7f98
Make broadcasted normal's gradient overflow-proof.
MrVPlusOne Jul 14, 2021
a7ff2d7
Update the test for zero-dimensional arguments.
MrVPlusOne Jul 14, 2021
8bc4d89
Update src/modeling_library/distributions/normal.jl
MrVPlusOne Jul 14, 2021
8d79f46
Add workaround for the zero-dimensional array issue in https://github…
bzinberg Jul 14, 2021
0d37660
Implement and use the right variant of `finite_diff` for this more ge…
bzinberg Jul 14, 2021
e712c71
fix incorrect unit tests :grimacing:
bzinberg Jul 14, 2021
b35f326
indentation
bzinberg Jul 17, 2021
6dc4492
when testing `BroadcastedNormal`, directly verify shapes of the retur…
bzinberg Jul 17, 2021
127e07c
clarify that the equivalence is for `logpdf`, not `logpdf_grad`
bzinberg Jul 17, 2021
677282a
add ref link in docstring
bzinberg Jul 17, 2021
36e8c1f
add another ref link in docstring
bzinberg Jul 17, 2021
8e50e66
rename `unbroadcast_for_arg` -> `_unbroadcast_for_arg`, connoting tha…
bzinberg Jul 17, 2021
f7cc6d9
indentation
bzinberg Jul 17, 2021
ab66a2b
Rename `unbroadcast` functions to say what they generically do, rathe…
bzinberg Jul 17, 2021
7159b11
ternary expression is long, use `if` instead
bzinberg Jul 17, 2021
f9ac944
tidy up the body of `_unbroadcast_to_shape` a bit
bzinberg Jul 17, 2021
e08dd07
The term "unbroadcast" is non-obvious enough that it deserves a docst…
bzinberg Jul 17, 2021
c4573fc
`arg` -> `a`
bzinberg Jul 17, 2021
5d3f192
Add tests and modify an existing test to exercise the case where the …
bzinberg Jul 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/modeling_library/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ function logpdf(::BroadcastedNormal,
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
msg="Shape of `x` does not agree with the sample space")
z = (x .- mu) ./ std
var = std .* std
diff = x .- mu
sum(- (abs2.(z) .+ log(2π)) / 2 .- log.(std))
end

Expand All @@ -85,10 +83,29 @@ function logpdf_grad(::BroadcastedNormal,
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
msg="Shape of `x` does not agree with the sample space")
z = (x .- mu) ./ std
deriv_x = sum(- z ./ std)
deriv_x = - z ./ std
deriv_mu = -deriv_x
deriv_std = sum(-1. ./ std .+ abs2.(z) ./ std)
(deriv_x, deriv_mu, deriv_std)
deriv_std = -1. ./ std .+ abs2(z) ./ std
MrVPlusOne marked this conversation as resolved.
Show resolved Hide resolved
(unbroadcast_for_arg(x, deriv_x),
unbroadcast_for_arg(mu, deriv_mu),
unbroadcast_for_arg(std, deriv_std))
bzinberg marked this conversation as resolved.
Show resolved Hide resolved
end

unbroadcast_for_arg(::Real, grad) = sum(grad)
unbroadcast_for_arg(::Array{Float64, 0}, grad::Real) = fill(grad)
function unbroadcast_for_arg(
arg::AbstractArray{<:Real, N}, grad::AbstractArray{T}
)::AbstractArray{T, N} where {N,T}
size(arg) == size(grad) ? grad : unbroadcast_grad(size(arg), grad)
end

function unbroadcast_grad(
old_shape::NTuple{l_old, Int}, grad::AbstractArray{T, l_new}
) where {T, l_old, l_new}
@assert l_new >= l_old
new_shape = size(grad)
dims=filter(i -> i > l_old || old_shape[i] == 1 && new_shape[i] > 1, 1:l_new)
dropdims(sum(grad; dims=dims); dims=tuple((l_old+1:l_new)...))::AbstractArray{T, l_old}
end

random(::Normal, mu::Real, std::Real) = mu + std * randn()
Expand Down
6 changes: 3 additions & 3 deletions test/modeling_library/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ end
f = (x, mu, std) -> logpdf(broadcasted_normal, x, mu, std)
args = (fill(0.4), fill(0.2), fill(0.3))
actual = logpdf_grad(broadcasted_normal, args...)
@test isapprox(actual[1], finite_diff(f, args, 1, dx; broadcast=true))
@test isapprox(actual[2], finite_diff(f, args, 2, dx; broadcast=true))
@test isapprox(actual[3], finite_diff(f, args, 3, dx; broadcast=true))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Contributor Author

@MrVPlusOne MrVPlusOne Jul 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually a little unsure how to fix these tests. It seems that the first argument of isapprox is a zero-dimensional array (which I believe is the expected behavior) but the second argument is a Float64, which causes Julia to complain that no method of isapprox matches the given argument types. Any suggestions on how to fix this?

Copy link
Contributor

@bzinberg bzinberg Jul 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After investigating this a bit, I think it boils down to JuliaLang/julia#28866, which I consider to be a bug in Julia. Yes -- in the version of the code before this PR, both the LHS and the RHS should be 0-dimensional arrays. For the RHS, the issue appears to be that these lines

Gen.jl/test/runtests.jl

Lines 22 to 23 in d79aded

pos_args[i] = copy(args[i]) .+ dx
neg_args[i] = copy(args[i]) .- dx

do ::Aray{T,0} .+ ::T and ::Aray{T,0} .- ::T, which should give an ::Array{T,0} but due to JuliaLang/julia#28866 gives a scalar. So the LHS of the ./ on the following line should have an Array{T,0} as the ith argument to f, thus (since broadcast = true) should output an Array{T,0}, and consequently the division operation should be ::Array{T,0} ./ ::T which should output an Array{T,0}, but that is not what happens.

I think the cleanest way to fix this is to add a workaround to the (test-only) function finite_diff so that it handles zero-dimensional array valued arguments correctly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'm currently drafting this.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, in addition to the above, the function f being supplied to finite_diff does not satisfy the requirements in its docstring: it takes array-valued arguments, but is not the broadcast of a function that takes scalar-valued arguments (i.e. logpdf_grad(vector-valued distribution parameters) is not a broadcast of a bunch of logpdf_grad(scalar-valued distribution parameters)'s). So I guess we should add a variant of finite_diff that has the semantics we need.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think to properly test this, we will need a more general finite_diff function or use an autodiff library implementation as a reference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you clearly trust autodiff more than I do 🙂

@test isapprox(actual[1][], finite_diff(f, args, 1, dx))
@test isapprox(actual[2][], finite_diff(f, args, 2, dx))
@test isapprox(actual[3][], finite_diff(f, args, 3, dx))
end

@testset "array normal (trivially broadcasted: all args have same shape)" begin
Expand Down