Skip to content

Commit

Permalink
Refactored SimulationStrategy classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
MetinSa committed Aug 20, 2021
1 parent c350365 commit b84e250
Showing 1 changed file with 58 additions and 64 deletions.
122 changes: 58 additions & 64 deletions zodipy/simulation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from math import radians
from typing import Union, Iterable, List
import warnings
from typing import Iterable

import healpy as hp
import numpy as np
Expand All @@ -13,7 +11,7 @@

@dataclass
class SimulationStrategy(ABC):
"""Base class that represents a simulation strategy.
"""Base class that represents a simulation strategy.
Attributes
----------
Expand All @@ -24,15 +22,17 @@ class SimulationStrategy(ABC):
Configuration object that determines how a component is integrated
along a line-of-sight.
observer_locations
The locations of the observer.
earth_locations
The locations of the Earth corresponding to the observer locations.
The location(s) of the observer.
earth_location
The location(s) of the Earth.
"""

model: InterplanetaryDustModel
integration_config: IntegrationConfig
observer_locations: Iterable
earth_locations: Iterable
hit_maps: np.ndarray


@abstractmethod
def simulate(self, nside: int, freq: float, solar_cut: float) -> np.ndarray:
Expand All @@ -57,72 +57,74 @@ def simulate(self, nside: int, freq: float, solar_cut: float) -> np.ndarray:
Simulated Zodiacal emission.
"""

@staticmethod
def get_observed_pixels(
X_observer: np.ndarray,
X_unit: np.ndarray,
solar_cut: Union[float, None]
) -> List[np.ndarray]:
"""Returns a list of observed pixels per observation.
All pixels that have an angular distance of larger than some angle
solar_cut between the observer and the sun are masked.
"""

if solar_cut is None:
return Ellipsis
@dataclass
class InstantaneousStrategy(SimulationStrategy):
"""Simulation strategy for instantaneous emission."""

angular_distance = (
hp.rotator.angdist(obs , X_unit) for obs in X_observer
)
def simulate(self, nside: int, freq: float) -> np.ndarray:
"""See base class for a description."""

observed_pixels = [
ang_dist < radians(solar_cut) for ang_dist in angular_distance
]
components = self.model.components
emissivities = self.model.emissivities

return observed_pixels
X_observer = self.observer_locations
X_earth = self.earth_locations

if (hit_map := self.hit_maps) is not None:
pixels = np.flatnonzero(hit_map)
else:
pixels = Ellipsis

class InstantaneousStrategy(SimulationStrategy):
"""Simulation strategy for instantaneous emission."""
npix = hp.nside2npix(nside)
X_unit = np.asarray(hp.pix2vec(nside, np.arange(npix)))[pixels]

def __init__(
self, model, integration_config, observer_locations, earth_locations
) -> None:
"""Initializing the strategy."""
emission = np.zeros((len(components), npix))

super().__init__(
model, integration_config, observer_locations, earth_locations
)
for comp_idx, (comp_name, comp) in enumerate(components.items()):
integration_config = self.integration_config[comp_name]
R = integration_config.R

def simulate(self, nside: int, freq: float, solar_cut: float) -> np.ndarray:
comp_emission = comp.get_emission(
freq, X_observer, X_earth, X_unit, R
)
integrated_comp_emission = integration_config.integrator(
comp_emission, R, dx=integration_config.dR, axis=0
)

comp_emissivity = emissivities.get_emissivity(comp_name, freq)
integrated_comp_emission *= comp_emissivity

emission[comp_idx, pixels] = integrated_comp_emission

return emission * 1e20


@dataclass
class TimeOrderedStrategy(SimulationStrategy):
"""Simulation strategy for time-ordered emission."""

def simulate(self, nside: int, freq: float) -> np.ndarray:
"""See base class for a description."""

npix = hp.nside2npix(nside)
pixels = np.arange(npix)

hit_maps = self.hit_maps
if hp.get_nside(hit_maps) != nside:
hit_maps = hp.ud_grade(self.hit_maps, nside, power=-2)

X_observer = self.observer_locations
X_earth = self.earth_locations
X_unit = np.asarray(hp.pix2vec(nside, pixels))

n_observations = len(X_observer)

pixels = self.get_observed_pixels(X_observer, X_unit, solar_cut)
X_unit = np.asarray(hp.pix2vec(nside, np.arange(npix)))

components = self.model.components
emissivities = self.model.emissivities

# Unobserved pixels are represented as NANs
emission = np.zeros((n_observations, len(components), npix)) + np.NAN
emission = np.zeros((len(components), npix))

for observation_idx, (observer_pos, earth_pos) in enumerate(
zip(X_observer, X_earth)
):
if solar_cut is None:
observed_pixels = pixels
else:
observed_pixels = pixels[observation_idx]
unit_vectors = X_unit[:, observed_pixels]
for observer_pos, earth_pos, hit_map in zip(X_observer, X_earth, hit_maps):
pixels = np.flatnonzero(hit_map)
unit_vectors = X_unit[:, pixels]

for comp_idx, (comp_name, comp) in enumerate(components.items()):
integration_config = self.integration_config[comp_name]
Expand All @@ -137,16 +139,8 @@ def simulate(self, nside: int, freq: float, solar_cut: float) -> np.ndarray:

comp_emissivity = emissivities.get_emissivity(comp_name, freq)
integrated_comp_emission *= comp_emissivity

emission[observation_idx, comp_idx, observed_pixels] = (
integrated_comp_emission
emission[comp_idx, pixels] += (
integrated_comp_emission * hit_map[pixels]
)

with warnings.catch_warnings():
# np.nanmean throws a RuntimeWarning if all pixels along an
# axis is NANs. This may occur when parts of the sky is left
# unobserved over all observations. Here we manually disable
# the warning thay is thrown in the aforementioned scenario.
warnings.filterwarnings("ignore", category=RuntimeWarning)

return np.nanmean(emission, axis=0) * 1e20
return emission / hit_maps.sum(axis=0) * 1e20

0 comments on commit b84e250

Please sign in to comment.