Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.19 #240

Merged
merged 84 commits into from
Aug 22, 2023
Merged

0.19 #240

Show file tree
Hide file tree
Changes from 71 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
c78c166
First steps toward factoring out TransformVariables
cscherrer Sep 12, 2022
7b0d390
factor out exponential-families and tweedie
cscherrer Sep 13, 2022
3df9cfe
Merge branch 'factor-out-measurefamilies' into dev
cscherrer Sep 13, 2022
f168d89
AffinePushfwd <: AbstractPushforward
cscherrer Sep 13, 2022
0e2d0b7
don't export `mydot`
cscherrer Sep 20, 2022
d662ddb
moving things around
cscherrer Sep 20, 2022
8bc7361
drop TV functionality
cscherrer Sep 21, 2022
64e9ec6
bugfix
cscherrer Sep 23, 2022
a5a14bb
bugfix
cscherrer Sep 23, 2022
d12af8b
format
cscherrer Sep 28, 2022
f82c121
factor out Chain (let's move it to a new package)
cscherrer Sep 28, 2022
ce30ebe
fix type instability
cscherrer Sep 28, 2022
d253c2e
update Bernoulli
cscherrer Oct 12, 2022
4d35766
import massof
cscherrer Oct 14, 2022
df78a90
bugfix
cscherrer Oct 14, 2022
8a7aec7
add some methods
cscherrer Oct 14, 2022
38fc6c5
Revert "First steps toward factoring out TransformVariables"
cscherrer Oct 14, 2022
dadd081
smfinv => invsmf
cscherrer Oct 14, 2022
6af4782
update MeasureBase version
cscherrer Oct 17, 2022
c35ab1c
add some smfs
cscherrer Oct 17, 2022
464a2f1
Merge branch 'dev' of https://github.com/cscherrer/MeasureTheory.jl i…
cscherrer Oct 17, 2022
db8411f
edits
cscherrer Oct 17, 2022
f498cf3
Merge branch 'dev' of https://github.com/cscherrer/MeasureTheory.jl i…
cscherrer Oct 17, 2022
620439e
fix typo
cscherrer Oct 17, 2022
770cedf
fix laplace
cscherrer Oct 17, 2022
54ca9a4
fix bernoulli
cscherrer Oct 17, 2022
f94a629
fix beta
cscherrer Oct 17, 2022
6fc3132
fix StudentT
cscherrer Oct 17, 2022
c8db4c1
Fix Uniform
cscherrer Oct 17, 2022
f54ab64
fix Gumbel
cscherrer Oct 17, 2022
49d8693
fix Dirichlet
cscherrer Oct 17, 2022
8f98e9e
minor Dirichlet update
cscherrer Oct 17, 2022
b8d175e
more Dirichlet
cscherrer Oct 17, 2022
2ce5203
drop as(m::DistributionMeasures.DistributionMeasure) (for now)
cscherrer Oct 17, 2022
01a843d
drop tests for Chain
cscherrer Oct 17, 2022
6b44634
drop DistributionMeasures dependency (for now)
cscherrer Oct 17, 2022
2d06b4d
fix cauchy
cscherrer Oct 17, 2022
12243eb
fixing Binomial
cscherrer Oct 17, 2022
d0b2fbf
small updates
cscherrer Oct 17, 2022
ca7ec79
Merge branch 'dev' of https://github.com/cscherrer/MeasureTheory.jl i…
cscherrer Oct 17, 2022
e74d29c
update
cscherrer Oct 18, 2022
afc7515
updates
cscherrer Oct 18, 2022
0a1c7cb
drop brokwn measures
cscherrer Oct 18, 2022
491d081
Merge branch 'dev' of https://github.com/cscherrer/MeasureTheory.jl i…
cscherrer Oct 18, 2022
9bbfd41
fix normal
cscherrer Oct 18, 2022
4249638
asparams methods
cscherrer Oct 18, 2022
c929db6
update tests
cscherrer Oct 18, 2022
ff27404
smfAD macro
cscherrer Oct 19, 2022
db44529
update tests
cscherrer Oct 19, 2022
365247a
smf(::Normal)
cscherrer Oct 19, 2022
44f791a
move import
cscherrer Oct 19, 2022
eae4be3
add ForwardDiff
cscherrer Oct 19, 2022
a37900b
Merge branch 'dev' of https://github.com/cscherrer/MeasureTheory.jl i…
cscherrer Oct 19, 2022
ab33e86
dep constraint
cscherrer Oct 19, 2022
5972297
fix tests
cscherrer Oct 19, 2022
dfb420c
smf for affine pushfwd
cscherrer Oct 19, 2022
e7ad315
update tests
cscherrer Oct 19, 2022
f540634
update test name
cscherrer Oct 19, 2022
1db8600
tests
cscherrer Oct 19, 2022
3b4b58b
tests
cscherrer Oct 19, 2022
eec1544
update StudentT
cscherrer Oct 19, 2022
8f1c106
tests
cscherrer Oct 19, 2022
d40fc32
Some `For` updates
cscherrer Oct 20, 2022
ce53774
optimize `Normal`
cscherrer Oct 20, 2022
75656cc
update insupport for Normal
cscherrer Oct 21, 2022
81b9fa3
update basemeasure for For measures
cscherrer Oct 21, 2022
90a8dca
some affine updates
cscherrer Oct 21, 2022
8fcd05b
some fixes
cscherrer Oct 24, 2022
2544517
bump version
cscherrer Oct 27, 2022
516e237
formatting
cscherrer Oct 31, 2022
bfbe8c5
oops didn't mean to add that
cscherrer Oct 31, 2022
5a25a51
format
cscherrer Nov 3, 2022
2f0fa26
MeasureBase 0.14 + other updates
cscherrer Aug 21, 2023
038fabb
compatibility fixes
cscherrer Aug 21, 2023
eae6f85
updates
cscherrer Aug 21, 2023
86f58f9
more fixes
cscherrer Aug 21, 2023
8c1c1c0
fixes
cscherrer Aug 21, 2023
203d8c4
more fixes
cscherrer Aug 21, 2023
021d429
tests passing
cscherrer Aug 21, 2023
b307372
bump version
cscherrer Aug 21, 2023
07074d6
drop breakage CI for now
cscherrer Aug 22, 2023
744c4a4
Merge branch 'cs-mb-0.14' into dev
cscherrer Aug 22, 2023
f86085e
update Julia versions for CI
cscherrer Aug 22, 2023
9a3e8c7
Drop commented-out code
cscherrer Aug 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MeasureTheory"
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.18.1"
version = "0.19.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -10,16 +10,15 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
DistributionMeasures = "35643b39-bfd4-4670-843f-16596ca89bf3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Infinities = "e1ba4f0e-776d-440f-acd9-e1d2e9742647"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Expand Down Expand Up @@ -47,20 +46,18 @@ Compat = "3.42, 4"
ConcreteStructs = "0.2"
ConstructionBase = "1.3"
DensityInterface = "0.4"
DistributionMeasures = "0.2"
Distributions = "0.25"
DynamicIterators = "0.4"
FillArrays = "0.12, 0.13"
ForwardDiff = "0.10"
IfElse = "0.1"
Infinities = "0.1"
InverseFunctions = "0.1"
KeywordCalls = "0.2"
LazyArrays = "0.22"
LogExpFunctions = "0.3.3"
MLStyle = "0.4"
MacroTools = "0.5"
MappedArrays = "0.4"
MeasureBase = "0.13"
MeasureBase = "0.14"
NamedTupleTools = "0.13, 0.14"
PositiveFactorizations = "0.2"
PrettyPrinting = "0.3, 0.4"
Expand Down
28 changes: 16 additions & 12 deletions src/MeasureTheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ using MLStyle
import TransformVariables
const TV = TransformVariables

using DistributionMeasures
# using DistributionMeasures
using TransformVariables: asℝ₊, as𝕀, asℝ, transform

import Base
import Distributions
const Dists = Distributions
# import Distributions
# const Dists = Distributions
cscherrer marked this conversation as resolved.
Show resolved Hide resolved

export TV
export transform
Expand All @@ -27,14 +27,12 @@ export Lebesgue
export ℝ, ℝ₊, 𝕀
export ⊙
export SpikeMixture
export CountingMeasure
export TrivialMeasure
export Likelihood
export testvalue
export basekernel

using Infinities
using DynamicIterators
using KeywordCalls
using ConstructionBase
using Accessors
Expand All @@ -58,9 +56,12 @@ import MeasureBase:
paramnames,
∫,
𝒹,
∫exp
∫exp,
smf,
invsmf,
massof
import MeasureBase: ≪
using MeasureBase: BoundedInts, BoundedReals, CountingMeasure, IntegerDomain, IntegerNumbers
using MeasureBase: BoundedInts, BoundedReals, CountingBase, IntegerDomain, IntegerNumbers
using MeasureBase: weightedmeasure, restrict
using MeasureBase: AbstractTransitionKernel

Expand Down Expand Up @@ -97,13 +98,16 @@ using MeasureBase: kernel
using MeasureBase: Returns
import MeasureBase: proxy, @useproxy
import MeasureBase: basemeasure_depth
using MeasureBase: LebesgueMeasure
using MeasureBase: LebesgueBase

import DensityInterface: logdensityof
import DensityInterface: densityof
import DensityInterface: DensityKind
using DensityInterface

using ForwardDiff
using ForwardDiff: Dual

gentype(μ::AbstractMeasure) = typeof(testvalue(μ))

# gentype(μ::AbstractMeasure) = gentype(basemeasure(μ))
Expand All @@ -117,20 +121,18 @@ xlogy(x, y) = x * log(y)
xlog1py(x::Number, y::Number) = LogExpFunctions.xlog1py(x, y)
xlog1py(x, y) = x * log(1 + y)

as(args...; kwargs...) = TV.as(args...; kwargs...)
using MeasureBase: Φ, Φinv

include("utils.jl")
include("const.jl")
include("combinators/for.jl")
# include("traits.jl")
include("parameterized.jl")

include("macros.jl")
include("combinators/affine.jl")
include("combinators/weighted.jl")
include("combinators/product.jl")
include("combinators/transforms.jl")
include("combinators/exponential-families.jl")
# include("combinators/exponential-families.jl")
include("resettable-rng.jl")
include("realized.jl")
include("combinators/chain.jl")
Expand Down Expand Up @@ -165,5 +167,7 @@ include("combinators/ifelse.jl")
include("transforms/corrcholesky.jl")
include("transforms/ordered.jl")

include("parameterized.jl")

include("distproxy.jl")
end # module
91 changes: 73 additions & 18 deletions src/combinators/affine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ logjac(f::AffineTransform{(:μ, :σ)}) = logjac(f.σ)
logjac(f::AffineTransform{(:μ, :λ)}) = -logjac(f.λ)
logjac(f::AffineTransform{(:σ,)}) = logjac(f.σ)
logjac(f::AffineTransform{(:λ,)}) = -logjac(f.λ)
logjac(f::AffineTransform{(:μ,)}) = 0.0
logjac(f::AffineTransform{(:μ,)}) = static(0.0)

###############################################################################

Expand Down Expand Up @@ -161,7 +161,7 @@ basemeasure(d::OrthoLebesgue) = d

logdensity_def(::OrthoLebesgue, x) = static(0.0)

struct AffinePushfwd{N,M,T} <: AbstractMeasure
struct AffinePushfwd{N,M,T} <: MeasureBase.AbstractPushforward
f::AffineTransform{N,T}
parent::M
end
Expand All @@ -172,6 +172,10 @@ function Pretty.tile(d::AffinePushfwd)
Pretty.list_layout([pars, Pretty.tile(d.parent)]; prefix = :AffinePushfwd)
end

@inline MeasureBase.transport_origin(d::AffinePushfwd) = d.parent
@inline MeasureBase.to_origin(d::AffinePushfwd, y) = inverse(getfield(d, :f))(y)
@inline MeasureBase.from_origin(d::AffinePushfwd, x) = getfield(d, :f)(x)

@inline function testvalue(d::AffinePushfwd)
f = getfield(d, :f)
z = testvalue(parent(d))
Expand Down Expand Up @@ -282,7 +286,7 @@ end
@inline function basemeasure(d::AffinePushfwd{N,L}) where {N,L<:Lebesgue}
weightedmeasure(-logjac(d), d.parent)
end
@inline function basemeasure(d::AffinePushfwd{N,L}) where {N,L<:LebesgueMeasure}
@inline function basemeasure(d::AffinePushfwd{N,L}) where {N,L<:LebesgueBase}
weightedmeasure(-logjac(d), d.parent)
end

Expand Down Expand Up @@ -313,19 +317,6 @@ supportdim(nt::NamedTuple{(:σ,),T}) where {T} = colsize(nt.σ)
supportdim(nt::NamedTuple{(:λ,),T}) where {T} = rowsize(nt.λ)
supportdim(nt::NamedTuple{(:μ,),T}) where {T} = size(nt.μ)

asparams(::AffinePushfwd, ::StaticSymbol{:μ}) = asℝ
asparams(::AffinePushfwd, ::StaticSymbol{:σ}) = asℝ₊
asparams(::Type{A}, ::StaticSymbol{:μ}) where {A<:AffinePushfwd} = asℝ
asparams(::Type{A}, ::StaticSymbol{:σ}) where {A<:AffinePushfwd} = asℝ₊

function asparams(d::AffinePushfwd{N,M,T}, ::StaticSymbol{:μ}) where {N,M,T<:AbstractArray}
as(Array, asℝ, size(d.μ))
end

function asparams(d::AffinePushfwd{N,M,T}, ::StaticSymbol{:σ}) where {N,M,T<:AbstractArray}
as(Array, asℝ, size(d.σ))
end

function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::AffinePushfwd) where {T}
z = rand(rng, T, parent(d))
f = getfield(d, :f)
Expand All @@ -336,8 +327,8 @@ end
insupport(d.parent, inverse(d.f)(x))
end

