Skip to content

Commit

Permalink
Merge pull request #336 from lincc-frameworks/select_random_lc
Browse files Browse the repository at this point in the history
Adds a random lightcurve selection function to the Ensemble
  • Loading branch information
dougbrn authored Dec 22, 2023
2 parents 97a2bfc + bc5ff2b commit e5404a6
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
48 changes: 48 additions & 0 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,54 @@ def _sync_tables(self):
self.object.set_dirty(False)
return self

def select_random_timeseries(self, seed=None):
    """Pick one lightcurve at random and return it as a timeseries.

    A partition of the Object table is drawn at random first, and an
    object id is then drawn at random from that partition's index.
    Partitions that turn out to be empty are skipped.

    Parameters
    ----------
    seed: int, or None
        Seed for the random number generator, so that successive runs
        return the same object id. Defaults to `None`, in which case no
        seed is set.

    Returns
    -------
    ts: `TimeSeries`
        The result of ``self.to_timeseries`` for the selected object id.

    Raises
    ------
    IndexError
        If no partition of the Object table contains any object ids.

    Note
    ----
    The selection is not uniform over objects. Because a partition is
    chosen before an object (which avoids searching the full index
    space) and partitions may hold different numbers of objects, an
    object in a small partition is more likely to be picked than an
    object in a large one.
    """
    generator = np.random.default_rng(seed)

    # Randomise the order in which partitions are visited, so that falling
    # through an empty partition still lands on a random next candidate.
    visit_order = np.arange(self.object.npartitions)
    generator.shuffle(visit_order)

    cursor = 0
    while True:
        candidate = visit_order[cursor]
        candidate_ids = self.object.partitions[candidate].index
        # Explicit length test (rather than truthiness) because the index
        # is not a plain Python container.
        if len(candidate_ids) > 0:
            break
        cursor += 1
        if cursor >= len(visit_order):
            raise IndexError("Found no object IDs in the Object Table.")

    # Draw the lightcurve id from the first non-empty partition found.
    lcid = generator.choice(candidate_ids.values)
    print(f"Selected Object {lcid} from Partition {candidate}")

    return self.to_timeseries(lcid)

def to_timeseries(
self,
target,
Expand Down
58 changes: 58 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TapeSeries,
TapeObjectFrame,
TapeSourceFrame,
TimeSeries,
)
from tape.analysis.stetsonj import calc_stetson_J
from tape.analysis.structure_function.base_argument_container import StructureFunctionArgumentContainer
Expand Down Expand Up @@ -1829,6 +1830,63 @@ def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta):
assert isinstance(parquet_ensemble.select_frame("sf2_result"), EnsembleFrame)


@pytest.mark.parametrize("repartition", [False, True])
@pytest.mark.parametrize("seed", [None, 42])
def test_select_random_timeseries(parquet_ensemble, repartition, seed):
    """Check that ensemble.select_random_timeseries returns a TimeSeries.

    With a fixed seed the selected object id is pinned, both for the
    original partitioning and after repartitioning the Object table.
    """
    ens = parquet_ensemble
    if repartition:
        ens.object = ens.object.repartition(3)

    ts = ens.select_random_timeseries(seed=seed)
    assert isinstance(ts, TimeSeries)

    # Only the seeded runs are reproducible; the expected id depends on
    # how the Object table is partitioned.
    if seed == 42:
        expected_id = 88480001333818899 if repartition else 88472935274829959
        assert ts.meta["id"] == expected_id


@pytest.mark.parametrize("all_empty", [False, True])
def test_select_random_timeseries_empty_partitions(dask_client, all_empty):
    """Test the edge case where the Object table has empty partitions."""
    # A single-object dataset: after repartitioning, most partitions are empty.
    single_object_source = {
        "id": [42],
        "flux": [1],
        "time": [1],
        "err": [1],
        "band": [1],
    }
    colmap = ColumnMapper().assign(
        id_col="id",
        time_col="time",
        flux_col="flux",
        err_col="err",
        band_col="band",
    )

    ens = Ensemble(client=dask_client)
    ens.from_source_dict(single_object_source, column_mapper=colmap)

    # The single id will be in the last partition
    ens.object = ens.object.repartition(5)

    if not all_empty:
        # Empty partitions must be skipped until the only object is found.
        ts = ens.select_random_timeseries()
        assert ts.meta["id"] == 42
        return

    # Drop the last partition so that no partition holds an id, and make
    # sure the expected error is raised.
    ens.object = ens.object.partitions[0:-1]
    with pytest.raises(IndexError):
        ens.select_random_timeseries()


def test_to_timeseries(parquet_ensemble):
"""
Test that ensemble.to_timeseries() runs and assigns the correct metadata
Expand Down

0 comments on commit e5404a6

Please sign in to comment.