-
Notifications
You must be signed in to change notification settings - Fork 162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix logpdf_grad
for BroadcastedNormal.
#433
Conversation
Now it should correctly handle non-scalar arguments.
Thanks for making this PR! I'm not sure why the Travis build errored, but separate of that, it'd great if you could adjust the implementation to be overflow-safe in line with #321! |
Fix keyword syntax for earlier julia versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @MrVPlusOne! A couple of initial comments now; I'll come back later to suggest unit / regression tests.
@test isapprox(actual[1], finite_diff(f, args, 1, dx; broadcast=true)) | ||
@test isapprox(actual[2], finite_diff(f, args, 2, dx; broadcast=true)) | ||
@test isapprox(actual[3], finite_diff(f, args, 3, dx; broadcast=true)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm actually a little unsure how to fix these tests. It seems that the first argument of isapprox
is a zero-dimensional array (which I believe is the expected behavior) but the second argument is a Float64
, which causes Julia to complain that no method of isapprox
matches the given argument types. Any suggestions on how to fix this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After investigating this a bit, I think it boils down to JuliaLang/julia#28866, which I consider to be a bug in Julia. Yes -- in the version of the code before this PR, both the LHS and the RHS should be 0-dimensional arrays. For the RHS, the issue appears to be that these lines
Lines 22 to 23 in d79aded
pos_args[i] = copy(args[i]) .+ dx | |
neg_args[i] = copy(args[i]) .- dx |
do
::Aray{T,0} .+ ::T
and ::Aray{T,0} .- ::T
, which should give an ::Array{T,0}
but due to JuliaLang/julia#28866 gives a scalar. So the LHS of the ./
on the following line should have an Array{T,0}
as the i
th argument to f
, thus (since broadcast = true
) should output an Array{T,0}
, and consequently the division operation should be ::Array{T,0} ./ ::T
which should output an Array{T,0}
, but that is not what happens.
I think the cleanest way to fix this is to add a workaround to the (test-only) function finite_diff
so that it handles zero-dimensional array valued arguments correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I'm currently drafting this.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aha, in addition to the above, the function f
being supplied to finite_diff
does not satisfy the requirements in its docstring: it takes array-valued arguments, but is not the broadcast of a function that takes scalar-valued arguments (i.e. logpdf_grad(vector-valued distribution parameters)
is not a broadcast of a bunch of logpdf_grad(scalar-valued distribution parameters)
's). So I guess we should add a variant of finite_diff
that has the semantics we need.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think to properly test this, we will need a more general finite_diff
function or use an autodiff library implementation as a reference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you clearly trust autodiff more than I do 🙂
Co-authored-by: bzinberg <[email protected]>
logpdf_grad
for BroadcastedNormal.logpdf_grad
for BroadcastedNormal.
@MrVPlusOne, I've drafted a fix to the unit test breakage. Take a look and LMKWYT? |
@bzinberg Thanks! The fix looks good to me! |
…n values of `logpdf_grad`
…t it's an implementation detail
…r than being coupled to a specific usage where the target shape is an "arg" and the full array is a "grad"
@MrVPlusOne, I've made a few more changes that are mostly by way of tidying.
The one functional change I made was to add to a couple of the unit tests an explicit check on the shape of the returned arrays. I think regression testing is covered by the shape checks that were fixed in e712c71. They passed before, but that is because the unit tests were incorrect. @MrVPlusOne, any comments you'd like to make on the above? If looks good to you, I think this is ready to merge. |
Thanks for the cleaning up; they all look good to me! I think one more nice thing to have would be to add a unit test for the case where the two arguments of |
…parameters `broadcasted_normal` have different ranks See probcomp#433 (comment)
Great idea @MrVPlusOne, how's this? |
Thanks, looks good to me! |
Thanks @MrVPlusOne for using and improving Gen! |
As requested by #432, this PR fixes the gradient of BroadcastedNormal for the non-scalar cases.