Skip to content

Commit

Permalink
Merge pull request #433 from MrVPlusOne/fix-broadcasted-normal
Browse files Browse the repository at this point in the history
Fix `logpdf_grad` for BroadcastedNormal.
  • Loading branch information
bzinberg authored Aug 2, 2021
2 parents d79aded + 5d3f192 commit acd7005
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 23 deletions.
46 changes: 40 additions & 6 deletions src/modeling_library/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Samples an `Array{Float64, max(N1, N2)}` of shape
`Broadcast.broadcast_shapes(size(mu), size(std))` where each element is
independently normally distributed. This is equivalent to (a reshape of) a
multivariate normal with diagonal covariance matrix, but its implementation is
more efficient than that of the more general `mvnormal` for this case.
more efficient than that of the more general [`mvnormal`](@ref) for this case.
The shapes of `mu` and `std` must be broadcast-compatible.
Expand Down Expand Up @@ -65,8 +65,6 @@ function logpdf(::BroadcastedNormal,
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
msg="Shape of `x` does not agree with the sample space")
z = (x .- mu) ./ std
var = std .* std
diff = x .- mu
sum(- (abs2.(z) .+ log(2π)) / 2 .- log.(std))
end

Expand All @@ -85,10 +83,46 @@ function logpdf_grad(::BroadcastedNormal,
assert_has_shape(x, broadcast_shapes_or_crash(mu, std);
msg="Shape of `x` does not agree with the sample space")
z = (x .- mu) ./ std
deriv_x = sum(- z ./ std)
deriv_x = - z ./ std
deriv_mu = -deriv_x
deriv_std = sum(-1. ./ std .+ abs2.(z) ./ std)
(deriv_x, deriv_mu, deriv_std)
deriv_std = -1. ./ std .+ abs2.(z) ./ std
(_unbroadcast_like(x, deriv_x),
_unbroadcast_like(mu, deriv_mu),
_unbroadcast_like(std, deriv_std))
end

_unbroadcast_like(::Real, full_arr) = sum(full_arr)
_unbroadcast_like(::AbstractArray{<:Real, 0}, full_arr::Real) = fill(full_arr)
function _unbroadcast_like(a::AbstractArray{<:Real, N},
full_arr::AbstractArray{T}
)::AbstractArray{T, N} where {N,T}
if size(a) == size(full_arr)
return full_arr
end
return _unbroadcast_to_shape(size(a), full_arr)
end

"""
"Unbroadcasts" `full_arr` to have shape `target_shape` by:
* Summing over all dims that would be increased by a broadcast from shape
`target_shape` to shape `size(full_arr)`
* Then dropping trailing dims (which will all be 1's) as needed so that the
result has shape `target_shape`.
Requires that `size(full_arr)` is "strictly bigger" than `target_shape`, in the
sense that
Broadcast.broadcast_shapes(target_shape, size(full_arr)) == size(full_arr)
"""
function _unbroadcast_to_shape(target_shape::NTuple{target_ndims, Int},
full_arr::AbstractArray{T, full_ndims}
) where {T, target_ndims, full_ndims}
@assert full_ndims >= target_ndims
should_sum_dim(i) = (i > target_ndims) || (target_shape[i] == 1 &&
size(full_arr, i) > 1)
dropdims(sum(full_arr; dims=filter(should_sum_dim, 1:full_ndims));
dims=Dims(target_ndims + 1 : full_ndims))
end

random(::Normal, mu::Real, std::Real) = mu + std * randn()
Expand Down
46 changes: 30 additions & 16 deletions test/modeling_library/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,14 @@ end
x = broadcasted_normal(fill(0), fill(1))

# logpdf_grad
f = (x, mu, std) -> logpdf(broadcasted_normal, x, mu, std)
f(x, mu, std) = logpdf(broadcasted_normal, x, mu, std)
args = (fill(0.4), fill(0.2), fill(0.3))
actual = logpdf_grad(broadcasted_normal, args...)

@test actual[1] isa AbstractArray && size(actual[1]) == ()
@test actual[2] isa AbstractArray && size(actual[2]) == ()
@test actual[3] isa AbstractArray && size(actual[3]) == ()

@test isapprox(actual[1], finite_diff(f, args, 1, dx; broadcast=true))
@test isapprox(actual[2], finite_diff(f, args, 2, dx; broadcast=true))
@test isapprox(actual[3], finite_diff(f, args, 3, dx; broadcast=true))
Expand All @@ -144,27 +149,37 @@ end
broadcasted_normal(mu, std)

# logpdf_grad
f = (x, mu, std) -> logpdf(broadcasted_normal, x, mu, std)
args = (mu, std, x)
f(x_, mu_, std_) = logpdf(broadcasted_normal, x_, mu_, std_)
args = (x, mu, std)
actual = logpdf_grad(broadcasted_normal, args...)
@test isapprox(actual[1], finite_diff(f, args, 1, dx; broadcast=true))
@test isapprox(actual[2], finite_diff(f, args, 2, dx; broadcast=true))
@test isapprox(actual[3], finite_diff(f, args, 3, dx; broadcast=true))

@test actual[1] isa AbstractArray && size(actual[1]) == (2, 3)
@test actual[2] isa AbstractArray && size(actual[2]) == (2, 3)
@test actual[3] isa AbstractArray && size(actual[3]) == (2, 3)

@test isapprox(actual[1], finite_diff_arr_fullarg(f, args, 1, dx); rtol=1e-7)
@test isapprox(actual[2], finite_diff_arr_fullarg(f, args, 2, dx); rtol=1e-7)
@test isapprox(actual[3], finite_diff_arr_fullarg(f, args, 3, dx); rtol=1e-7)
end

@testset "broadcasted normal" begin

## Return shape of `broadcasted_normal`
@test size(broadcasted_normal([0. 0. 0.], 1.)) == (1, 3)
@test size(broadcasted_normal(zeros(1, 3, 4), ones(2, 1, 4))) == (2, 3, 4)
@test size(broadcasted_normal(zeros(1, 3), ones(2, 1, 1))) == (2, 3, 1)
@test_throws DimensionMismatch broadcasted_normal([0 0 0], [1 1])
# Numpy and Julia use different conventions for which direction the
# implicit 1-padding goes. In Julia, it's not `(1, 2, 3)` but rather
# `(2, 3, 1)` that is broadcast-compatible with the shape `(2, 3)`.
@test_throws DimensionMismatch broadcasted_normal(zeros(2, 3), ones(1, 2, 3))

## Return shape of `logpdf` and `logpdf_grad`
@test size(logpdf(broadcasted_normal,
ones(2, 4), ones(2, 1), ones(1, 4))) == ()
@test all(size(g) == ()
for g in logpdf_grad(
broadcasted_normal, ones(2, 4), ones(2, 1), ones(1, 4)))
@test [size(g) for g in logpdf_grad(
broadcasted_normal, ones(2, 4), ones(2, 1), ones(1, 4))
] == [(2, 4), (2, 1), (1, 4)]
# `x` has the wrong shape
@test_throws DimensionMismatch logpdf(broadcasted_normal,
ones(1, 2), ones(1,3), ones(2,1))
Expand All @@ -182,21 +197,20 @@ end
@test_throws DimensionMismatch logpdf_grad(broadcasted_normal,
ones(2, 1), ones(1,2), ones(1,3))

## Equivalence of broadcast to supplying bigger arrays for `mu` and `std`
## For `logpdf`, equivalence of broadcast to supplying bigger arrays for
## `mu` and `std`
compact = OrderedDict(:x => reshape([ 0.2 0.3 0.4 0.5 ;
0.5 0.4 0.3 0.2 ],
(2, 4)),
(2, 4, 1)),
:mu => reshape([0.7 0.7 0.8 0.6],
(1, 4)),
:std => reshape([0.2, 0.1],
(2, 1)))
(2, 1, 1)))
expanded = OrderedDict(:x => compact[:x],
:mu => repeat(compact[:mu], outer=(2, 1)),
:std => repeat(compact[:std], outer=(1, 4)))
:mu => repeat(compact[:mu], outer=(2, 1, 1)),
:std => repeat(compact[:std], outer=(1, 4, 1)))
@test (logpdf(broadcasted_normal, values(compact)...) ==
logpdf(broadcasted_normal, values(expanded)...))
@test (logpdf_grad(broadcasted_normal, values(compact)...) ==
logpdf_grad(broadcasted_normal, values(expanded)...))
end

