Commit
Add NamedTupleVariate and ProductNamedTupleDistribution (#1803)
* Add NamedTupleVariate
* Add ProductNamedTupleDistribution
* Correctly implement eltype
* Simplify insupport implementation
* Overload std for ProductNamedTupleDistribution
* Simplify rand for ProductNamedTupleDistribution
* Reformat line
* Add docstring to ProductNamedTupleDistribution
* Add marginal API function
* Add marginal for ProductDistribution
* Rearrange marginal
* Allow tuple indexing via marginal
* Make logpdf type-stable
* Add loglikelihood
* Support extrema for multivariate distributions
* Add tests
* Improve type-inferrability
* Remove extension
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* Remove marginal
* Add sampler for product namedtuple
* Use ProductNamedTupleSampler for array rand calls
* Add docs page for product distributions
* Fix typo
* Fix ProductNamedTuple docstring
* Add deprecation warning to Product docstring
* Move multivariate product distributions to own page
* Document NamedTuple products
* Add docs index
* Document usage of ProductNamedTuple
* Load Distributions for jldoctest
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* Call method on NamedTuple
* Revert to typejoin based eltype
* Explicitly check eltype of dist matches that of draw
* Correctly compute eltype for nested prod namedtuple distributions
* Revert "Call method on NamedTuple"
  This reverts commit a86cac4.
* Update test/namedtuple/productnamedtuple.jl
  Co-authored-by: David Widmann <[email protected]>
* Support permutations of NamedTuple fields
* Fix formatting
* Support permutations of names in kldivergence

---------

Co-authored-by: David Widmann <[email protected]>
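As a rough illustration of what the commit adds, the sketch below builds a `NamedTuple`-variate product and checks two properties implied by the diff: the joint log-density is the sum of the marginal log-densities, and the field order of the argument does not matter. It assumes a Distributions.jl version that includes this commit; the field names `x` and `y` follow the docstring example, while the seed, the sample size, and the variable names are arbitrary choices made for this sketch.

```julia
using Distributions, Random

# Product of two independent named marginals (same marginals as the docstring example).
d = product_distribution((x=Normal(), y=Dirichlet([2, 4])))

rng = Random.MersenneTwister(832)  # arbitrary seed, used only for reproducibility
nt = rand(rng, d)                  # a NamedTuple with fields `x` and `y`

# Independence: the joint log-density equals the sum of the marginal log-densities.
lp_joint = logpdf(d, nt)
lp_sum = logpdf(Normal(), nt.x) + logpdf(Dirichlet([2, 4]), nt.y)
@assert lp_joint ≈ lp_sum

# Permuting the fields of the argument leaves the result unchanged.
@assert logpdf(d, (y=nt.y, x=nt.x)) ≈ lp_joint

# Array draws go through the sampler; loglikelihood sums logpdf over the draws.
xs = rand(rng, d, (3,))
@assert loglikelihood(d, xs) ≈ sum(x -> logpdf(d, x), xs)
```

Passing the dimensions as a tuple matches the `dims::Dims` method defined in the diff.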
Showing 11 changed files with 450 additions and 11 deletions.
@@ -0,0 +1,27 @@

````markdown
# Product Distributions

Product distributions are joint distributions of multiple independent distributions.
It is recommended to use `product_distribution` to construct product distributions.
Depending on the type of the argument, it may construct a different distribution type.

## Multivariate products

```@docs
Distributions.product_distribution(::AbstractArray{<:Distribution{<:ArrayLikeVariate}})
Distributions.product_distribution(::AbstractVector{<:Normal})
Distributions.ProductDistribution
Distributions.Product
```

## NamedTuple-variate products

```@docs
Distributions.product_distribution(::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}})
Distributions.ProductNamedTupleDistribution
```

## Index

```@index
Pages = ["product.md"]
```
````
@@ -0,0 +1,174 @@

````julia
"""
    ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
        Distribution{NamedTupleVariate{Tnames},S}

A distribution of `NamedTuple`s, constructed from a `NamedTuple` of independent named
distributions.

Users should use [`product_distribution`](@ref) to construct a product distribution of
independent distributions instead of constructing a `ProductNamedTupleDistribution`
directly.

# Examples

```jldoctest ProductNamedTuple; setup = :(using Distributions, Random; Random.seed!(832))
julia> d = product_distribution((x=Normal(), y=Dirichlet([2, 4])))
ProductNamedTupleDistribution{(:x, :y)}(
x: Normal{Float64}(μ=0.0, σ=1.0)
y: Dirichlet{Int64, Vector{Int64}, Float64}(alpha=[2, 4])
)

julia> nt = rand(d)
(x = 1.5155385995160346, y = [0.533531876438439, 0.466468123561561])

julia> pdf(d, nt)
0.13702825691074877

julia> pdf(d, reverse(nt))  # order of fields does not matter
0.13702825691074877

julia> mode(d)  # mode of marginals
(x = 0.0, y = [0.25, 0.75])

julia> mean(d)  # mean of marginals
(x = 0.0, y = [0.3333333333333333, 0.6666666666666666])

julia> var(d)  # var of marginals
(x = 1.0, y = [0.031746031746031744, 0.031746031746031744])
```
"""
struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
    Distribution{NamedTupleVariate{Tnames},S}
    dists::NamedTuple{Tnames,Tdists}
end
function ProductNamedTupleDistribution(
    dists::NamedTuple{K,V}
) where {K,V<:Tuple{Distribution,Vararg{Distribution}}}
    vs = _product_valuesupport(values(dists))
    eltypes = _product_namedtuple_eltype(values(dists))
    return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists)
end

_gentype(d::UnivariateDistribution) = eltype(d)
_gentype(d::Distribution{<:ArrayLikeVariate{S}}) where {S} = Array{eltype(d),S}
function _gentype(d::Distribution{CholeskyVariate})
    T = eltype(d)
    return LinearAlgebra.Cholesky{T,Matrix{T}}
end
function _gentype(d::ProductNamedTupleDistribution{K}) where {K}
    return NamedTuple{K,Tuple{map(_gentype, values(d.dists))...}}
end
_gentype(::Distribution) = Any

_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...)

function Base.show(io::IO, d::ProductNamedTupleDistribution)
    return show_multline(io, d, collect(pairs(d.dists)))
end

function distrname(::ProductNamedTupleDistribution{K}) where {K}
    return "ProductNamedTupleDistribution{$K}"
end

"""
    product_distribution(dists::NamedTuple{K,Tuple{Vararg{Distribution}}}) where {K}

Create a distribution of `NamedTuple`s as a product distribution of independent named
distributions.

The function falls back to constructing a [`ProductNamedTupleDistribution`](@ref)
distribution but specialized methods can be defined.
"""
function product_distribution(
    dists::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}}
)
    return ProductNamedTupleDistribution(dists)
end

# Properties

Base.eltype(::Type{<:ProductNamedTupleDistribution{<:Any,<:Any,<:Any,T}}) where {T} = T

Base.minimum(d::ProductNamedTupleDistribution) = map(minimum, d.dists)

Base.maximum(d::ProductNamedTupleDistribution) = map(maximum, d.dists)

function _named_fields_match(x::NamedTuple{K}, y::NamedTuple) where {K}
    length(x) == length(y) || return false
    try
        NamedTuple{K}(y)
        return true
    catch
        return false
    end
end

function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple) where {K}
    return (
        _named_fields_match(dist.dists, x) &&
        all(map(insupport, dist.dists, NamedTuple{K}(x)))
    )
end

# Evaluation

function pdf(dist::ProductNamedTupleDistribution, x::NamedTuple)
    return exp(logpdf(dist, x))
end

function logpdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple) where {K}
    return sum(map(logpdf, dist.dists, NamedTuple{K}(x)))
end

function loglikelihood(dist::ProductNamedTupleDistribution, x::NamedTuple)
    return logpdf(dist, x)
end

function loglikelihood(dist::ProductNamedTupleDistribution, xs::AbstractArray{<:NamedTuple})
    return sum(Base.Fix1(loglikelihood, dist), xs)
end

# Statistics

mode(d::ProductNamedTupleDistribution) = map(mode, d.dists)

mean(d::ProductNamedTupleDistribution) = map(mean, d.dists)

var(d::ProductNamedTupleDistribution) = map(var, d.dists)

std(d::ProductNamedTupleDistribution) = map(std, d.dists)

entropy(d::ProductNamedTupleDistribution) = sum(entropy, values(d.dists))

function kldivergence(
    d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution
) where {K}
    _named_fields_match(d1.dists, d2.dists) || throw(
        ArgumentError(
            "Sets of named tuple fields are not the same: !issetequal($(keys(d1.dists)), $(keys(d2.dists)))",
        ),
    )
    return sum(map(kldivergence, d1.dists, NamedTuple{K}(d2.dists)))
end

# Sampling

function sampler(d::ProductNamedTupleDistribution{K,<:Any,S}) where {K,S}
    samplers = map(sampler, d.dists)
    Tsamplers = typeof(values(samplers))
    return ProductNamedTupleSampler{K,Tsamplers,S}(samplers)
end

function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where {K}
    return NamedTuple{K}(map(Base.Fix1(rand, rng), d.dists))
end
function Base.rand(
    rng::AbstractRNG, d::ProductNamedTupleDistribution{K}, dims::Dims
) where {K}
    return convert(AbstractArray{<:NamedTuple{K}}, _rand(rng, sampler(d), dims))
end

function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray)
    return _rand!(rng, sampler(d), xs)
end
````
@@ -0,0 +1,23 @@

```julia
struct ProductNamedTupleSampler{Tnames,Tsamplers,S<:ValueSupport} <:
    Sampleable{NamedTupleVariate{Tnames},S}
    samplers::NamedTuple{Tnames,Tsamplers}
end

function Base.rand(rng::AbstractRNG, spl::ProductNamedTupleSampler{K}) where {K}
    return NamedTuple{K}(map(Base.Fix1(rand, rng), spl.samplers))
end

function _rand(rng::AbstractRNG, spl::ProductNamedTupleSampler, dims::Dims)
    return map(CartesianIndices(dims)) do _
        return rand(rng, spl)
    end
end

function _rand!(
    rng::AbstractRNG, spl::ProductNamedTupleSampler, xs::AbstractArray{<:NamedTuple{K}}
) where {K}
    for i in eachindex(xs)
        xs[i] = NamedTuple{K}(rand(rng, spl))
    end
    return xs
end
```