Skip to content

Commit

Permalink
Rename FLRW to HasDistanceMeasures
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Mar 9, 2023
1 parent 0afc73a commit 202035c
Show file tree
Hide file tree
Showing 12 changed files with 98 additions and 100 deletions.
4 changes: 2 additions & 2 deletions src/cosmology/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""The Cosmology API standard."""

from cosmology.api._background import FriedmannLemaitreRobertsonWalker
from cosmology.api._components import (
HasBaryonComponent,
HasDarkEnergyComponent,
Expand All @@ -12,6 +11,7 @@
)
from cosmology.api._constants import CosmologyConstantsAPINamespace
from cosmology.api._core import CosmologyAPI
from cosmology.api._distances import HasDistanceMeasures
from cosmology.api._extras import HasHubbleParameter, HasTcmb
from cosmology.api._namespace import CosmologyAPINamespace
from cosmology.api._standard import StandardCosmologyAPI
Expand All @@ -22,7 +22,7 @@

__all__ = [
"CosmologyAPI",
"FriedmannLemaitreRobertsonWalker",
"HasDistanceMeasures",
"StandardCosmologyAPI",
# components
"HasGlobalCurvatureComponent",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@runtime_checkable
class FriedmannLemaitreRobertsonWalker(CosmologyAPI[ArrayT], Protocol):
class HasDistanceMeasures(CosmologyAPI[ArrayT], Protocol):
"""Cosmology API protocol for isotropic cosmologies.
This is a protocol class that defines the standard API for isotropic
Expand Down
4 changes: 2 additions & 2 deletions src/cosmology/api/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Protocol, runtime_checkable

from cosmology.api._array_api import ArrayT
from cosmology.api._background import FriedmannLemaitreRobertsonWalker
from cosmology.api._components import (
HasBaryonComponent,
HasDarkEnergyComponent,
Expand All @@ -15,6 +14,7 @@
HasNeutrinoComponent,
HasPhotonComponent,
)
from cosmology.api._distances import HasDistanceMeasures
from cosmology.api._extras import HasHubbleParameter, HasTcmb

