FEATURE: add jax backend
- add JAX to the backends
- various changes to support JIT compilation
- add JAX backend to the test suite
ColmTalbot committed Dec 15, 2023
1 parent 3db35c5 commit 9a49404
Showing 14 changed files with 325 additions and 105 deletions.
21 changes: 18 additions & 3 deletions gwpopulation/backend.py
@@ -9,8 +9,9 @@
 ]
 __all_with_scs = ["models.mass", "utils"]
 __backend__ = ""
-SUPPORTED_BACKENDS = ["numpy", "cupy"]
-_scipy_module = dict(numpy="scipy", cupy="cupyx.scipy")
+SUPPORTED_BACKENDS = ["numpy", "cupy", "jax"]
+_np_module = dict(numpy="numpy", cupy="cupy", jax="jax.numpy")
+_scipy_module = dict(numpy="scipy", cupy="cupyx.scipy", jax="jax.scipy")


 def disable_cupy():
@@ -42,13 +43,27 @@ def set_backend(backend="numpy"):
     elif backend == __backend__:
         return

+    if backend == "jax":
+        from jax import config
+
+        config.update("jax_enable_x64", True)
+
     from importlib import import_module

     try:
-        xp = import_module(backend)
+        xp = import_module(_np_module[backend])
         scs = import_module(_scipy_module[backend]).special
     except ModuleNotFoundError:
         raise ModuleNotFoundError(f"{backend} not installed")
     except ImportError:
         raise ImportError(f"{backend} installed but not importable")
+    if backend == "jax":
+        try:
+            from jax.scipy.integrate import trapezoid
+
+            xp.trapz = trapezoid
+        except ModuleNotFoundError:
+            pass
     for module in __all_with_xp:
         __backend__ = backend
         import_module(f".{module}", package="gwpopulation").xp = xp
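
The new _np_module and _scipy_module mappings let set_backend resolve both the
array and scipy namespaces from a single backend name. A minimal usage sketch
(assuming set_backend is re-exported at the package level, as it is for the
existing numpy and cupy backends):

    import gwpopulation

    # Selecting "jax" also turns on double precision via
    # config.update("jax_enable_x64", True) and aliases xp.trapz to
    # jax.scipy.integrate.trapezoid when that import is available.
    gwpopulation.set_backend("jax")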
11 changes: 5 additions & 6 deletions gwpopulation/hyperpe.py
@@ -9,7 +9,7 @@
 from bilby.core.utils import logger
 from bilby.hyper.model import Model

-from .utils import get_name, to_numpy
+from .utils import get_name, to_number, to_numpy

 xp = np

@@ -137,14 +137,13 @@ def ln_likelihood_and_variance(self):
         variance += selection_variance
         ln_l += selection
         self._pop_added(added_keys)
-        return ln_l, float(variance)
+        return ln_l, to_number(variance, float)

     def log_likelihood_ratio(self):
         ln_l, variance = self.ln_likelihood_and_variance()
-        if variance > self._max_variance or xp.isnan(ln_l):
-            return -self._inf
-        else:
-            return float(xp.nan_to_num(ln_l))
+        ln_l = xp.nan_to_num(ln_l, nan=-xp.inf)
+        ln_l -= xp.nan_to_num(xp.inf * (self.maximum_uncertainty < variance), nan=0)
+        return to_number(xp.nan_to_num(ln_l), float)

     def noise_log_likelihood(self):
         return self.total_noise_evidence
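
log_likelihood_ratio previously branched on the variance cut and on
xp.isnan(ln_l); Python-level branches like that fail under jax.jit because the
values are abstract tracers at trace time. The rewrite expresses the same
cutoff arithmetically. A standalone sketch of the trick (illustrative names,
not the gwpopulation API):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def clipped_ln_l(ln_l, variance, max_variance):
        # NaN log-likelihoods become -inf with no Python branch.
        ln_l = jnp.nan_to_num(ln_l, nan=-jnp.inf)
        # inf * True is inf, so failing the variance cut subtracts inf;
        # inf * False is NaN, which nan_to_num maps to 0, leaving passing
        # points unchanged.
        ln_l -= jnp.nan_to_num(jnp.inf * (max_variance < variance), nan=0.0)
        # The final nan_to_num keeps the result finite for the sampler.
        return jnp.nan_to_num(ln_l)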
58 changes: 33 additions & 25 deletions gwpopulation/models/interped.py
@@ -1,8 +1,26 @@
+from functools import partial
+
 import numpy as np

+from ..utils import to_numpy
+
 xp = np


+def _setup_interpolant(nodes, values, kind="cubic", backend=xp):
+    """
+    Cache the information necessary for linear interpolation of the mass
+    ratio normalisation
+    """
+    from cached_interpolate import RegularCachingInterpolant as CachingInterpolant
+
+    nodes = to_numpy(nodes)
+    interpolant = CachingInterpolant(nodes, nodes, kind=kind, backend=backend)
+    interpolant.conversion = xp.asarray(interpolant.conversion)
+    interpolant = partial(interpolant, xp.asarray(values))
+    return interpolant
+
+
 class InterpolatedNoBaseModelIdentical(object):
     """
     Base class for the Interpolated classes with no base model
@@ -35,50 +53,40 @@ def variable_names(self):
         return keys

     def setup_interpolant(self, nodes, values):
-        from cached_interpolate import CachingInterpolant
-
-        kwargs = dict(x=nodes, y=values, kind=self.kind, backend=xp)
-        self._norm_spline = CachingInterpolant(**kwargs)
+        kwargs = dict(kind=self.kind, backend=xp)
+        self._norm_spline = _setup_interpolant(nodes, self._xs, **kwargs)
         self._data_spline = {
-            param: CachingInterpolant(**kwargs) for param in self.parameters
+            param: _setup_interpolant(nodes, values[param], **kwargs)
+            for param in self.parameters
         }

     def p_x_unnormed(self, dataset, parameter, x_splines, f_splines, **kwargs):

-        if self.spline_selector is None:
-            if self._norm_spline is None:
-                self.setup_interpolant(x_splines, f_splines)
-
-            self.spline_selector = (dataset[f"{parameter}"] >= x_splines[0]) & (
-                dataset[f"{parameter}"] <= x_splines[-1]
-            )
-
-        perturbation = self._data_spline[parameter](
-            x=dataset[f"{parameter}"][self.spline_selector], y=f_splines
-        )
-
-        p_x = xp.zeros(xp.shape(dataset[self.parameters[0]]))
-        p_x[self.spline_selector] = xp.exp(perturbation)
+        if self._norm_spline is None:
+            self.setup_interpolant(x_splines, dataset)
+
+        perturbation = self._data_spline[parameter](y=f_splines)
+
+        p_x = xp.exp(perturbation)
+        p_x *= (dataset[f"{parameter}"] >= x_splines[0]) & (
+            dataset[f"{parameter}"] <= x_splines[-1]
+        )
         return p_x

     def norm_p_x(self, f_splines=None, x_splines=None, **kwargs):
-        if self.norm_selector is None:
-            self.norm_selector = (self._xs >= x_splines[0]) & (
-                self._xs <= x_splines[-1]
-            )
-
-        perturbation = self._norm_spline(x=self._xs[self.norm_selector], y=f_splines)
-        p_x = xp.zeros(len(self._xs))
-        p_x[self.norm_selector] = xp.exp(perturbation)
+        perturbation = self._norm_spline(y=f_splines)
+        p_x = xp.exp(perturbation)
+        p_x *= (self._xs >= x_splines[0]) & (self._xs <= x_splines[-1])
         norm = xp.trapz(p_x, self._xs)
         return norm

     def p_x_identical(self, dataset, **kwargs):

         self.infer_n_nodes(**kwargs)

-        f_splines = np.array([kwargs[f"{key}"] for key in self.fkeys])
-        x_splines = np.array([kwargs[f"{key}"] for key in self.xkeys])
+        f_splines = xp.array([kwargs[key] for key in self.fkeys])
+        x_splines = xp.array([kwargs[key] for key in self.xkeys])

         p_x = xp.ones(xp.shape(dataset[self.parameters[0]]))

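
Both p_x_unnormed and norm_p_x now multiply by a boolean in-range mask rather
than assigning through a boolean index: JAX arrays are immutable, so
p_x[selector] = ... raises, and the mask form also keeps shapes static under
jit. A minimal sketch of the pattern:

    import jax.numpy as jnp

    x = jnp.linspace(0, 1, 101)
    inside = (x >= 0.2) & (x <= 0.8)

    # Evaluate everywhere, then zero the out-of-range entries by
    # multiplying with the mask; no in-place assignment is needed.
    p_x = jnp.exp(x) * inside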
57 changes: 28 additions & 29 deletions gwpopulation/models/mass.py
@@ -37,15 +37,13 @@ def double_power_law_primary_mass(mass, alpha_1, alpha_2, mmin, mmax, break_frac
         Maximum mass in the powerlaw distributed component (:math:`m_\max`).
     """

-    prob = xp.zeros_like(mass)
     m_break = mmin + break_fraction * (mmax - mmin)
     correction = powerlaw(m_break, alpha=-alpha_2, low=m_break, high=mmax) / powerlaw(
         m_break, alpha=-alpha_1, low=mmin, high=m_break
     )
-    low_part = powerlaw(mass[mass < m_break], alpha=-alpha_1, low=mmin, high=m_break)
-    prob[mass < m_break] = low_part * correction
-    high_part = powerlaw(mass[mass >= m_break], alpha=-alpha_2, low=m_break, high=mmax)
-    prob[mass >= m_break] = high_part
+    low_part = powerlaw(mass, alpha=-alpha_1, low=mmin, high=m_break)
+    high_part = powerlaw(mass, alpha=-alpha_2, low=m_break, high=mmax)
+    prob = low_part * (mass < m_break) * correction + high_part * (mass >= m_break)
     return prob / (1 + correction)

@@ -527,14 +525,15 @@ def __call__(self, dataset, *args, **kwargs):
         beta = kwargs.pop("beta")
         mmin = kwargs.get("mmin", self.mmin)
         mmax = kwargs.get("mmax", self.mmax)
-        if mmin < self.mmin:
-            raise ValueError(
-                f"{self.__class__}: mmin ({mmin}) < self.mmin ({self.mmin})"
-            )
-        if mmax > self.mmax:
-            raise ValueError(
-                f"{self.__class__}: mmax ({mmax}) > self.mmax ({self.mmax})"
-            )
+        if "jax" not in xp.__name__:
+            if mmin < self.mmin:
+                raise ValueError(
+                    f"{self.__class__}: mmin ({mmin}) < self.mmin ({self.mmin})"
+                )
+            if mmax > self.mmax:
+                raise ValueError(
+                    f"{self.__class__}: mmax ({mmax}) > self.mmax ({self.mmax})"
+                )
         delta_m = kwargs.get("delta_m", 0)
         p_m1 = self.p_m1(dataset, **kwargs, **self.kwargs)
         p_q = self.p_q(dataset, beta=beta, mmin=mmin, delta_m=delta_m)
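
The bounds checks are skipped for JAX because under jit mmin and mmax are
abstract tracers, so an if statement cannot inspect their values and raising
would fail at trace time. If runtime validation were still wanted under jit,
one jit-compatible alternative (a sketch only, not used by this commit) is
jax.experimental.checkify:

    import jax
    from jax.experimental import checkify

    def check_bounds(mmin, floor):
        # A functional error channel replaces `if mmin < floor: raise ...`.
        checkify.check(mmin >= floor, "mmin {m} below minimum {f}", m=mmin, f=floor)
        return mmin

    checked = jax.jit(checkify.checkify(check_bounds))
    err, out = checked(2.0, 5.0)
    err.throw()  # raises here only if the check failed inside the jitted call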
@@ -549,18 +548,24 @@ def p_m1(self, dataset, **kwargs):
             dataset["mass_1"], mmin=mmin, mmax=self.mmax, delta_m=delta_m
         )
         norm = self.norm_p_m1(delta_m=delta_m, **kwargs)
+        print(
+            norm,
+            self.smoothing(
+                dataset["mass_1"], mmin=mmin, mmax=self.mmax, delta_m=delta_m
+            ),
+        )
         return p_m / norm

     def norm_p_m1(self, delta_m, **kwargs):
         """Calculate the normalisation factor for the primary mass"""
         mmin = kwargs.get("mmin", self.mmin)
-        if delta_m == 0:
+        if "jax" not in xp.__name__ and delta_m == 0:
             return 1
         p_m = self.__class__.primary_model(self.m1s, **kwargs)
         p_m *= self.smoothing(self.m1s, mmin=mmin, mmax=self.mmax, delta_m=delta_m)

         norm = xp.trapz(p_m, self.m1s)
-        return norm
+        return norm ** (delta_m > 0)

     def p_q(self, dataset, beta, mmin, delta_m):
         p_q = powerlaw(dataset["mass_ratio"], beta, 1, mmin / dataset["mass_1"])
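
norm_p_m1 (and norm_p_q below) keeps the delta_m == 0 shortcut without a
Python branch by exponentiating with a boolean: norm ** True is norm and
norm ** False is exactly 1. A sketch of the idiom, using the trapezoid rule
the backend module patches in as xp.trapz:

    import jax.numpy as jnp
    from jax.scipy.integrate import trapezoid

    def norm_or_one(p_m, m1s, delta_m):
        norm = trapezoid(p_m, m1s)
        # With smoothing active (delta_m > 0) this is norm; with
        # delta_m == 0 it collapses to 1, and it traces cleanly under jit.
        return norm ** (delta_m > 0)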
@@ -580,13 +585,11 @@ def p_q(self, dataset, beta, mmin, delta_m):

     def norm_p_q(self, beta, mmin, delta_m):
         """Calculate the mass ratio normalisation by linear interpolation"""
-        if delta_m == 0.0:
-            return 1
         p_q = powerlaw(self.qs_grid, beta, 1, mmin / self.m1s_grid)
         p_q *= self.smoothing(
             self.m1s_grid * self.qs_grid, mmin=mmin, mmax=self.m1s_grid, delta_m=delta_m
         )
-        norms = xp.nan_to_num(xp.trapz(p_q, self.qs, axis=0))
+        norms = xp.nan_to_num(xp.trapz(p_q, self.qs, axis=0)) ** (delta_m > 0)

         return self._q_interpolant(norms)

@@ -595,16 +598,11 @@ def _cache_q_norms(self, masses):
         Cache the information necessary for linear interpolation of the mass
         ratio normalisation
         """
-        from functools import partial
-
-        from cached_interpolate import RegularCachingInterpolant as CachingInterpolant
-
-        from ..utils import to_numpy
-
-        nodes = to_numpy(self.m1s)
-        interpolant = CachingInterpolant(nodes, nodes, kind="cubic", backend=xp)
-        interpolant.conversion = xp.asarray(interpolant.conversion)
-        self._q_interpolant = partial(interpolant, xp.asarray(masses))
+        from .interped import _setup_interpolant
+
+        self._q_interpolant = _setup_interpolant(
+            self.m1s, masses, kind="cubic", backend=xp
+        )

     @staticmethod
     def smoothing(masses, mmin, mmax, delta_m):
@@ -623,8 +621,9 @@ def smoothing(masses, mmin, mmax, delta_m):
         See also, https://en.wikipedia.org/wiki/Window_function#Planck-taper_window
         """
-        if delta_m > 0.0:
-            shifted_mass = xp.clip((masses - mmin) / delta_m, 1e-6, 1 - 1e-6)
+        if "jax" in xp.__name__ or delta_m > 0.0:
+            shifted_mass = xp.nan_to_num((masses - mmin) / delta_m, nan=0)
+            shifted_mass = xp.clip(shifted_mass, 1e-6, 1 - 1e-6)
             exponent = 1 / shifted_mass - 1 / (1 - shifted_mass)
             window = scs.expit(-exponent)
             window *= (masses >= mmin) * (masses <= mmax)
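
In smoothing, the division (masses - mmin) / delta_m is guarded with
nan_to_num before clipping, so delta_m == 0 degenerates to a sharp cut at the
mass bounds instead of propagating NaNs, and the "jax" in xp.__name__ clause
makes the branch unconditional under JAX. A standalone sketch of the resulting
Planck-taper window (mirroring the committed logic, not importing it):

    import jax.numpy as jnp
    from jax.scipy.special import expit

    def planck_taper(masses, mmin, mmax, delta_m):
        # The 0/0 at masses == mmin becomes 0; the clip bounds the exponent.
        shifted = jnp.nan_to_num((masses - mmin) / delta_m, nan=0.0)
        shifted = jnp.clip(shifted, 1e-6, 1 - 1e-6)
        window = expit(-(1 / shifted - 1 / (1 - shifted)))
        return window * (masses >= mmin) * (masses <= mmax)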
2 changes: 0 additions & 2 deletions gwpopulation/models/spin.py
@@ -72,8 +72,6 @@ def independent_spin_magnitude_beta(
     amax_1, amax_2: float
         Maximum spin of the more/less massive black hole.
     """
-    if alpha_chi_1 < 0 or beta_chi_1 < 0 or alpha_chi_2 < 0 or beta_chi_2 < 0:
-        return 0
     prior = beta_dist(
         dataset["a_1"], alpha_chi_1, beta_chi_1, scale=amax_1
     ) * beta_dist(dataset["a_2"], alpha_chi_2, beta_chi_2, scale=amax_2)