Merge Main into Ensemble Refactor Branch (#304)
* check divisions, enable lazy syncs

* check divisions, enable lazy syncs

* initial tests

* add tests; calc_nobs preserve divisions

* batch with divisions

* cleanup

* fix sf2 tests

* add sync_tables check

* cleanup

* fix calc_nobs reset_index issue

* per table warnings; index comments

* add map_partitions mode for calc_nobs when divisions are known

* build metadata

* build metadata

* add multi partition test

* add version file to init

* add small test

* Fix table syncing to use inner joins. (#303)

* Fix table syncing to use inner joins.

* fix lint error

* Update test

---------

Co-authored-by: Doug Branton <[email protected]>
wilsonbb and dougbrn authored Dec 1, 2023
1 parent a714b10 commit 5c847e1
Showing 5 changed files with 269 additions and 85 deletions.
1 change: 1 addition & 0 deletions src/tape/__init__.py
@@ -3,3 +3,4 @@
from .ensemble_frame import * # noqa
from .timeseries import * # noqa
from .ensemble_readers import * # noqa
from ._version import __version__ # noqa
156 changes: 108 additions & 48 deletions src/tape/ensemble.py
@@ -13,11 +13,10 @@
from .analysis.feature_extractor import BaseLightCurveFeature, FeatureExtractor
from .analysis.structure_function import SF_METHODS
from .analysis.structurefunction2 import calc_sf2
from .ensemble_frame import EnsembleFrame, EnsembleSeries, ObjectFrame, SourceFrame, TapeFrame, TapeSeries
from .ensemble_frame import EnsembleFrame, EnsembleSeries, ObjectFrame, SourceFrame, TapeFrame, TapeObjectFrame, TapeSourceFrame, TapeSeries
from .timeseries import TimeSeries
from .utils import ColumnMapper

# TODO import from EnsembleFrame...?
SOURCE_FRAME_LABEL = "source"
OBJECT_FRAME_LABEL = "object"

@@ -48,7 +47,6 @@ def __init__(self, client=True, **kwargs):
# A unique ID to allocate new result frame labels.
self.default_frame_id = 1

# TODO([email protected]) Replace self._source and self._object with these
self.source = None # Source Table EnsembleFrame
self.object = None # Object Table EnsembleFrame

@@ -779,40 +777,68 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
"""

if by_band:
band_counts = (
self._source.groupby([self._id_col])[self._band_col] # group by each object
.value_counts() # count occurrence of each band
.to_frame() # convert series to dataframe
.reset_index() # break up the multiindex
.categorize(columns=[self._band_col]) # retype the band labels as categories
.pivot_table(values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum")
) # the pivot_table call makes each band_count a column of the id_col row

# repartition the result to align with object
if self._object.known_divisions:
self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
band_counts = band_counts.repartition(npartitions=self._object.npartitions)
# Grab these up front to help out the task graph
id_col = self._id_col
band_col = self._band_col

# Get the band metadata
unq_bands = np.unique(self._source[band_col])
meta = {band: float for band in unq_bands}

# Map the groupby to each partition
band_counts = self._source.map_partitions(
lambda x: x.groupby(id_col)[[band_col]]
.value_counts()
.to_frame()
.reset_index()
.pivot_table(values=band_col, index=id_col, columns=band_col, aggfunc="sum"),
meta=meta,
).repartition(divisions=self._object.divisions)
else:
band_counts = (
self._source.groupby([self._id_col])[self._band_col] # group by each object
.value_counts() # count occurrence of each band
.to_frame() # convert series to dataframe
.rename(columns={self._band_col: "counts"}) # rename column
.reset_index() # break up the multiindex
.categorize(columns=[self._band_col]) # retype the band labels as categories
.pivot_table(
values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum"
)
) # the pivot_table call makes each band_count a column of the id_col row

band_counts = band_counts.repartition(npartitions=self._object.npartitions)

# short-hand for calculating nobs_total
band_counts["total"] = band_counts[list(band_counts.columns)].sum(axis=1)

bands = band_counts.columns.values
self._object = self._object.assign(**{label + "_" + band: band_counts[band] for band in bands})
self._object = self._object.assign(
**{label + "_" + str(band): band_counts[band] for band in bands}
)

if temporary:
self._object_temp.extend(label + "_" + band for band in bands)
self._object_temp.extend(label + "_" + str(band) for band in bands)

else:
counts = self._source.groupby([self._id_col])[[self._band_col]].aggregate("count")

# repartition the result to align with object
if self._object.known_divisions:
self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
counts = counts.repartition(npartitions=self._object.npartitions)
if self._object.known_divisions and self._source.known_divisions:
# Grab these up front to help out the task graph
id_col = self._id_col
band_col = self._band_col

# Map the groupby to each partition
counts = self._source.map_partitions(
lambda x: x.groupby([id_col])[[band_col]].aggregate("count")
).repartition(divisions=self._object.divisions)
else:
counts = counts.repartition(npartitions=self._object.npartitions)
# Just do a groupby on all source
counts = (
self._source.groupby([self._id_col])[[self._band_col]]
.aggregate("count")
.repartition(npartitions=self._object.npartitions)
)

self._object = self._object.assign(**{label + "_total": counts[self._band_col]})

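A minimal sketch of the pattern this hunk relies on: when the source table is sorted on the object id and its divisions are known, all rows for a given object sit in a single partition, so a per-partition groupby gives the same answer as a global one and Dask can skip the shuffle. The toy data, column names, and the `unstack`-based pivot below are illustrative stand-ins, not TAPE's exact call chain.

```python
import numpy as np
import pandas as pd
import dask.dataframe as dd

# Toy source table: one row per observation, indexed by object id.
src = pd.DataFrame(
    {"band": ["g", "r", "g", "r", "r"]},
    index=pd.Index([1, 1, 2, 2, 2], name="id"),
)
source = dd.from_pandas(src, npartitions=2)  # sorted index -> known divisions

bands = np.unique(src["band"])
meta = {band: int for band in bands}

# Each object's rows live in exactly one partition, so a per-partition
# groupby/pivot equals the global groupby, with no shuffle.
nobs_by_band = source.map_partitions(
    lambda x: x.groupby("id")["band"]
    .value_counts()
    .unstack(fill_value=0)
    .reindex(columns=bands, fill_value=0),  # consistent columns across partitions
    meta=meta,
)
print(nobs_by_band.compute())
```

The `repartition(divisions=self._object.divisions)` call in the diff then aligns the result with the object table partition for partition.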
@@ -849,8 +875,7 @@ def prune(self, threshold=50, col_name=None):
col_name = "nobs_total"

# Mask on object table
mask = self._object[col_name] >= threshold
self.update_frame(self._object[mask])
self = self.query(f"{col_name} >= {threshold}", table="object")

self._object.set_dirty(True) # Object table is now dirty

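The rewritten `prune` now routes through `Ensemble.query`, which the diff shows accepting a filter expression and a `table` argument. A hedged usage sketch, assuming an `Ensemble` instance `ens` with data already loaded (the threshold and label are arbitrary):

```python
# Count observations per object, then keep only well-sampled objects;
# this is the same filter prune() applies with its default column.
ens.calc_nobs(by_band=False, label="nobs")
ens = ens.query("nobs_total >= 50", table="object")
```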
@@ -1134,12 +1159,18 @@ def s2n_inter_quartile_range(flux, err):
meta=meta,
)

# Inherit divisions if known from source and the resulting index is the id
# Groupby on index should always return a subset that adheres to the same divisions criteria
if self._source.known_divisions and batch.index.name == self._id_col:
batch.divisions = self._source.divisions

if label is not None:
if label == "":
label = self._generate_frame_label()
print(f"Using generated label, {label}, for a batch result.")
# Track the result frame under the provided label
self.add_frame(batch, label)

if compute:
return batch.compute()
else:
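Why the inherited divisions are safe, sketched on the toy `source` frame from the calc_nobs example above: a groupby over the sorted index can only combine rows within one partition's id range, so the input divisions still bracket every output partition. This illustrates the reasoning, not TAPE's `batch` internals.

```python
# One output row per id; ids never cross partitions, so the original
# divisions remain a valid description of the result.
per_object = source.map_partitions(lambda x: x.groupby("id")[["band"]].count())
per_object.divisions = source.divisions  # mirrors batch.divisions = self._source.divisions
```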
@@ -1243,8 +1274,6 @@ def from_dask_dataframe(
The ensemble object with the Dask dataframe data loaded.
"""
self._load_column_mapper(column_mapper, **kwargs)

# TODO([email protected]): Determine most efficient way to convert to SourceFrame/ObjectFrame
source_frame = SourceFrame.from_dask_dataframe(source_frame, self)

# Set the index of the source frame and save the resulting table
@@ -1255,7 +1284,6 @@
self.update_frame(self._generate_object_table())

else:
# TODO([email protected]): Determine most efficient way to convert to SourceFrame/ObjectFrame
self.update_frame(ObjectFrame.from_dask_dataframe(object_frame, ensemble=self))
self.update_frame(self._object.set_index(self._id_col, sorted=sorted, sort=sort))

@@ -1270,6 +1298,12 @@
elif partition_size:
self._source = self._source.repartition(partition_size=partition_size)

# Check that Divisions are established, warn if not.
for name, table in [("object", self._object), ("source", self._source)]:
if not table.known_divisions:
warnings.warn(
f"Divisions for {name} are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information."
)
return self

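For context on the new warning: divisions become known when Dask sorts a frame on its index, which is what the `sort`/`sorted` flags ultimately trigger. A minimal sketch on a raw Dask frame; the file name and column are placeholders.

```python
import dask.dataframe as dd

ddf = dd.read_parquet("source.parquet")  # divisions generally unknown at this point
ddf = ddf.set_index("ps1_objid")         # shuffles and sorts; divisions become known
assert ddf.known_divisions               # lazy joins and repartitions now possible
```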
def from_hipscat(self, dir, source_subdir="source", object_subdir="object", column_mapper=None, **kwargs):
@@ -1464,7 +1498,10 @@ def from_parquet(
columns.append(self._provenance_col)

# Read in the source parquet file(s)
source = SourceFrame.from_parquet(source_file, index=self._id_col, columns=columns, ensemble=self)
# Index is set False so that we can set it with a future set_index call
# This has the advantage of letting Dask set partition boundaries based
# on the divisions between the sources of different objects.
source = SourceFrame.from_parquet(source_file, index=False, columns=columns, ensemble=self)

# Generate a provenance column if not provided
if self._provenance_col is None:
@@ -1474,7 +1511,9 @@
object = None
if object_file:
# Read in the object file(s)
object = ObjectFrame.from_parquet(object_file, index=self._id_col, ensemble=self)
# Index is set False so that we can set it with a future set_index call
# This matters more for source than for object, but parity seems good here
object = ObjectFrame.from_parquet(object_file, index=False, ensemble=self)
return self.from_dask_dataframe(
source_frame=source,
object_frame=object,
@@ -1660,13 +1699,7 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux

def _generate_object_table(self):
"""Generate an empty object table from the source table."""
sor_idx = self._source.index.unique()
obj_df = pd.DataFrame(index=sor_idx)

# Convert the resulting dataframe into an ObjectFrame
# TODO(wbeebe): Switch to a cleaner loading function
res = ObjectFrame.from_dask_dataframe(
dd.from_pandas(obj_df, npartitions=int(np.ceil(self._source.npartitions / 100))), ensemble=self)
res = self._source.map_partitions(lambda x: TapeObjectFrame(index=x.index.unique()))

return res

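What the one-line replacement does, sketched with a plain `pd.DataFrame` standing in for `TapeObjectFrame` and reusing the toy `source` frame from above: each source partition is reduced to an empty frame that carries only its unique object ids, so the generated object table inherits the source table's partitioning and divisions instead of being rebuilt from a computed index.

```python
# Per-partition reduction: keep only the unique ids, no columns.
objects = source.map_partitions(lambda x: pd.DataFrame(index=x.index.unique()))
print(objects.compute())  # zero-column frame indexed by object id, aligned with source
```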
@@ -1719,9 +1752,20 @@ def _sync_tables(self):

if self._object.is_dirty():
# Sync Object to Source; remove any missing objects from source
obj_idx = list(self._object.index.compute())
self.update_frame(self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)]))
self.update_frame(self._source.persist()) # persist the source frame

if self._object.known_divisions and self._source.known_divisions:
# Lazily create an empty object table (just index) for joining
empty_obj = self._object.map_partitions(lambda x: TapeObjectFrame(index=x.index))
if type(empty_obj) != type(self._object):
raise ValueError("Bad type for empty_obj: " + str(type(empty_obj)))

# Join source onto the empty object table to align
self.update_frame(self._source.join(empty_obj, how="inner"))
else:
warnings.warn("Divisions are not known, syncing using a non-lazy method.")
obj_idx = list(self._object.index.compute())
self.update_frame(self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)]))
self.update_frame(self._source.persist()) # persist the source frame

# Drop Temporary Source Columns on Sync
if len(self._source_temp):
@@ -1731,10 +1775,20 @@

if self._source.is_dirty(): # not elif
if not self.keep_empty_objects:
# Sync Source to Object; remove any objects that do not have sources
sor_idx = list(self._source.index.unique().compute())
self.update_frame(self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)]))
self.update_frame(self._object.persist()) # persist the object frame
if self._object.known_divisions and self._source.known_divisions:
# Lazily create an empty source table (just unique indexes) for joining
empty_src = self._source.map_partitions(lambda x: TapeSourceFrame(index=x.index.unique()))
if type(empty_src) != type(self._source):
raise ValueError("Bad type for empty_src: " + str(type(empty_src)))

# Join object onto the empty unique source table to align
self.update_frame(self._object.join(empty_src, how="inner"))
else:
warnings.warn("Divisions are not known, syncing using a non-lazy method.")
# Sync Source to Object; remove any objects that do not have sources
sor_idx = list(self._source.index.unique().compute())
self.update_frame(self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)]))
self.update_frame(self._object.persist()) # persist the object frame

# Drop Temporary Object Columns on Sync
if len(self._object_temp):
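The heart of the #303 fix in isolation: syncing by inner-joining one table against an index-only version of the other. With known divisions the join stays lazy and partition-aligned, replacing the old pattern of computing the index and filtering with `isin`. A hedged sketch reusing the toy frames from above; the surviving id is arbitrary.

```python
# Hypothetical post-prune object table: only object 2 survived.
surviving = dd.from_pandas(pd.DataFrame(index=pd.Index([2], name="id")), npartitions=1)

# The diff's pattern: strip to an index-only frame, then inner-join to align.
empty_obj = surviving.map_partitions(lambda x: pd.DataFrame(index=x.index))
synced = source.join(empty_obj, how="inner")  # lazy; nothing is computed here
print(synced.compute())                       # only object 2's observations remain
```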
@@ -1834,7 +1888,7 @@ def _build_index(self, obj_id, band):
index = pd.MultiIndex.from_tuples(tuples, names=["object_id", "band", "index"])
return index

def sf2(self, sf_method="basic", argument_container=None, use_map=True):
def sf2(self, sf_method="basic", argument_container=None, use_map=True, compute=True):
"""Wrapper interface for calling structurefunction2 on the ensemble
Parameters
@@ -1876,11 +1930,17 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True):
self._source.index,
argument_container=argument_container,
)
return result

else:
result = self.batch(calc_sf2, use_map=use_map, argument_container=argument_container)
result = self.batch(
calc_sf2, use_map=use_map, argument_container=argument_container, compute=compute
)

return result
# Inherit divisions information if known
if self._source.known_divisions and self._object.known_divisions:
result.divisions = self._source.divisions

return result

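Hedged usage of the new `compute` flag, assuming an `Ensemble` named `ens` with data loaded: `compute=False` keeps the structure-function result as a lazy frame, with divisions inherited when both tables have them, so further operations can be chained before one materializing call.

```python
res = ens.sf2(sf_method="basic", compute=False)  # lazy result frame
df = res.compute()                               # materialize when ready
```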
def _translate_meta(self, meta):
"""Translates Dask-style meta into a TapeFrame or TapeSeries object.
19 changes: 19 additions & 0 deletions tests/tape_tests/conftest.py
@@ -270,6 +270,25 @@ def parquet_ensemble(dask_client):
return ens


# pylint: disable=redefined-outer-name
@pytest.fixture
def parquet_ensemble_with_divisions(dask_client):
"""Create an Ensemble from parquet data."""
ens = Ensemble(client=dask_client)
ens.from_parquet(
"tests/tape_tests/data/source/test_source.parquet",
"tests/tape_tests/data/object/test_object.parquet",
id_col="ps1_objid",
time_col="midPointTai",
band_col="filterName",
flux_col="psFlux",
err_col="psFluxErr",
sort=True,
)

return ens


# pylint: disable=redefined-outer-name
@pytest.fixture
def parquet_ensemble_from_source(dask_client):