@inline function Distributions.cdf(d::AffinePushfwd, x)
cdf(parent(d), inverse(d.f)(x))
@inline function MeasureBase.smf(d::AffinePushfwd, x)
smf(parent(d), inverse(d.f)(x))
end

@inline function mean(d::AffinePushfwd)
Expand All @@ -364,3 +355,67 @@ end
@inline function std(d::AffinePushfwd{(:μ, :λ)})
std(parent(d)) / d.λ
end

###############################################################################
# smf

@inline function smf(d::AffinePushfwd{(:μ,)}, x)
f = getfield(d, :f)
smf(parent(d), inverse(f)(x))
end

@inline function smf(d::AffinePushfwd{(:μ, :σ)}, x)
f = getfield(d, :f)
p = smf(parent(d), inverse(f)(x))
d.σ > 0 ? p : one(p) - p
end

@inline function smf(d::AffinePushfwd{(:σ,)}, x)
f = getfield(d, :f)
p = smf(parent(d), inverse(f)(x))
d.σ > 0 ? p : one(p) - p
end

@inline function smf(d::AffinePushfwd{(:λ,)}, x)
f = getfield(d, :f)
p = smf(parent(d), inverse(f)(x))
d.λ > 0 ? p : one(p) - p
end

@inline function smf(d::AffinePushfwd{(:μ, :λ)}, x)
f = getfield(d, :f)
p = smf(parent(d), inverse(f)(x))
d.λ > 0 ? p : one(p) - p
end

###############################################################################
# invsmf

@inline function invsmf(d::AffinePushfwd{(:μ,)}, p)
f = getfield(d, :f)
f(invsmf(parent(d), p))
end

@inline function invsmf(d::AffinePushfwd{(:μ, :σ)}, p)
p = d.σ > 0 ? p : one(p) - p
f = getfield(d, :f)
f(invsmf(parent(d), p))
end

@inline function invsmf(d::AffinePushfwd{(:σ,)}, p)
p = d.σ > 0 ? p : one(p) - p
f = getfield(d, :f)
f(invsmf(parent(d), p))
end

@inline function invsmf(d::AffinePushfwd{(:λ,)}, p)
p = d.λ > 0 ? p : one(p) - p
f = getfield(d, :f)
f(invsmf(parent(d), p))
end

@inline function invsmf(d::AffinePushfwd{(:μ, :λ)}, p)
p = d.λ > 0 ? p : one(p) - p
f = getfield(d, :f)
f(invsmf(parent(d), p))
end
Loading