Funsor based TraceEnum_ELBO implementation (#1512)
* initial commit

* wip

* remove TraceMarkovEnum_ELBO

* pair coded

* clean up

* Make enum example work

* port tests from pyro to numpyro

* Add missing test file

* traceenum_elbo2

* pair coded

* pass more tests

* pass all tests

* organize

* lint

* fixes

* test_gradient

* fix TraceGraph_ELBO

* lint

* Revert tracegraph_elbo changes

* Address masked distribution

* revert changes at replay messenger

* refactor

* clean

* add validations

* fix comments

* lint

* fix validation

* fix

* fix enum_vars

* rm wordclouds.png

* address comments

Co-authored-by: Du Phan <[email protected]>
ordabayevy and fehiepsi authored Jan 23, 2023
1 parent c46b0db commit 09a3e0b
Showing 8 changed files with 2,745 additions and 7 deletions.
6 changes: 6 additions & 0 deletions numpyro/contrib/funsor/enum_messenger.py
@@ -521,6 +521,12 @@ def _get_batch_shape(cond_indep_stack):
def process_message(self, msg):
if msg["type"] in ["to_funsor", "to_data"]:
return super().process_message(msg)
if msg["type"] == "sample" and self.size != self.subsample_size:
plate_to_scale = msg.setdefault("plate_to_scale", {})
assert self.name not in plate_to_scale
plate_to_scale[self.name] = (
self.size / self.subsample_size if self.subsample_size else 1
)
return OrigPlateMessenger.process_message(self, msg)

def postprocess_message(self, msg):
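The hunk above records a per-plate scale on each subsampled site so that downstream ELBOs can recover the minibatch correction. A minimal sketch of the behavior being encoded (the model is hypothetical, not from this commit): with a plate of size 100 subsampled to 10, every site inside is scaled by 100 / 10 = 10, keeping the ELBO an unbiased estimate of the full-data objective.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(data):
    # plate of 100 points subsampled to 10: log-densities inside are scaled
    # by size / subsample_size = 10, the factor stored in plate_to_scale
    with numpyro.plate("data", 100, subsample_size=10) as idx:
        numpyro.sample("obs", dist.Normal(0.0, 1.0), obs=data[idx])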
9 changes: 9 additions & 0 deletions numpyro/distributions/kl.py
@@ -39,6 +39,7 @@
Normal,
Weibull,
)
from numpyro.distributions.discrete import CategoricalProbs
from numpyro.distributions.distribution import (
Delta,
Distribution,
@@ -146,6 +147,14 @@ def kl_divergence(p, q):
return t1 - t2 + t3


@dispatch(CategoricalProbs, CategoricalProbs)
def kl_divergence(p, q):
t = p.probs * (p.logits - q.logits)
t = jnp.where(q.probs == 0, jnp.inf, t)
t = jnp.where(p.probs == 0, 0.0, t)
return t.sum(-1)


@dispatch(Dirichlet, Dirichlet)
def kl_divergence(p, q):
# From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
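A quick sanity check of the new CategoricalProbs rule above (not part of the commit; assumes this branch is installed). The implementation computes KL(p || q) = sum_i p_i * (log p_i - log q_i), with the two `jnp.where` guards enforcing the 0 * log 0 = 0 convention and an infinite divergence wherever q assigns zero mass but p does not:

import jax.numpy as jnp
from numpyro.distributions import CategoricalProbs
from numpyro.distributions.kl import kl_divergence

p = CategoricalProbs(jnp.array([0.5, 0.5, 0.0]))
q = CategoricalProbs(jnp.array([0.25, 0.25, 0.5]))
# 0.5*log(0.5/0.25) + 0.5*log(0.5/0.25) + 0  (the p == 0 entry contributes nothing)
print(kl_divergence(p, q))  # ~0.6931 == log(2)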
7 changes: 7 additions & 0 deletions numpyro/handlers.py
@@ -609,6 +609,13 @@ def process_message(self, msg):
msg["scale"] = (
self.scale if msg.get("scale") is None else self.scale * msg["scale"]
)
plate_to_scale = msg.setdefault("plate_to_scale", {})
scale = (
self.scale
if plate_to_scale.get(None) is None
else self.scale * plate_to_scale[None]
)
plate_to_scale[None] = scale


class scope(Messenger):
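Here the `None` key collects scale factors introduced by `handlers.scale` itself, as opposed to factors tied to a named subsampled plate. A small sketch of the handler's visible effect (hypothetical usage; only `msg["scale"]` is inspected, since `plate_to_scale` is internal bookkeeping on this branch):

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import scale, seed, trace

def model():
    with scale(scale=10.0):  # not tied to any plate, so it lands under key None
        numpyro.sample("x", dist.Normal(0.0, 1.0))

tr = trace(seed(model, rng_seed=0)).get_trace()
print(tr["x"]["scale"])  # 10.0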
2 changes: 2 additions & 0 deletions numpyro/infer/__init__.py
@@ -6,6 +6,7 @@
ELBO,
RenyiELBO,
Trace_ELBO,
TraceEnum_ELBO,
TraceGraph_ELBO,
TraceMeanField_ELBO,
)
@@ -49,6 +50,7 @@
"SA",
"SVI",
"Trace_ELBO",
"TraceEnum_ELBO",
"TraceGraph_ELBO",
"TraceMeanField_ELBO",
]
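With the export in place, `TraceEnum_ELBO` drops into the usual SVI loop. A hypothetical end-to-end example (model, guide, and data are illustrative, not from the commit); the discrete site is marked for parallel enumeration and summed out exactly:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, TraceEnum_ELBO
from numpyro.optim import Adam

def model(data):
    p = numpyro.sample("p", dist.Beta(1.0, 1.0))
    with numpyro.plate("N", data.shape[0]):
        # discrete latent, eliminated exactly by TraceEnum_ELBO
        z = numpyro.sample("z", dist.Bernoulli(p), infer={"enumerate": "parallel"})
        numpyro.sample("obs", dist.Normal(z, 1.0), obs=data)

def guide(data):
    a = numpyro.param("a", 1.0, constraint=constraints.positive)
    b = numpyro.param("b", 1.0, constraint=constraints.positive)
    numpyro.sample("p", dist.Beta(a, b))

svi = SVI(model, guide, Adam(0.01), TraceEnum_ELBO(max_plate_nesting=1))
svi_result = svi.run(jax.random.PRNGKey(0), 1000, jnp.array([0.1, 2.3, -0.4]))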
303 changes: 300 additions & 3 deletions numpyro/infer/elbo.py
@@ -1,8 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from functools import partial
from collections import OrderedDict, defaultdict
from functools import partial, reduce
from operator import itemgetter
import warnings

@@ -11,10 +11,16 @@
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from numpyro.distributions import ExpandedDistribution, MaskedDistribution
from numpyro.distributions.kl import kl_divergence
from numpyro.distributions.util import scale_and_mask
from numpyro.handlers import Messenger, replay, seed, substitute, trace
from numpyro.infer.util import get_importance_trace, log_density
from numpyro.infer.util import (
_without_rsample_stop_gradient,
get_importance_trace,
is_identically_one,
log_density,
)
from numpyro.ops.provenance import eval_provenance, get_provenance
from numpyro.util import _validate_model, check_model_guide_match, find_stack_level

@@ -710,3 +716,294 @@ def single_particle_elbo(rng_key):
else:
rng_keys = random.split(rng_key, self.num_particles)
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))


def get_importance_trace_enum(model, guide, args, kwargs, params, max_plate_nesting):
"""
(EXPERIMENTAL) Returns traces from the enumerated guide and the enumerated model that is run against it.
The returned traces also store the log probability at each site and the log measure for measure vars.
"""
import funsor
from numpyro.contrib.funsor import (
enum,
plate_to_enum_plate,
to_funsor,
trace as _trace,
)

with plate_to_enum_plate(), enum(
first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
):
guide = substitute(guide, data=params)
with _without_rsample_stop_gradient():
guide_trace = _trace(guide).get_trace(*args, **kwargs)
model = substitute(replay(model, guide_trace), data=params)
model_trace = _trace(model).get_trace(*args, **kwargs)
guide_trace = {
name: site for name, site in guide_trace.items() if site["type"] == "sample"
}
model_trace = {
name: site for name, site in model_trace.items() if site["type"] == "sample"
}
for is_model, tr in zip((False, True), (guide_trace, model_trace)):
for name, site in tr.items():
if is_model and (site["is_observed"] or (site["name"] in guide_trace)):
site["is_measure"] = False
if "log_prob" not in site:
value = site["value"]
intermediates = site["intermediates"]
if intermediates:
log_prob = site["fn"].log_prob(value, intermediates)
else:
log_prob = site["fn"].log_prob(value)

dim_to_name = site["infer"]["dim_to_name"]
site["log_prob"] = to_funsor(
log_prob, output=funsor.Real, dim_to_name=dim_to_name
)
if site.get("is_measure", True):
# get rid of masking
base_fn = site["fn"]
batch_shape = base_fn.batch_shape
while isinstance(
base_fn, (MaskedDistribution, ExpandedDistribution)
):
base_fn = base_fn.base_dist
base_fn = base_fn.expand(batch_shape)
if intermediates:
log_measure = base_fn.log_prob(value, intermediates)
else:
log_measure = base_fn.log_prob(value)
# dice factor
if not site["infer"].get("enumerate") == "parallel":
log_measure = log_measure - funsor.ops.detach(log_measure)
site["log_measure"] = to_funsor(
log_measure, output=funsor.Real, dim_to_name=dim_to_name
)
return model_trace, guide_trace
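# Note on the dice factor above: subtracting funsor.ops.detach(log_measure)
# makes the factor identically zero in the forward pass while re-attaching
# the score-function gradient. The same trick in miniature, in plain JAX
# (a sketch, not library code):
#
#     def surrogate(log_q, cost):
#         dice = jnp.exp(log_q - jax.lax.stop_gradient(log_q))  # value 1.0
#         return dice * cost
#
#     jax.grad(surrogate)(jnp.log(0.3), 2.0)  # == cost == 2.0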


def _partition(model_sum_deps, sum_vars):
# Construct a bipartite graph between model_sum_deps and the sum_vars
neighbors = OrderedDict([(t, []) for t in model_sum_deps.keys()])
for key, deps in model_sum_deps.items():
for dim in deps:
if dim in sum_vars:
neighbors[key].append(dim)
neighbors.setdefault(dim, []).append(key)

# Partition the bipartite graph into connected components for contraction.
components = []
while neighbors:
v, pending = neighbors.popitem()
component = OrderedDict([(v, None)]) # used as an OrderedSet
for v in pending:
component[v] = None
while pending:
v = pending.pop()
for v in neighbors.pop(v):
if v not in component:
component[v] = None
pending.append(v)

# Split this connected component into factors and measures.
# Append only if component_factors is non-empty
component_factors = frozenset(v for v in component if v not in sum_vars)
if component_factors:
component_measures = frozenset(v for v in component if v in sum_vars)
components.append((component_factors, component_measures))
return components
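# Example (hypothetical names): cost terms "x" and "y" both depend on the
# enumerated variable "z", so they must be contracted jointly, while "w" has
# no enumerated dependencies and forms its own uncontracted component:
#
#     model_sum_deps = {"x": frozenset({"z"}), "y": frozenset({"z"}), "w": frozenset()}
#     _partition(model_sum_deps, frozenset({"z"}))
#     # [(frozenset({'x', 'y'}), frozenset({'z'})), (frozenset({'w'}), frozenset())]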


class TraceEnum_ELBO(ELBO):
"""
A TraceEnum implementation of ELBO-based SVI. The gradient estimator
is constructed along the lines of reference [1] specialized to the case
of the ELBO. It supports arbitrary dependency structure for the model
and guide.
Fine-grained conditional dependency information as recorded in the
trace is used to reduce the variance of the gradient estimator.
In particular provenance tracking [2] is used to find the ``cost`` terms
that depend on each non-reparameterizable sample site.
Enumerated variables are eliminated using the TVE algorithm for plated
factor graphs [3].
References

[1] `Storchastic: A Framework for General Stochastic Automatic Differentiation`,
Emile van Krieken, Jakub M. Tomczak, Annette ten Teije

[2] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

[3] `Tensor Variable Elimination for Plated Factor Graphs`,
Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu,
Neeraj Pradhan, Alexander M. Rush, Noah Goodman
"""

can_infer_discrete = True

def __init__(self, num_particles=1, max_plate_nesting=float("inf")):
if max_plate_nesting == float("inf"):
raise ValueError(
"Currently, we require `max_plate_nesting` to be a non-positive integer."
)
self.max_plate_nesting = max_plate_nesting
super().__init__(num_particles=num_particles)

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
def single_particle_elbo(rng_key):
import funsor
from numpyro.contrib.funsor import to_data, to_funsor

model_seed, guide_seed = random.split(rng_key)
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)

model_trace, guide_trace = get_importance_trace_enum(
seeded_model,
seeded_guide,
args,
kwargs,
param_map,
self.max_plate_nesting,
)
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="strict")

# Find dependencies on non-reparameterizable sample sites for
# each cost term in the model and the guide.
model_deps, guide_deps = get_provenance(
eval_provenance(
partial(
track_nonreparam(get_importance_log_probs),
seeded_model,
seeded_guide,
args,
kwargs,
param_map,
)
)
)

sum_vars = frozenset(
[
name
for name, site in model_trace.items()
if site.get("is_measure", True)
]
)
model_sum_deps = {
k: v & sum_vars for k, v in model_deps.items() if k not in sum_vars
}
model_deps = {
k: v - sum_vars for k, v in model_deps.items() if k not in sum_vars
}

elbo = 0.0
for group_names, group_sum_vars in _partition(model_sum_deps, sum_vars):
if not group_sum_vars:
# uncontracted logp cost term
assert len(group_names) == 1
name = next(iter(group_names))
cost = model_trace[name]["log_prob"]
scale = model_trace[name]["scale"]
deps = model_deps[name]
dice_factors = [guide_trace[key]["log_measure"] for key in deps]
else:
# compute contracted cost term
group_factors = tuple(
model_trace[name]["log_prob"] for name in group_names
)
group_factors += tuple(
model_trace[var]["log_measure"] for var in group_sum_vars
)
group_factor_vars = frozenset().union(
*[f.inputs for f in group_factors]
)
group_plates = group_factor_vars - frozenset(model_trace.keys())
outermost_plates = frozenset.intersection(
*(frozenset(f.inputs) & group_plates for f in group_factors)
)
elim_plates = group_plates - outermost_plates
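# Plates shared by every factor in the group survive the contraction;
# plates private to only some factors are eliminated below together with
# the enumerated variables (tensor variable elimination).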
cost = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
group_factors,
plates=group_plates,
eliminate=group_sum_vars | elim_plates,
)
# incorporate the effects of subsampling and handlers.scale through a common scale factor
group_scales = {}
for name in group_names:
for plate, value in (
model_trace[name].get("plate_to_scale", {}).items()
):
if plate in group_scales:
if value != group_scales[plate]:
raise ValueError(
"Expected all enumerated sample sites to share a common scale factor, "
f"but found different scales at plate('{plate}')."
)
else:
group_scales[plate] = value
scale = (
reduce(lambda a, b: a * b, group_scales.values())
if group_scales
else None
)
# combine deps
deps = frozenset().union(
*[model_deps[name] for name in group_names]
)
# check model guide enumeration constraint
for key in deps:
site = guide_trace[key]
if site["infer"].get("enumerate") == "parallel":
for plate in (
frozenset(site["log_measure"].inputs) & elim_plates
):
raise ValueError(
"Expected model enumeration to be no more global than guide enumeration, but found "
f"model enumeration sites upstream of guide site '{key}' in plate('{plate}')."
"Try converting some model enumeration sites to guide enumeration sites."
)
# combine dice factors
dice_factors = [
guide_trace[key]["log_measure"].reduce(
funsor.ops.add,
frozenset(guide_trace[key]["log_measure"].inputs)
& elim_plates,
)
for key in deps
]

if dice_factors:
dice_factor = reduce(lambda a, b: a + b, dice_factors)
cost = cost * funsor.ops.exp(dice_factor)
if (scale is not None) and (not is_identically_one(scale)):
cost = cost * to_funsor(scale)

elbo = elbo + cost.reduce(funsor.ops.add)

for name, deps in guide_deps.items():
# -logq cost term
cost = -guide_trace[name]["log_prob"]
scale = guide_trace[name]["scale"]
dice_factors = [guide_trace[key]["log_measure"] for key in deps]
if dice_factors:
dice_factor = reduce(lambda a, b: a + b, dice_factors)
cost = cost * funsor.ops.exp(dice_factor)
if (scale is not None) and (not is_identically_one(scale)):
cost = cost * to_funsor(scale)
elbo = elbo + cost.reduce(funsor.ops.add)

return to_data(elbo)

# Return (-elbo) since by convention we do gradient descent on a loss and
# the ELBO is a lower bound that needs to be maximized.
if self.num_particles == 1:
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
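The contracted cost terms above are computed by funsor's tensor variable elimination. A miniature of the same `sum_product` call, using funsor directly (a sketch; shapes, names, and values are illustrative, and funsor's JAX backend is assumed):

from collections import OrderedDict

import jax.numpy as jnp
import funsor

funsor.set_backend("jax")

# plate "i" of size 3; enumerated variable "z" of size 2
log_pz = funsor.Tensor(
    jnp.broadcast_to(jnp.log(jnp.array([0.3, 0.7])), (3, 2)),
    OrderedDict(i=funsor.Bint[3], z=funsor.Bint[2]),
)
log_px = funsor.Tensor(jnp.zeros((3, 2)), OrderedDict(i=funsor.Bint[3], z=funsor.Bint[2]))
out = funsor.sum_product.sum_product(
    funsor.ops.logaddexp,
    funsor.ops.add,
    [log_pz, log_px],
    plates=frozenset({"i"}),
    eliminate=frozenset({"z"}),  # logsumexp z out per plate element; keep plate "i"
)
print(out.inputs)  # OrderedDict([('i', Bint[3])])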