Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize comparison binary operations #9542

Merged
merged 5 commits into from
Nov 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 79 additions & 18 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,7 @@ def _binaryop(
fn: str,
fill_value: Any = None,
reflect: bool = False,
can_reindex: bool = False,
*args,
**kwargs,
):
Expand All @@ -1898,14 +1899,17 @@ def _binaryop(
for right, (name, left) in zip(rhs, lhs._data.items())
}
elif isinstance(rhs, DataFrame):
if fn in cudf.utils.utils._EQUALITY_OPS:
if not lhs.columns.equals(rhs.columns) or not lhs.index.equals(
rhs.index
):
raise ValueError(
"Can only compare identically-labeled "
"DataFrame objects"
)
if (
not can_reindex
and fn in cudf.utils.utils._EQUALITY_OPS
and (
not lhs.columns.equals(rhs.columns)
or not lhs.index.equals(rhs.index)
)
):
raise ValueError(
"Can only compare identically-labeled " "DataFrame objects"
)

lhs, rhs = _align_indices(lhs, rhs)

Expand Down Expand Up @@ -6481,14 +6485,6 @@ def unstack(self, level=-1, fill_value=None):
self, level=level, fill_value=fill_value
)

def equals(self, other):
if not isinstance(other, DataFrame):
return False
for self_name, other_name in zip(self._data.names, other._data.names):
if self_name != other_name:
return False
return super().equals(other)

def explode(self, column, ignore_index=False):
"""
Transform each element of a list-like to a row, replicating index
Expand Down Expand Up @@ -6536,14 +6532,27 @@ def explode(self, column, ignore_index=False):
return super()._explode(column, ignore_index)


def make_binop_func(op):
def make_binop_func(op, postprocess=None):
# This function is used to wrap binary operations in Frame with an
# appropriate API for DataFrame as required for pandas compatibility. The
# main effect is reordering and error-checking parameters in
# DataFrame-specific ways. The postprocess argument is a callable that may
# optionally be provided to modify the result of the binop if additional
# processing is needed for pandas compatibility. The callable must have the
# signature
# def postprocess(left, right, output)
# where left and right are the inputs to the binop and output is the result
# of calling the wrapped Frame binop.
wrapped_func = getattr(Frame, op)

@functools.wraps(wrapped_func)
def wrapper(self, other, axis="columns", level=None, fill_value=None):
if axis not in (1, "columns"):
raise NotImplementedError("Only axis=1 supported at this time.")
return wrapped_func(self, other, axis, level, fill_value)
output = wrapped_func(self, other, axis, level, fill_value)
if postprocess is None:
return output
return postprocess(self, other, output)

# functools.wraps copies module level attributes to `wrapper` and sets
# __wrapped__ attributes to `wrapped_func`. Cpython looks up the signature
Expand All @@ -6560,6 +6569,7 @@ def wrapper(self, other, axis="columns", level=None, fill_value=None):
return wrapper


# Wrap arithmetic Frame binop functions with the expected API for Series.
for binop in [
"add",
"radd",
Expand All @@ -6583,6 +6593,57 @@ def wrapper(self, other, axis="columns", level=None, fill_value=None):
setattr(DataFrame, binop, make_binop_func(binop))


def _make_replacement_func(value):
vyasr marked this conversation as resolved.
Show resolved Hide resolved
# This function generates a postprocessing function suitable for use with
# make_binop_func that fills null columns with the desired fill value.

def func(left, right, output):
# This function may be passed as the postprocess argument to
# make_binop_func. Columns that are only present in one of the inputs
# will be null in the output. This function postprocesses the output to
# replace those nulls with some desired output.
if isinstance(right, Series):
uncommon_columns = set(left._column_names) ^ set(right.index)
elif isinstance(right, DataFrame):
uncommon_columns = set(left._column_names) ^ set(
right._column_names
)
elif _is_scalar_or_zero_d_array(right):
for name, col in output._data.items():
output._data[name] = col.fillna(value)
return output
else:
return output

for name in uncommon_columns:
output._data[name] = column.full(
size=len(output), fill_value=value, dtype="bool"
)
return output

return func


# The ne comparator needs special postprocessing because elements that missing
# in one operand should be treated as null and result in True in the output
# rather than simply propagating nulls.
DataFrame.ne = make_binop_func("ne", _make_replacement_func(True))


# All other comparison operators needs return False when one of the operands is
# missing in the input.
for binop in [
"eq",
"lt",
"le",
"gt",
"ge",
]:
setattr(
DataFrame, binop, make_binop_func(binop, _make_replacement_func(False))
)


def from_pandas(obj, nan_as_null=None):
"""
Convert certain Pandas objects into the cudf equivalent.
Expand Down
Loading