diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py index 6d4256c323..03e6d587bd 100644 --- a/pymc3/distributions/bart.py +++ b/pymc3/distributions/bart.py @@ -14,271 +14,127 @@ import numpy as np -from pandas import DataFrame, Series +from aesara.tensor.random.op import RandomVariable, default_shape_from_params from pymc3.distributions.distribution import NoDistribution -from pymc3.distributions.tree import LeafNode, SplitNode, Tree __all__ = ["BART"] -class BaseBART(NoDistribution): - def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs): - - self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y) - - super().__init__(shape=X.shape[0], dtype="float64", initval=0, *args, **kwargs) - - if self.X.ndim != 2: - raise ValueError("The design matrix X must have two dimensions") - - if self.Y.ndim != 1: - raise ValueError("The response matrix Y must have one dimension") - if self.X.shape[0] != self.Y.shape[0]: - raise ValueError( - "The design matrix X and the response matrix Y must have the same number of elements" - ) - if not isinstance(m, int): - raise ValueError("The number of trees m type must be int") - if m < 1: - raise ValueError("The number of trees m must be greater than zero") - - if alpha <= 0 or 1 <= alpha: - raise ValueError( - "The value for the alpha parameter for the tree structure " - "must be in the interval (0, 1)" - ) - - self.num_observations = X.shape[0] - self.num_variates = X.shape[1] - self.available_predictors = list(range(self.num_variates)) - self.ssv = SampleSplittingVariable(split_prior, self.num_variates) - self.m = m - self.alpha = alpha - self.trees = self.init_list_of_trees() - self.all_trees = [] - self.mean = fast_mean() - self.prior_prob_leaf_node = compute_prior_probability(alpha) - - def preprocess_XY(self, X, Y): - if isinstance(Y, (Series, DataFrame)): - Y = Y.to_numpy() - if isinstance(X, (Series, DataFrame)): - X = X.to_numpy() - missing_data = np.any(np.isnan(X)) - X = np.random.normal(X, np.std(X, 0) / 100) - return X, Y, missing_data - - def init_list_of_trees(self): - initial_value_leaf_nodes = self.Y.mean() / self.m - initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32") - list_of_trees = [] - for i in range(self.m): - new_tree = Tree.init_tree( - tree_id=i, - leaf_node_value=initial_value_leaf_nodes, - idx_data_points=initial_idx_data_points_leaf_nodes, - ) - list_of_trees.append(new_tree) - # Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J. - # bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013 - # The sum_trees_output will contain the sum of the predicted output for all trees. - # When R_j is needed we subtract the current predicted output for tree T_j. - self.sum_trees_output = np.full_like(self.Y, self.Y.mean()) - - return list_of_trees - - def __iter__(self): - return iter(self.trees) - - def __repr_latex(self): - raise NotImplementedError - - def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable): - x_j = self.X[idx_data_points_split_node, idx_split_variable] - if self.missing_data: - x_j = x_j[~np.isnan(x_j)] - values = np.unique(x_j) - # The last value is never available as it would leave the right subtree empty. 
- return values[:-1] - - def grow_tree(self, tree, index_leaf_node): - current_node = tree.get_node(index_leaf_node) - - index_selected_predictor = self.ssv.rvs() - selected_predictor = self.available_predictors[index_selected_predictor] - available_splitting_rules = self.get_available_splitting_rules( - current_node.idx_data_points, selected_predictor - ) - # This can be unsuccessful when there are not available splitting rules - if available_splitting_rules.size == 0: - return False, None - - index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules)) - selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule] - new_split_node = SplitNode( - index=index_leaf_node, - idx_split_variable=selected_predictor, - split_value=selected_splitting_rule, - ) - - left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points( - new_split_node, current_node.idx_data_points - ) - - left_node_value = self.draw_leaf_value(left_node_idx_data_points) - right_node_value = self.draw_leaf_value(right_node_idx_data_points) - - new_left_node = LeafNode( - index=current_node.get_idx_left_child(), - value=left_node_value, - idx_data_points=left_node_idx_data_points, - ) - new_right_node = LeafNode( - index=current_node.get_idx_right_child(), - value=right_node_value, - idx_data_points=right_node_idx_data_points, - ) - tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node) - - return True, index_selected_predictor - - def get_new_idx_data_points(self, current_split_node, idx_data_points): - idx_split_variable = current_split_node.idx_split_variable - split_value = current_split_node.split_value - - left_idx = self.X[idx_data_points, idx_split_variable] <= split_value - left_node_idx_data_points = idx_data_points[left_idx] - right_node_idx_data_points = idx_data_points[~left_idx] - - return left_node_idx_data_points, right_node_idx_data_points - - def get_residuals(self): - """Compute the residuals.""" - R_j = self.Y - self.sum_trees_output - return R_j - - def get_residuals_loo(self, tree): - """Compute the residuals without leaving the passed tree out.""" - R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations)) - return R_j - - def draw_leaf_value(self, idx_data_points): - """Draw the residual mean.""" - R_j = self.get_residuals()[idx_data_points] - draw = self.mean(R_j) - return draw - - def predict(self, X_new): - """Compute out of sample predictions evaluated at X_new""" - trees = self.all_trees - num_observations = X_new.shape[0] - pred = np.zeros((len(trees), num_observations)) - np.random.randint(len(trees)) - for draw, trees_to_sum in enumerate(trees): - new_Y = np.zeros(num_observations) - for tree in trees_to_sum: - new_Y += [tree.predict_out_of_sample(x) for x in X_new] - pred[draw] = new_Y - return pred - - -def compute_prior_probability(alpha): +class BARTRV(RandomVariable): """ - Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)). - Taken from equation 19 in [Rockova2018]. - - Parameters - ---------- - alpha : float - - Returns - ------- - list with probabilities for leaf nodes - - References - ---------- - .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART. 
- arXiv, `link `__ + Base class for BART """ - prior_leaf_prob = [0] - depth = 1 - while prior_leaf_prob[-1] < 1: - prior_leaf_prob.append(1 - alpha ** depth) - depth += 1 - return prior_leaf_prob - - -def fast_mean(): - """If available use Numba to speed up the computation of the mean.""" - try: - from numba import jit - except ImportError: - return np.mean - - @jit - def mean(a): - count = a.shape[0] - suma = 0 - for i in range(count): - suma += a[i] - return suma / count - - return mean + name = "BART" + ndim_supp = 1 + ndims_params = [2, 1, 0, 0, 0, 1] + dtype = "floatX" + _print_name = ("BART", "\\operatorname{BART}") + all_trees = None + + def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): + return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes) + + @classmethod + def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs): + size = kwargs.pop("size", None) + X_new = kwargs.pop("X_new", None) + all_trees = cls.all_trees + if all_trees: + + if size is None: + size = () + elif isinstance(size, int): + size = [size] + + flatten_size = 1 + for s in size: + flatten_size *= s + + idx = rng.randint(len(all_trees), size=flatten_size) + + if X_new is None: + pred = np.zeros((flatten_size, all_trees[0][0].num_observations)) + for ind, p in enumerate(pred): + for tree in all_trees[idx[ind]]: + p += tree.predict_output() + else: + pred = np.zeros((flatten_size, X_new.shape[0])) + for ind, p in enumerate(pred): + for tree in all_trees[idx[ind]]: + p += np.array([tree.predict_out_of_sample(x) for x in X_new]) + return pred.reshape((*size, -1)) + else: + return np.full_like(cls.Y, cls.Y.mean()) -def discrete_uniform_sampler(upper_value): - """Draw from the uniform distribution with bounds [0, upper_value).""" - return int(np.random.random() * upper_value) - - -class SampleSplittingVariable: - def __init__(self, prior, num_variates): - self.prior = prior - self.num_variates = num_variates - - if self.prior is not None: - self.prior = np.asarray(self.prior) - self.prior = self.prior / self.prior.sum() - if self.prior.size != self.num_variates: - raise ValueError( - f"The size of split_prior ({self.prior.size}) should be the " - f"same as the number of covariates ({self.num_variates})" - ) - self.enu = list(enumerate(np.cumsum(self.prior))) - def rvs(self): - if self.prior is None: - return int(np.random.random() * self.num_variates) - else: - r = np.random.random() - for i, v in self.enu: - if r <= v: - return i +bart = BARTRV() -class BART(BaseBART): +class BART(NoDistribution): """ - BART distribution. + Bayesian Additive Regression Tree distribution. Distribution representing a sum over trees Parameters ---------- X : array-like - The design matrix. + The covariate matrix. Y : array-like The response vector. m : int Number of trees alpha : float - Control the prior probability over the depth of the trees. Must be in the interval (0, 1), - altought it is recomenned to be in the interval (0, 0.5]. + Control the prior probability over the depth of the trees. Even when it can takes values in + the interval (0, 1), it is recommended to be in the interval (0, 0.5]. + k : float + Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1 + and 3. split_prior : array-like - Each element of split_prior should be in the [0, 1] interval and the elements should sum - to 1. Otherwise they will be normalized. 
- Defaults to None, all variable have the same a prior probability + Each element of split_prior should be in the [0, 1] interval and the elements should sum to + 1. Otherwise they will be normalized. + Defaults to None, i.e. all covariates have the same prior probability to be selected. """ - def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None): - super().__init__(X, Y, m, alpha, split_prior) + def __new__( + cls, + name, + X, + Y, + m=50, + alpha=0.25, + k=2, + split_prior=None, + **kwargs, + ): + + cls.all_trees = [] + + bart_op = type( + f"BART_{name}", + (BARTRV,), + dict( + name="BART", + all_trees=cls.all_trees, + inplace=False, + initval=Y.mean(), + X=X, + Y=Y, + m=m, + alpha=alpha, + k=k, + split_prior=split_prior, + ), + )() + + NoDistribution.register(BARTRV) + + cls.rv_op = bart_op + params = [X, Y, m, alpha, k] + return super().__new__(cls, name, *params, **kwargs) + + @classmethod + def dist(cls, *params, **kwargs): + return super().dist(params, **kwargs) diff --git a/pymc3/distributions/tree.py b/pymc3/distributions/tree.py index 8e84bd9a7c..31ed47b530 100644 --- a/pymc3/distributions/tree.py +++ b/pymc3/distributions/tree.py @@ -16,41 +16,50 @@ from copy import deepcopy +import aesara import numpy as np class Tree: """Full binary tree + A full binary tree is a tree where each node has exactly zero or two children. This structure is used as the basic component of the Bayesian Additive Regression Tree (BART) + Attributes ---------- tree_structure : dict - A dictionary that represents the nodes stored in breadth-first order, based in the array method - for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). + A dictionary that represents the nodes stored in breadth-first order, based in the array + method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays). The dictionary's keys are integers that represent the nodes position. - The dictionary's values are objects of type SplitNode or LeafNode that represent the nodes of the tree itself. + The dictionary's values are objects of type SplitNode or LeafNode that represent the nodes + of the tree itself. num_nodes : int Total number of nodes. idx_leaf_nodes : list List with the index of the leaf nodes of the tree. idx_prunable_split_nodes : list - List with the index of the prunable splitting nodes of the tree. A splitting node is prunable if both - its children are leaf nodes. + List with the index of the prunable splitting nodes of the tree. A splitting node is + prunable if both its children are leaf nodes. tree_id : int Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART. + num_observations : int + Number of observations used to fit BART. 
+ Parameters ---------- tree_id : int, optional + num_observations : int, optional """ - def __init__(self, tree_id=0): + def __init__(self, tree_id=0, num_observations=0): self.tree_structure = {} self.num_nodes = 0 self.idx_leaf_nodes = [] self.idx_prunable_split_nodes = [] self.tree_id = tree_id + self.num_observations = num_observations def __getitem__(self, index): return self.get_node(index) @@ -77,12 +86,13 @@ def delete_node(self, index): del self.tree_structure[index] self.num_nodes -= 1 - def predict_output(self, num_observations): - output = np.zeros(num_observations) + def predict_output(self): + output = np.zeros(self.num_observations) for node_index in self.idx_leaf_nodes: current_node = self.get_node(node_index) output[current_node.idx_data_points] = current_node.value - return output + + return output.astype(aesara.config.floatX) def predict_out_of_sample(self, x): """ @@ -163,7 +173,7 @@ def init_tree(tree_id, leaf_node_value, idx_data_points): ------- """ - new_tree = Tree(tree_id) + new_tree = Tree(tree_id, len(idx_data_points)) new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points) return new_tree diff --git a/pymc3/sampling.py b/pymc3/sampling.py index f032ad1fda..e8131ce642 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -42,6 +42,7 @@ from pymc3.backends.base import BaseTrace, MultiTrace from pymc3.backends.ndarray import NDArray from pymc3.blocking import DictToArrayBijection +from pymc3.distributions import NoDistribution from pymc3.exceptions import IncorrectArgumentsError, SamplingError from pymc3.model import Model, Point, modelcontext from pymc3.parallel_sampling import Draw, _cpu_count @@ -232,13 +233,17 @@ def _print_step_hierarchy(s: Step, level=0) -> None: _log.info(">" * level + f"{s.__class__.__name__}: [{varnames}]") -def all_continuous(vars): +def all_continuous(vars, model): """Check that vars not include discrete variables or BART variables, excepting observed RVs.""" vars_ = [var for var in vars if not (var.owner and hasattr(var.tag, "observations"))] + if any( [ - (var.dtype in discrete_types or (var.owner and isinstance(var.owner.op, pm.BART))) + ( + var.dtype in discrete_types + or isinstance(model.values_to_rvs[var].owner.op, NoDistribution) + ) for var in vars_ ] ): @@ -499,7 +504,7 @@ def sample( draws += tune - if step is None and init is not None and all_continuous(model.value_vars): + if step is None and init is not None and all_continuous(model.value_vars, model): try: # By default, try to use NUTS _log.info("Auto-assigning NUTS sampler...") @@ -635,8 +640,13 @@ def sample( trace.report._t_sampling = t_sampling if "variable_inclusion" in trace.stat_names: - variable_inclusion = np.stack(trace.get_sampler_stats("variable_inclusion")).mean(0) - trace.report.variable_importance = variable_inclusion / variable_inclusion.sum() + for strace in trace._straces.values(): + for stat in strace._stats: + if "variable_inclusion" in stat: + if trace.nchains > 1: + stat["variable_inclusion"] = np.vstack(stat["variable_inclusion"]) + else: + stat["variable_inclusion"] = [np.vstack(stat["variable_inclusion"])] n_chains = len(trace.chains) _log.info( @@ -2128,7 +2138,7 @@ def init_nuts( vars = kwargs.get("vars", model.value_vars) if set(vars) != set(model.value_vars): raise ValueError("Must use init_nuts on all variables of a model.") - if not all_continuous(vars): + if not all_continuous(vars, model): raise ValueError("init_nuts can only be used for models with only " "continuous variables.") if not 
isinstance(init, str): diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py index c5e9603a90..aaaaa9f4b2 100644 --- a/pymc3/step_methods/hmc/base_hmc.py +++ b/pymc3/step_methods/hmc/base_hmc.py @@ -101,6 +101,7 @@ def __init__( # XXX: If the dimensions of these terms change, the step size # dimension-scaling should change as well, no? test_point = self._model.initial_point + nuts_vars = [test_point[v.name] for v in vars] size = sum(v.size for v in nuts_vars) diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index dca02a6c9b..267c20659f 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -18,7 +18,7 @@ from pymc3.aesaraf import floatX from pymc3.backends.report import SamplerWarning, WarningType -from pymc3.distributions import BART +from pymc3.distributions.bart import BARTRV from pymc3.math import logbern, logdiffexp_numpy from pymc3.step_methods.arraystep import Competence from pymc3.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData @@ -198,8 +198,9 @@ def _hamiltonian_step(self, start, p0, step_size): @staticmethod def competence(var, has_grad): """Check how appropriate this class is for sampling a random variable.""" + dist = getattr(var.owner, "op", None) - if var.dtype in continuous_types and has_grad and not isinstance(dist, BART): + if var.dtype in continuous_types and has_grad and not isinstance(dist, BARTRV): return Competence.IDEAL return Competence.INCOMPATIBLE diff --git a/pymc3/step_methods/pgbart.py b/pymc3/step_methods/pgbart.py index b3b00bfa52..351f1ae8a2 100644 --- a/pymc3/step_methods/pgbart.py +++ b/pymc3/step_methods/pgbart.py @@ -14,13 +14,18 @@ import logging +from typing import Any, Dict, List, Tuple + +import aesara import numpy as np from aesara import function as aesara_function +from pandas import DataFrame, Series from pymc3.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements -from pymc3.distributions import BART -from pymc3.distributions.tree import Tree +from pymc3.blocking import RaveledVars +from pymc3.distributions.bart import BARTRV +from pymc3.distributions.tree import LeafNode, SplitNode, Tree from pymc3.model import modelcontext from pymc3.step_methods.arraystep import ArrayStepShared, Competence @@ -56,12 +61,44 @@ class PGBART(ArrayStepShared): generates_stats = True stats_dtypes = [{"variable_inclusion": np.ndarray}] - def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None): - _log.warning("The BART model is experimental. Use with caution.") + def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None): + _log.warning("BART is experimental. 
Use with caution.") model = modelcontext(model) initial_values = model.initial_point - vars = inputvars(vars) - self.bart = vars[0].distribution + value_bart = inputvars(vars)[0] + self.bart = model.values_to_rvs[value_bart].owner.op + + self.X, self.Y, self.missing_data = preprocess_XY(self.bart.X, self.bart.Y) + self.m = self.bart.m + self.alpha = self.bart.alpha + self.k = self.bart.k + self.split_prior = self.bart.split_prior + if self.split_prior is None: + self.split_prior = np.ones(self.X.shape[1]) + + self.init_mean = self.Y.mean() + # if data is binary + Y_unique = np.unique(self.Y) + if Y_unique.size == 2 and np.all(Y_unique == [0, 1]): + self.mu_std = 6 / (self.k * self.m ** 0.5) + # maybe we need to check for count data + else: + self.mu_std = self.Y.std() / (self.k * self.m ** 0.5) + + self.num_observations = self.X.shape[0] + self.num_variates = self.X.shape[1] + self.available_predictors = list(range(self.num_variates)) + + sum_trees_output = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX) + self.a_tree = Tree.init_tree( + tree_id=0, + leaf_node_value=self.init_mean / self.m, + idx_data_points=np.arange(self.num_observations, dtype="int32"), + ) + self.mean = fast_mean() + self.normal = NormalSampler() + self.prior_prob_leaf_node = compute_prior_probability(self.alpha) + self.ssv = SampleSplittingVariable(self.split_prior) self.tune = True self.idx = 0 @@ -69,63 +106,78 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m self.sum_trees = [] self.chunk = chunk - if chunk == "auto": - self.chunk = max(1, int(self.bart.m * 0.1)) - self.bart.chunk = self.chunk + if self.chunk == "auto": + self.chunk = max(1, int(self.m * 0.1)) self.num_particles = num_particles self.log_num_particles = np.log(num_particles) self.indices = list(range(1, num_particles)) self.max_stages = max_stages - self.old_trees_particles_list = [] - for i in range(self.bart.m): - p = ParticleTree(self.bart.trees[i], self.bart.prior_prob_leaf_node) - self.old_trees_particles_list.append(p) shared = make_shared_replacements(initial_values, vars, model) self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared) + self.init_likelihood = self.likelihood_logp(sum_trees_output) + self.init_log_weight = self.init_likelihood - self.log_num_particles + self.all_particles = [] + for i in range(self.m): + self.a_tree.tree_id = i + p = ParticleTree( + self.a_tree, + self.init_log_weight, + self.init_likelihood, + ) + self.all_particles.append(p) super().__init__(vars, shared) - def astep(self, _): - bart = self.bart - num_observations = bart.num_observations - variable_inclusion = np.zeros(bart.num_variates, dtype="int") + def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: + point_map_info = q.point_map_info + sum_trees_output = q.data - # For the tunning phase we restrict max_stages to a low number, otherwise it is almost sure - # we will reach max_stages given that our first set of m trees is not good at all. - # Can set max_stages as a function of the number of variables/dimensions? 
- if self.tune: - max_stages = 5 - else: - max_stages = self.max_stages + variable_inclusion = np.zeros(self.num_variates, dtype="int") - if self.idx == bart.m: + if self.idx == self.m: self.idx = 0 for idx in range(self.idx, self.idx + self.chunk): - if idx >= bart.m: + if idx >= self.m: break + tree = self.all_particles[idx].tree + sum_trees_output_noi = sum_trees_output - tree.predict_output() self.idx += 1 - tree = bart.trees[idx] - R_j = bart.get_residuals_loo(tree) # Generate an initial set of SMC particles # at the end of the algorithm we return one of these particles as the new tree - particles = self.init_particles(tree.tree_id, R_j, num_observations) + particles = self.init_particles(tree.tree_id) - for t in range(1, max_stages): + for t in range(self.max_stages): # Get old particle at stage t - particles[0] = self.get_old_tree_particle(tree.tree_id, t) + if t > 0: + particles[0] = self.get_old_tree_particle(tree.tree_id, t) # sample each particle (try to grow each tree) - for c in range(1, self.num_particles): - particles[c].sample_tree_sequential(bart) + compute_logp = [True] + for p in particles[1:]: + clp = p.sample_tree_sequential( + self.ssv, + self.available_predictors, + self.prior_prob_leaf_node, + self.X, + self.missing_data, + sum_trees_output, + self.mean, + self.m, + self.normal, + self.mu_std, + ) + compute_logp.append(clp) # Update weights. Since the prior is used as the proposal,the weights # are updated additively as the ratio of the new and old log_likelihoods - for p_idx, p in enumerate(particles): - new_likelihood = self.likelihood_logp(p.tree.predict_output(num_observations)) - p.log_weight += new_likelihood - p.old_likelihood_logp - p.old_likelihood_logp = new_likelihood - + for clp, p in zip(compute_logp, particles): + if clp: # Compute the likelihood when p has changed from the previous iteration + new_likelihood = self.likelihood_logp( + sum_trees_output_noi + p.tree.predict_output() + ) + p.log_weight += new_likelihood - p.old_likelihood_logp + p.old_likelihood_logp = new_likelihood # Normalize weights - W, normalized_weights = self.normalize(particles) + W_t, normalized_weights = self.normalize(particles) # Resample all but first particle re_n_w = normalized_weights[1:] / normalized_weights[1:].sum() @@ -133,37 +185,42 @@ def astep(self, _): particles[1:] = particles[new_indices] # Set the new weights - w_t = W - self.log_num_particles for p in particles: - p.log_weight = w_t + p.log_weight = W_t # Check if particles can keep growing, otherwise stop iterating - non_available_nodes_for_expansion = np.ones(self.num_particles - 1) - for c in range(1, self.num_particles): - if len(particles[c].expansion_nodes) != 0: - non_available_nodes_for_expansion[c - 1] = 0 - if np.all(non_available_nodes_for_expansion): + non_available_nodes_for_expansion = [] + for p in particles[1:]: + if p.expansion_nodes: + non_available_nodes_for_expansion.append(0) + if all(non_available_nodes_for_expansion): break # Get the new tree and update - new_tree = np.random.choice(particles, p=normalized_weights) - self.old_trees_particles_list[tree.tree_id] = new_tree - bart.trees[idx] = new_tree.tree - new_prediction = new_tree.tree.predict_output(num_observations) - bart.sum_trees_output = bart.Y - R_j + new_prediction - - if not self.tune: + new_particle = np.random.choice(particles, p=normalized_weights) + new_tree = new_particle.tree + new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles + self.all_particles[tree.tree_id] = new_particle + 
sum_trees_output = sum_trees_output_noi + new_tree.predict_output() + + if self.tune: + for index in new_particle.used_variates: + self.split_prior[index] += 1 + self.ssv = SampleSplittingVariable(self.split_prior) + else: self.iter += 1 - self.sum_trees.append(new_tree.tree) - if not self.iter % bart.m: - bart.all_trees.append(self.sum_trees) + self.sum_trees.append(new_tree) + if not self.iter % self.m: + # XXX update the all_trees variable in BARTRV to be used in the rng_fn method + # this fails for chains > 1 as the variable is not shared between proccesses + self.bart.all_trees.append(self.sum_trees) self.sum_trees = [] - for index in new_tree.used_variates: + for index in new_particle.used_variates: variable_inclusion[index] += 1 stats = {"variable_inclusion": variable_inclusion} - - return bart.sum_trees_output, [stats] + sum_trees_output = RaveledVars(sum_trees_output, point_map_info) + return sum_trees_output, [stats] @staticmethod def competence(var, has_grad): @@ -171,108 +228,293 @@ def competence(var, has_grad): PGBART is only suitable for BART distributions """ dist = getattr(var.owner, "op", None) - if isinstance(dist, BART): + if isinstance(dist, BARTRV): return Competence.IDEAL return Competence.INCOMPATIBLE def normalize(self, particles): """ - Use logsumexp trick to get W and softmax to get normalized_weights + Use logsumexp trick to get W_t and softmax to get normalized_weights """ log_w = np.array([p.log_weight for p in particles]) log_w_max = log_w.max() log_w_ = log_w - log_w_max w_ = np.exp(log_w_) w_sum = w_.sum() - W = log_w_max + np.log(w_sum) + W_t = log_w_max + np.log(w_sum) - self.log_num_particles normalized_weights = w_ / w_sum # stabilize weights to avoid assigning exactly zero probability to a particle normalized_weights += 1e-12 - return W, normalized_weights + return W_t, normalized_weights def get_old_tree_particle(self, tree_id, t): - old_tree_particle = self.old_trees_particles_list[tree_id] + old_tree_particle = self.all_particles[tree_id] old_tree_particle.set_particle_to_step(t) return old_tree_particle - def init_particles(self, tree_id, R_j, num_observations): + def init_particles(self, tree_id): """ Initialize particles """ - # The first particle is from the tree we are trying to replace - prev_tree = self.get_old_tree_particle(tree_id, 0) - likelihood = self.likelihood_logp(prev_tree.tree.predict_output(num_observations)) - prev_tree.old_likelihood_logp = likelihood - prev_tree.log_weight = likelihood - self.log_num_particles - particles = [prev_tree] - - # The rest of the particles are identically initialized - initial_value_leaf_nodes = R_j.mean() - initial_idx_data_points_leaf_nodes = np.arange(num_observations, dtype="int32") - new_tree = Tree.init_tree( - tree_id=tree_id, - leaf_node_value=initial_value_leaf_nodes, - idx_data_points=initial_idx_data_points_leaf_nodes, - ) - likelihood_logp = self.likelihood_logp(new_tree.predict_output(num_observations)) - log_weight = likelihood_logp - self.log_num_particles - for i in range(1, self.num_particles): + p = self.get_old_tree_particle(tree_id, 0) + p.log_weight = self.init_log_weight + p.old_likelihood_logp = self.init_likelihood + particles = [p] + + for _ in self.indices: + self.a_tree.tree_id = tree_id particles.append( - ParticleTree(new_tree, self.bart.prior_prob_leaf_node, log_weight, likelihood_logp) + ParticleTree( + self.a_tree, + self.init_log_weight, + self.init_likelihood, + ) ) return np.array(particles) - def resample(self, particles, weights): - """ - resample a set of 
particles given its weights - """ - particles = np.random.choice(particles, size=len(particles), p=weights) - return particles - class ParticleTree: """ Particle tree """ - def __init__(self, tree, prior_prob_leaf_node, log_weight=0, likelihood=0): + def __init__(self, tree, log_weight, likelihood): self.tree = tree.copy() # keeps the tree that we care at the moment - self.expansion_nodes = tree.idx_leaf_nodes.copy() # This should be the array [0] + self.expansion_nodes = [0] self.tree_history = [self.tree] self.expansion_nodes_history = [self.expansion_nodes] - self.log_weight = 0 - self.prior_prob_leaf_node = prior_prob_leaf_node + self.log_weight = log_weight self.old_likelihood_logp = likelihood self.used_variates = [] - def sample_tree_sequential(self, bart): + def sample_tree_sequential( + self, + ssv, + available_predictors, + prior_prob_leaf_node, + X, + missing_data, + sum_trees_output, + mean, + m, + normal, + mu_std, + ): + clp = False if self.expansion_nodes: index_leaf_node = self.expansion_nodes.pop(0) # Probability that this node will remain a leaf node - prob_leaf = self.prior_prob_leaf_node[self.tree[index_leaf_node].depth] + prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth] if prob_leaf < np.random.random(): - grow_successful, index_selected_predictor = bart.grow_tree( - self.tree, index_leaf_node + clp, index_selected_predictor = grow_tree( + self.tree, + index_leaf_node, + ssv, + available_predictors, + X, + missing_data, + sum_trees_output, + mean, + m, + normal, + mu_std, ) - if grow_successful: - # Add new leaf nodes indexes + if clp: new_indexes = self.tree.idx_leaf_nodes[-2:] self.expansion_nodes.extend(new_indexes) self.used_variates.append(index_selected_predictor) self.tree_history.append(self.tree) self.expansion_nodes_history.append(self.expansion_nodes) + return clp def set_particle_to_step(self, t): if len(self.tree_history) <= t: - self.tree = self.tree_history[-1] - self.expansion_nodes = self.expansion_nodes_history[-1] - else: - self.tree = self.tree_history[t] - self.expansion_nodes = self.expansion_nodes_history[t] + t = -1 + self.tree = self.tree_history[t] + self.expansion_nodes = self.expansion_nodes_history[t] + + +def preprocess_XY(X, Y): + if isinstance(Y, (Series, DataFrame)): + Y = Y.to_numpy() + if isinstance(X, (Series, DataFrame)): + X = X.to_numpy() + missing_data = np.any(np.isnan(X)) + Y = Y.astype(float) + return X, Y, missing_data + + +class SampleSplittingVariable: + def __init__(self, alpha_prior): + """ + Sample splitting variables proportional to `alpha_prior`. + + This is equivalent as sampling weights from a Dirichlet distribution with `alpha_prior` + parameter and then using those weights to sample from the available spliting variables. + This enforce sparsity. + """ + self.enu = list(enumerate(np.cumsum(alpha_prior / alpha_prior.sum()))) + + def rvs(self): + r = np.random.random() + for i, v in self.enu: + if r <= v: + return i + + +def compute_prior_probability(alpha): + """ + Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)). + Taken from equation 19 in [Rockova2018]. + + Parameters + ---------- + alpha : float + + Returns + ------- + list with probabilities for leaf nodes + + References + ---------- + .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART. 
+ arXiv, `link `__ + """ + prior_leaf_prob = [0] + depth = 1 + while prior_leaf_prob[-1] < 1: + prior_leaf_prob.append(1 - alpha ** depth) + depth += 1 + return prior_leaf_prob + + +def grow_tree( + tree, + index_leaf_node, + ssv, + available_predictors, + X, + missing_data, + sum_trees_output, + mean, + m, + normal, + mu_std, +): + current_node = tree.get_node(index_leaf_node) + + index_selected_predictor = ssv.rvs() + selected_predictor = available_predictors[index_selected_predictor] + available_splitting_values = X[current_node.idx_data_points, selected_predictor] + if missing_data: + available_splitting_values = available_splitting_values[ + ~np.isnan(available_splitting_values) + ] + + if available_splitting_values.size == 0: + return False, None + + idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values)) + selected_splitting_rule = available_splitting_values[idx_selected_splitting_values] + new_split_node = SplitNode( + index=index_leaf_node, + idx_split_variable=selected_predictor, + split_value=selected_splitting_rule, + ) + + left_node_idx_data_points, right_node_idx_data_points = get_new_idx_data_points( + new_split_node, current_node.idx_data_points, X + ) + + left_node_value = draw_leaf_value( + sum_trees_output[left_node_idx_data_points], mean, m, normal, mu_std + ) + right_node_value = draw_leaf_value( + sum_trees_output[right_node_idx_data_points], mean, m, normal, mu_std + ) + + new_left_node = LeafNode( + index=current_node.get_idx_left_child(), + value=left_node_value, + idx_data_points=left_node_idx_data_points, + ) + new_right_node = LeafNode( + index=current_node.get_idx_right_child(), + value=right_node_value, + idx_data_points=right_node_idx_data_points, + ) + tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node) + + return True, index_selected_predictor + + +def get_new_idx_data_points(current_split_node, idx_data_points, X): + idx_split_variable = current_split_node.idx_split_variable + split_value = current_split_node.split_value + + left_idx = X[idx_data_points, idx_split_variable] <= split_value + left_node_idx_data_points = idx_data_points[left_idx] + right_node_idx_data_points = idx_data_points[~left_idx] + + return left_node_idx_data_points, right_node_idx_data_points + + +def draw_leaf_value(sum_trees_output_idx, mean, m, normal, mu_std): + """Draw Gaussian distributed leaf values""" + if sum_trees_output_idx.size == 0: + return 0 + else: + mu_mean = mean(sum_trees_output_idx) / m + draw = normal.random() * mu_std + mu_mean + return draw + + +def fast_mean(): + """If available use Numba to speed up the computation of the mean.""" + try: + from numba import jit + except ImportError: + return np.mean + + @jit + def mean(a): + count = a.shape[0] + suma = 0 + for i in range(count): + suma += a[i] + return suma / count + + return mean + + +def discrete_uniform_sampler(upper_value): + """Draw from the uniform distribution with bounds [0, upper_value). + + This is the same and np.random.randit(upper_value) but faster. 
+ """ + return int(np.random.random() * upper_value) + + +class NormalSampler: + """ + Cache samples from a standard normal distribution + """ + + def __init__(self): + self.size = 1000 + self.cache = [] + + def random(self): + if not self.cache: + self.update() + return self.cache.pop() + + def update(self): + self.cache = np.random.normal(loc=0.0, scale=1, size=self.size).tolist() def logp(point, out_vars, vars, shared): diff --git a/pymc3/tests/test_bart.py b/pymc3/tests/test_bart.py new file mode 100644 index 0000000000..5d221633a4 --- /dev/null +++ b/pymc3/tests/test_bart.py @@ -0,0 +1,79 @@ +import numpy as np + +from numpy.random import RandomState +from numpy.testing import assert_almost_equal + +import pymc3 as pm + + +def test_split_node(): + split_node = pm.distributions.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0) + assert split_node.index == 5 + assert split_node.idx_split_variable == 2 + assert split_node.split_value == 3.0 + assert split_node.depth == 2 + assert split_node.get_idx_parent_node() == 2 + assert split_node.get_idx_left_child() == 11 + assert split_node.get_idx_right_child() == 12 + + +def test_leaf_node(): + leaf_node = pm.distributions.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3]) + assert leaf_node.index == 5 + assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3]) + assert leaf_node.value == 3.14 + assert leaf_node.get_idx_parent_node() == 2 + assert leaf_node.get_idx_left_child() == 11 + assert leaf_node.get_idx_right_child() == 12 + + +def test_bart_vi(): + X = np.random.normal(0, 1, size=(3, 250)).T + Y = np.random.normal(0, 1, size=250) + X[:, 0] = np.random.normal(Y, 0.1) + + with pm.Model() as model: + mu = pm.BART("mu", X, Y, m=10) + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(random_seed=3415) + var_imp = ( + idata.sample_stats["variable_inclusion"] + .stack(samples=("chain", "draw")) + .mean("samples") + ) + var_imp /= var_imp.sum() + assert var_imp[0] > var_imp[1:].sum() + np.testing.assert_almost_equal(var_imp.sum(), 1) + + +def test_bart_random(): + X = np.random.normal(0, 1, size=(2, 50)).T + Y = np.random.normal(0, 1, size=50) + + with pm.Model() as model: + mu = pm.BART("mu", X, Y, m=10) + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(random_seed=3415, chains=1) + + rng = RandomState(12345) + pred_all = mu.owner.op.rng_fn(rng, size=2) + rng = RandomState(12345) + pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10]) + + assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) + assert pred_all.shape == (2, 50) + assert pred_first.shape == (10,) + + +def test_missing_data(): + X = np.random.normal(0, 1, size=(2, 50)).T + Y = np.random.normal(0, 1, size=50) + X[10:20, 0] = np.nan + + with pm.Model() as model: + mu = pm.BART("mu", X, Y, m=10) + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(random_seed=3415) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index e09ce70b90..ffa20bfd68 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -174,21 +174,6 @@ def test_trace_report(self, step_cls, discard): assert trace.report.n_draws == 100 assert isinstance(trace.report.t_sampling, float) - @pytest.mark.xfail(reason="BART not refactored for v4") - def test_trace_report_bart(self): - X = np.random.normal(0, 1, size=(3, 250)).T - Y = np.random.normal(0, 1, size=250) - X[:, 0] = np.random.normal(Y, 0.1) - - 
with pm.Model() as model: - mu = pm.BART("mu", X, Y, m=20) - sigma = pm.HalfNormal("sigma", 1) - y = pm.Normal("y", mu, sigma, observed=Y) - trace = pm.sample(500, tune=100, random_seed=3415, return_inferencedata=False) - var_imp = trace.report.variable_importance - assert var_imp[0] > var_imp[1:].sum() - npt.assert_almost_equal(var_imp.sum(), 1) - def test_return_inferencedata(self): with self.model: kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())
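Note on the new user-facing API introduced by this patch, as exercised by the tests added in pymc3/tests/test_bart.py: pm.BART is declared like any other distribution and sampled with the PGBART step (auto-assigned), while out-of-sample predictions go through the rng_fn method of the generated BARTRV op. The sketch below mirrors those tests; the synthetic data and seed are illustrative only, and chains=1 is used because, as the XXX comment in pgbart.py notes, the all_trees list is not shared across processes.

import numpy as np
import pymc3 as pm

# Synthetic data: only the first covariate carries signal.
X = np.random.normal(0, 1, size=(3, 250)).T
Y = np.random.normal(0, 1, size=250)
X[:, 0] = np.random.normal(Y, 0.1)

with pm.Model() as model:
    mu = pm.BART("mu", X, Y, m=10, alpha=0.25, k=2)
    sigma = pm.HalfNormal("sigma", 1)
    y = pm.Normal("y", mu, sigma, observed=Y)
    idata = pm.sample(random_seed=3415, chains=1)

# Variable inclusion counts are now reported as per-draw sampler stats.
var_imp = (
    idata.sample_stats["variable_inclusion"]
    .stack(samples=("chain", "draw"))
    .mean("samples")
)

# Out-of-sample predictions: the fitted trees live on the BARTRV op, and
# rng_fn evaluates a randomly chosen posterior sum-of-trees at X_new.
rng = np.random.RandomState(12345)
pred = mu.owner.op.rng_fn(rng, X_new=X[:10])  # shape (10,)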
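For reference, the tree-depth prior that this patch moves into pgbart.py (compute_prior_probability, after Rockova and Saha, 2018) gives the probability that a node at depth d remains a leaf as 1 - alpha**d. A standalone copy of the helper with a worked example for the default alpha=0.25:

import numpy as np

def compute_prior_probability(alpha):
    """Probability that a node at each depth stays a leaf: 1 - alpha ** depth."""
    prior_leaf_prob = [0]  # depth 0: the root is always a candidate to split
    depth = 1
    while prior_leaf_prob[-1] < 1:
        prior_leaf_prob.append(1 - alpha ** depth)
        depth += 1
    return prior_leaf_prob

# With the default alpha=0.25 the prior strongly favors shallow trees:
print(compute_prior_probability(0.25)[:4])  # -> [0, 0.75, 0.9375, 0.984375]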
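PGBART.normalize keeps the particle weights on the log scale and applies the log-sum-exp trick before resampling, so that weights far below zero do not underflow. A self-contained sketch of the same computation (the standalone function name here is mine):

import numpy as np

def normalize_log_weights(log_w, log_num_particles):
    """Return (W_t, normalized_weights) as in PGBART.normalize."""
    log_w_max = log_w.max()
    w_ = np.exp(log_w - log_w_max)  # shift so the largest weight is exp(0) = 1
    w_sum = w_.sum()
    W_t = log_w_max + np.log(w_sum) - log_num_particles
    normalized_weights = w_ / w_sum
    # stabilize weights to avoid assigning exactly zero probability to a particle
    normalized_weights += 1e-12
    return W_t, normalized_weights

W_t, probs = normalize_log_weights(np.array([-1000.0, -1001.0, -1005.0]), np.log(3))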
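Finally, the residual bookkeeping: the step method carries a single running vector sum_trees_output and obtains the fit of all trees except the one being updated by subtraction (the "diff trick" from Kapelner and Bleich's bartMachine paper, cited in the code this patch removes) rather than re-summing the remaining m - 1 trees. A toy illustration with made-up numbers:

import numpy as np

# Predictions of three tiny trees over the same five observations.
tree_preds = np.array([
    [0.1, 0.2, 0.0, 0.1, 0.3],
    [0.0, 0.1, 0.2, 0.0, 0.1],
    [0.2, 0.0, 0.1, 0.1, 0.0],
])
sum_trees_output = tree_preds.sum(axis=0)

# Fit of all trees except tree j, in O(n) instead of O(n * m):
j = 1
sum_trees_output_noi = sum_trees_output - tree_preds[j]
assert np.allclose(sum_trees_output_noi, tree_preds[[0, 2]].sum(axis=0))

# After tree j is replaced by a new tree, the running sum is updated in place:
new_tree_pred = np.array([0.05, 0.05, 0.15, 0.0, 0.1])
sum_trees_output = sum_trees_output_noi + new_tree_pred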