Skip to content

Commit

Permalink
Merge pull request #234 from lincc-frameworks/f2m_from_float
Browse files Browse the repository at this point in the history
Adds ability to give a single float input to zero_point
  • Loading branch information
dougbrn authored Sep 15, 2023
2 parents 9688d12 + 16e509e commit 7561df3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
53 changes: 36 additions & 17 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,21 +1240,15 @@ def from_source_dict(self, source_dict, column_mapper=None, npartitions=1, **kwa
**kwargs,
)

def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag", out_col_name=None):
def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux_col=None, err_col=None):
"""Converts a flux column into a magnitude column.
Parameters
----------
flux_col: 'str'
The name of the ensemble flux column to convert into magnitudes.
zero_point: 'str'
zero_point: 'str' or 'float'
The name of the ensemble column containing the zero point
information for column transformation.
err_col: 'str', optional
The name of the ensemble column containing the errors to propagate.
Errors are propagated using the following approximation:
Err= (2.5/log(10))*(flux_error/flux), which holds mainly when the
error in flux is much smaller than the flux.
information for column transformation. Alternatively, a single
float number to apply for all fluxes.
zp_form: `str`, optional
The form of the zero point column, either "flux" or
"magnitude"/"mag". Determines how the zero point (zp) is applied in
Expand All @@ -1265,26 +1259,51 @@ def convert_flux_to_mag(self, flux_col, zero_point, err_col=None, zp_form="mag",
The name of the output magnitude column, if None then the output
is just the flux column name + "_mag". The error column is also
generated as the out_col_name + "_err".
flux_col: 'str', optional
The name of the ensemble flux column to convert into magnitudes.
Uses the Ensemble mapped flux column if not specified.
err_col: 'str', optional
The name of the ensemble column containing the errors to propagate.
Errors are propagated using the following approximation:
Err= (2.5/log(10))*(flux_error/flux), which holds mainly when the
error in flux is much smaller than the flux. Uses the Ensemble
mapped error column if not specified.
Returns
----------
ensemble: `tape.ensemble.Ensemble`
The ensemble object with a new magnitude (and error) column.
"""

# Assign Ensemble cols if not provided
if flux_col is None:
flux_col = self._flux_col
if err_col is None:
err_col = self._err_col

if out_col_name is None:
out_col_name = flux_col + "_mag"

if zp_form == "flux": # mag = -2.5*np.log10(flux/zp)
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])}
)
if isinstance(zero_point, str):
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])}
)
else:
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)}
)

elif zp_form == "magnitude" or zp_form == "mag": # mag = -2.5*np.log10(flux) + zp
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]}
)

if isinstance(zero_point, str):
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]}
)
else:
self._source = self._source.assign(
**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point}
)
else:
raise ValueError(f"{zp_form} is not a valid zero_point format.")

Expand Down
26 changes: 10 additions & 16 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_from_parquet(data_fixture, request):
# Check to make sure the critical quantity labels are bound to real columns
assert parquet_ensemble._source[col] is not None


@pytest.mark.parametrize(
"data_fixture",
[
Expand Down Expand Up @@ -108,7 +109,6 @@ def test_from_dataframe(data_fixture, request):
amplitude = ens.batch(calc_stetson_J)
assert len(amplitude) == 5


def test_available_datasets(dask_client):
"""
Test that the ensemble is able to successfully read in the list of available TAPE datasets
Expand Down Expand Up @@ -749,10 +749,10 @@ def test_coalesce(dask_client, drop_inputs):
assert col in ens._source.columns


@pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)])
@pytest.mark.parametrize("zp_form", ["flux", "mag", "magnitude", "lincc"])
@pytest.mark.parametrize("err_col", [None, "error"])
@pytest.mark.parametrize("out_col_name", [None, "mag"])
def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name):
def test_convert_flux_to_mag(dask_client, zero_point, zp_form, out_col_name):
ens = Ensemble(client=dask_client)

source_dict = {
Expand All @@ -775,32 +775,26 @@ def test_convert_flux_to_mag(dask_client, zp_form, err_col, out_col_name):
ens.from_source_dict(source_dict, column_mapper=col_map)

if zp_form == "flux":
ens.convert_flux_to_mag("flux", "zp_flux", err_col, zp_form, out_col_name)
ens.convert_flux_to_mag(zero_point[1], zp_form, out_col_name)

res_mag = ens._source.compute()[output_column].to_list()[0]
assert pytest.approx(res_mag, 0.001) == 21.28925

if err_col is not None:
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979
else:
assert output_column + "_err" not in ens._source.columns
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979

elif zp_form == "mag" or zp_form == "magnitude":
ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, out_col_name)
ens.convert_flux_to_mag(zero_point[0], zp_form, out_col_name)

res_mag = ens._source.compute()[output_column].to_list()[0]
assert pytest.approx(res_mag, 0.001) == 21.28925

if err_col is not None:
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979
else:
assert output_column + "_err" not in ens._source.columns
res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
assert pytest.approx(res_err, 0.001) == 0.355979

else:
with pytest.raises(ValueError):
ens.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, "mag")
ens.convert_flux_to_mag(zero_point[0], zp_form, "mag")


def test_find_day_gap_offset(dask_client):
Expand Down

0 comments on commit 7561df3

Please sign in to comment.