diff --git a/pymc3/step_methods/pgbart.py b/pymc3/step_methods/pgbart.py
index fab674b20a6..351f1ae8a26 100644
--- a/pymc3/step_methods/pgbart.py
+++ b/pymc3/step_methods/pgbart.py
@@ -61,7 +61,7 @@ class PGBART(ArrayStepShared):
     generates_stats = True
     stats_dtypes = [{"variable_inclusion": np.ndarray}]
 
-    def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):
+    def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None):
         _log.warning("BART is experimental. Use with caution.")
         model = modelcontext(model)
         initial_values = model.initial_point
@@ -78,7 +78,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
         self.init_mean = self.Y.mean()
         # if data is binary
-        if np.all(np.unique(self.Y) == [0, 1]):
+        Y_unique = np.unique(self.Y)
+        if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
             self.mu_std = 6 / (self.k * self.m ** 0.5)
         # maybe we need to check for count data
         else:
@@ -97,6 +98,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
         self.mean = fast_mean()
         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
+        self.ssv = SampleSplittingVariable(self.split_prior)
 
         self.tune = True
         self.idx = 0
@@ -120,7 +122,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
             self.a_tree.tree_id = i
             p = ParticleTree(
                 self.a_tree,
-                self.prior_prob_leaf_node,
                 self.init_log_weight,
                 self.init_likelihood,
             )
@@ -132,7 +133,6 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data
 
         variable_inclusion = np.zeros(self.num_variates, dtype="int")
-        self.ssv = SampleSplittingVariable(self.split_prior)
 
         if self.idx == self.m:
             self.idx = 0
@@ -140,21 +140,24 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         for idx in range(self.idx, self.idx + self.chunk):
             if idx >= self.m:
                 break
-            self.idx += 1
             tree = self.all_particles[idx].tree
             sum_trees_output_noi = sum_trees_output - tree.predict_output()
+            self.idx += 1
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree.tree_id)
 
-            for t in range(1, self.max_stages):
+            for t in range(self.max_stages):
                 # Get old particle at stage t
-                particles[0] = self.get_old_tree_particle(tree.tree_id, t)
+                if t > 0:
+                    particles[0] = self.get_old_tree_particle(tree.tree_id, t)
                 # sample each particle (try to grow each tree)
-                for c in range(1, self.num_particles):
-                    particles[c].sample_tree_sequential(
+                compute_logp = [True]
+                for p in particles[1:]:
+                    clp = p.sample_tree_sequential(
                         self.ssv,
                         self.available_predictors,
+                        self.prior_prob_leaf_node,
                         self.X,
                         self.missing_data,
                         sum_trees_output,
@@ -163,34 +166,34 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.normal,
                         self.mu_std,
                     )
+                    compute_logp.append(clp)
                 # Update weights. Since the prior is used as the proposal, the weights
                 # are updated additively as the ratio of the new and old log_likelihoods
-                for p in particles:
-                    new_likelihood = self.likelihood_logp(
-                        sum_trees_output_noi + p.tree.predict_output()
-                    )
-                    p.log_weight += new_likelihood - p.old_likelihood_logp
-                    p.old_likelihood_logp = new_likelihood
-
+                for clp, p in zip(compute_logp, particles):
+                    if clp:  # Compute the likelihood when p has changed from the previous iteration
+                        new_likelihood = self.likelihood_logp(
+                            sum_trees_output_noi + p.tree.predict_output()
+                        )
+                        p.log_weight += new_likelihood - p.old_likelihood_logp
+                        p.old_likelihood_logp = new_likelihood
                 # Normalize weights
                 W_t, normalized_weights = self.normalize(particles)
 
-                # Set the new weights
-                for p in particles:
-                    p.log_weight = W_t
-
                 # Resample all but first particle
                 re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
-
                 new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
                 particles[1:] = particles[new_indices]
 
+                # Set the new weights
+                for p in particles:
+                    p.log_weight = W_t
+
                 # Check if particles can keep growing, otherwise stop iterating
-                non_available_nodes_for_expansion = np.ones(self.num_particles - 1)
-                for c in range(1, self.num_particles):
-                    if len(particles[c].expansion_nodes) != 0:
-                        non_available_nodes_for_expansion[c - 1] = 0
-                if np.all(non_available_nodes_for_expansion):
+                non_available_nodes_for_expansion = []
+                for p in particles[1:]:
+                    if p.expansion_nodes:
+                        non_available_nodes_for_expansion.append(0)
+                if all(non_available_nodes_for_expansion):
                     break
 
             # Get the new tree and update
@@ -203,6 +206,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             if self.tune:
                 for index in new_particle.used_variates:
                     self.split_prior[index] += 1
+                self.ssv = SampleSplittingVariable(self.split_prior)
             else:
                 self.iter += 1
                 self.sum_trees.append(new_tree)
@@ -253,14 +257,16 @@ def init_particles(self, tree_id):
         """
         Initialize particles
         """
-        particles = [self.get_old_tree_particle(tree_id, 0)]
+        p = self.get_old_tree_particle(tree_id, 0)
+        p.log_weight = self.init_log_weight
+        p.old_likelihood_logp = self.init_likelihood
+        particles = [p]
 
-        for _ in range(1, self.num_particles):
+        for _ in self.indices:
             self.a_tree.tree_id = tree_id
             particles.append(
                 ParticleTree(
                     self.a_tree,
-                    self.prior_prob_leaf_node,
                     self.init_log_weight,
                     self.init_likelihood,
                 )
@@ -274,13 +280,12 @@ class ParticleTree:
     Particle tree
     """
 
-    def __init__(self, tree, prior_prob_leaf_node, log_weight=0, likelihood=0):
+    def __init__(self, tree, log_weight, likelihood):
         self.tree = tree.copy()  # keeps the tree that we care about at the moment
         self.expansion_nodes = [0]
         self.tree_history = [self.tree]
         self.expansion_nodes_history = [self.expansion_nodes]
         self.log_weight = log_weight
-        self.prior_prob_leaf_node = prior_prob_leaf_node
         self.old_likelihood_logp = likelihood
         self.used_variates = []
 
@@ -288,6 +293,7 @@ def sample_tree_sequential(
         self,
         ssv,
         available_predictors,
+        prior_prob_leaf_node,
         X,
         missing_data,
         sum_trees_output,
@@ -296,13 +302,14 @@ def sample_tree_sequential(
         normal,
         mu_std,
     ):
+        clp = False
         if self.expansion_nodes:
             index_leaf_node = self.expansion_nodes.pop(0)
             # Probability that this node will remain a leaf node
-            prob_leaf = self.prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
 
             if prob_leaf < np.random.random():
-                index_selected_predictor = grow_tree(
+                clp, index_selected_predictor = grow_tree(
                     self.tree,
                     index_leaf_node,
                     ssv,
@@ -315,21 +322,20 @@ def sample_tree_sequential(
                     normal,
                     mu_std,
                 )
-                if index_selected_predictor is not None:
+                if clp:
                     new_indexes = self.tree.idx_leaf_nodes[-2:]
                     self.expansion_nodes.extend(new_indexes)
                     self.used_variates.append(index_selected_predictor)
 
         self.tree_history.append(self.tree)
         self.expansion_nodes_history.append(self.expansion_nodes)
+        return clp
 
     def set_particle_to_step(self, t):
         if len(self.tree_history) <= t:
-            self.tree = self.tree_history[-1]
-            self.expansion_nodes = self.expansion_nodes_history[-1]
-        else:
-            self.tree = self.tree_history[t]
-            self.expansion_nodes = self.expansion_nodes_history[t]
+            t = -1
+        self.tree = self.tree_history[t]
+        self.expansion_nodes = self.expansion_nodes_history[t]
 
 
 def preprocess_XY(X, Y):
@@ -410,7 +416,7 @@ def grow_tree(
         ]
 
     if available_splitting_values.size == 0:
-        return None
+        return False, None
 
     idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
     selected_splitting_rule = available_splitting_values[idx_selected_splitting_values]
@@ -443,7 +449,7 @@ def grow_tree(
     )
    tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
 
-    return index_selected_predictor
+    return True, index_selected_predictor
 
 
 def get_new_idx_data_points(current_split_node, idx_data_points, X):
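The guarded binary check in the `@@ -78` hunk is worth a second look: the old `np.all(np.unique(self.Y) == [0, 1])` broadcasts an `(n,)`-shaped array against a length-2 list, which misbehaves whenever `n` is neither 1 nor 2. A minimal sketch of the failure mode and the fix, assuming only that `Y` can hold arbitrarily many distinct values:

```python
import numpy as np

# The old check compared np.unique(Y) directly against [0, 1]. For three or
# more unique values the shapes (n,) and (2,) do not broadcast, and the
# elementwise comparison warns or raises depending on the NumPy version.
for y in (np.array([0, 1, 0, 1]), np.array([0, 1, 2]), np.zeros(3)):
    y_unique = np.unique(y)
    # `and` short-circuits, so the shape-sensitive comparison only runs
    # when there are exactly two unique values.
    is_binary = y_unique.size == 2 and np.all(y_unique == [0, 1])
    print(y_unique, is_binary)  # [0 1] True, then [0 1 2] False, then [0.] False
```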
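Hoisting `SampleSplittingVariable(self.split_prior)` out of `astep` and into `__init__`, and rebuilding it only when tuning bumps `split_prior`, means the sampler's setup cost is paid once per change to the prior instead of once per step. An illustrative sketch of that pattern (a stand-in with the same role, not the PyMC3 class itself):

```python
import numpy as np

class SampleSplittingVariable:
    """Stand-in sampler: draws predictor indices proportional to `weights`."""

    def __init__(self, weights):
        w = np.asarray(weights, dtype=float)
        self.cum_weights = np.cumsum(w / w.sum())  # precomputed once

    def rvs(self):
        # Inverse-CDF draw of an index in [0, len(weights))
        return int(np.searchsorted(self.cum_weights, np.random.random()))

split_prior = np.ones(3)
ssv = SampleSplittingVariable(split_prior)   # built once, as in __init__
draws = [ssv.rvs() for _ in range(1000)]     # many cheap draws per step
split_prior[2] += 1                          # tuning updates the prior ...
ssv = SampleSplittingVariable(split_prior)   # ... and only then is it rebuilt
```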
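The `clp` flags returned by `sample_tree_sequential` (and threaded through `compute_logp`) let the weight update skip particles whose tree did not change, so the expensive likelihood call runs only where it matters. A self-contained sketch of that caching pattern, with illustrative names rather than the PyMC3 API:

```python
import numpy as np

rng = np.random.default_rng(0)

class Particle:
    def __init__(self, output):
        self.output = output            # stand-in for tree.predict_output()
        self.old_likelihood_logp = 0.0  # cached from the previous stage
        self.log_weight = 0.0

def likelihood_logp(x):
    return -0.5 * float(np.sum(x ** 2))  # stand-in for the model likelihood

particles = [Particle(rng.normal(size=5)) for _ in range(4)]
compute_logp = [True, False, True, False]  # per-particle `clp` flags

for clp, p in zip(compute_logp, particles):
    if clp:  # unchanged particles keep their cached likelihood and weight
        new_likelihood = likelihood_logp(p.output)
        p.log_weight += new_likelihood - p.old_likelihood_logp
        p.old_likelihood_logp = new_likelihood
```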