clean code, refactor and small speed-up
aloctavodia committed Aug 23, 2021
1 parent 1ac00fc commit 4d8696f
Showing 1 changed file with 47 additions and 41 deletions.
88 changes: 47 additions & 41 deletions pymc3/step_methods/pgbart.py
@@ -61,7 +61,7 @@ class PGBART(ArrayStepShared):
     generates_stats = True
     stats_dtypes = [{"variable_inclusion": np.ndarray}]
 
-    def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", model=None):
+    def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None):
         _log.warning("BART is experimental. Use with caution.")
         model = modelcontext(model)
         initial_values = model.initial_point
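
Two related changes here and in the astep hunk further down: the default cap on SMC stages drops from 5000 to 100, and the stage loop becomes range(self.max_stages) with the retained particle swapped in only for t > 0. In practice the loop usually exits much earlier, through the check that stops iterating once no particle has leaf nodes left to expand.
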
@@ -78,7 +78,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
 
         self.init_mean = self.Y.mean()
         # if data is binary
-        if np.all(np.unique(self.Y) == [0, 1]):
+        Y_unique = np.unique(self.Y)
+        if Y_unique.size == 2 and np.all(Y_unique == [0, 1]):
             self.mu_std = 6 / (self.k * self.m ** 0.5)
             # maybe we need to check for count data
         else:
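
Note on the hunk above: the old test compared np.unique(self.Y) against [0, 1] regardless of length, so any response with more than two distinct values triggered NumPy's shape-mismatched elementwise comparison, which (depending on the NumPy version) answers with a bare False plus a DeprecationWarning or refuses outright. The new size guard short-circuits before that comparison can happen. A standalone illustration with toy data (not part of the patch):

import numpy as np

def is_binary(y):
    y_unique = np.unique(y)
    # The size guard short-circuits, so the elementwise comparison
    # only ever runs on two arrays of equal length.
    return y_unique.size == 2 and bool(np.all(y_unique == [0, 1]))

print(is_binary(np.array([0, 1, 1, 0])))  # True  -> binary response
print(is_binary(np.array([0, 1, 2, 3])))  # False -> more than two values
print(is_binary(np.zeros(5)))             # False -> constant response
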
@@ -97,6 +98,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
         self.mean = fast_mean()
         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
+        self.ssv = SampleSplittingVariable(self.split_prior)
 
         self.tune = True
         self.idx = 0
@@ -120,7 +122,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
             self.a_tree.tree_id = i
             p = ParticleTree(
                 self.a_tree,
-                self.prior_prob_leaf_node,
                 self.init_log_weight,
                 self.init_likelihood,
             )
@@ -132,29 +133,31 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data
 
         variable_inclusion = np.zeros(self.num_variates, dtype="int")
-        self.ssv = SampleSplittingVariable(self.split_prior)
 
         if self.idx == self.m:
             self.idx = 0
 
         for idx in range(self.idx, self.idx + self.chunk):
             if idx >= self.m:
                 break
-            self.idx += 1
             tree = self.all_particles[idx].tree
             sum_trees_output_noi = sum_trees_output - tree.predict_output()
+            self.idx += 1
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree.tree_id)
 
-            for t in range(1, self.max_stages):
+            for t in range(self.max_stages):
                 # Get old particle at stage t
-                particles[0] = self.get_old_tree_particle(tree.tree_id, t)
+                if t > 0:
+                    particles[0] = self.get_old_tree_particle(tree.tree_id, t)
                 # sample each particle (try to grow each tree)
-                for c in range(1, self.num_particles):
-                    particles[c].sample_tree_sequential(
+                compute_logp = [True]
+                for p in particles[1:]:
+                    clp = p.sample_tree_sequential(
                         self.ssv,
                         self.available_predictors,
+                        self.prior_prob_leaf_node,
                         self.X,
                         self.missing_data,
                         sum_trees_output,
@@ -163,34 +166,34 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.normal,
                         self.mu_std,
                     )
+                    compute_logp.append(clp)
                 # Update weights. Since the prior is used as the proposal,the weights
                 # are updated additively as the ratio of the new and old log_likelihoods
-                for p in particles:
-                    new_likelihood = self.likelihood_logp(
-                        sum_trees_output_noi + p.tree.predict_output()
-                    )
-                    p.log_weight += new_likelihood - p.old_likelihood_logp
-                    p.old_likelihood_logp = new_likelihood
-
+                for clp, p in zip(compute_logp, particles):
+                    if clp:  # Compute the likelihood when p has changed from the previous iteration
+                        new_likelihood = self.likelihood_logp(
+                            sum_trees_output_noi + p.tree.predict_output()
+                        )
+                        p.log_weight += new_likelihood - p.old_likelihood_logp
+                        p.old_likelihood_logp = new_likelihood
                 # Normalize weights
                 W_t, normalized_weights = self.normalize(particles)
 
-                # Set the new weights
-                for p in particles:
-                    p.log_weight = W_t
-
                 # Resample all but first particle
                 re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
-
                 new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
                 particles[1:] = particles[new_indices]
 
+                # Set the new weights
+                for p in particles:
+                    p.log_weight = W_t
+
                 # Check if particles can keep growing, otherwise stop iterating
-                non_available_nodes_for_expansion = np.ones(self.num_particles - 1)
-                for c in range(1, self.num_particles):
-                    if len(particles[c].expansion_nodes) != 0:
-                        non_available_nodes_for_expansion[c - 1] = 0
-                if np.all(non_available_nodes_for_expansion):
+                non_available_nodes_for_expansion = []
+                for p in particles[1:]:
+                    if p.expansion_nodes:
+                        non_available_nodes_for_expansion.append(0)
+                if all(non_available_nodes_for_expansion):
                     break
 
             # Get the new tree and update
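
The speed-up advertised in the commit message lives in the two hunks above: sample_tree_sequential now reports whether a particle's tree actually grew (clp), and the expensive likelihood call is repeated only for particles that changed, while unchanged particles keep their cached old_likelihood_logp. Because the tree prior doubles as the SMC proposal, the incremental weight is just the likelihood ratio, which in log space is added: log_weight += new_logp - old_logp. A self-contained toy sketch of the caching pattern (illustrative names only, not the PyMC3 API):

import numpy as np

rng = np.random.default_rng(0)

def loglik(value):
    return -0.5 * value**2  # toy log-likelihood: standard normal, up to a constant

class ToyParticle:
    def __init__(self, value):
        self.value = value
        self.log_weight = 0.0
        self.old_logp = loglik(value)  # cached from the previous stage

    def try_grow(self):
        # The proposal may leave the particle untouched; report what happened.
        if rng.random() < 0.5:
            self.value += rng.normal()
            return True   # changed: the likelihood must be recomputed
        return False      # unchanged: the cached value is still valid

particles = [ToyParticle(rng.normal()) for _ in range(5)]

changed = [True]  # slot 0 (the retained particle) is always re-evaluated
for p in particles[1:]:
    changed.append(p.try_grow())

for did_change, p in zip(changed, particles):
    if did_change:
        new_logp = loglik(p.value)
        p.log_weight += new_logp - p.old_logp  # prior-as-proposal weight update
        p.old_logp = new_logp
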
@@ -203,6 +206,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             if self.tune:
                 for index in new_particle.used_variates:
                     self.split_prior[index] += 1
+                self.ssv = SampleSplittingVariable(self.split_prior)
             else:
                 self.iter += 1
                 self.sum_trees.append(new_tree)
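
The other half of the speed-up: SampleSplittingVariable is now built once in __init__ and rebuilt in the hunk above only while tuning, when the split_prior counts actually change, instead of being reconstructed at the top of every astep call. A sketch of the pattern, with a hypothetical weighted sampler standing in for the real class:

import numpy as np

class WeightedIndexSampler:
    # Hypothetical stand-in: the costly setup work happens once, up front.
    def __init__(self, counts):
        w = np.asarray(counts, dtype=float)
        self.cumulative = np.cumsum(w / w.sum())

    def rvs(self):
        return int(np.searchsorted(self.cumulative, np.random.random()))

split_prior = np.ones(4)
sampler = WeightedIndexSampler(split_prior)          # built once, as in __init__

for step in range(100):
    tuning = step < 50
    idx = sampler.rvs()
    if tuning:
        split_prior[idx] += 1                        # counts only move while tuning,
        sampler = WeightedIndexSampler(split_prior)  # so rebuild only here
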
@@ -253,14 +257,16 @@ def init_particles(self, tree_id):
         """
         Initialize particles
         """
-        particles = [self.get_old_tree_particle(tree_id, 0)]
+        p = self.get_old_tree_particle(tree_id, 0)
+        p.log_weight = self.init_log_weight
+        p.old_likelihood_logp = self.init_likelihood
+        particles = [p]
 
-        for _ in range(1, self.num_particles):
+        for _ in self.indices:
             self.a_tree.tree_id = tree_id
             particles.append(
                 ParticleTree(
                     self.a_tree,
-                    self.prior_prob_leaf_node,
                     self.init_log_weight,
                     self.init_likelihood,
                 )
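
Two details ride along in init_particles: the retained particle fetched with get_old_tree_particle still carries the log_weight and old_likelihood_logp from the final stage of the sweep that stored it, so both are reset to their initial values before the new sweep starts; and the fill loop now iterates over self.indices, which already enumerates the non-retained particle slots, instead of an equivalent range(1, self.num_particles).
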
@@ -274,20 +280,20 @@ class ParticleTree:
     Particle tree
     """
 
-    def __init__(self, tree, prior_prob_leaf_node, log_weight=0, likelihood=0):
+    def __init__(self, tree, log_weight, likelihood):
         self.tree = tree.copy()  # keeps the tree that we care at the moment
         self.expansion_nodes = [0]
         self.tree_history = [self.tree]
         self.expansion_nodes_history = [self.expansion_nodes]
         self.log_weight = log_weight
-        self.prior_prob_leaf_node = prior_prob_leaf_node
         self.old_likelihood_logp = likelihood
         self.used_variates = []
 
     def sample_tree_sequential(
         self,
         ssv,
         available_predictors,
+        prior_prob_leaf_node,
         X,
         missing_data,
         sum_trees_output,
@@ -296,13 +302,14 @@ def sample_tree_sequential(
         normal,
         mu_std,
     ):
+        clp = False
         if self.expansion_nodes:
             index_leaf_node = self.expansion_nodes.pop(0)
             # Probability that this node will remain a leaf node
-            prob_leaf = self.prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
 
             if prob_leaf < np.random.random():
-                index_selected_predictor = grow_tree(
+                clp, index_selected_predictor = grow_tree(
                     self.tree,
                     index_leaf_node,
                     ssv,
@@ -315,21 +322,20 @@ def sample_tree_sequential(
                     normal,
                     mu_std,
                 )
-                if index_selected_predictor is not None:
+                if clp:
                     new_indexes = self.tree.idx_leaf_nodes[-2:]
                     self.expansion_nodes.extend(new_indexes)
                     self.used_variates.append(index_selected_predictor)
 
             self.tree_history.append(self.tree)
             self.expansion_nodes_history.append(self.expansion_nodes)
+        return clp
 
     def set_particle_to_step(self, t):
         if len(self.tree_history) <= t:
-            self.tree = self.tree_history[-1]
-            self.expansion_nodes = self.expansion_nodes_history[-1]
-        else:
-            self.tree = self.tree_history[t]
-            self.expansion_nodes = self.expansion_nodes_history[t]
+            t = -1
+        self.tree = self.tree_history[t]
+        self.expansion_nodes = self.expansion_nodes_history[t]
 
 
 def preprocess_XY(X, Y):
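
The rewritten set_particle_to_step in the previous hunk collapses the duplicated if/else into a single pair of lookups: a stage index past the recorded history is clamped to -1, i.e. to the deepest state the particle ever reached.
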
@@ -410,7 +416,7 @@ def grow_tree(
         ]
 
     if available_splitting_values.size == 0:
-        return None
+        return False, None
 
     idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
     selected_splitting_rule = available_splitting_values[idx_selected_splitting_values]
@@ -443,7 +449,7 @@
     )
     tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
 
-    return index_selected_predictor
+    return True, index_selected_predictor
 
 
 def get_new_idx_data_points(current_split_node, idx_data_points, X):
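
The last two hunks complete the new calling convention: grow_tree now always returns a (grew, index_selected_predictor) pair, so sample_tree_sequential can unpack it directly and branch on the boolean (if clp:) rather than testing index_selected_predictor is not None, and the same flag propagates upward as the return value collected into compute_logp.
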