Skip to content

Commit

Permalink
Add NamedTupleVariate and ProductNamedTupleDistribution (#1803)
Browse files Browse the repository at this point in the history
* Add NamedTupleVariate

* Add ProductNamedTupleDistribution

* Correctly implement eltype

* Simplify insupport implementation

* Overload std for ProductNamedTupleDistribution

* Simplify rand for ProductNamedTupleDistribution

* Reformat line

* Add docstring to ProductNamedTupleDistribution

* Add marginal API function

* Add marginal for ProductDistribution

* Rearrange marginal

* Allow tuple indexing via marginal

* Make logpdf type-stable

* Add loglikelihood

* Support extrema for multivariate distributions

* Add tests

* Improve type-inferrability

* Remove extension

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Remove marginal

* Add sampler for product namedtuple

* Use ProductNamedTupleSampler for array rand calls

* Add docs page for product distributions

* Fix typo

* Fix ProductNamedTuple docstring

* Add deprecation warning to Product docstring

* Move multivariate product distributions to own page

* Document NamedTuple products

* Add docs index

* Document usage of ProductNamedTuple

* Load Distributions for jldoctest

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Call method on NamedTuple

* Revert to typejoin based eltype

* Explicitly check eltype of dist matches that of draw

* Correctly compute eltype for nested prod namedtuple distributions

* Revert "Call method on NamedTuple"

This reverts commit a86cac4.

* Update test/namedtuple/productnamedtuple.jl

Co-authored-by: David Widmann <[email protected]>

* Support permutations of NamedTuple fields

* Fix formatting

* Support permutations of names in kldivergence

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
sethaxen and devmotion authored Jan 17, 2025
1 parent 957f0c0 commit 07d1e78
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ makedocs(
"reshape.md",
"cholesky.md",
"mixture.md",
"product.md",
"order_statistics.md",
"convolution.md",
"fit.md",
Expand Down
10 changes: 0 additions & 10 deletions docs/src/multivariate.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ MvNormalCanon
MvLogitNormal
MvLogNormal
Dirichlet
Product
```

## Addition Methods
Expand Down Expand Up @@ -105,15 +104,6 @@ params{D<:Distributions.AbstractMvLogNormal}(::Type{D},m::AbstractVector,S::Abst
Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)
```

## Product distributions

```@docs
Distributions.product_distribution
```

Using `product_distribution` is advised to construct product distributions.
For some distributions, it constructs a special multivariate type.

## Index

```@index
Expand Down
27 changes: 27 additions & 0 deletions docs/src/product.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Product Distributions

Product distributions are joint distributions of multiple independent distributions.
It is recommended to use `product_distribution` to construct product distributions.
Depending on the type of the argument, it may construct a different distribution type.

## Multivariate products

```@docs
Distributions.product_distribution(::AbstractArray{<:Distribution{<:ArrayLikeVariate}})
Distributions.product_distribution(::AbstractVector{<:Normal})
Distributions.ProductDistribution
Distributions.Product
```

## NamedTuple-variate products

```@docs
Distributions.product_distribution(::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}})
Distributions.ProductNamedTupleDistribution
```

## Index

```@index
Pages = ["product.md"]
```
2 changes: 2 additions & 0 deletions src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export
Multivariate,
Matrixvariate,
CholeskyVariate,
NamedTupleVariate,
Discrete,
Continuous,
Sampleable,
Expand Down Expand Up @@ -296,6 +297,7 @@ include("univariates.jl")
include("edgeworth.jl")
include("multivariates.jl")
include("matrixvariates.jl")
include("namedtuple/productnamedtuple.jl")
include("cholesky/lkjcholesky.jl")
include("samplers.jl")

Expand Down
6 changes: 6 additions & 0 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ const Univariate = ArrayLikeVariate{0}
const Multivariate = ArrayLikeVariate{1}
const Matrixvariate = ArrayLikeVariate{2}

"""
`F <: NamedTupleVariate{K}` specifies that the variate or a sample is of type
`NamedTuple{K}`.
"""
struct NamedTupleVariate{K} <: VariateForm end

"""
`F <: CholeskyVariate` specifies that the variate or a sample is of type
`LinearAlgebra.Cholesky`.
Expand Down
4 changes: 4 additions & 0 deletions src/multivariate/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ An N dimensional `MultivariateDistribution` constructed from a vector of N indep
```julia
Product(Uniform.(rand(10), 1)) # A 10-dimensional Product from 10 independent `Uniform` distributions.
```
!!! note
`Product` is deprecated and will be removed in the next breaking release.
Use [`product_distribution`](@ref) instead.
"""
struct Product{
S<:ValueSupport,
Expand Down
174 changes: 174 additions & 0 deletions src/namedtuple/productnamedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""
ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
Distribution{NamedTupleVariate{Tnames},S}
A distribution of `NamedTuple`s, constructed from a `NamedTuple` of independent named
distributions.
Users should use [`product_distribution`](@ref) to construct a product distribution of
independent distributions instead of constructing a `ProductNamedTupleDistribution`
directly.
# Examples
```jldoctest ProductNamedTuple; setup = :(using Distributions, Random; Random.seed!(832))
julia> d = product_distribution((x=Normal(), y=Dirichlet([2, 4])))
ProductNamedTupleDistribution{(:x, :y)}(
x: Normal{Float64}(μ=0.0, σ=1.0)
y: Dirichlet{Int64, Vector{Int64}, Float64}(alpha=[2, 4])
)
julia> nt = rand(d)
(x = 1.5155385995160346, y = [0.533531876438439, 0.466468123561561])
julia> pdf(d, nt)
0.13702825691074877
julia> pdf(d, reverse(nt)) # order of fields does not matter
0.13702825691074877
julia> mode(d) # mode of marginals
(x = 0.0, y = [0.25, 0.75])
julia> mean(d) # mean of marginals
(x = 0.0, y = [0.3333333333333333, 0.6666666666666666])
julia> var(d) # var of marginals
(x = 1.0, y = [0.031746031746031744, 0.031746031746031744])
```
"""
struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
Distribution{NamedTupleVariate{Tnames},S}
dists::NamedTuple{Tnames,Tdists}
end
function ProductNamedTupleDistribution(
dists::NamedTuple{K,V}
) where {K,V<:Tuple{Distribution,Vararg{Distribution}}}
vs = _product_valuesupport(values(dists))
eltypes = _product_namedtuple_eltype(values(dists))
return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists)
end

_gentype(d::UnivariateDistribution) = eltype(d)
_gentype(d::Distribution{<:ArrayLikeVariate{S}}) where {S} = Array{eltype(d),S}
function _gentype(d::Distribution{CholeskyVariate})
T = eltype(d)
return LinearAlgebra.Cholesky{T,Matrix{T}}
end
function _gentype(d::ProductNamedTupleDistribution{K}) where {K}
return NamedTuple{K,Tuple{map(_gentype, values(d.dists))...}}
end
_gentype(::Distribution) = Any

_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...)

function Base.show(io::IO, d::ProductNamedTupleDistribution)
return show_multline(io, d, collect(pairs(d.dists)))
end

function distrname(::ProductNamedTupleDistribution{K}) where {K}
return "ProductNamedTupleDistribution{$K}"
end

"""
product_distribution(dists::NamedTuple{K,Tuple{Vararg{Distribution}}}) where {K}
Create a distribution of `NamedTuple`s as a product distribution of independent named
distributions.
The function falls back to constructing a [`ProductNamedTupleDistribution`](@ref)
distribution but specialized methods can be defined.
"""
function product_distribution(
dists::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}}
)
return ProductNamedTupleDistribution(dists)
end

# Properties

Base.eltype(::Type{<:ProductNamedTupleDistribution{<:Any,<:Any,<:Any,T}}) where {T} = T

Base.minimum(d::ProductNamedTupleDistribution) = map(minimum, d.dists)

Base.maximum(d::ProductNamedTupleDistribution) = map(maximum, d.dists)

function _named_fields_match(x::NamedTuple{K}, y::NamedTuple) where {K}
length(x) == length(y) || return false
try
NamedTuple{K}(y)
return true
catch
return false
end
end

function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple) where {K}
return (
_named_fields_match(dist.dists, x) &&
all(map(insupport, dist.dists, NamedTuple{K}(x)))
)
end

# Evaluation

function pdf(dist::ProductNamedTupleDistribution, x::NamedTuple)
return exp(logpdf(dist, x))
end

function logpdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple) where {K}
return sum(map(logpdf, dist.dists, NamedTuple{K}(x)))
end

function loglikelihood(dist::ProductNamedTupleDistribution, x::NamedTuple)
return logpdf(dist, x)
end

function loglikelihood(dist::ProductNamedTupleDistribution, xs::AbstractArray{<:NamedTuple})
return sum(Base.Fix1(loglikelihood, dist), xs)
end

# Statistics

mode(d::ProductNamedTupleDistribution) = map(mode, d.dists)

mean(d::ProductNamedTupleDistribution) = map(mean, d.dists)

var(d::ProductNamedTupleDistribution) = map(var, d.dists)

std(d::ProductNamedTupleDistribution) = map(std, d.dists)

entropy(d::ProductNamedTupleDistribution) = sum(entropy, values(d.dists))

function kldivergence(
d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution
) where {K}
_named_fields_match(d1.dists, d2.dists) || throw(
ArgumentError(
"Sets of named tuple fields are not the same: !issetequal($(keys(d1.dists)), $(keys(d2.dists)))",
),
)
return sum(map(kldivergence, d1.dists, NamedTuple{K}(d2.dists)))
end

# Sampling

function sampler(d::ProductNamedTupleDistribution{K,<:Any,S}) where {K,S}
samplers = map(sampler, d.dists)
Tsamplers = typeof(values(samplers))
return ProductNamedTupleSampler{K,Tsamplers,S}(samplers)
end

function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where {K}
return NamedTuple{K}(map(Base.Fix1(rand, rng), d.dists))
end
function Base.rand(
rng::AbstractRNG, d::ProductNamedTupleDistribution{K}, dims::Dims
) where {K}
return convert(AbstractArray{<:NamedTuple{K}}, _rand(rng, sampler(d), dims))
end

function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray)
return _rand!(rng, sampler(d), xs)
end
4 changes: 3 additions & 1 deletion src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ for fname in ["aliastable.jl",
"vonmises.jl",
"vonmisesfisher.jl",
"discretenonparametric.jl",
"categorical.jl"]
"categorical.jl",
"productnamedtuple.jl",
]

include(joinpath("samplers", fname))
end
23 changes: 23 additions & 0 deletions src/samplers/productnamedtuple.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
struct ProductNamedTupleSampler{Tnames,Tsamplers,S<:ValueSupport} <:
Sampleable{NamedTupleVariate{Tnames},S}
samplers::NamedTuple{Tnames,Tsamplers}
end

function Base.rand(rng::AbstractRNG, spl::ProductNamedTupleSampler{K}) where {K}
return NamedTuple{K}(map(Base.Fix1(rand, rng), spl.samplers))
end

function _rand(rng::AbstractRNG, spl::ProductNamedTupleSampler, dims::Dims)
return map(CartesianIndices(dims)) do _
return rand(rng, spl)
end
end

function _rand!(
rng::AbstractRNG, spl::ProductNamedTupleSampler, xs::AbstractArray{<:NamedTuple{K}}
) where {K}
for i in eachindex(xs)
xs[i] = NamedTuple{K}(rand(rng, spl))
end
return xs
end
Loading

0 comments on commit 07d1e78

Please sign in to comment.