Skip to content

Commit

Permalink
Fix cupy numpy switching (#73)
Browse files Browse the repository at this point in the history
* DEV: improve backend settings

* TEST: make tests run on cupy if available
  • Loading branch information
ColmTalbot authored Oct 3, 2023
1 parent fd07a67 commit 858c68c
Show file tree
Hide file tree
Showing 20 changed files with 1,038 additions and 947 deletions.
35 changes: 6 additions & 29 deletions gwpopulation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,16 @@
The code is hosted at `<www.github.com/ColmTalbot/gwpopulation>`_.
"""
from . import conversions, cupy_utils, hyperpe, models, utils, vt
from . import conversions, hyperpe, models, utils, vt
from .backend import SUPPORTED_BACKENDS, disable_cupy, enable_cupy, set_backend
from .hyperpe import RateLikelihood

try:
from ._version import version as __version__
except ModuleNotFoundError: # development mode
__version__ = "unknown"


__all_with_xp = [
models.mass,
models.redshift,
models.spin,
cupy_utils,
hyperpe,
utils,
vt,
]


def disable_cupy():
import numpy as np

for module in __all_with_xp:
module.xp = np


def enable_cupy():
try:
import cupy as cp
except ImportError:
import numpy as cp

print("Cannot import cupy, falling back to numpy.")
for module in __all_with_xp:
module.xp = cp
try:
set_backend("cupy")
except ModuleNotFoundError:
set_backend("numpy")
56 changes: 56 additions & 0 deletions gwpopulation/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
__all_with_xp = [
"hyperpe",
"models.interped",
"models.mass",
"models.redshift",
"models.spin",
"utils",
"vt",
]
__all_with_scs = ["models.mass", "utils"]
__backend__ = ""
SUPPORTED_BACKENDS = ["numpy", "cupy"]
_scipy_module = dict(numpy="scipy", cupy="cupyx.scipy")


def disable_cupy():
from warnings import warn

warn(
f"Function enable_cupy is deprecated, use set_backed('cupy') instead",
DeprecationWarning,
)
set_backend(backend="numpy")


def enable_cupy():
from warnings import warn

warn(
f"Function enable_cupy is deprecated, use set_backed('cupy') instead",
DeprecationWarning,
)
set_backend(backend="cupy")


def set_backend(backend="numpy"):
global __backend__
if backend not in SUPPORTED_BACKENDS:
raise ValueError(
f"Backend {backend} not supported, should be in {', '.join(SUPPORTED_BACKENDS)}"
)
elif backend == __backend__:
return

from importlib import import_module

try:
xp = import_module(backend)
scs = import_module(_scipy_module[backend]).special
except ModuleNotFoundError:
raise ModuleNotFoundError(f"{backend} not installed")
for module in __all_with_xp:
__backend__ = backend
import_module(f".{module}", package="gwpopulation").xp = xp
for module in __all_with_scs:
import_module(f".{module}", package="gwpopulation").scs = scs
119 changes: 0 additions & 119 deletions gwpopulation/cupy_utils.py

This file was deleted.

14 changes: 10 additions & 4 deletions gwpopulation/hyperpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from bilby.core.utils import logger
from bilby.hyper.model import Model

from .cupy_utils import CUPY_LOADED, to_numpy, xp
from .utils import get_name
from .utils import get_name, to_numpy

xp = np


class HyperparameterLikelihood(Likelihood):
Expand Down Expand Up @@ -67,8 +68,13 @@ def __init__(
If the uncertainty is larger than this value a log likelihood of
-inf will be returned. Default = inf
"""
if cupy and not CUPY_LOADED:
logger.warning("Cannot import cupy, falling back to numpy.")
if cupy:
from .backend import set_backend

try:
set_backend("cupy")
except ImportError:
logger.warning(f"Cupy not available, using {xp.__name__}.")

self.samples_per_posterior = max_samples
self.data = self.resample_posteriors(posteriors, max_samples=max_samples)
Expand Down
4 changes: 2 additions & 2 deletions gwpopulation/models/interped.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from ..cupy_utils import trapz, xp
xp = np


class InterpolatedNoBaseModelIdentical(object):
Expand Down Expand Up @@ -70,7 +70,7 @@ def norm_p_x(self, f_splines=None, x_splines=None, **kwargs):
perturbation = self._norm_spline(x=self._xs[self.norm_selector], y=f_splines)
p_x = xp.zeros(len(self._xs))
p_x[self.norm_selector] = xp.exp(perturbation)
norm = trapz(p_x, self._xs)
norm = xp.trapz(p_x, self._xs)
return norm

def p_x_identical(self, dataset, **kwargs):
Expand Down
53 changes: 25 additions & 28 deletions gwpopulation/models/mass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
"""
import inspect

from ..cupy_utils import trapz, xp
import numpy as np
import scipy.special as scs

from ..utils import powerlaw, truncnorm

xp = np


def double_power_law_primary_mass(mass, alpha_1, alpha_2, mmin, mmax, break_fraction):
r"""
Expand Down Expand Up @@ -555,7 +559,7 @@ def norm_p_m1(self, delta_m, **kwargs):
p_m = self.__class__.primary_model(self.m1s, **kwargs)
p_m *= self.smoothing(self.m1s, mmin=mmin, mmax=self.mmax, delta_m=delta_m)

norm = trapz(p_m, self.m1s)
norm = xp.trapz(p_m, self.m1s)
return norm

def p_q(self, dataset, beta, mmin, delta_m):
Expand All @@ -582,29 +586,25 @@ def norm_p_q(self, beta, mmin, delta_m):
p_q *= self.smoothing(
self.m1s_grid * self.qs_grid, mmin=mmin, mmax=self.m1s_grid, delta_m=delta_m
)
norms = trapz(p_q, self.qs, axis=0)

all_norms = (
norms[self.n_below] * (1 - self.step) + norms[self.n_above] * self.step
)
norms = xp.nan_to_num(xp.trapz(p_q, self.qs, axis=0))

return all_norms
return self._q_interpolant(norms)

def _cache_q_norms(self, masses):
"""
Cache the information necessary for linear interpolation of the mass
ratio normalisation
"""
self.n_below = xp.zeros_like(masses, dtype=int) - 1
m_below = xp.zeros_like(masses)
for mm in self.m1s:
self.n_below += masses > mm
m_below[masses > mm] = mm
self.n_above = self.n_below + 1
max_idx = len(self.m1s)
self.n_below[self.n_below < 0] = 0
self.n_above[self.n_above == max_idx] = max_idx - 1
self.step = xp.minimum((masses - m_below) / self.dm, 1)
from functools import partial

from cached_interpolate import RegularCachingInterpolant as CachingInterpolant

from ..utils import to_numpy

nodes = to_numpy(self.m1s)
interpolant = CachingInterpolant(nodes, nodes, kind="cubic", backend=xp)
interpolant.conversion = xp.asarray(interpolant.conversion)
self._q_interpolant = partial(interpolant, xp.asarray(masses))

@staticmethod
def smoothing(masses, mmin, mmax, delta_m):
Expand All @@ -623,17 +623,14 @@ def smoothing(masses, mmin, mmax, delta_m):
See also, https://en.wikipedia.org/wiki/Window_function#Planck-taper_window
"""
window = xp.ones_like(masses)
if delta_m > 0.0:
smoothing_region = (masses >= mmin) & (masses < (mmin + delta_m))
shifted_mass = masses[smoothing_region] - mmin
if shifted_mass.size:
exponent = xp.nan_to_num(
delta_m / shifted_mass + delta_m / (shifted_mass - delta_m)
)
window[smoothing_region] = 1 / (xp.exp(exponent) + 1)
window[(masses < mmin) | (masses > mmax)] = 0
return window
shifted_mass = xp.clip((masses - mmin) / delta_m, 1e-6, 1 - 1e-6)
exponent = 1 / shifted_mass - 1 / (1 - shifted_mass)
window = scs.expit(-exponent)
window *= (masses >= mmin) * (masses <= mmax)
return window
else:
return xp.ones(masses.shape)


class SinglePeakSmoothedMassDistribution(BaseSmoothedMassDistribution):
Expand Down
6 changes: 4 additions & 2 deletions gwpopulation/models/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import numpy as np

from ..cupy_utils import to_numpy, trapz, xp
from ..utils import to_numpy

xp = np


class _Redshift(object):
Expand Down Expand Up @@ -49,7 +51,7 @@ def normalisation(self, parameters):
(float, array-like): Total spacetime volume
"""
psi_of_z = self.psi_of_z(redshift=self.zs, **parameters)
norm = trapz(psi_of_z * self.dvc_dz / (1 + self.zs), self.zs)
norm = xp.trapz(psi_of_z * self.dvc_dz / (1 + self.zs), self.zs)
return norm

def probability(self, dataset, **parameters):
Expand Down
2 changes: 1 addition & 1 deletion gwpopulation/models/spin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Implemented spin models
"""
import numpy as xp

from ..cupy_utils import xp
from ..utils import beta_dist, truncnorm, unnormalized_2d_gaussian
from .interped import InterpolatedNoBaseModelIdentical

Expand Down
Loading

0 comments on commit 858c68c

Please sign in to comment.