diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 76ed533fb8..8a2557501a 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -63,7 +63,6 @@ jobs:
       --ignore=pymc/tests/test_distributions_random.py
       --ignore=pymc/tests/test_idata_conversion.py
       --ignore=pymc/tests/test_smc.py
-      --ignore=pymc/tests/test_bart.py
       --ignore=pymc/tests/test_missing.py

     - |
@@ -77,7 +76,6 @@ jobs:
       pymc/tests/test_updates.py
       pymc/tests/test_transforms.py
       pymc/tests/test_smc.py
-      pymc/tests/test_bart.py
      pymc/tests/test_mixture.py

     - |
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 7568a9a2f1..34a0ff6c8f 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -15,7 +15,6 @@ API Reference
    api/smc
    api/backends
    api/data
-   api/bart
    api/ode
    api/tuning
    api/math
diff --git a/docs/source/api/bart.rst b/docs/source/api/bart.rst
deleted file mode 100644
index 4ee4325678..0000000000
--- a/docs/source/api/bart.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-Bayesian Additive Regression Trees (BART)
-*****************************************
-
-.. currentmodule:: pymc
-
-.. autosummary::
-   :toctree: generated/
-
-   BART
-   PGBART
-   bart.plot_dependence
-   bart.predict
diff --git a/pymc/__init__.py b/pymc/__init__.py
index 82de5b4ee6..f28d137e12 100644
--- a/pymc/__init__.py
+++ b/pymc/__init__.py
@@ -52,7 +52,6 @@ def __set_compiler_flags():
 from pymc import gp, ode, sampling
 from pymc.aesaraf import *
 from pymc.backends import *
-from pymc.bart import *
 from pymc.blocking import *
 from pymc.data import *
 from pymc.distributions import *
diff --git a/pymc/bart/__init__.py b/pymc/bart/__init__.py
deleted file mode 100644
index b244c69cf6..0000000000
--- a/pymc/bart/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from pymc.bart.bart import BART
-from pymc.bart.pgbart import PGBART
-from pymc.bart.utils import plot_dependence, predict
-
-__all__ = ["BART", "PGBART"]
diff --git a/pymc/bart/bart.py b/pymc/bart/bart.py
deleted file mode 100644
index fd00cc32a5..0000000000
--- a/pymc/bart/bart.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import aesara.tensor as at
-import numpy as np
-
-from aeppl.logprob import _logprob
-from aesara.tensor.random.op import RandomVariable, default_shape_from_params
-from pandas import DataFrame, Series
-
-from pymc.distributions.distribution import NoDistribution, _get_moment
-
-__all__ = ["BART"]
-
-
-class BARTRV(RandomVariable):
-    """
-    Base class for BART
-    """
-
-    name = "BART"
-    ndim_supp = 1
-    ndims_params = [2, 1, 0, 0, 0, 1]
-    dtype = "floatX"
-    _print_name = ("BART", "\\operatorname{BART}")
-    all_trees = None
-
-    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
-        return default_shape_from_params(self.ndim_supp, dist_params, rep_param_idx, param_shapes)
-
-    @classmethod
-    def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
-        return np.full_like(cls.Y, cls.Y.mean())
-
-
-bart = BARTRV()
-
-
-class BART(NoDistribution):
-    """
-    Bayesian Additive Regression Tree distribution.
-
-    Distribution representing a sum over trees.
-
-    Parameters
-    ----------
-    X : array-like
-        The covariate matrix.
-    Y : array-like
-        The response vector.
-    m : int
-        Number of trees.
-    alpha : float
-        Controls the prior probability over the depth of the trees. Defaults to 0.25.
-        It is recommended to be in the interval (0, 0.5].
-    k : float
-        Scale parameter for the values of the leaf nodes. Defaults to 2.
-    split_prior : array-like
-        Each element of split_prior should be in the [0, 1] interval and the elements should sum to
-        1. Otherwise they will be normalized.
-        Defaults to None, i.e. all covariates have the same prior probability to be selected.
-    """
-
-    def __new__(
-        cls,
-        name,
-        X,
-        Y,
-        m=50,
-        alpha=0.25,
-        k=2,
-        split_prior=None,
-        **kwargs,
-    ):
-
-        X, Y = preprocess_XY(X, Y)
-
-        bart_op = type(
-            f"BART_{name}",
-            (BARTRV,),
-            dict(
-                name="BART",
-                inplace=False,
-                initval=Y.mean(),
-                X=X,
-                Y=Y,
-                m=m,
-                alpha=alpha,
-                k=k,
-                split_prior=split_prior,
-            ),
-        )()
-
-        NoDistribution.register(BARTRV)
-
-        @_get_moment.register(BARTRV)
-        def get_moment(rv, size, *rv_inputs):
-            return cls.get_moment(rv, size, *rv_inputs)
-
-        cls.rv_op = bart_op
-        params = [X, Y, m, alpha, k]
-        return super().__new__(cls, name, *params, **kwargs)
-
-    @classmethod
-    def dist(cls, *params, **kwargs):
-        return super().dist(params, **kwargs)
-
-    def logp(x, *inputs):
-        """Calculate log probability.
-
-        Parameters
-        ----------
-        x: numeric, TensorVariable
-            Value for which log-probability is calculated.
-
-        Returns
-        -------
-        TensorVariable
-        """
-        return at.zeros_like(x)
-
-    @classmethod
-    def get_moment(cls, rv, size, *rv_inputs):
-        mean = at.fill(size, rv.Y.mean())
-        return mean
-
-
-def preprocess_XY(X, Y):
-    if isinstance(Y, (Series, DataFrame)):
-        Y = Y.to_numpy()
-    if isinstance(X, (Series, DataFrame)):
-        X = X.to_numpy()
-    # X = np.random.normal(X, X.std(0)/100)
-    Y = Y.astype(float)
-    X = X.astype(float)
-    return X, Y
-
-
-@_logprob.register(BARTRV)
-def logp(op, value_var, *dist_params, **kwargs):
-    _dist_params = dist_params[3:]
-    value_var = value_var[0]
-    return BART.logp(value_var, *_dist_params)
diff --git a/pymc/bart/pgbart.py b/pymc/bart/pgbart.py
deleted file mode 100644
index 50a20a2844..0000000000
--- a/pymc/bart/pgbart.py
+++ /dev/null
@@ -1,542 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from copy import copy, deepcopy
-
-import aesara
-import numpy as np
-
-from aesara import function as aesara_function
-
-from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
-from pymc.bart.bart import BARTRV
-from pymc.bart.tree import LeafNode, SplitNode, Tree
-from pymc.model import modelcontext
-from pymc.step_methods.arraystep import ArrayStepShared, Competence
-
-_log = logging.getLogger("pymc")
-
-
-class PGBART(ArrayStepShared):
-    """
-    Particle Gibbs BART sampling step.
-
-    Parameters
-    ----------
-    vars: list
-        List of value variables for the sampler.
-    num_particles : int
-        Number of particles. Defaults to 40.
-    max_stages : int
-        Maximum number of iterations. Defaults to 100.
-    batch : int or tuple
-        Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
-        during and after tuning. If a tuple is passed the first element is the batch size
-        during tuning and the second the batch size after tuning.
-    model: PyMC Model
-        Optional model for sampling step. Defaults to None (taken from context).
-    """
-
-    name = "pgbart"
-    default_blocked = False
-    generates_stats = True
-    stats_dtypes = [{"variable_inclusion": np.ndarray, "bart_trees": np.ndarray}]
-
-    def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", model=None):
-        _log.warning("BART is experimental. Use with caution.")
-        model = modelcontext(model)
-        initial_values = model.compute_initial_point()
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-            vars = inputvars(vars)
-        value_bart = vars[0]
-        self.bart = model.values_to_rvs[value_bart].owner.op
-
-        self.X = self.bart.X
-        self.Y = self.bart.Y
-        self.missing_data = np.any(np.isnan(self.X))
-        self.m = self.bart.m
-        self.alpha = self.bart.alpha
-        self.k = self.bart.k
-        self.alpha_vec = self.bart.split_prior
-        if self.alpha_vec is None:
-            self.alpha_vec = np.ones(self.X.shape[1])
-
-        self.init_mean = self.Y.mean()
-        # if data is binary
-        Y_unique = np.unique(self.Y)
-        if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
-            self.mu_std = 3 / (self.k * self.m**0.5)
-        # maybe we need to check for count data
-        else:
-            self.mu_std = self.Y.std() / (self.k * self.m**0.5)
-
-        self.num_observations = self.X.shape[0]
-        self.num_variates = self.X.shape[1]
-        self.available_predictors = list(range(self.num_variates))
-
-        self.sum_trees = np.full_like(self.Y, self.init_mean).astype(aesara.config.floatX)
-        self.a_tree = Tree.init_tree(
-            leaf_node_value=self.init_mean / self.m,
-            idx_data_points=np.arange(self.num_observations, dtype="int32"),
-        )
-        self.mean = fast_mean()
-
-        self.normal = NormalSampler()
-        self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
-        self.ssv = SampleSplittingVariable(self.alpha_vec)
-
-        self.tune = True
-
-        if batch == "auto":
-            batch = max(1, int(self.m * 0.1))
-            self.batch = (batch, batch)
-        else:
-            if isinstance(batch, (tuple, list)):
-                self.batch = batch
-            else:
-                self.batch = (batch, batch)
-
-        self.log_num_particles = np.log(num_particles)
-        self.indices = list(range(2, num_particles))
-        self.len_indices = len(self.indices)
-        self.max_stages = max_stages
-
-        shared = make_shared_replacements(initial_values, vars, model)
-        self.likelihood_logp = logp(initial_values, [model.datalogpt], vars, shared)
-        self.all_particles = []
-        for i in range(self.m):
-            self.a_tree.leaf_node_value = self.init_mean / self.m
-            p = ParticleTree(self.a_tree)
-            self.all_particles.append(p)
-        self.all_trees = np.array([p.tree for p in self.all_particles])
-        super().__init__(vars, shared)
-
-    def astep(self, _):
-        variable_inclusion = np.zeros(self.num_variates, dtype="int")
-
-        tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
-        for tree_id in tree_ids:
-            # Generate an initial set of SMC particles
-            # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree_id)
-            # Compute the sum of trees without the old tree, that we are attempting to replace
-            self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()
-            # Resample leaf values for particle 1 which is a copy of the old tree
-            particles[1].sample_leafs(
-                self.sum_trees,
-                self.X,
-                self.mean,
-                self.m,
-                self.normal,
-                self.mu_std,
-            )
-
-            # The old tree and the one with new leafs do not grow so we update the weights only once
-            self.update_weight(particles[0], old=True)
-            self.update_weight(particles[1], old=True)
-            for _ in range(self.max_stages):
-                # Sample each particle (try to grow each tree), except for the first two
-                stop_growing = True
-                for p in particles[2:]:
-                    tree_grew = p.sample_tree(
-                        self.ssv,
-                        self.available_predictors,
-                        self.prior_prob_leaf_node,
-                        self.X,
-                        self.missing_data,
-                        self.sum_trees,
-                        self.mean,
-                        self.m,
-                        self.normal,
-                        self.mu_std,
-                    )
-                    if tree_grew:
-                        self.update_weight(p)
-                    if p.expansion_nodes:
-                        stop_growing = False
-                if stop_growing:
-                    break
-                # Normalize weights
-                W_t, normalized_weights = self.normalize(particles[2:])
-
-                # Resample all but first two particles
-                new_indices = np.random.choice(
-                    self.indices, size=self.len_indices, p=normalized_weights
-                )
-                particles[2:] = particles[new_indices]
-
-                # Set the new weights
-                for p in particles[2:]:
-                    p.log_weight = W_t
-
-            for p in particles[2:]:
-                p.log_weight = p.old_likelihood_logp
-
-            _, normalized_weights = self.normalize(particles)
-            # Get the new tree and update
-            new_particle = np.random.choice(particles, p=normalized_weights)
-            new_tree = new_particle.tree
-            self.all_trees[tree_id] = new_tree
-            new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
-            self.all_particles[tree_id] = new_particle
-            self.sum_trees = self.sum_trees_noi + new_tree.predict_output()
-
-            if self.tune:
-                self.ssv = SampleSplittingVariable(self.alpha_vec)
-                for index in new_particle.used_variates:
-                    self.alpha_vec[index] += 1
-            else:
-                for index in new_particle.used_variates:
-                    variable_inclusion[index] += 1
-
-        stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
-        return self.sum_trees, [stats]
-
-    def normalize(self, particles):
-        """
-        Use the logsumexp trick to get W_t and softmax to get normalized_weights.
-        """
-        log_w = np.array([p.log_weight for p in particles])
-        log_w_max = log_w.max()
-        log_w_ = log_w - log_w_max
-        w_ = np.exp(log_w_)
-        w_sum = w_.sum()
-        W_t = log_w_max + np.log(w_sum) - self.log_num_particles
-        normalized_weights = w_ / w_sum
-        # stabilize weights to avoid assigning exactly zero probability to a particle
-        normalized_weights += 1e-12
-
-        return W_t, normalized_weights
-
-    def init_particles(self, tree_id: int) -> np.ndarray:
-        """
-        Initialize particles.
-        """
-        p = self.all_particles[tree_id]
-        particles = [p, p.copy()]
-
-        for _ in self.indices:
-            particles.append(ParticleTree(self.a_tree))
-
-        return np.array(particles)
-
-    def update_weight(self, particle, old=False):
-        """
-        Update the weight of a particle.
-
-        Since the prior is used as the proposal, the weights are updated additively by the
-        difference of the new and old log-likelihoods.
-        """
-        new_likelihood = self.likelihood_logp(self.sum_trees_noi + particle.tree.predict_output())
-        if old:
-            particle.log_weight = new_likelihood
-            particle.old_likelihood_logp = new_likelihood
-        else:
-            particle.log_weight += new_likelihood - particle.old_likelihood_logp
-            particle.old_likelihood_logp = new_likelihood
-
-    @staticmethod
-    def competence(var, has_grad):
-        """
-        PGBART is only suitable for BART distributions.
-        """
-        dist = getattr(var.owner, "op", None)
-        if isinstance(dist, BARTRV):
-            return Competence.IDEAL
-        return Competence.INCOMPATIBLE
-
-
-class ParticleTree:
-    """
-    Particle tree.
-    """
-
-    def __init__(self, tree):
-        self.tree = tree.copy()  # keeps the tree that we care about at the moment
-        self.expansion_nodes = [0]
-        self.log_weight = 0
-        self.old_likelihood_logp = 0
-        self.used_variates = []
-
-    def copy(self):
-        return deepcopy(self)
-
-    def sample_tree(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees,
-        mean,
-        m,
-        normal,
-        mu_std,
-    ):
-        tree_grew = False
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees,
-                    mean,
-                    m,
-                    normal,
-                    mu_std,
-                )
-                if index_selected_predictor is not None:
-                    new_indexes = self.tree.idx_leaf_nodes[-2:]
-                    self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
-                    tree_grew = True
-
-        return tree_grew
-
-    def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):
-
-        sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)
-
-
-class SampleSplittingVariable:
-    def __init__(self, alpha_vec):
-        """
-        Sample splitting variables proportional to `alpha_vec`.
-
-        This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
-        This enforces sparsity.
-        """
-        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
-
-    def rvs(self):
-        r = np.random.random()
-        for i, v in self.enu:
-            if r <= v:
-                return i
-
-
-def compute_prior_probability(alpha):
-    """
-    Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
-    Taken from equation 19 in [Rockova2018].
-
-    Parameters
-    ----------
-    alpha : float
-
-    Returns
-    -------
-    list with probabilities for leaf nodes
-
-    References
-    ----------
-    .. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
-        arXiv, `link `__
-    """
-    prior_leaf_prob = [0]
-    depth = 1
-    while prior_leaf_prob[-1] < 1:
-        prior_leaf_prob.append(1 - alpha**depth)
-        depth += 1
-    return prior_leaf_prob
-
-
-def grow_tree(
-    tree,
-    index_leaf_node,
-    ssv,
-    available_predictors,
-    X,
-    missing_data,
-    sum_trees,
-    mean,
-    m,
-    normal,
-    mu_std,
-):
-    current_node = tree.get_node(index_leaf_node)
-    idx_data_points = current_node.idx_data_points
-
-    index_selected_predictor = ssv.rvs()
-    selected_predictor = available_predictors[index_selected_predictor]
-    available_splitting_values = X[idx_data_points, selected_predictor]
-    if missing_data:
-        idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
-        available_splitting_values = available_splitting_values[
-            ~np.isnan(available_splitting_values)
-        ]
-
-    if available_splitting_values.size > 0:
-        idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
-        split_value = available_splitting_values[idx_selected_splitting_values]
-
-        new_idx_data_points = get_new_idx_data_points(
-            split_value, idx_data_points, selected_predictor, X
-        )
-        current_node_children = (
-            current_node.get_idx_left_child(),
-            current_node.get_idx_right_child(),
-        )
-
-        new_nodes = []
-        for idx in range(2):
-            idx_data_point = new_idx_data_points[idx]
-            node_value = draw_leaf_value(
-                sum_trees[idx_data_point],
-                X[idx_data_point, selected_predictor],
-                mean,
-                m,
-                normal,
-                mu_std,
-            )
-
-            new_node = LeafNode(
-                index=current_node_children[idx],
-                value=node_value,
-                idx_data_points=idx_data_point,
-            )
-            new_nodes.append(new_node)
-
-        new_split_node = SplitNode(
-            index=index_leaf_node,
-            idx_split_variable=selected_predictor,
-            split_value=split_value,
-        )
-
-        # update tree nodes and indexes
-        tree.delete_node(index_leaf_node)
-        tree.set_node(index_leaf_node, new_split_node)
-        tree.set_node(new_nodes[0].index, new_nodes[0])
-        tree.set_node(new_nodes[1].index, new_nodes[1])
-
-        return index_selected_predictor
-
-
-def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):
-
-    for idx in tree.idx_leaf_nodes:
-        if idx > 0:
-            leaf = tree[idx]
-            idx_data_points = leaf.idx_data_points
-            parent_node = tree[leaf.get_idx_parent_node()]
-            selected_predictor = parent_node.idx_split_variable
-            node_value = draw_leaf_value(
-                sum_trees[idx_data_points],
-                X[idx_data_points, selected_predictor],
-                mean,
-                m,
-                normal,
-                mu_std,
-            )
-            leaf.value = node_value
-
-
-def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
-
-    left_idx = X[idx_data_points, selected_predictor] <= split_value
-    left_node_idx_data_points = idx_data_points[left_idx]
-    right_node_idx_data_points = idx_data_points[~left_idx]
-
-    return left_node_idx_data_points, right_node_idx_data_points
-
-
-def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
-    """Draw Gaussian distributed leaf values."""
-    if Y_mu_pred.size == 0:
-        return 0
-    else:
-        norm = normal.random() * mu_std
-        if Y_mu_pred.size == 1:
-            mu_mean = Y_mu_pred.item() / m
-        else:
-            mu_mean = mean(Y_mu_pred) / m
-
-        draw = norm + mu_mean
-        return draw
-
-
-def fast_mean():
-    """If available use Numba to speed up the computation of the mean."""
-    try:
-        from numba import jit
-    except ImportError:
-        return np.mean
-
-    @jit
-    def mean(a):
-        count = a.shape[0]
-        suma = 0
-        for i in range(count):
-            suma += a[i]
-        return suma / count
-
-    return mean
-
-
-def discrete_uniform_sampler(upper_value):
-    """Draw from the uniform distribution with bounds [0, upper_value).
-
-    This is the same as np.random.randint(upper_value) but faster.
-    """
-    return int(np.random.random() * upper_value)
-
-
-class NormalSampler:
-    """
-    Cache samples from a standard normal distribution.
-    """
-
-    def __init__(self):
-        self.size = 1000
-        self.cache = []
-
-    def random(self):
-        if not self.cache:
-            self.update()
-        return self.cache.pop()
-
-    def update(self):
-        self.cache = np.random.normal(loc=0.0, scale=1, size=self.size).tolist()
-
-
-def logp(point, out_vars, vars, shared):
-    """Compile an Aesara function of the model with the input and output variables.
-
-    Parameters
-    ----------
-    out_vars: List
-        containing :class:`pymc.Distribution` for the output variables
-    vars: List
-        containing :class:`pymc.Distribution` for the input variables
-    shared: List
-        containing :class:`aesara.tensor.Tensor` for the dependent shared data
-    """
-    out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
-    f = aesara_function([inarray0], out_list[0])
-    f.trust_input = True
-    return f
diff --git a/pymc/bart/tree.py b/pymc/bart/tree.py
deleted file mode 100644
index 4705690d2d..0000000000
--- a/pymc/bart/tree.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# Copyright 2020 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-
-from copy import deepcopy
-
-import aesara
-import numpy as np
-
-
-class Tree:
-    """Full binary tree.
-
-    A full binary tree is a tree where each node has exactly zero or two children.
-    This structure is used as the basic component of the Bayesian Additive Regression Tree (BART).
-
-    Attributes
-    ----------
-    tree_structure : dict
-        A dictionary that represents the nodes stored in breadth-first order, based on the array
-        method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
-        The dictionary's keys are integers that represent the nodes' positions.
-        The dictionary's values are objects of type SplitNode or LeafNode that represent the nodes
-        of the tree itself.
-    idx_leaf_nodes : list
-        List with the index of the leaf nodes of the tree.
-    num_observations : int
-        Number of observations used to fit BART.
-
-    Parameters
-    ----------
-    num_observations : int, optional
-    """
-
-    def __init__(self, num_observations=0):
-        self.tree_structure = {}
-        self.idx_leaf_nodes = []
-        self.num_observations = num_observations
-
-    def __getitem__(self, index):
-        return self.get_node(index)
-
-    def __setitem__(self, index, node):
-        self.set_node(index, node)
-
-    def copy(self):
-        return deepcopy(self)
-
-    def get_node(self, index):
-        return self.tree_structure[index]
-
-    def set_node(self, index, node):
-        self.tree_structure[index] = node
-        if isinstance(node, LeafNode):
-            self.idx_leaf_nodes.append(index)
-
-    def delete_node(self, index):
-        current_node = self.get_node(index)
-        if isinstance(current_node, LeafNode):
-            self.idx_leaf_nodes.remove(index)
-        del self.tree_structure[index]
-
-    def predict_output(self):
-        output = np.zeros(self.num_observations)
-        for node_index in self.idx_leaf_nodes:
-            current_node = self.get_node(node_index)
-            output[current_node.idx_data_points] = current_node.value
-
-        return output.astype(aesara.config.floatX)
-
-    def predict_out_of_sample(self, X):
-        """
-        Predict the output of the tree for an unobserved point x.
-
-        Parameters
-        ----------
-        X : numpy array
-            Unobserved point
-
-        Returns
-        -------
-        float
-            Value of the leaf node where the unobserved point lies.
-        """
-        leaf_node = self._traverse_tree(X, node_index=0)
-        return leaf_node.value
-
-    def _traverse_tree(self, x, node_index=0):
-        """
-        Traverse the tree starting from a particular node given an unobserved point.
-
-        Parameters
-        ----------
-        x : np.ndarray
-        node_index : int
-
-        Returns
-        -------
-        LeafNode
-        """
-        current_node = self.get_node(node_index)
-        if isinstance(current_node, SplitNode):
-            if x[current_node.idx_split_variable] <= current_node.split_value:
-                left_child = current_node.get_idx_left_child()
-                current_node = self._traverse_tree(x, left_child)
-            else:
-                right_child = current_node.get_idx_right_child()
-                current_node = self._traverse_tree(x, right_child)
-        return current_node
-
-    @staticmethod
-    def init_tree(leaf_node_value, idx_data_points):
-        """
-
-        Parameters
-        ----------
-        leaf_node_value
-        idx_data_points
-
-        Returns
-        -------
-
-        """
-        new_tree = Tree(len(idx_data_points))
-        new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
-        return new_tree
-
-
-class BaseNode:
-    def __init__(self, index):
-        self.index = index
-        self.depth = int(math.floor(math.log(index + 1, 2)))
-
-    def get_idx_parent_node(self):
-        return (self.index - 1) // 2
-
-    def get_idx_left_child(self):
-        return self.index * 2 + 1
-
-    def get_idx_right_child(self):
-        return self.get_idx_left_child() + 1
-
-
-class SplitNode(BaseNode):
-    def __init__(self, index, idx_split_variable, split_value):
-        super().__init__(index)
-
-        self.idx_split_variable = idx_split_variable
-        self.split_value = split_value
-
-
-class LeafNode(BaseNode):
-    def __init__(self, index, value, idx_data_points):
-        super().__init__(index)
-        self.value = value
-        self.idx_data_points = idx_data_points
diff --git a/pymc/bart/utils.py b/pymc/bart/utils.py
deleted file mode 100644
index ea77f79393..0000000000
--- a/pymc/bart/utils.py
+++ /dev/null
@@ -1,297 +0,0 @@
-import arviz as az
-import matplotlib.pyplot as plt
-import numpy as np
-
-from numpy.random import RandomState
-from scipy.interpolate import griddata
-from scipy.signal import savgol_filter
-
-
-def predict(idata, rng, X_new=None, size=None):
-    """
-    Generate samples from the BART posterior.
-
-    Parameters
-    ----------
-    idata: InferenceData
-        InferenceData containing a collection of BART_trees in sample_stats group
-    rng: NumPy random generator
-    X_new : array-like
-        A new covariate matrix. Use it to obtain out-of-sample predictions
-    size: int or tuple
-        Number of samples.
-    """
-    bart_trees = idata.sample_stats.bart_trees
-    stacked_trees = bart_trees.stack(trees=["chain", "draw"])
-    if size is None:
-        size = ()
-    elif isinstance(size, int):
-        size = [size]
-
-    flatten_size = 1
-    for s in size:
-        flatten_size *= s
-
-    idx = rng.randint(len(stacked_trees.trees), size=flatten_size)
-
-    if X_new is None:
-        pred = np.zeros((flatten_size, stacked_trees[0, 0].item().num_observations))
-        for ind, p in enumerate(pred):
-            for tree in stacked_trees.isel(trees=idx[ind]).values:
-                p += tree.predict_output()
-    else:
-        pred = np.zeros((flatten_size, X_new.shape[0]))
-        for ind, p in enumerate(pred):
-            for tree in stacked_trees.isel(trees=idx[ind]).values:
-                p += np.array([tree.predict_out_of_sample(x) for x in X_new])
-    return pred.reshape((*size, -1))
-
-
-def plot_dependence(
-    idata,
-    X=None,
-    Y=None,
-    kind="pdp",
-    xs_interval="linear",
-    xs_values=None,
-    var_idx=None,
-    var_discrete=None,
-    samples=50,
-    instances=10,
-    random_seed=None,
-    sharey=True,
-    rug=True,
-    smooth=True,
-    indices=None,
-    grid="long",
-    color="C0",
-    color_mean="C0",
-    alpha=0.1,
-    figsize=None,
-    smooth_kwargs=None,
-    ax=None,
-):
-    """
-    Partial dependence or individual conditional expectation plot.
-
-    Parameters
-    ----------
-    idata: InferenceData
-        InferenceData containing a collection of BART_trees in sample_stats group
-    X : array-like
-        The covariate matrix.
-    Y : array-like
-        The response vector.
-    kind : str
-        Whether to plot a partial dependence plot ("pdp") or an individual conditional expectation
-        plot ("ice"). Defaults to pdp.
-    xs_interval : str
-        Method used to compute the values X used to evaluate the predicted function. "linear",
-        evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
-        quantiles of X. "insample", the evaluation is done at the values of X.
-        For discrete variables these options are omitted.
-    xs_values : int or list
-        Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
-        points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of
-        quantiles to compute, which must be between 0 and 1 inclusive.
-        Ignored when ``xs_interval="insample"``.
-    var_idx : list
-        List of the indices of the covariate for which to compute the pdp or ice.
-    var_discrete : list
-        List of the indices of the covariate treated as discrete.
-    samples : int
-        Number of posterior samples used in the predictions. Defaults to 50.
-    instances : int
-        Number of instances of X to plot. Only relevant for ``kind="ice"`` plots.
-    random_seed : int
-        random_seed used to sample from the posterior. Defaults to None.
-    sharey : bool
-        Controls sharing of properties among y-axes. Defaults to True.
-    rug : bool
-        Whether to include a rugplot. Defaults to True.
-    smooth : bool
-        If True the result will be smoothed by first computing a linear interpolation of the data
-        over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
-        Defaults to True.
-    grid : str or tuple
-        How to arrange the subplots. Defaults to "long", one subplot below the other.
-        Other options are "wide", one subplot next to each other, or a tuple indicating the number
-        of rows and columns.
-    color : matplotlib valid color
-        Color used to plot the pdp or ice. Defaults to "C0".
-    color_mean : matplotlib valid color
-        Color used to plot the mean pdp or ice. Defaults to "C0".
-    alpha : float
-        Transparency level, should be in the interval [0, 1].
-    figsize : tuple
-        Figure size. If None it will be defined automatically.
-    smooth_kwargs : dict
-        Additional keywords modifying the Savitzky-Golay filter.
-        See scipy.signal.savgol_filter() for details.
-    ax : axes
-        Matplotlib axes.
-
-    Returns
-    -------
-    axes: matplotlib axes
-    """
-    if kind not in ["pdp", "ice"]:
-        raise ValueError(f"kind={kind} is not supported. Available options are 'pdp' or 'ice'")
-
-    if xs_interval not in ["insample", "linear", "quantiles"]:
-        raise ValueError(
-            f"""{xs_interval} is not supported.
-            Available options are 'insample', 'linear' or 'quantiles'"""
-        )
-
-    rng = RandomState(seed=random_seed)
-
-    if hasattr(X, "columns") and hasattr(X, "values"):
-        X_names = list(X.columns)
-        X = X.values
-    else:
-        X_names = []
-
-    if hasattr(Y, "name"):
-        Y_label = f"Predicted {Y.name}"
-    else:
-        Y_label = "Predicted Y"
-
-    num_observations = X.shape[0]
-    num_covariates = X.shape[1]
-
-    indices = list(range(num_covariates))
-
-    if var_idx is None:
-        var_idx = indices
-    if var_discrete is None:
-        var_discrete = []
-
-    if X_names:
-        X_labels = [X_names[idx] for idx in var_idx]
-    else:
-        X_labels = [f"X_{idx}" for idx in var_idx]
-
-    if xs_interval == "linear" and xs_values is None:
-        xs_values = 10
-
-    if xs_interval == "quantiles" and xs_values is None:
-        xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95]
-
-    if kind == "ice":
-        instances = np.random.choice(range(X.shape[0]), replace=False, size=instances)
-
-    new_Y = []
-    new_X_target = []
-    y_mins = []
-
-    new_X = np.zeros_like(X)
-    idx_s = list(range(X.shape[0]))
-    for i in var_idx:
-        indices_mi = indices[:]
-        indices_mi.pop(i)
-        y_pred = []
-        if kind == "pdp":
-            if i in var_discrete:
-                new_X_i = np.unique(X[:, i])
-            else:
-                if xs_interval == "linear":
-                    new_X_i = np.linspace(np.nanmin(X[:, i]), np.nanmax(X[:, i]), xs_values)
-                elif xs_interval == "quantiles":
-                    new_X_i = np.quantile(X[:, i], q=xs_values)
-                elif xs_interval == "insample":
-                    new_X_i = X[:, i]
-
-            for x_i in new_X_i:
-                new_X[:, indices_mi] = X[:, indices_mi]
-                new_X[:, i] = x_i
-                y_pred.append(np.mean(predict(idata, rng, X_new=new_X, size=samples), 1))
-            new_X_target.append(new_X_i)
-        else:
-            for instance in instances:
-                new_X = X[idx_s]
-                new_X[:, indices_mi] = X[:, indices_mi][instance]
-                y_pred.append(np.mean(predict(idata, rng, X_new=new_X, size=samples), 0))
-                new_X_target.append(new_X[:, i])
-        y_mins.append(np.min(y_pred))
-        new_Y.append(np.array(y_pred).T)
-
-    if ax is None:
-        if grid == "long":
-            fig, axes = plt.subplots(len(var_idx), sharey=sharey, figsize=figsize)
-        elif grid == "wide":
-            fig, axes = plt.subplots(1, len(var_idx), sharey=sharey, figsize=figsize)
-        elif isinstance(grid, tuple):
-            fig, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize)
-        axes = np.ravel(axes)
-    else:
-        axes = [ax]
-        fig = ax.get_figure()
-
-    for i, ax in enumerate(axes):
-        if i >= len(var_idx):
-            ax.set_axis_off()
-            fig.delaxes(ax)
-        else:
-            var = var_idx[i]
-            if var in var_discrete:
-                if kind == "pdp":
-                    y_means = new_Y[i].mean(0)
-                    hdi = az.hdi(new_Y[i])
-                    ax.errorbar(
-                        new_X_target[i],
-                        y_means,
-                        (y_means - hdi[:, 0], hdi[:, 1] - y_means),
-                        fmt=".",
-                        color=color,
-                    )
-                else:
-                    ax.plot(new_X_target[i], new_Y[i], ".", color=color, alpha=alpha)
-                    ax.plot(new_X_target[i], new_Y[i].mean(1), "o", color=color_mean)
-                ax.set_xticks(new_X_target[i])
-            elif smooth:
-                if smooth_kwargs is None:
-                    smooth_kwargs = {}
-                smooth_kwargs.setdefault("window_length", 55)
-                smooth_kwargs.setdefault("polyorder", 2)
-                x_data = np.linspace(np.nanmin(new_X_target[i]), np.nanmax(new_X_target[i]), 200)
-                x_data[0] = (x_data[0] + x_data[1]) / 2
-                if kind == "pdp":
-                    interp = griddata(new_X_target[i], new_Y[i].mean(0), x_data)
-                else:
-                    interp = griddata(new_X_target[i], new_Y[i], x_data)
-
-                y_data = savgol_filter(interp, axis=0, **smooth_kwargs)
-
-                if kind == "pdp":
-                    az.plot_hdi(
-                        new_X_target[i], new_Y[i], color=color, fill_kwargs={"alpha": alpha}, ax=ax
-                    )
-                    ax.plot(x_data, y_data, color=color_mean)
-                else:
-                    ax.plot(x_data, y_data.mean(1), color=color_mean)
-                    ax.plot(x_data, y_data, color=color, alpha=alpha)
-
-            else:
-                idx = np.argsort(new_X_target[i])
-                if kind == "pdp":
-                    az.plot_hdi(
-                        new_X_target[i],
-                        new_Y[i],
-                        smooth=smooth,
-                        fill_kwargs={"alpha": alpha},
-                        ax=ax,
-                    )
-                    ax.plot(new_X_target[i][idx], new_Y[i][idx].mean(0), color=color)
-                else:
-                    ax.plot(new_X_target[i][idx], new_Y[i][idx], color=color, alpha=alpha)
-                    ax.plot(new_X_target[i][idx], new_Y[i][idx].mean(1), color=color_mean)
-
-            if rug:
-                lb = np.min(y_mins)
-                ax.plot(X[:, var], np.full_like(X[:, var], lb), "k|")
-
-            ax.set_xlabel(X_labels[i])
-
-    fig.text(-0.05, 0.5, Y_label, va="center", rotation="vertical", fontsize=15)
-    return axes
diff --git a/pymc/sampling.py b/pymc/sampling.py
index cda15813f0..0ac24fec4b 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -54,9 +54,7 @@
 from pymc.backends.arviz import _DefaultTrace
 from pymc.backends.base import BaseTrace, MultiTrace
 from pymc.backends.ndarray import NDArray
-from pymc.bart.pgbart import PGBART
 from pymc.blocking import DictToArrayBijection
-from pymc.distributions import NoDistribution
 from pymc.exceptions import IncorrectArgumentsError, SamplingError
 from pymc.initial_point import (
     PointType,
@@ -66,17 +64,7 @@
 )
 from pymc.model import Model, modelcontext
 from pymc.parallel_sampling import Draw, _cpu_count
-from pymc.step_methods import (
-    NUTS,
-    BinaryGibbsMetropolis,
-    BinaryMetropolis,
-    CategoricalGibbsMetropolis,
-    CompoundStep,
-    DEMetropolis,
-    HamiltonianMC,
-    Metropolis,
-    Slice,
-)
+from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
 from pymc.util import (
@@ -100,16 +88,6 @@
     "draw",
 ]

-STEP_METHODS = (
-    NUTS,
-    HamiltonianMC,
-    Metropolis,
-    BinaryMetropolis,
-    BinaryGibbsMetropolis,
-    Slice,
-    CategoricalGibbsMetropolis,
-    PGBART,
-)

 Step: TypeAlias = Union[BlockedStep, CompoundStep]
 ArrayLike: TypeAlias = Union[np.ndarray, List[float]]
@@ -167,7 +145,7 @@ def instantiate_steppers(
     return steps


-def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None):
+def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
     """Assign model variables to appropriate step methods.

     Passing a specified model will auto-assign its constituent stochastic
@@ -200,6 +178,9 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None
     steps = []
     assigned_vars = set()

+    if methods is None:
+        methods = pm.STEP_METHODS
+
     if step is not None:
         try:
             steps += list(step)
@@ -212,6 +193,7 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None
     # variables
     selected_steps = defaultdict(list)
     model_logpt = model.logpt()
+
     for var in model.value_vars:
         if var not in assigned_vars:
             # determine if a gradient can be computed
@@ -221,6 +203,7 @@ def assign_step_methods(model, step=None, methods=None, step_kwargs=None
                 tg.grad(model_logpt, var)
             except (NotImplementedError, tg.NullTypeGradError):
                 has_gradient = False
+
             # select the best method
             rv_var = model.values_to_rvs[var]
             selected = max(
@@ -249,20 +232,12 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:
     _log.info(">" * level + f"{s.__class__.__name__}: [{varnames}]")


-def all_continuous(vars, model):
-    """Check that vars not include discrete variables or BART variables, excepting observed RVs."""
+def all_continuous(vars):
+    """Check that vars do not include discrete variables, excepting observed RVs."""

-    vars_ = [var for var in vars if not (var.owner and hasattr(var.tag, "observations"))]
+    vars_ = [var for var in vars if not hasattr(var.tag, "observations")]

-    if any(
-        [
-            (
-                var.dtype in discrete_types
-                or isinstance(model.values_to_rvs[var].owner.op, NoDistribution)
-            )
-            for var in vars_
-        ]
-    ):
+    if any([(var.dtype in discrete_types) for var in vars_]):
         return False
     else:
         return True
@@ -403,7 +378,7 @@ def sample(
        ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``,
        ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``,
-       ``DEMetropolis``, ``DEMetropolisZ``, ``slice``, ``pgbart``
+       ``DEMetropolis``, ``DEMetropolisZ``, ``slice``

     B. If you manually declare the ``step_method``\ s, within the ``step``
        kwarg, then you can address the ``step_method`` kwargs directly.
@@ -490,29 +465,7 @@ def sample(
         draws += tune

     initial_points = None
-    if step is None and init is not None and all_continuous(model.value_vars, model):
-        try:
-            # By default, try to use NUTS
-            _log.info("Auto-assigning NUTS sampler...")
-            initial_points, step = init_nuts(
-                init=init,
-                chains=chains,
-                n_init=n_init,
-                model=model,
-                seeds=random_seed,
-                progressbar=progressbar,
-                jitter_max_retries=jitter_max_retries,
-                tune=tune,
-                initvals=initvals,
-                **kwargs,
-            )
-        except (AttributeError, NotImplementedError, tg.NullTypeGradError):
-            # gradient computation failed
-            _log.info("Initializing NUTS failed. Falling back to elementwise auto-assignment.")
-            _log.debug("Exception in init nuts", exc_info=True)
-            step = assign_step_methods(model, step, step_kwargs=kwargs)
-    else:
-        step = assign_step_methods(model, step, step_kwargs=kwargs)
+    step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)

     if isinstance(step, list):
         step = CompoundStep(step)
@@ -634,24 +587,6 @@ def sample(
     mtrace.report._n_draws = n_draws
     mtrace.report._t_sampling = t_sampling

-    if "variable_inclusion" in mtrace.stat_names:
-        for strace in mtrace._straces.values():
-            for stat in strace._stats:
-                if "variable_inclusion" in stat:
-                    if mtrace.nchains > 1:
-                        stat["variable_inclusion"] = np.vstack(stat["variable_inclusion"])
-                    else:
-                        stat["variable_inclusion"] = [np.vstack(stat["variable_inclusion"])]
-
-    if "bart_trees" in mtrace.stat_names:
-        for strace in mtrace._straces.values():
-            for stat in strace._stats:
-                if "bart_trees" in stat:
-                    if mtrace.nchains > 1:
-                        stat["bart_trees"] = np.vstack(stat["bart_trees"])
-                    else:
-                        stat["bart_trees"] = [np.vstack(stat["bart_trees"])]
-
     n_chains = len(mtrace.chains)
     _log.info(
         f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
@@ -2309,8 +2244,8 @@ def init_nuts(
     vars = kwargs.get("vars", model.value_vars)
     if set(vars) != set(model.value_vars):
         raise ValueError("Must use init_nuts on all variables of a model.")
-    if not all_continuous(vars, model):
-        raise ValueError("init_nuts can only be used for models with only " "continuous variables.")
+    if not all_continuous(vars):
+        raise ValueError("init_nuts can only be used for models with continuous variables.")

     if not isinstance(init, str):
         raise TypeError("init must be a string.")
diff --git a/pymc/step_methods/__init__.py b/pymc/step_methods/__init__.py
index 2b419feecc..12f404850c 100644
--- a/pymc/step_methods/__init__.py
+++ b/pymc/step_methods/__init__.py
@@ -36,3 +36,13 @@
     RecursiveDAProposal,
 )
 from pymc.step_methods.slicer import Slice
+
+STEP_METHODS = (
+    NUTS,
+    HamiltonianMC,
+    Metropolis,
+    BinaryMetropolis,
+    BinaryGibbsMetropolis,
+    Slice,
+    CategoricalGibbsMetropolis,
+)
diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index b55650eb80..39ca4a6fd1 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -18,7 +18,6 @@
 from pymc.aesaraf import floatX
 from pymc.backends.report import SamplerWarning, WarningType
-from pymc.bart.bart import BARTRV
 from pymc.math import logbern, logdiffexp_numpy
 from pymc.step_methods.arraystep import Competence
 from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
@@ -200,8 +199,8 @@
     def competence(var, has_grad):
         """Check how appropriate this class is for sampling a random variable."""
         dist = getattr(var.owner, "op", None)
-        if var.dtype in continuous_types and has_grad and not isinstance(dist, BARTRV):
-            return Competence.IDEAL
+        if var.dtype in continuous_types and has_grad:
+            return Competence.PREFERRED
         return Competence.INCOMPATIBLE

     def warnings(self):
diff --git a/pymc/tests/test_bart.py b/pymc/tests/test_bart.py
deleted file mode 100644
index 131ef9c858..0000000000
--- a/pymc/tests/test_bart.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import numpy as np
-import pytest
-
-from numpy.random import RandomState
-from numpy.testing import assert_almost_equal, assert_array_equal
-
-import pymc as pm
-
-from pymc.tests.test_distributions_moments import assert_moment_is_expected
-
-
-def test_split_node():
-    split_node = pm.bart.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
-    assert split_node.index == 5
-    assert split_node.idx_split_variable == 2
-    assert split_node.split_value == 3.0
-    assert split_node.depth == 2
-    assert split_node.get_idx_parent_node() == 2
-    assert split_node.get_idx_left_child() == 11
-    assert split_node.get_idx_right_child() == 12
-
-
-def test_leaf_node():
-    leaf_node = pm.bart.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
-    assert leaf_node.index == 5
-    assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
-    assert leaf_node.value == 3.14
-    assert leaf_node.get_idx_parent_node() == 2
-    assert leaf_node.get_idx_left_child() == 11
-    assert leaf_node.get_idx_right_child() == 12
-
-
-def test_bart_vi():
-    X = np.random.normal(0, 1, size=(250, 3))
-    Y = np.random.normal(0, 1, size=250)
-    X[:, 0] = np.random.normal(Y, 0.1)
-
-    with pm.Model() as model:
-        mu = pm.BART("mu", X, Y, m=10)
-        sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
-        var_imp = (
-            idata.sample_stats["variable_inclusion"]
-            .stack(samples=("chain", "draw"))
-            .mean("samples")
-        )
-        var_imp /= var_imp.sum()
-        assert var_imp[0] > var_imp[1:].sum()
-        assert_almost_equal(var_imp.sum(), 1)
-
-
-def test_missing_data():
-    X = np.random.normal(0, 1, size=(50, 2))
-    Y = np.random.normal(0, 1, size=50)
-    X[10:20, 0] = np.nan
-
-    with pm.Model() as model:
-        mu = pm.BART("mu", X, Y, m=10)
-        sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(tune=10, draws=10, chains=1, random_seed=3415)
-
-
-class TestUtils:
-    X_norm = np.random.normal(0, 1, size=(50, 2))
-    X_binom = np.random.binomial(1, 0.5, size=(50, 1))
-    X = np.hstack([X_norm, X_binom])
-    Y = np.random.normal(0, 1, size=50)
-
-    with pm.Model() as model:
-        mu = pm.BART("mu", X, Y, m=10)
-        sigma = pm.HalfNormal("sigma", 1)
-        y = pm.Normal("y", mu, sigma, observed=Y)
-        idata = pm.sample(random_seed=3415)
-
-    def test_predict(self):
-        rng = RandomState(12345)
-        pred_all = pm.bart.utils.predict(self.idata, rng, size=2)
-        rng = RandomState(12345)
-        pred_first = pm.bart.utils.predict(self.idata, rng, X_new=self.X[:10])
-
-        assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
-        assert pred_all.shape == (2, 50)
-        assert pred_first.shape == (10,)
-
-    @pytest.mark.parametrize(
-        "kwargs",
-        [
-            {},
-            {
-                "kind": "pdp",
-                "samples": 2,
-                "xs_interval": "quantiles",
-                "xs_values": [0.25, 0.5, 0.75],
-                "var_discrete": [3],
-            },
-            {"kind": "ice", "instances": 2},
-            {"var_idx": [0], "rug": False, "smooth": False, "color": "k"},
-            {"grid": (1, 2), "sharey": "none", "alpha": 1},
-        ],
-    )
-    def test_pdp(self, kwargs):
-        pm.bart.utils.plot_dependence(self.idata, X=self.X, Y=self.Y, **kwargs)
-
-    def test_pdp_pandas_labels(self):
-        pd = pytest.importorskip("pandas")
-
-        X_names = ["norm1", "norm2", "binom"]
-        X_pd = pd.DataFrame(self.X, columns=X_names)
-        Y_pd = pd.Series(self.Y, name="response")
-        axes = pm.bart.utils.plot_dependence(self.idata, X=X_pd, Y=Y_pd)
-
-        figure = axes[0].figure
-        assert figure.texts[0].get_text() == "Predicted response"
-        assert_array_equal([ax.get_xlabel() for ax in axes], X_names)
-
-
-@pytest.mark.parametrize(
-    "size, expected",
-    [
-        (None, np.zeros(50)),
-    ],
-)
-def test_bart_moment(size, expected):
-    X = np.zeros((50, 2))
-    Y = np.zeros(50)
-    with pm.Model() as model:
-        pm.BART("x", X=X, Y=Y, size=size)
-    assert_moment_is_expected(model, expected)
diff --git a/pymc/tests/test_step.py b/pymc/tests/test_step.py
index 7f8480004d..8d647b88fb 100644
--- a/pymc/tests/test_step.py
+++ b/pymc/tests/test_step.py
@@ -27,6 +27,8 @@
 from aesara.graph.op import Op
 from numpy.testing import assert_array_almost_equal

+import pymc as pm
+
 from pymc.aesaraf import floatX
 from pymc.data import Data
 from pymc.distributions import (
@@ -741,6 +743,26 @@ def kill_grad(x):
         steps = assign_step_methods(model, [])
         assert isinstance(steps, Slice)

+    def test_modify_step_methods(self):
+        """Test step methods can be changed"""
+        # remove nuts from step_methods
+        step_methods = list(pm.STEP_METHODS)
+        step_methods.remove(NUTS)
+        pm.STEP_METHODS = step_methods
+
+        with Model() as model:
+            Normal("x", 0, 1)
+            steps = assign_step_methods(model, [])
+        assert not isinstance(steps, NUTS)
+
+        # add back nuts
+        pm.STEP_METHODS = step_methods + [NUTS]
+
+        with Model() as model:
+            Normal("x", 0, 1)
+            steps = assign_step_methods(model, [])
+        assert isinstance(steps, NUTS)
+

 class TestPopulationSamplers:
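Note on the user-facing effect of this changeset: `assign_step_methods` now takes `methods=None` and resolves it to `pm.STEP_METHODS` at call time, so the registry of auto-assignable samplers can be swapped out by downstream code — exactly what `test_modify_step_methods` above exercises. Below is a minimal sketch of that pattern, assuming `assign_step_methods` is importable from `pymc.sampling` (as the test module already uses it); the fallback sampler named in the comment is illustrative, not guaranteed.

```python
import pymc as pm

from pymc.sampling import assign_step_methods
from pymc.step_methods import NUTS

# Remove NUTS from the global registry; auto-assignment then picks the
# highest-competence method among the remaining ones (e.g. Slice for a
# scalar continuous variable).
pm.STEP_METHODS = [m for m in pm.STEP_METHODS if m is not NUTS]

with pm.Model() as model:
    pm.Normal("x", 0, 1)
    step = assign_step_methods(model, [])
assert not isinstance(step, NUTS)

# Restore the registry so later models auto-assign NUTS again.
pm.STEP_METHODS = list(pm.STEP_METHODS) + [NUTS]
```

Because the lookup happens at call time rather than being frozen into a default argument, external packages can append their own step-method classes to `pm.STEP_METHODS` — presumably the hook that lets a sampler like PGBART live outside the core library now that `STEP_METHODS` no longer lists it.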