Skip to content

Commit

Permalink
Merge pull request #442 from mwcraig/add-passband-column-format
Browse files Browse the repository at this point in the history
Add method to transform catalog to band-named columns
  • Loading branch information
mwcraig authored Sep 10, 2024
2 parents 5fd8075 + 7a719a4 commit 754f24f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 0 deletions.
59 changes: 59 additions & 0 deletions stellarphot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,65 @@ def from_vizier(

return cat

def passband_columns(self, passbands=None):
"""
Return an `astropy.table.Table` with passbands as column names instead
of the default format, which has a single column for passbands.
Parameters
----------
passbands : list, optional
List of passbands to include in the output. If not provided, all
passbands in the catalog will be included.
Returns
-------
`astropy.table.Table`
Table of catalog information with passbands as column names. See Notes below
for important details about column names.
Notes
-----
The column names in the output will be the passband names with ``mag_`` as a
prefix. An error column for each passband will be generated, with the prefix
``mag_error_``. If the catalog already has columns with these names, they will
be overwritten. The input catalog will not be changed.
"""
catalog_passbands = set(self["passband"])
if passbands is None:
passbands = catalog_passbands
input_passbands = set(passbands)
missing_passbands = input_passbands - catalog_passbands
if missing_passbands:
raise ValueError(
f"Passbands \"{', '.join(missing_passbands)}\" not found in catalog."
)
passband_mask = np.zeros(len(self), dtype=bool)
for passband in input_passbands:
passband_mask |= self["passband"] == passband

reduced_input = self[passband_mask]

# Switch to pandas for making the new table.
df = reduced_input.to_pandas()

# This makes a MultiIndex for the columns -- "mag" and "mag_error" are the
# top level, and the passbands are the second level.
df = df.pivot(
columns="passband", index=["id", "ra", "dec"], values=["mag", "mag_error"]
)

# The column names are a MultiIndex, so we flatten them to either "mag_band"
# or "mag_error_band", where "band" is the passband name.
df.columns = df.columns.to_series().str.join("_")

# We also reset the index which was set to the id, ra, and dec columns above.
df = df.reset_index()

# Convert back to an astropy table and return it.
return Table.from_pandas(df)


def apass_dr9(field_center, radius=1 * u.degree, clip_by_frame=False, padding=100):
"""
Expand Down
48 changes: 48 additions & 0 deletions stellarphot/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,54 @@ def test_tidy_vizier_catalog_several_mags():
assert set(result["passband"]) == {"V", "B", "i", "r-i", "B-V"}


def test_catalog_to_passband_columns():
# Test that the passband columns are correctly identified
# even if the passband names are not in the passband map.
apass_input = Table.read(get_pkg_data_filename("data/test_apass_subset.ecsv"))

input_data = CatalogData._tidy_vizier_catalog(
apass_input,
r"^([a-zA-Z]+|[a-zA-Z]+-[a-zA-Z]+)_?mag$",
r"^([a-zA-Z]+-[a-zA-Z]+)$",
)
print(input_data.colnames)
input_data["RAJ2000"].unit = u.deg
input_data["DEJ2000"].unit = u.deg
cat = CatalogData(
input_data=input_data,
colname_map=dict(recno="id", RAJ2000="ra", DEJ2000="dec"),
catalog_name="APASS",
catalog_source="Vizier",
passband_map=dict(g="SG", r="SR", i="SI"),
)

# Check that calling with a bad filter name raises an error
with pytest.raises(ValueError, match="not found in catalog"):
cat.passband_columns(["not a filter"])

# These are the only passbands in the input test data
passbands = ["V", "SI"]
new_cat = cat.passband_columns(passbands=passbands)

# Check the column names
for pb in passbands:
assert f"mag_{pb}" in new_cat.colnames
assert f"mag_error_{pb}" in new_cat.colnames

# Check the data
assert set(new_cat["mag_V"]) == set(apass_input["Vmag"])
assert set(new_cat["mag_error_V"]) == set(apass_input["e_Vmag"])
assert set(new_cat["mag_SI"]) == set(apass_input["i_mag"])
assert set(new_cat["mag_error_SI"]) == set(apass_input["e_i_mag"])

# Check that calling with no passbands returns all of the passbands
new_cat = cat.passband_columns()
cat_pbs = set(cat["passband"])
for pb in cat_pbs:
assert f"mag_{pb}" in new_cat.colnames
assert f"mag_error_{pb}" in new_cat.colnames


@pytest.mark.remote_data
def test_catalog_from_vizier_search_apass():
# Nothing special about this point...
Expand Down

0 comments on commit 754f24f

Please sign in to comment.