__all__: list[str] = []
Expand All @@ -31,7 +31,7 @@ class StandardCosmologyAPI(
HasGlobalCurvatureComponent[ArrayT],
HasTcmb[ArrayT],
HasHubbleParameter[ArrayT],
FriedmannLemaitreRobertsonWalker[ArrayT],
HasDistanceMeasures[ArrayT],
Protocol,
):
"""API Protocol for the standard cosmology and expected set of components.
Expand Down
44 changes: 21 additions & 23 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
CosmologyAPI,
CosmologyAPINamespace,
CosmologyConstantsAPINamespace,
FriedmannLemaitreRobertsonWalker,
HasBaryonComponent,
HasDarkEnergyComponent,
HasDarkMatterComponent,
HasDistanceMeasures,
HasGlobalCurvatureComponent,
HasMatterComponent,
HasNeutrinoComponent,
Expand Down Expand Up @@ -293,47 +293,45 @@ def hastcmb_cls(
# Background API


BACKGROUND_FLRW_ATTRS, BACKGROUND_FLRW_METHS = _get_attrs_meths(
FriedmannLemaitreRobertsonWalker, CosmologyAPI
)
DISTANCES_ATTRS, DISTANCES_METHS = _get_attrs_meths(HasDistanceMeasures, CosmologyAPI)


@pytest.fixture(scope="session")
def bkg_flrw_attrs() -> frozenset[str]:
"""The FriedmannLemaitreRobertsonWalker atributes."""
return BACKGROUND_FLRW_ATTRS
def dists_attrs() -> frozenset[str]:
"""The HasDistanceMeasures atributes."""
return DISTANCES_ATTRS


@pytest.fixture(scope="session")
def bkg_flrw_meths() -> frozenset[str]:
"""The FriedmannLemaitreRobertsonWalker methods."""
return BACKGROUND_FLRW_METHS
def dists_meths() -> frozenset[str]:
"""The HasDistanceMeasures methods."""
return DISTANCES_METHS


@pytest.fixture(scope="session")
def bkg_flrw_cls(
def dists_cls(
cosmology_cls: type[CosmologyAPI],
bkg_flrw_attrs: set[str],
bkg_flrw_meths: set[str],
) -> type[FriedmannLemaitreRobertsonWalker]:
dists_attrs: set[str],
dists_meths: set[str],
) -> type[HasDistanceMeasures]:
"""An example Background class."""
flds = set() # there are no fields
return make_dataclass(
"ExampleFriedmannLemaitreRobertsonWalker",
"ExampleHasDistanceMeasures",
[(n, Array, field(default_factory=_default_one)) for n in flds],
bases=(cosmology_cls,),
namespace={n: property(_return_one) for n in bkg_flrw_attrs - flds}
| {n: _return_1arg for n in bkg_flrw_meths},
namespace={n: property(_return_one) for n in dists_attrs - flds}
| {n: _return_1arg for n in dists_meths},
frozen=True,
)


@pytest.fixture(scope="session")
def bkg_flrw(
bkg_flrw_cls: type[FriedmannLemaitreRobertsonWalker],
) -> FriedmannLemaitreRobertsonWalker:
def dists(
dists_cls: type[HasDistanceMeasures],
) -> HasDistanceMeasures:
"""An example FLRW API instance."""
return bkg_flrw_cls()
return dists_cls()


# ==============================================================================
Expand All @@ -359,7 +357,7 @@ def standard_meths() -> frozenset[str]:

@pytest.fixture(scope="session")
def standard_cls( # noqa: PLR0913
bkg_flrw_cls: type[FriedmannLemaitreRobertsonWalker],
dists_cls: type[HasDistanceMeasures],
globalcurvature_cls: type[HasGlobalCurvatureComponent],
matter_cls: type[HasMatterComponent],
baryon_cls: type[HasBaryonComponent],
Expand All @@ -379,7 +377,7 @@ def standard_cls( # noqa: PLR0913
darkmatter_cls,
matter_cls,
globalcurvature_cls,
bkg_flrw_cls,
dists_cls,
)
flds = functools.reduce(
operator.or_, ({f.name for f in fields(c)} for c in bases)
Expand Down
60 changes: 0 additions & 60 deletions tests/api/test_background.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/api/test_components/test_darkenergy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ class ExampleHasDarkEnergyComponent:
# TODO: more examples?


def test_compliant_darkenergycomponent(bkg_flrw_cls):
def test_compliant_darkenergycomponent(dists_cls):
"""
Test that a compliant instance is a
`cosmology.api.HasDarkEnergyComponent`.
"""
ExampleHasDarkEnergyComponent = make_dataclass(
"ExampleHasDarkEnergyComponent",
[(n, Array, field(default_factory=_default_one)) for n in {"Omega_de0"}],
bases=(bkg_flrw_cls,),
bases=(dists_cls,),
namespace={"Omega_de": _return_1arg},
frozen=True,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_components/test_globalcurvature.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ class ExampleHasGlobalCurvatureComponent:
# TODO: more examples?


def test_compliant_globalcurvaturecomponent(bkg_flrw_cls):
def test_compliant_globalcurvaturecomponent(dists_cls):
"""
Test that a compliant instance is a
`cosmology.api.HasGlobalCurvatureComponent`.
"""
ExampleHasGlobalCurvatureComponent = make_dataclass(
"ExampleHasGlobalCurvatureComponent",
[(n, Array, field(default_factory=_default_one)) for n in {}],
bases=(bkg_flrw_cls,),
bases=(dists_cls,),
namespace={"Omega_k0": _return_one, "Omega_k": _return_1arg},
frozen=True,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_components/test_matter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ class ExampleHasMatterComponent:
# TODO: more examples?


def test_compliant_mattercomponent(bkg_flrw_cls):
def test_compliant_mattercomponent(dists_cls):
"""
Test that a compliant instance is a
`cosmology.api.HasMatterComponent`.
"""
ExampleHasMatterComponent = make_dataclass(
"ExampleHasMatterComponent",
[(n, Array, field(default_factory=_default_one)) for n in {"Omega_m0"}],
bases=(bkg_flrw_cls,),
bases=(dists_cls,),
namespace={"Omega_m": _return_1arg},
frozen=True,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_components/test_neutrino.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ class ExampleHasNeutrinoComponent:
# TODO: more examples?


def test_compliant_neutrinocomponent(bkg_flrw_cls):
def test_compliant_neutrinocomponent(dists_cls):
"""
Test that a compliant instance is a
`cosmology.api.HasNeutrinoComponent`.
"""
ExampleHasNeutrinoComponent = make_dataclass(
"ExampleHasNeutrinoComponent",
[(n, Array, field(default_factory=_default_one)) for n in {"Neff", "m_nu"}],
bases=(bkg_flrw_cls,),
bases=(dists_cls,),
namespace={"Omega_nu0": _return_one, "Omega_nu": _return_1arg},
frozen=True,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_components/test_photon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ class ExampleHasPhotonComponent:
# TODO: more examples?


def test_compliant_photoncomponent(bkg_flrw_cls):
def test_compliant_photoncomponent(dists_cls):
"""
Test that a compliant instance is a
`cosmology.api.HasPhotonComponent`.
"""
ExampleHasPhotonComponent = make_dataclass(
"ExampleHasPhotonComponent",
[(n, Array, field(default_factory=_default_one)) for n in {}],
bases=(bkg_flrw_cls,),
bases=(dists_cls,),
namespace={"Omega_gamma0": _return_one, "Omega_gamma": _return_1arg},
frozen=True,
)
Expand Down
60 changes: 60 additions & 0 deletions tests/api/test_distances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Test ``cosmology.api.core``."""

