From 5084f2f739da198da110758d618adf474b756c8a Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Sun, 14 Jul 2024 19:29:28 -0500 Subject: [PATCH] MAINT: update for wcosmo astropy-like (#97) * MAINT: update for wcosmo astropy-like * MAINT: update to wcosmo astropy-like * FORMAT: add empty line * MAINT: update inheritance test for new wcosmo * TST: make tests work for different backends --- gwpopulation/experimental/cosmo_models.py | 9 ++++++--- gwpopulation/models/redshift.py | 5 ++++- test/redshift_test.py | 9 ++++++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/gwpopulation/experimental/cosmo_models.py b/gwpopulation/experimental/cosmo_models.py index 9af776d0..6067a21e 100644 --- a/gwpopulation/experimental/cosmo_models.py +++ b/gwpopulation/experimental/cosmo_models.py @@ -7,7 +7,9 @@ """ import numpy as xp -from wcosmo import FlatwCDM, available, z_at_value +from wcosmo import z_at_value +from wcosmo.astropy import WCosmoMixin, available +from wcosmo.utils import disable_units as wcosmo_disable_units from .jax import NonCachingModel @@ -24,6 +26,7 @@ class CosmoMixin: """ def __init__(self, cosmo_model="Planck15"): + wcosmo_disable_units() self.cosmo_model = cosmo_model if self.cosmo_model == "FlatwCDM": self.cosmology_names = ["H0", "Om0", "w0"] @@ -60,10 +63,10 @@ def cosmology(self, parameters): Returns ======= - wcosmo.FlatwCDM + wcosmo.astropy.WCosmoMixin The cosmology model. """ - if isinstance(self._cosmo, FlatwCDM): + if isinstance(self._cosmo, WCosmoMixin): return self._cosmo else: return self._cosmo(**self.cosmology_variables(parameters)) diff --git a/gwpopulation/models/redshift.py b/gwpopulation/models/redshift.py index fef8f0d5..6ccf4778 100644 --- a/gwpopulation/models/redshift.py +++ b/gwpopulation/models/redshift.py @@ -274,7 +274,10 @@ def total_four_volume(lamb, analysis_time, max_redshift=2.3): ----- This assumes a :code:`Planck15` cosmology. """ - from wcosmo.wcosmo import Planck15 + from wcosmo.astropy import Planck15 + from wcosmo.utils import disable_units + + disable_units() redshifts = xp.linspace(0, max_redshift, 2500) psi_of_z = (1 + redshifts) ** lamb diff --git a/test/redshift_test.py b/test/redshift_test.py index b683c001..5e886ad1 100644 --- a/test/redshift_test.py +++ b/test/redshift_test.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from astropy.cosmology import Planck15 from bilby.core.prior import PriorDict, Uniform +from wcosmo.astropy import Planck15 +from wcosmo.utils import disable_units import gwpopulation from gwpopulation.models import redshift @@ -50,13 +51,14 @@ def test_powerlaw_volume(backend): trivial case """ gwpopulation.set_backend(backend) + disable_units() xp = gwpopulation.utils.xp zs = xp.linspace(1e-3, 2.3, 1000) zs_numpy = gwpopulation.utils.to_numpy(zs) model = redshift.PowerLawRedshift() parameters = dict(lamb=1) total_volume = np.trapz( - Planck15.differential_comoving_volume(zs_numpy).value * 4 * np.pi, + Planck15.differential_comoving_volume(zs_numpy) * 4 * np.pi, zs_numpy, ) approximation = float(model.normalisation(parameters)) @@ -69,8 +71,9 @@ def test_zero_outside_domain(): def test_four_volume(): + disable_units() assert ( - Planck15.comoving_volume(2.3).value / 1e9 + Planck15.comoving_volume(2.3) / 1e9 - redshift.total_four_volume(lamb=1, analysis_time=1, max_redshift=2.3) < 1e-3 )