@testset "multivariate normal" begin
Expand Down
28 changes: 27 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ function finite_diff(f::Function, args::Tuple, i::Int, dx::Float64;
if broadcast
pos_args[i] = copy(args[i]) .+ dx
neg_args[i] = copy(args[i]) .- dx
return (f(pos_args...) - f(neg_args...)) ./ (2. * dx)
ans = (f(pos_args...) - f(neg_args...)) ./ (2. * dx)
# Workaround for
# https://github.com/probcomp/Gen.jl/pull/433#discussion_r669958584
if args[i] isa AbstractArray && ndims(args[i]) == 0
return fill(ans)
end
return ans
else
pos_args[i] += dx
neg_args[i] -= dx
Expand Down Expand Up @@ -74,6 +80,26 @@ function finite_diff_arr(f::Function, args::Tuple, i::Int, idx, dx::Float64)
return (f(pos_args...) - f(neg_args...)) / (2. * dx)
end

"""
Returns the partial derivatives of `f` with respect to all entries of
`args[i]`.
That is, returns an array of the same shape as `args[i]`, each entry of which
is [`finite_diff_arr`](@ref) applied to the corresponding entry of `args[i]`.
Requires that `args[i]` have nonzero rank. Due to [1], handling
zero-dimensional arrays properly in this function is not feasible; the caller
should handle that case on their own.
[1] https://github.com/JuliaLang/julia/issues/28866
"""
function finite_diff_arr_fullarg(f::Function, args::Tuple, i::Int, dx::Float64)
@assert args[i] isa AbstractArray
@assert ndims(args[i]) > 0
return [finite_diff_arr(f, args, i, idx, dx)
for idx in keys(args[i])]
end

const dx = 1e-6

include("autodiff.jl")
Expand Down

0 comments on commit acd7005

Please sign in to comment.