Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for running batch on a single lightcurve #420

Merged
merged 7 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion docs/tutorials/batch_showcase.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@
"res1.compute() # Compute to see the result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function on only a single object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The same batch call as above but now only on the final lightcurve.\n",
"\n",
"lc_id = 109 # id of the final lightcurve in the data\n",
"lc_res = ens.batch(my_mean, \"flux\", single_lc=lc_id)\n",
"lc_res.compute()"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -585,7 +605,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.14"
},
"vscode": {
"interpreter": {
Expand Down
31 changes: 23 additions & 8 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,7 @@
use_map=True,
on=None,
label="",
single_lc=None,
**kwargs,
):
"""Run a function from tape.TimeSeries on the available ids
Expand Down Expand Up @@ -1107,6 +1108,9 @@
source or object tables. If not specified, then the id column is
used by default. For TAPE and `light-curve` functions this is
populated automatically.
single_lc: `int`, optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be really helpful to also allow single_lc=True, where select_random_lightcurve is used to pick one for the user.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that makes sense. To do that, I added an id_only parameter to Ensemble.select_random_timeseries so that it will only fetch the object id of a random lightcurve. But let me know if you would prefer that not added to the API and I can break that out differently.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! That looks like a good solution

If provided, only the lightcurve with the specified id will be
used in batch. Default is None.
label: 'str', optional
If provided the ensemble will use this label to track the result
dataframe. If not provided, a label of the from "result_{x}" where x
Expand All @@ -1133,6 +1137,12 @@
from light_curve import EtaE
ens.batch(EtaE(), band_to_calc='g')

To run a TAPE function on a single lightcurve:
from tape.analysis.stetsonj import calc_stetson_J
ens = Ensemble().from_dataset('rrlyr82')
lc_id = 4378437892 # The lightcurve id
ensemble.batch(calc_stetson_J, band_to_calc='i', single_lc=lc_id)

Run a custom function on the ensemble::

def s2n_inter_quartile_range(flux, err):
Expand Down Expand Up @@ -1160,6 +1170,13 @@
if meta is None:
meta = (self._id_col, float) # return a series of ids, default assume a float is returned

src_to_batch = self.source
obj_to_batch = self.object

if single_lc is not None:
src_to_batch = src_to_batch.loc[single_lc]
obj_to_batch = obj_to_batch.loc[single_lc]

# Translate the meta into an appropriate TapeFrame or TapeSeries. This ensures that the
# batch result will be an EnsembleFrame or EnsembleSeries.
meta = self._translate_meta(meta)
Expand All @@ -1178,15 +1195,13 @@
on[-1] = self._band_col

# Handle object columns to group on
source_cols = list(self.source.columns)
object_cols = list(self.object.columns)
source_cols = list(src_to_batch.columns)
object_cols = list(obj_to_batch.columns)
object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)]

if len(object_group_cols) > 0:
object_col_dd = self.object[object_group_cols]
source_to_batch = self.source.merge(object_col_dd, how="left")
else:
source_to_batch = self.source # Can directly use the source table
obj_to_batch = obj_to_batch[object_group_cols]
src_to_batch = src_to_batch.merge(obj_to_batch, how="left")

Check warning on line 1204 in src/tape/ensemble.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble.py#L1203-L1204

Added lines #L1203 - L1204 were not covered by tests

id_col = self._id_col # pre-compute needed for dask in lambda function

Expand All @@ -1211,11 +1226,11 @@

id_col = self._id_col # need to grab this before mapping

batch = source_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta)
batch = src_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta)

else: # use groupby
# don't use _batch_apply as meta must be specified in the apply call
batch = source_to_batch.groupby(on, group_keys=True, sort=False).apply(
batch = src_to_batch.groupby(on, group_keys=True, sort=False).apply(
_apply_func_to_lc,
func,
*args,
Expand Down
19 changes: 19 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,6 +2127,25 @@ def my_bounds(flux):
assert all([col in res.columns for col in res.compute().columns])


@pytest.mark.parametrize("data_fixture", ["parquet_ensemble", "parquet_ensemble_with_divisions"])
def test_batch_single_lc(data_fixture, request):
"""
Test that ensemble.batch() can run a function on a single light curve.
"""
parquet_ensemble = request.getfixturevalue(data_fixture)

lc = 88472935274829959

lc_res = parquet_ensemble.prune(10).batch(
calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=lc
)
assert len(lc_res) == 1

# Now ensure that we got the same result when we ran the function on the entire ensemble.
full_res = parquet_ensemble.prune(10).batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None)
assert full_res.compute().loc[lc].stetsonJ == lc_res.compute().iloc[0].stetsonJ


def test_batch_labels(parquet_ensemble):
"""
Test that ensemble.batch() generates unique labels for result frames when none are provided.
Expand Down