Skip to content

Commit

Permalink
MAINT: update for wcosmo astropy-like (#97)
Browse files Browse the repository at this point in the history
* MAINT: update for wcosmo astropy-like

* MAINT: update to wcosmo astropy-like

* FORMAT: add empty line

* MAINT: update inheritance test for new wcosmo

* TST: make tests work for different backends
  • Loading branch information
ColmTalbot authored Jul 15, 2024
1 parent 96fb1a0 commit 5084f2f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
9 changes: 6 additions & 3 deletions gwpopulation/experimental/cosmo_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
"""

import numpy as xp
from wcosmo import FlatwCDM, available, z_at_value
from wcosmo import z_at_value
from wcosmo.astropy import WCosmoMixin, available
from wcosmo.utils import disable_units as wcosmo_disable_units

from .jax import NonCachingModel

Expand All @@ -24,6 +26,7 @@ class CosmoMixin:
"""

def __init__(self, cosmo_model="Planck15"):
wcosmo_disable_units()
self.cosmo_model = cosmo_model
if self.cosmo_model == "FlatwCDM":
self.cosmology_names = ["H0", "Om0", "w0"]
Expand Down Expand Up @@ -60,10 +63,10 @@ def cosmology(self, parameters):
Returns
=======
wcosmo.FlatwCDM
wcosmo.astropy.WCosmoMixin
The cosmology model.
"""
if isinstance(self._cosmo, FlatwCDM):
if isinstance(self._cosmo, WCosmoMixin):
return self._cosmo
else:
return self._cosmo(**self.cosmology_variables(parameters))
Expand Down
5 changes: 4 additions & 1 deletion gwpopulation/models/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,10 @@ def total_four_volume(lamb, analysis_time, max_redshift=2.3):
-----
This assumes a :code:`Planck15` cosmology.
"""
from wcosmo.wcosmo import Planck15
from wcosmo.astropy import Planck15
from wcosmo.utils import disable_units

disable_units()

redshifts = xp.linspace(0, max_redshift, 2500)
psi_of_z = (1 + redshifts) ** lamb
Expand Down
9 changes: 6 additions & 3 deletions test/redshift_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest
from astropy.cosmology import Planck15
from bilby.core.prior import PriorDict, Uniform
from wcosmo.astropy import Planck15
from wcosmo.utils import disable_units

import gwpopulation
from gwpopulation.models import redshift
Expand Down Expand Up @@ -50,13 +51,14 @@ def test_powerlaw_volume(backend):
trivial case
"""
gwpopulation.set_backend(backend)
disable_units()
xp = gwpopulation.utils.xp
zs = xp.linspace(1e-3, 2.3, 1000)
zs_numpy = gwpopulation.utils.to_numpy(zs)
model = redshift.PowerLawRedshift()
parameters = dict(lamb=1)
total_volume = np.trapz(
Planck15.differential_comoving_volume(zs_numpy).value * 4 * np.pi,
Planck15.differential_comoving_volume(zs_numpy) * 4 * np.pi,
zs_numpy,
)
approximation = float(model.normalisation(parameters))
Expand All @@ -69,8 +71,9 @@ def test_zero_outside_domain():


def test_four_volume():
disable_units()
assert (
Planck15.comoving_volume(2.3).value / 1e9
Planck15.comoving_volume(2.3) / 1e9
- redshift.total_four_volume(lamb=1, analysis_time=1, max_redshift=2.3)
< 1e-3
)

0 comments on commit 5084f2f

Please sign in to comment.