From 7b0d39048f043dcd30323eb84127137e95a1b912 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 13 Sep 2022 06:40:12 -0700 Subject: [PATCH] factor out exponential-families and tweedie --- Project.toml | 2 - src/MeasureTheory.jl | 1 - src/combinators/exponential-families.jl | 119 ------------------ src/combinators/tweedie.jl | 159 ------------------------ 4 files changed, 281 deletions(-) delete mode 100644 src/combinators/exponential-families.jl delete mode 100644 src/combinators/tweedie.jl diff --git a/Project.toml b/Project.toml index 26a3a666..d26cc4d2 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ Infinities = "e1ba4f0e-776d-440f-acd9-e1d2e9742647" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f" -LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" @@ -55,7 +54,6 @@ IfElse = "0.1" Infinities = "0.1" InverseFunctions = "0.1" KeywordCalls = "0.2" -LazyArrays = "0.22" LogExpFunctions = "0.3.3" MLStyle = "0.4" MacroTools = "0.5" diff --git a/src/MeasureTheory.jl b/src/MeasureTheory.jl index e43642d7..86108fbc 100644 --- a/src/MeasureTheory.jl +++ b/src/MeasureTheory.jl @@ -130,7 +130,6 @@ include("combinators/affine.jl") include("combinators/weighted.jl") include("combinators/product.jl") include("combinators/transforms.jl") -include("combinators/exponential-families.jl") include("resettable-rng.jl") include("realized.jl") include("combinators/chain.jl") diff --git a/src/combinators/exponential-families.jl b/src/combinators/exponential-families.jl deleted file mode 100644 index e3e169a7..00000000 --- a/src/combinators/exponential-families.jl +++ /dev/null @@ -1,119 +0,0 @@ -export ExponentialFamily -using LazyArrays - -@concrete terse struct ExponentialFamily <: AbstractTransitionKernel - support_contains - base - mdim - pdim - t - x - a -end - -MeasureBase.insupport(fam::ExponentialFamily, x) = fam.support_contains(x) - -function ExponentialFamily(support_contains, base, mdim, pdim, t, a) - return ExponentialFamily(support_contains, base, mdim, pdim, t, I, a) -end - -function MeasureBase.powermeasure(fam::ExponentialFamily, dims::NTuple) - support_contains(x) = all(xj -> fam.support_contains(xj), x) - t = Tuple((y -> f.(y) for f in fam.t)) - a(η) = LazyArrays.BroadcastArray(fam.a, η) - p = prod(dims) - ExponentialFamily( - support_contains, - fam.base^dims, - fam.mdim * p, - fam.pdim * p, - t, - fam.x, - a, - ) -end - -powermeasure(fam::ExponentialFamily, ::Tuple{}) = fam - -@concrete terse struct ExpFamMeasure <: AbstractMeasure - fam - η # instantiated to a value - a # instantiated to a value -end - -MeasureBase.insupport(μ::ExpFamMeasure, x) = μ.fam.support_contains(x) - -@inline function (fam::ExponentialFamily)(β) - η = fam.x * β - a = fam.a(η) - ExpFamMeasure(fam, η, a) -end - -MeasureBase.basemeasure(d::ExpFamMeasure) = d.fam.base - -tracedot(a::AbstractVector, b::AbstractVector) = dot(a, b) - -tracedot(a::AbstractVector, x, b::AbstractVector) = dot(a, x, b) - -tracedot(a, b) = sum((dot(view(a, :, j), view(b, :, j)) for j in 1:size(a, 2))) - -tracedot(a, x, b) = - sum(1:size(a, 2)) do j - dot(view(a, :, j), x, view(b, :, j)) - end - -# @inline function tracedot(a::BlockDiag, b::BlockDiag) -# numblocks = length(a.blocks) -# sum(tracedot(a.blocks[j], b.blocks[j]) for j in 1:length(a.blocks)) -# end - -# @inline function tracedot(a::BlockDiag, x::BlockDiag, b::BlockDiag) -# numblocks = length(x.blocks) -# sum(tracedot(a.blocks[j], x.blocks[j], b.blocks[j]) for j in 1:length(x.blocks)) -# end - -function logdensity_def(d::ExpFamMeasure, y) - t = ApplyArray(vcat, (f.(y) for f in d.fam.t)...) - η = d.η - dot(t, η) -end - -function withX(fam::ExponentialFamily, x) - @inline t(y) = fam.t.(y) - newx = ApplyArray(kron, x, fam.x) - η(β) = fam.η.(β) - a(β) = sum(fam.a, β) - ExponentialFamily(fam.base^size(x, 1), t, x, η, a) -end - -@concrete terse struct ExpFamLikelihood <: AbstractLikelihood - fam - y - tᵀx - c -end - -# function regression(fam, uᵀ, vᵀ) - -# end - -function MeasureBase.likelihoodof(fam::ExponentialFamily, y) - c = logdensityof(fam.base, y) - t = ApplyArray(vcat, (f.(y) for f in fam.t)...) - tᵀx = t' * fam.x - ExpFamLikelihood(fam, y, tᵀx, c) -end - -@inline function logdensityof(ℓ::ExpFamLikelihood, β) - xβ = ApplyArray(*, ℓ.fam.x, β) - a = sum(ℓ.fam.a(xβ)) - # a = sum(ℓ.fam.a, ApplyArray(*, ℓ.fam.uᵀ', ℓ.fam.vᵀ, β)) - ℓ.c + dot(ℓ.tᵀx, β) - a -end - -basemeasure(fam::ExponentialFamily) = fam.base - -# function stack_functions(funs, inds) -# function(x::AbstractArray{T,N}) where {T,N} -# ApplyArray(cat, ) -# end diff --git a/src/combinators/tweedie.jl b/src/combinators/tweedie.jl deleted file mode 100644 index b18a3fb8..00000000 --- a/src/combinators/tweedie.jl +++ /dev/null @@ -1,159 +0,0 @@ -export Tweedie - -abstract type AbstractEDM <: AbstractTransitionKernel end - -""" -From https://en.wikipedia.org/wiki/Tweedie_distribution: - -> The Tweedie distributions include a number of familiar distributions as well as -some unusual ones, each being specified by the domain of the index parameter. We -have the - - extreme stable distribution, p < 0, - normal distribution, p = 0, - Poisson distribution, p = 1, - compound Poisson-gamma distribution, 1 < p < 2, - gamma distribution, p = 2, - positive stable distributions, 2 < p < 3, - Inverse Gaussian distribution, p = 3, - positive stable distributions, p > 3, and - extreme stable distributions, p = ∞. - -For 0 < p < 1 no Tweedie model exists. Note that all stable distributions mean -actually generated by stable distributions. -""" -struct Tweedie{S,B,D,P} <: AbstractEDM - support_contains::S - base::B - dim::D - p::P -end - -struct TweedieMeasure{B,Θ,P,S,C} <: AbstractMeasure - fam::Tweedie{B,Θ,P} - θ::Θ - σ::S - cumulant::C -end - -mean(d::TweedieMeasure) = tweedie_mean(d.fam.p, d.θ) - -var(d::TweedieMeasure) = d.σ^2 * mean(d)^d.fam.p - -############################################################################### -# Tweedie cumulants - -@inline function tweedie_cumulant(p::P, θ) where {P} - if p == zero(P) - return 0.5 * θ^2 - elseif p == one(P) - return exp(θ) - elseif p == 2 - return -log(-θ) - else - α = (p - 2) / (p - 1) - coeff = (α - 1) / α - return coeff * (θ / (α - 1))^α - end -end - -@inline function tweedie_cumulant(::StaticFloat64{0.0}, θ) - return 0.5 * θ^2 -end - -@inline function tweedie_cumulant(::StaticFloat64{1.0}, θ) - return exp(θ) -end - -@inline function tweedie_cumulant(::StaticFloat64{2.0}, θ) - return -log(-θ) -end - -@generated function tweedie_cumulant(::StaticFloat64{p}, θ) where {p} - α = (p - 2) / (p - 1) - coeff = (α - 1) / α - - quote - $(Expr(:meta, :inline)) - coeff * (θ / (α - 1))^α - end -end - -@inline function (fam::Tweedie)(par) - base = fam.base(par.σ) - θ = fam.θ(par) - η = fam.η(θ) - t = fam.t - a = fam.a(θ) - TweedieMeasure(base, θ, p, σ) -end - -############################################################################### -# Tweedie mean function - -@inline function tweedie_mean(p::P, θ) where {P} - if p == zero(P) - return θ - elseif p == one(P) - return exp(θ) - elseif p == 2 - return inv(log(-θ)) - else - α_minus_1 = (p - 2) / (p - 1) - 1 - return (θ / α_minus_1)^α_minus_1 - end -end - -@inline function tweedie_mean(::StaticFloat64{0.0}, θ) - return θ -end - -@inline function tweedie_mean(::StaticFloat64{1.0}, θ) - return exp(θ) -end - -@inline function tweedie_mean(::StaticFloat64{2.0}, θ) - return inv(log(-θ)) -end - -@generated function tweedie_mean(::StaticFloat64{p}, θ) where {p} - α_minus_1 = (p - 2) / (p - 1) - 1 - - quote - $(Expr(:meta, :inline)) - (θ / α_minus_1)^α_minus_1 - end -end - -basemeasure(d::TweedieMeasure) = d.base - -function logdensity_def(d::TweedieMeasure, x) - mydot(x, d.θ) - d.cumulant -end - -function MeasureBase.powermeasure(fam::Tweedie, dims::NTuple{N,I}) where {N,I} - base(σ) = fam.base(σ)^dims - a = AffineTransform((σ = prod(dims),)) ∘ fam.a - Tweedie(fam.base^dims, fam.θ, fam.η, t, a) -end - -struct TweedieLikelihood{C,Θ,H,T,A} <: AbstractLikelihood - c::C - θ::Θ - η::H - t::T - a::A -end - -function likelihoodof(fam::Tweedie, x) - c = logdensityof(fam.base, x) - t = fam.t(x) - TweedieLikelihood(c, fam.θ, fam.η, t, fam.a) -end - -@inline function logdensity_def(ℓ::TweedieLikelihood, par) - θ = ℓ.θ(par) - mydot(θ, ℓ.t) - ℓ.a(θ) + ℓ.c -end - -basemeasure(fam::Tweedie) = fam.base