Skip to content

Commit

Permalink
Add moment for Simulator distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Nov 19, 2021
1 parent ac5126b commit 64d8396
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
12 changes: 11 additions & 1 deletion pymc/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from scipy.spatial import cKDTree

from pymc.aesaraf import floatX
from pymc.distributions.distribution import NoDistribution
from pymc.distributions.distribution import NoDistribution, _get_moment

__all__ = ["Simulator"]

Expand Down Expand Up @@ -223,13 +223,23 @@ def logp(op, value_var_list, *dist_params, **kwargs):
value_var = value_var_list[0]
return cls.logp(value_var, op, dist_params)

@_get_moment.register(SimulatorRV)
def get_moment(op, rv, size, *rv_inputs):
return cls.get_moment(rv, size, *rv_inputs)

cls.rv_op = sim_op
return super().__new__(cls, name, *params, **kwargs)

@classmethod
def dist(cls, *params, **kwargs):
return super().dist(params, **kwargs)

@classmethod
def get_moment(cls, rv, size, *sim_inputs):
# Take the mean of 10 draws
multiple_sim = rv.owner.op(*sim_inputs, size=at.concatenate([[10], rv.shape]))
return at.mean(multiple_sim, axis=0)

@classmethod
def logp(cls, value, sim_op, sim_inputs):
# Use a new rng to avoid non-randomness in parallel sampling
Expand Down
40 changes: 40 additions & 0 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import aesara
import numpy as np
import pytest
import scipy.stats as st

from aesara import tensor as at
from scipy import special

import pymc as pm

from pymc import Simulator
from pymc.distributions import (
AsymmetricLaplace,
Bernoulli,
Expand Down Expand Up @@ -1074,3 +1076,41 @@ def test_zero_inflated_negative_binomial_moment(psi, mu, alpha, size, expected):
with Model() as model:
ZeroInflatedNegativeBinomial("x", psi=psi, mu=mu, alpha=alpha, size=size)
assert_moment_is_expected(model, expected)


@pytest.mark.parametrize("mu", [0, np.arange(3)], ids=str)
@pytest.mark.parametrize("sigma", [1, np.array([1, 2, 5])], ids=str)
@pytest.mark.parametrize("size", [None, 3, (5, 3)], ids=str)
def test_simulator_moment(mu, sigma, size):
def normal_sim(rng, mu, sigma, size):
return rng.normal(mu, sigma, size=size)

with Model() as model:
x = Simulator("x", normal_sim, mu, sigma, size=size)

fn = make_initial_point_fn(
model=model,
return_transformed=False,
default_strategy="moment",
)

random_draw = model["x"].eval()
result = fn(0)["x"]
assert result.shape == random_draw.shape

# We perform a z-test between the moment and expected mean from a sample of 10 draws
# This test fails if the number of samples averaged in get_moment(Simulator)
# is much smaller than 10, but would not catch the case where the number of samples
# is higher than the expected 10

n = 10 # samples
expected_sample_mean = mu
expected_sample_mean_std = np.sqrt(sigma ** 2 / n)

# Multiple test adjustment for z-test to maintain alpha=0.01
alpha = 0.01
alpha /= 2 * 2 * 3 # Correct for number of test permutations
alpha /= random_draw.size # Correct for distribution size
cutoff = st.norm().ppf(1 - (alpha / 2))

assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)

0 comments on commit 64d8396

Please sign in to comment.