Skip to content

Commit

Permalink
Add Explicit Metadata Propagation for EnsembleFrame joins (#301)
Browse files Browse the repository at this point in the history
* Support propagating frame metadata in joins

* Update doc strings and test
  • Loading branch information
wilsonbb authored Nov 29, 2023
1 parent 7e6abaf commit a714b10
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 2 deletions.
64 changes: 63 additions & 1 deletion src/tape/ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,68 @@ def merge(self, right, **kwargs):
result = super().merge(right, **kwargs)
return self._propagate_metadata(result)

def join(self, other, **kwargs):
    """Join columns of another DataFrame.

    Note that if `other` is a different type, we expect the result to have
    the type of this object regardless of the value of the `how` parameter.

    This docstring was copied from pandas.core.frame.DataFrame.join.
    Some inconsistencies with this version may exist.

    Join columns with `other` DataFrame either on index or on a key
    column. Efficiently join multiple DataFrame objects by index at once by
    passing a list.

    Parameters
    ----------
    other : DataFrame, Series, or a list containing any combination of them
        Index should be similar to one of the columns in this one. If a
        Series is passed, its name attribute must be set, and that will be
        used as the column name in the resulting joined DataFrame.
    on : str, list of str, or array-like, optional
        Column or index level name(s) in the caller to join on the index
        in `other`, otherwise joins index-on-index. If multiple
        values given, the `other` DataFrame must have a MultiIndex. Can
        pass an array as the join key if it is not already contained in
        the calling DataFrame. Like an Excel VLOOKUP operation.
    how : {'left', 'right', 'outer', 'inner', 'cross'}, default 'left'
        How to handle the operation of the two objects.

        * left: use calling frame's index (or column if on is specified)
        * right: use `other`'s index.
        * outer: form union of calling frame's index (or column if on is
          specified) with `other`'s index, and sort it lexicographically.
        * inner: form intersection of calling frame's index (or column if
          on is specified) with `other`'s index, preserving the order
          of the calling's one.
        * cross: creates the cartesian product from both frames, preserves
          the order of the left keys.
    lsuffix : str, default ''
        Suffix to use from left frame's overlapping columns.
    rsuffix : str, default ''
        Suffix to use from right frame's overlapping columns.
    sort : bool, default False
        Order result DataFrame lexicographically by the join key. If False,
        the order of the join key depends on the join type (how keyword).
    validate : str, optional
        If specified, checks if join is of specified type.

        * "one_to_one" or "1:1": check if join keys are unique in both left
          and right datasets.
        * "one_to_many" or "1:m": check if join keys are unique in left dataset.
        * "many_to_one" or "m:1": check if join keys are unique in right dataset.
        * "many_to_many" or "m:m": allowed, but does not result in checks.

    Returns
    -------
    result: `tape._Frame`
        A TAPE dataframe containing columns from both the caller and `other`.
    """
    # Let the parent class perform the join (all keyword arguments are
    # forwarded untouched), then hand the outcome to _propagate_metadata so
    # the returned frame is derived from this frame rather than `other`.
    return self._propagate_metadata(super().join(other, **kwargs))

def drop(self, labels=None, axis=0, columns=None, errors="raise"):
"""Drop specified labels from rows or columns.
Expand Down Expand Up @@ -316,7 +378,7 @@ def drop(self, labels=None, axis=0, columns=None, errors="raise"):
Returns
-------
result: `tape._Frame`
Returns the frame or Nonewith the specified
Returns the frame or None with the specified
index or column labels removed or None if inplace=True.
"""
result = self._propagate_metadata(super().drop(labels=labels, axis=axis, columns=columns, errors=errors))
Expand Down
39 changes: 38 additions & 1 deletion tests/tape_tests/test_ensemble_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,41 @@ def test_object_and_source_frame_propagation(data_fixture, request):
assert isinstance(merged_frame, SourceFrame)
assert merged_frame.label == SOURCE_LABEL
assert merged_frame.ensemble == ens
assert merged_frame.is_dirty()
assert merged_frame.is_dirty()


def test_object_and_source_joins(parquet_ensemble):
    """
    Test that SourceFrame and ObjectFrame metadata and class type are correctly propagated across
    joins.
    """
    # Get Source and object frames to test joins on.
    source_frame, object_frame = parquet_ensemble.source.copy(), parquet_ensemble.object.copy()

    # Verify their metadata was preserved in the copy()
    assert source_frame.label == SOURCE_LABEL
    assert source_frame.ensemble is parquet_ensemble
    assert object_frame.label == OBJECT_LABEL
    assert object_frame.ensemble is parquet_ensemble

    # Join a SourceFrame (left) with an ObjectFrame (right)
    # Validate that metadata is preserved and the outputted object is a SourceFrame
    joined_source = source_frame.join(object_frame, how="left")
    # Labels are compared by value (==), matching the checks above; an identity
    # check would only pass if the propagated label were the very same object.
    assert joined_source.label == SOURCE_LABEL
    # Exact-type and ensemble checks intentionally use `is`: the join must yield
    # precisely this frame class and reference the same ensemble object.
    assert type(joined_source) is SourceFrame
    assert joined_source.ensemble is parquet_ensemble

    # Now the same form of join (in terms of left/right) but produce an ObjectFrame. This is
    # because frame1.join(frame2) will yield frame1's type regardless of left vs right.
    assert type(object_frame.join(source_frame, how="right")) is ObjectFrame

    # Join an ObjectFrame (left) with a SourceFrame (right)
    # Validate that metadata is preserved and the outputted object is an ObjectFrame
    joined_object = object_frame.join(source_frame, how="left")
    assert joined_object.label == OBJECT_LABEL
    assert type(joined_object) is ObjectFrame
    assert joined_object.ensemble is parquet_ensemble

    # Now the same form of join (in terms of left/right) but produce a SourceFrame. This is
    # because frame1.join(frame2) will yield frame1's type regardless of left vs right.
    assert type(source_frame.join(object_frame, how="right")) is SourceFrame

0 comments on commit a714b10

Please sign in to comment.