diff --git a/pyproject.toml b/pyproject.toml
index 51cbc490..3e7021bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@ dynamic=["version"]
 dependencies = [
     'pandas',
     'numpy<=1.23.5',
-    'dask>=2023.5.0',
+    'dask>=2023.6.1',
     'dask[distributed]',
     'pyarrow',
     'pyvo',
diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py
index bda92361..8d0b65dc 100644
--- a/src/tape/ensemble.py
+++ b/src/tape/ensemble.py
@@ -7,6 +7,7 @@
 import pandas as pd
 from dask.distributed import Client
 
+from collections import Counter
 from .analysis.base import AnalysisFunction
 from .analysis.feature_extractor import BaseLightCurveFeature, FeatureExtractor
 
@@ -151,7 +152,7 @@ def insert_sources(
 
         # Create the new row and set the paritioning to match the original dataframe.
         df2 = dd.DataFrame.from_dict(rows, npartitions=1)
-        df2 = df2.set_index(self._id_col, drop=True)
+        df2 = df2.set_index(self._id_col, drop=True, sort=True)
 
         # Save the divisions and number of partitions.
         prev_div = self._source.divisions
@@ -169,6 +170,8 @@
         elif self._source.npartitions != prev_num:
             self._source = self._source.repartition(npartitions=prev_num)
 
+        return self
+
     def client_info(self):
         """Calls the Dask Client, which returns cluster information
 
@@ -206,6 +209,57 @@ def info(self, verbose=True, memory_usage=True, **kwargs):
         print("Source Table")
         self._source.info(verbose=verbose, memory_usage=memory_usage, **kwargs)
 
+    def check_sorted(self, table="object"):
+        """Checks to see if an Ensemble DataFrame is sorted (increasing) on
+        the index.
+
+        Parameters
+        ----------
+        table: `str`, optional
+            The table to check, either "object" or "source". Defaults to "object".
+
+        Returns
+        -------
+        A boolean value indicating whether the index is sorted (True)
+        or not (False)
+        """
+        if table == "object":
+            idx = self._object.index
+        elif table == "source":
+            idx = self._source.index
+        else:
+            raise ValueError(f"{table} is not one of 'object' or 'source'")
+
+        # Use the existing index function to check if it's sorted (increasing)
+        return idx.is_monotonic_increasing.compute()
+
+    def check_lightcurve_cohesion(self):
+        """Checks to see if lightcurves are split across multiple partitions.
+
+        With partitioned data, and source information represented by rows, it
+        is possible that, when loading data or manipulating it in some way
+        (most likely a repartition), the sources for a given object will be
+        split among multiple partitions. This function will check to see if
+        all lightcurves are "cohesive", meaning the sources for that object
+        only live in a single partition of the dataset.
+
+        Returns
+        -------
+        A boolean value indicating whether the sources tied to a given object
+        are only found in a single partition (True), or if they are split
+        across multiple partitions (False)
+
+        """
+        idx = self._source.index
+        counts = idx.map_partitions(lambda a: Counter(a.unique())).compute()
+
+        unq_counter = counts[0]
+        for i in range(1, len(counts)):
+            unq_counter += counts[i]
+            if any(c >= 2 for c in unq_counter.values()):
+                return False
+        return True
+
     def compute(self, table=None, **kwargs):
         """Wrapper for dask.dataframe.DataFrame.compute()
 
@@ -802,7 +856,9 @@ def batch(self, func, *args, meta=None, use_map=True, compute=True, on=None, **k
             Determines whether `dask.dataframe.DataFrame.map_partitions` is
             used (True). Using map_partitions is generally more efficient, but
             requires the data from each lightcurve is housed in a single
-            partition. If False, a groupby will be performed instead.
+            partition. This can be checked using
+            `Ensemble.check_lightcurve_cohesion`. If False, a groupby will be
+            performed instead.
         compute: `boolean`
             Determines whether to compute the result immediately or hold for a
             later compute call.
@@ -961,6 +1017,8 @@ def from_dask_dataframe(
         sync_tables=True,
         npartitions=None,
         partition_size=None,
+        sorted=False,
+        sort=False,
         **kwargs,
     ):
         """Read in Dask dataframe(s) into an ensemble object
@@ -985,6 +1043,12 @@
         partition_size: `int`, optional
            If specified, attempts to repartition the ensemble to partitions
            of size `partition_size`.
+        sorted: `bool`, optional
+            If the index column is already sorted in increasing order.
+            Defaults to False.
+        sort: `bool`, optional
+            If True, sorts the DataFrame by the id column. Otherwise set the
+            index on the individual existing partitions. Defaults to False.
 
         Returns
         ----------
@@ -994,14 +1058,14 @@
         self._load_column_mapper(column_mapper, **kwargs)
 
         # Set the index of the source frame and save the resulting table
-        self._source = source_frame.set_index(self._id_col, drop=True)
+        self._source = source_frame.set_index(self._id_col, drop=True, sorted=sorted, sort=sort)
 
         if object_frame is None:  # generate an indexed object table from source
             self._object = self._generate_object_table()
 
         else:
             self._object = object_frame
-            self._object = self._object.set_index(self._id_col)
+            self._object = self._object.set_index(self._id_col, sorted=sorted, sort=sort)
 
         # Optionally sync the tables, recalculates nobs columns
         if sync_tables:
@@ -1148,6 +1212,8 @@ def from_parquet(
         additional_cols=True,
         npartitions=None,
         partition_size=None,
+        sorted=False,
+        sort=False,
         **kwargs,
     ):
         """Read in parquet file(s) into an ensemble object
@@ -1181,6 +1247,12 @@
         partition_size: `int`, optional
            If specified, attempts to repartition the ensemble to partitions
            of size `partition_size`.
+        sorted: `bool`, optional
+            If the index column is already sorted in increasing order.
+            Defaults to False.
+        sort: `bool`, optional
+            If True, sorts the DataFrame by the id column. Otherwise set the
+            index on the individual existing partitions. Defaults to False.
 
         Returns
         ----------
@@ -1218,6 +1290,8 @@
             sync_tables=sync_tables,
             npartitions=npartitions,
             partition_size=partition_size,
+            sorted=sorted,
+            sort=sort,
             **kwargs,
         )
 
@@ -1275,7 +1349,9 @@ def available_datasets(self):
 
         return {key: datasets_file[key]["description"] for key in datasets_file.keys()}
 
-    def from_source_dict(self, source_dict, column_mapper=None, npartitions=1, **kwargs):
+    def from_source_dict(
+        self, source_dict, column_mapper=None, npartitions=1, sorted=False, sort=False, **kwargs
+    ):
         """Load the sources into an ensemble from a dictionary.
 
         Parameters
         ----------
         source_dict: `dict`
             The dictionary containing the source information.
@@ -1288,6 +1364,12 @@ def from_source_dict(self, source_dict, column_mapper=None, npartitions=1, **kwa
         npartitions: `int`, optional
             If specified, attempts to repartition the ensemble to the
             specified number of partitions
+        sorted: `bool`, optional
+            If the index column is already sorted in increasing order.
+            Defaults to False.
+        sort: `bool`, optional
+            If True, sorts the DataFrame by the id column. Otherwise set the
+            index on the individual existing partitions. Defaults to False.
 
         Returns
         ----------
@@ -1304,6 +1386,8 @@
             column_mapper=column_mapper,
             sync_tables=True,
             npartitions=npartitions,
+            sorted=sorted,
+            sort=sort,
             **kwargs,
         )
 
diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py
index 3c30aeb5..1db32e2b 100644
--- a/tests/tape_tests/test_ensemble.py
+++ b/tests/tape_tests/test_ensemble.py
@@ -372,8 +372,15 @@ def test_insert_paritioned(dask_client):
         "flux": [0.5 * float(i) for i in range(num_points)],
         "band": [all_bands[i % 4] for i in range(num_points)],
     }
-    cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")
-    ens.from_source_dict(rows, column_mapper=cmap, npartitions=4)
+    cmap = ColumnMapper(
+        id_col="id",
+        time_col="time",
+        flux_col="flux",
+        err_col="err",
+        band_col="band",
+        provenance_col="provenance",
+    )
+    ens.from_source_dict(rows, column_mapper=cmap, npartitions=4, sort=True)
 
     # Save the old data for comparison.
     old_data = ens.compute("source")
@@ -435,6 +442,61 @@ def test_core_wrappers(parquet_ensemble):
     parquet_ensemble.compute()
 
 
+@pytest.mark.parametrize("data_sorted", [True, False])
+@pytest.mark.parametrize("npartitions", [1, 2])
+def test_check_sorted(dask_client, data_sorted, npartitions):
+    # Create some fake data.
+
+    if data_sorted:
+        rows = {
+            "id": [8001, 8001, 8001, 8001, 8002, 8002, 8002, 8002, 8002],
+            "time": [10.1, 10.2, 10.2, 11.1, 11.2, 11.3, 11.4, 15.0, 15.1],
+            "band": ["g", "g", "b", "g", "b", "g", "g", "g", "g"],
+            "err": [1.0, 2.0, 1.0, 3.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            "flux": [1.0, 2.0, 5.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0],
+        }
+    else:
+        rows = {
+            "id": [8002, 8002, 8002, 8002, 8002, 8001, 8001, 8002, 8002],
+            "time": [10.1, 10.2, 10.2, 11.1, 11.2, 11.3, 11.4, 15.0, 15.1],
+            "band": ["g", "g", "b", "g", "b", "g", "g", "g", "g"],
+            "err": [1.0, 2.0, 1.0, 3.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            "flux": [1.0, 2.0, 5.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0],
+        }
+    cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")
+    ens = Ensemble(client=dask_client)
+    ens.from_source_dict(rows, column_mapper=cmap, sort=False, npartitions=npartitions)
+
+    assert ens.check_sorted("source") == data_sorted
+
+
+@pytest.mark.parametrize("data_cohesion", [True, False])
+def test_check_lightcurve_cohesion(dask_client, data_cohesion):
+    # Create some fake data.
+
+    if data_cohesion:
+        rows = {
+            "id": [8001, 8001, 8001, 8001, 8001, 8002, 8002, 8002, 8002],
+            "time": [10.1, 10.2, 10.2, 11.1, 11.2, 11.3, 11.4, 15.0, 15.1],
+            "band": ["g", "g", "b", "g", "b", "g", "g", "g", "g"],
+            "err": [1.0, 2.0, 1.0, 3.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            "flux": [1.0, 2.0, 5.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0],
+        }
+    else:
+        rows = {
+            "id": [8001, 8001, 8001, 8001, 8002, 8002, 8002, 8002, 8001],
+            "time": [10.1, 10.2, 10.2, 11.1, 11.2, 11.3, 11.4, 15.0, 15.1],
+            "band": ["g", "g", "b", "g", "b", "g", "g", "g", "g"],
+            "err": [1.0, 2.0, 1.0, 3.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            "flux": [1.0, 2.0, 5.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0],
+        }
+    cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")
+    ens = Ensemble(client=dask_client)
+    ens.from_source_dict(rows, column_mapper=cmap, sort=False, npartitions=2)
+
+    assert ens.check_lightcurve_cohesion() == data_cohesion
+
+
 def test_persist(dask_client):
     # Create some fake data.
     rows = {
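
Usage sketch of what this change enables, for anyone trying the branch locally. It is not taken from the patch itself: the import paths and the toy rows dictionary are assumptions, loosely modeled on how the test suite constructs an Ensemble.

# Usage sketch. Assumptions: `Ensemble` and `ColumnMapper` are importable from
# the tape package as below, and the data here is made up for illustration.
from dask.distributed import Client

from tape import Ensemble
from tape.utils import ColumnMapper

# A tiny source table: two objects with a handful of observations each.
rows = {
    "id": [8001, 8001, 8001, 8002, 8002],
    "time": [10.1, 10.2, 10.3, 11.1, 11.2],
    "flux": [1.0, 2.0, 3.0, 4.0, 5.0],
    "err": [0.1, 0.1, 0.2, 0.2, 0.3],
    "band": ["g", "r", "g", "r", "g"],
}
cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")

ens = Ensemble(client=Client())

# sort=True asks set_index to shuffle on the id column, so the source table
# comes back sorted and each lightcurve should land in a single partition.
ens.from_source_dict(rows, column_mapper=cmap, npartitions=2, sort=True)

print(ens.check_sorted("source"))       # True when the source index is monotonically increasing
print(ens.check_lightcurve_cohesion())  # True when no object's sources span multiple partitions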