from __future__ import annotations

from dataclasses import field, make_dataclass

from cosmology.api import HasDistanceMeasures
from cosmology.api._array_api import Array

from .conftest import _default_one, _return_1arg, _return_one

################################################################################
# TESTS
################################################################################


def test_noncompliant_dists():
"""
Test that a non-compliant instance is not a
`cosmology.api.HasDistanceMeasures`.
"""
# Simple example: missing everything

class HasDistanceMeasuresCosmology:
pass

cosmo = HasDistanceMeasuresCosmology()

assert not isinstance(cosmo, HasDistanceMeasures)

# TODO: more examples?


def test_compliant_dists(dists_cls, dists_attrs, dists_meths):
"""
Test that an instance is `cosmology.api.HasDistanceMeasures` even if it
doesn't inherit from `cosmology.api.HasDistanceMeasures`.
"""
flds = set()

HasDistanceMeasuresCosmology = make_dataclass(
"HasDistanceMeasuresCosmology",
[(n, Array, field(default_factory=_default_one)) for n in flds],
bases=(dists_cls,),
namespace={n: property(_return_one) for n in dists_attrs - set(flds)}
| {n: _return_1arg for n in dists_meths},
frozen=True,
)

cosmo = HasDistanceMeasuresCosmology(name=None)

assert isinstance(cosmo, HasDistanceMeasures)


def test_fixture(dists):
"""
Test that the ``dists`` fixture is a
`cosmology.api.HasDistanceMeasures`.
"""
assert isinstance(dists, HasDistanceMeasures)
4 changes: 2 additions & 2 deletions tests/api/test_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from cosmology.api import StandardCosmologyAPI
from cosmology.api._array_api import Array
from cosmology.api._background import FriedmannLemaitreRobertsonWalker
from cosmology.api._components import (
HasBaryonComponent,
HasDarkEnergyComponent,
Expand All @@ -17,6 +16,7 @@
HasPhotonComponent,
)
from cosmology.api._core import CosmologyAPI
from cosmology.api._distances import HasDistanceMeasures
from cosmology.api._extras import HasHubbleParameter, HasTcmb

from .conftest import _default_one, _return_1arg, _return_one
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_compliant_standard(cosmology_cls, standard_attrs, standard_meths):

# Check Base and Background
assert isinstance(cosmo, CosmologyAPI)
assert isinstance(cosmo, FriedmannLemaitreRobertsonWalker)
assert isinstance(cosmo, HasDistanceMeasures)

# Check Components
assert isinstance(cosmo, HasBaryonComponent)
Expand Down

0 comments on commit 202035c

Please sign in to comment.