Skip to content

Commit

Permalink
DEV: remove float casts for jax jit (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
ColmTalbot authored Jan 31, 2024
1 parent e50708f commit 8a9db99
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions gwpopulation/hyperpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,17 @@ def generate_extra_statistics(self, sample):
ln_ls, variances = self._compute_per_event_ln_bayes_factors(
return_uncertainty=True
)
total_variance = float(sum(variances))
total_variance = sum(variances)
for ii in range(self.n_posteriors):
sample[f"ln_bf_{ii}"] = float(ln_ls[ii])
sample[f"var_{ii}"] = float(variances[ii])
sample[f"ln_bf_{ii}"] = to_number(ln_ls[ii], float)
sample[f"var_{ii}"] = to_number(variances[ii], float)
selection, variance = self._selection_function_with_uncertainty()
variance /= selection**2
selection_variance = variance * self.n_posteriors**2
sample["selection"] = selection
sample["selection_variance"] = variance
total_variance += selection_variance
sample["variance"] = float(total_variance)
sample["variance"] = to_number(total_variance, float)
if added_keys is not None:
for key in added_keys:
self.parameters.pop(key)
Expand Down

0 comments on commit 8a9db99

Please sign in to comment.