Skip to content

Commit

Permalink
Refactored how the simulation strategy is selected
Browse files Browse the repository at this point in the history
  • Loading branch information
MetinSa committed Sep 4, 2021
1 parent 16fead3 commit ac84be6
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 39 deletions.
2 changes: 1 addition & 1 deletion zodipy/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_coordinates(
X_earth: np.ndarray,
X_unit: np.ndarray,
R: np.ndarray,
) -> Tuple[np.ndarray]:
) -> Tuple[Tuple[np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
"""Returns coordinates for which to evaluate the density.
The density of a component is computed in the prime coordinate
Expand Down
3 changes: 1 addition & 2 deletions zodipy/_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np


EmissionCallable = Callable[
[float, np.ndarray, np.ndarray, np.ndarray, float], np.ndarray
]
Expand All @@ -18,7 +17,7 @@ def trapezoidal(
npix: int,
pixels: np.ndarray,
) -> np.ndarray:
"""Integrates the emission for a component using trapezoidal."""
"""Integrates the emission for a component using the trapezoidal method."""

comp_emission = np.zeros(npix)[pixels]
emission_prev = emission_func(freq, x_obs, x_earth, x_unit, R[0])
Expand Down
31 changes: 16 additions & 15 deletions zodipy/_los_config.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,45 @@
from typing import Tuple, Dict
from typing import Tuple, Dict, Union

import numpy as np

LOSConfigType = Dict[str, Union[np.ndarray, Tuple[float, int]]]


class LOSFactory:
"""Factory responsible for registring and book-keeping a line-of-sight (LOS)."""
"""Factory responsible for registering and book-keeping of LOS configs."""

def __init__(self) -> None:
self._configs = {}

def register_config(
self, name: str, components: Dict[str, Tuple[float, int]]
) -> None:
def register_config(self, name: str, components: LOSConfigType) -> None:
"""Initializes and stores a LOS."""

error_msg = (
"Line-of-sight config must either be an array, or a tuple with "
"the format (start, stop, n, geom) where geom is either "
"'linear' or 'log'"
)
config = {}
for key, value in components.items():
if isinstance(value, np.ndarray):
config[key] = value
elif isinstance(value, (tuple, list)):
elif isinstance(value, tuple):
try:
start, stop, n, geom = value
except ValueError:
raise ValueError(
"Line-of-sight config must either be an array, or "
"a tuple with the format (start, stop, n, geom)"
"where geom is either 'linear' or 'log'"
)
raise ValueError(error_msg)
if geom.lower() == "linear":
geom = np.linspace
elif geom.lower() == "log":
geom = np.geomspace
else:
raise ValueError("geom must be either 'linear' or 'log'")
raise ValueError(error_msg)
config[key] = geom(start, stop, n)

else:
raise ValueError(error_msg)
self._configs[name] = config

def get_config(self, name: str) -> np.ndarray:
def get_config(self, name: str) -> Dict[str, np.ndarray]:
"""Returns a registered config."""

config = self._configs.get(name)
Expand All @@ -46,5 +48,4 @@ def get_config(self, name: str) -> np.ndarray:
f"Config {name} is not registered. Available configs are "
f"{list(self._configs.keys())}"
)

return config
32 changes: 32 additions & 0 deletions zodipy/_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,35 @@ def simulate(self, nside: int, freq: float) -> np.ndarray:
warnings.filterwarnings("ignore", category=RuntimeWarning)

return emission / hit_counts.sum(axis=0) * 1e20


def get_simulation_strategy(
model: InterplanetaryDustModel,
line_of_sight_config: Dict[str, np.ndarray],
observer_locations: np.ndarray,
earth_locations: np.ndarray,
hit_counts: np.ndarray,
) -> SimulationStrategy:
"""Initializes and returns a simulation strategy given initial conditions."""

number_of_observations = len(observer_locations)
if hit_counts is not None:
hit_counts = np.asarray(hit_counts)
number_of_hit_counts = 1 if np.ndim(hit_counts) == 1 else len(hit_counts)
if number_of_hit_counts != number_of_observations:
raise ValueError(
f"The number of 'hit_counts' ({number_of_hit_counts}) are "
"not matching the number of observations "
f"({number_of_observations})"
)

if number_of_observations == 1:
simulation_strategy = InstantaneousStrategy
observer_locations = observer_locations.squeeze()
earth_locations = earth_locations.squeeze()
else:
simulation_strategy = TimeOrderedStrategy

return simulation_strategy(
model, line_of_sight_config, observer_locations, earth_locations, hit_counts
)
2 changes: 1 addition & 1 deletion zodipy/los_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import numpy as np


EPS = np.finfo(float).eps
RADIAL_CUTOFF = 6


LOS_configs = LOSFactory()

LOS_configs.register_config(
Expand Down
22 changes: 2 additions & 20 deletions zodipy/zodi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from zodipy._coordinates import get_target_coordinates, to_frame
from zodipy._simulation import InstantaneousStrategy, TimeOrderedStrategy
from zodipy._simulation import get_simulation_strategy
from zodipy.los_configs import LOS_configs
from zodipy.models import models

Expand Down Expand Up @@ -59,28 +59,10 @@ def __init__(

model = models.get_model(model)
line_of_sight_config = LOS_configs.get_config(line_of_sight_config)

observer_locations = get_target_coordinates(observer, epochs)
earth_locations = get_target_coordinates("earth", epochs)

number_of_observations = len(observer_locations)
if hit_counts is not None:
hit_counts = np.asarray(hit_counts)
number_of_hit_counts = 1 if np.ndim(hit_counts) == 1 else len(hit_counts)
if number_of_hit_counts != number_of_observations:
raise ValueError(
f"The number of 'hit_counts' ({number_of_hit_counts}) are not "
"matching the number of observations "
f"({number_of_observations})"
)

if number_of_observations == 1:
simulation_strategy = InstantaneousStrategy
observer_locations = observer_locations.squeeze()
earth_locations = earth_locations.squeeze()
else:
simulation_strategy = TimeOrderedStrategy
self._simulation_strategy = simulation_strategy(
self._simulation_strategy = get_simulation_strategy(
model, line_of_sight_config, observer_locations, earth_locations, hit_counts
)

Expand Down

0 comments on commit ac84be6

Please sign in to comment.