Skip to content

Commit

Permalink
add size argument and check for NoDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Aug 17, 2021
1 parent e98ebd8 commit 04c6800
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
40 changes: 24 additions & 16 deletions pymc3/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,35 @@ class BARTRV(RandomVariable):
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)

def __new__(cls, *args, **kwargs):
return super().__new__(cls)

@classmethod
def rng_fn(cls, rng=np.random.default_rng(), X_new=None, *args, **kwargs):
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
size = kwargs.pop("size", None)
X_new = kwargs.pop("X_new", None)
all_trees = cls.all_trees
if all_trees:
# this should be rng.integers() but when sampling from the prior/posterior predictive
# I get 'numpy.random.mtrand.RandomState' object has no attribute 'integers'
# So I guess those functions need to be updated
idx = np.random.randint(len(all_trees))
trees = all_trees[idx]

if size is None:
size = ()
elif isinstance(size, int):
size = [size]

flatten_size = 1
for s in size:
flatten_size *= s

idx = rng.randint(len(all_trees), size=flatten_size)

if X_new is None:
pred = np.zeros(trees[0].num_observations)
for tree in trees:
pred += tree.predict_output()
pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += tree.predict_output()
else:
pred = np.zeros(X_new.shape[0])
for tree in trees:
pred += np.array([tree.predict_out_of_sample(x) for x in X_new])
return pred
pred = np.zeros((flatten_size, X_new.shape[0]))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += np.array([tree.predict_out_of_sample(x) for x in X_new])
return pred.reshape((*size, -1))
else:
return np.full_like(cls.Y, cls.Y.mean())

Expand Down
7 changes: 5 additions & 2 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from pymc3.backends.base import BaseTrace, MultiTrace
from pymc3.backends.ndarray import NDArray
from pymc3.blocking import DictToArrayBijection
from pymc3.distributions.bart import BARTRV
from pymc3.distributions import NoDistribution
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
from pymc3.model import Model, Point, modelcontext
from pymc3.parallel_sampling import Draw, _cpu_count
Expand Down Expand Up @@ -240,7 +240,10 @@ def all_continuous(vars, model):

if any(
[
(var.dtype in discrete_types or isinstance(model.values_to_rvs[var].owner.op, BARTRV))
(
var.dtype in discrete_types
or isinstance(model.values_to_rvs[var].owner.op, NoDistribution)
)
for var in vars_
]
):
Expand Down

0 comments on commit 04c6800

Please sign in to comment.