Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract base class for analysis functions #206

Merged
merged 4 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/tape/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base import AnalysisFunction # noqa
from .light_curve import LightCurve # noqa
from .stetsonj import * # noqa
from .structurefunction2 import * # noqa
103 changes: 103 additions & 0 deletions src/tape/analysis/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
Contains the base class for analysis functions.
"""

from abc import ABC, abstractmethod
from typing import Callable, List

import pandas as pd


class AnalysisFunction(ABC, Callable):
"""Base class for analysis functions.

Analysis functions are functions that take few arrays representing
an object and return a single pandas.Series representing the result.

Methods
-------
cols(ens) -> List[str]
Return the columns that the analysis function takes as input.
meta(ens) -> pd.DataFrame
Return the metadata pandas.DataFrame required by Dask to pre-build
a computation graph. It is basically the schema for calculate() method
output.
on(ens) -> List[str]
Return the columns to group source table by.
Typically, `[ens._id_col]`.
__call__(*cols, **kwargs)
Calculate the analysis function.
"""

@abstractmethod
def cols(self, ens: "Ensemble") -> List[str]:
"""
Return the column names that the analysis function takes as input.

Parameters
----------
ens : Ensemble
The ensemble object, it could be required to get column names of
the "special" columns like `ens._time_col` or `ens._err_col`.

Returns
-------
List[str]
The column names to select and pass to .calculate() method.
For example `[ens._time_col, ens._flux_col]`.
"""
raise NotImplementedError

@abstractmethod
def meta(self, ens: "Ensemble"):
"""
Return the schema of the analysis function output.

Parameters
----------
ens : Ensemble
The ensemble object.

Returns
-------
pd.DataFrame or (str, dtype) tuple or {str: dtype} dictionary
Dask meta, for example
`pd.DataFrame(columns=['x', 'y'], dtype=float)`.
"""
raise NotImplementedError

@abstractmethod
def on(self, ens: "Ensemble") -> List[str]:
"""
Return the columns to group source table by.

Parameters
----------
ens : Ensemble
The ensemble object.

Returns:
--------
List[str]
The column names to group by. Typically, `[ens._id_col]`.
"""
return [ens._id_col]

@abstractmethod
def __call__(self, *cols, **kwargs):
"""Calculate the analysis function.

Parameters
----------
*cols : array_like
The columns to calculate the analysis function on. It must be
consistent with .cols(ens) output.
**kwargs
Additional keyword arguments.

Returns
-------
pd.Series or pd.DataFrame or array or value
The result, it must be consistent with .meta() output.
"""
raise NotImplementedError
152 changes: 92 additions & 60 deletions src/tape/analysis/stetsonj.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,96 @@
import numpy as np


def calc_stetson_J(flux, err, band, band_to_calc=None, check_nans=True):
"""Compute the StetsonJ statistic on data from one or several bands

Parameters
----------
flux : `numpy.ndarray` (N,)
Array of flux/magnitude measurements
err : `numpy.ndarray` (N,)
Array of associated flux/magnitude errors
band : `numpy.ndarray` (N,)
Array of associated band labels
band_to_calc : `str` or `list` of `str`
Bands to calculate StetsonJ on. Single band descriptor, or list
of such descriptors.
check_nans : `bool`
Boolean to run a check for NaN values and filter them out.

Returns
-------
stetsonJ : `dict`
StetsonJ statistic for each of input bands.

Notes
----------
In case that no value for `band_to_calc` is passed, the function is
executed on all available bands in `band`.
"""

# NaN filtering
if check_nans:
f_mask = np.isnan(flux)
e_mask = np.isnan(err) # always mask out nan errors?
nan_mask = np.logical_or(f_mask, e_mask)

flux = flux[~nan_mask]
err = err[~nan_mask]
band = band[~nan_mask]
from typing import Iterable, List, Optional, Union

unq_band = np.unique(band)

if band_to_calc is None:
band_to_calc = unq_band
if isinstance(band_to_calc, str):
band_to_calc = [band_to_calc]

assert hasattr(band_to_calc, "__iter__") is True

stetsonJ = {}
for b in band_to_calc:
if b in unq_band:
mask = band == b
fluxes = flux[mask]
errors = err[mask]
stetsonJ[b] = _stetson_J_single(fluxes, errors)
else:
stetsonJ[b] = np.nan

return stetsonJ
import numpy as np
import pandas as pd

from tape.analysis.base import AnalysisFunction


__all__ = ["calc_stetson_J", "StetsonJ"]


class StetsonJ(AnalysisFunction):
"""Compute the StetsonJ statistic on data from one or several bands"""

def cols(self, ens: "Ensemble") -> List[str]:
return [ens._flux_col, ens._err_col, ens._band_col]

def meta(self, ens: "Ensemble"):
return "stetsonJ", float

def on(self, ens: "Ensemble") -> List[str]:
return [ens._id_col]

def __call__(
self,
flux: np.ndarray,
err: np.ndarray,
band: np.ndarray,
*,
band_to_calc: Union[str, Iterable[str], None] = None,
check_nans: bool = False,
):
"""Compute the StetsonJ statistic on data from one or several bands

Parameters
----------
flux : `numpy.ndarray` (N,)
Array of flux/magnitude measurements
err : `numpy.ndarray` (N,)
Array of associated flux/magnitude errors
band : `numpy.ndarray` (N,)
Array of associated band labels
band_to_calc : `str` or `list` of `str`
Bands to calculate StetsonJ on. Single band descriptor, or list
of such descriptors.
check_nans : `bool`
Boolean to run a check for NaN values and filter them out.

Returns
-------
stetsonJ : `dict`
StetsonJ statistic for each of input bands.

Notes
----------
In case that no value for `band_to_calc` is passed, the function is
executed on all available bands in `band`.
"""

# NaN filtering
if check_nans:
f_mask = np.isnan(flux)
e_mask = np.isnan(err) # always mask out nan errors?
nan_mask = np.logical_or(f_mask, e_mask)

flux = flux[~nan_mask]
err = err[~nan_mask]
band = band[~nan_mask]

unq_band = np.unique(band)

if band_to_calc is None:
band_to_calc = unq_band
if isinstance(band_to_calc, str):
band_to_calc = [band_to_calc]

assert isinstance(band_to_calc, Iterable) is True

stetsonJ = {}
for b in band_to_calc:
if b in unq_band:
mask = band == b
fluxes = flux[mask]
errors = err[mask]
stetsonJ[b] = _stetson_J_single(fluxes, errors)
else:
stetsonJ[b] = np.nan

return stetsonJ


calc_stetson_J = StetsonJ()
calc_stetson_J.__doc__ = StetsonJ.__call__.__doc__


def _stetson_J_single(fluxes, errors):
Expand Down
Loading