Skip to content

Commit

Permalink
revert test
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Aug 23, 2021
1 parent 924f393 commit 1ac00fc
Showing 1 changed file with 0 additions and 25 deletions.
25 changes: 0 additions & 25 deletions pymc3/tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,6 @@ def test_leaf_node():
assert leaf_node.get_idx_right_child() == 12


def test_model():
X = np.linspace(7, 15, 100)
Y = np.sin(np.random.normal(X, 0.2)) + 3
X = X[:, None]

with pm.Model() as model:
sigma = pm.HalfNormal("sigma", 1)
mu = pm.BART("mu", X, Y, m=50)
y = pm.Normal("y", mu, sigma, observed=Y)
idata = pm.sample(chains=4)
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")

np.testing.assert_allclose(mean, Y, 0.5)

Y = np.repeat([0, 1], 50)
with pm.Model() as model:
mu_ = pm.BART("mu_", X, Y, m=50)
mu = pm.Deterministic("mu", pm.math.invlogit(mu_))
y = pm.Bernoulli("y", mu, observed=Y)
idata = pm.sample(chains=4)
mean = idata.posterior["mu"].stack(samples=("chain", "draw")).mean("samples")

np.testing.assert_allclose(mean, Y, atol=0.5)


def test_bart_vi():
X = np.random.normal(0, 1, size=(3, 250)).T
Y = np.random.normal(0, 1, size=250)
Expand Down

0 comments on commit 1ac00fc

Please sign in to comment.