Skip to content

Commit

Permalink
Fix EnsembleFrame.repartition (#349)
Browse files Browse the repository at this point in the history
* Add EnsembleFrame.repartition

* Repartition source frame with update_ensemble

* lint fix
  • Loading branch information
wilsonbb authored Jan 17, 2024
1 parent a3de1be commit f4109c6
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def insert_sources(
if all(prev_div):
self.update_frame(self.source.repartition(divisions=prev_div))
elif self.source.npartitions != prev_num:
self.source = self.source.repartition(npartitions=prev_num)
self.update_frame(self.source.repartition(npartitions=prev_num))

return self

Expand Down
75 changes: 75 additions & 0 deletions src/tape/ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,81 @@ def compute(self, **kwargs):
self.ensemble._lazy_sync_tables_from_frame(self)
return super().compute(**kwargs)

def repartition(
self,
divisions=None,
npartitions=None,
partition_size=None,
freq=None,
force=False,
):
"""Repartition dataframe along new divisions
Doc string below derived from dask.dataframe.DataFrame
Parameters
----------
divisions : list, optional
The "dividing lines" used to split the dataframe into partitions.
For ``divisions=[0, 10, 50, 100]``, there would be three output partitions,
where the new index contained [0, 10), [10, 50), and [50, 100), respectively.
See https://docs.dask.org/en/latest/dataframe-design.html#partitions.
Only used if npartitions and partition_size isn't specified.
For convenience if given an integer this will defer to npartitions
and if given a string it will defer to partition_size (see below)
npartitions : int, optional
Approximate number of partitions of output. Only used if partition_size
isn't specified. The number of partitions used may be slightly
lower than npartitions depending on data distribution, but will never be
higher.
partition_size: int or string, optional
Max number of bytes of memory for each partition. Use numbers or
strings like 5MB. If specified npartitions and divisions will be
ignored. Note that the size reflects the number of bytes used as
computed by ``pandas.DataFrame.memory_usage``, which will not
necessarily match the size when storing to disk.
.. warning::
This keyword argument triggers computation to determine
the memory size of each partition, which may be expensive.
freq : str, pd.Timedelta
A period on which to partition timeseries data like ``'7D'`` or
``'12h'`` or ``pd.Timedelta(hours=12)``. Assumes a datetime index.
force : bool, default False
Allows the expansion of the existing divisions.
If False then the new divisions' lower and upper bounds must be
the same as the old divisions'.
Notes
-----
Exactly one of `divisions`, `npartitions`, `partition_size`, or `freq`
should be specified. A ``ValueError`` will be raised when that is
not the case.
Also note that ``len(divisons)`` is equal to ``npartitions + 1``. This is because ``divisions``
represents the upper and lower bounds of each partition. The first item is the
lower bound of the first partition, the second item is the lower bound of the
second partition and the upper bound of the first partition, and so on.
The second-to-last item is the lower bound of the last partition, and the last
(extra) item is the upper bound of the last partition.
Examples
--------
>>> df = df.repartition(npartitions=10) # doctest: +SKIP
>>> df = df.repartition(divisions=[0, 5, 10, 20]) # doctest: +SKIP
>>> df = df.repartition(freq='7d') # doctest: +SKIP
"""
result = super().repartition(
divisions=divisions,
npartitions=npartitions,
partition_size=partition_size,
freq=freq,
force=force,
)
return self._propagate_metadata(result)


class TapeSeries(pd.Series):
"""A barebones extension of a Pandas series to be used for underlying Ensemble data.
Expand Down
7 changes: 7 additions & 0 deletions tests/tape_tests/test_ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ def test_ensemble_frame_propagation(data_fixture, request):
assert merged_frame.ensemble == ens
assert merged_frame.is_dirty()

# Test that frame metadata is preserved after repartitioning
repartitioned_frame = ens_frame.copy().repartition(npartitions=10)
assert isinstance(repartitioned_frame, EnsembleFrame)
assert repartitioned_frame.label == TEST_LABEL
assert repartitioned_frame.ensemble == ens
assert repartitioned_frame.is_dirty()

# Test that head returns a subset of the underlying TapeFrame.
h = ens_frame.head(5)
assert isinstance(h, TapeFrame)
Expand Down

0 comments on commit f4109c6

Please sign in to comment.