Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support DataFrame parameter in Series.dot #1931

Merged
merged 30 commits into from
Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
65e9b50
Happy path
xinrong-meng Nov 25, 2020
9933a3b
Keep DataFrame columns
xinrong-meng Nov 25, 2020
8e34cb2
compute.ops_on_diff_frames in caller
xinrong-meng Nov 25, 2020
e4e3962
_ksers is unordered
xinrong-meng Nov 25, 2020
53e7f77
Docstring example
xinrong-meng Nov 25, 2020
f5e4a84
Index equal
xinrong-meng Nov 25, 2020
e178304
Fix test
xinrong-meng Nov 25, 2020
0087c7b
Fewer Spark jobs
xinrong-meng Nov 30, 2020
18a50d3
Mypy fix
xinrong-meng Nov 30, 2020
6ef606b
Compare Index if pass in another Series
xinrong-meng Nov 30, 2020
afd2460
No f-strings for py 3.5
xinrong-meng Nov 30, 2020
8ba595b
Sort index before cmp
xinrong-meng Nov 30, 2020
a1eacd9
Combine test_dot
xinrong-meng Nov 30, 2020
31d852a
ks.option_context by def
xinrong-meng Nov 30, 2020
0e9fbfe
same_anchor; diff index order
xinrong-meng Nov 30, 2020
461892e
Rename; scol_for
xinrong-meng Nov 30, 2020
ca33f7d
Format
xinrong-meng Nov 30, 2020
60b803d
scol_for
xinrong-meng Dec 1, 2020
7fa9652
More test cases
xinrong-meng Dec 1, 2020
1c842d8
Reserve self.name; if same anchor
xinrong-meng Dec 1, 2020
dff1243
Index w/o name
xinrong-meng Dec 1, 2020
41ec2ff
Not combine anchor but to_frame
xinrong-meng Dec 1, 2020
a69f393
Type annotations
xinrong-meng Dec 1, 2020
7dbad5f
Use DataFrame not ks.DataFrame
xinrong-meng Dec 1, 2020
6d92a84
Same anchor inputs w/o combine_frames
xinrong-meng Dec 2, 2020
2899bed
Restructure tests
xinrong-meng Dec 2, 2020
8dd507c
Restore ser.dot(ser) with same length index
xinrong-meng Dec 2, 2020
9586625
Refactor
xinrong-meng Dec 2, 2020
10127ac
ser.dot(df) not same_anchor, must enable ops_on_diff_frames
xinrong-meng Dec 3, 2020
812de5a
Remove exception
xinrong-meng Dec 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions databricks/koalas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4719,9 +4719,9 @@ def combine_first(self, other) -> "Series":
*index_scols, cond.alias(self._internal.data_spark_column_names[0])
).distinct()
internal = self._internal.with_new_sdf(sdf)
return first_series(ks.DataFrame(internal))
return first_series(DataFrame(internal))

def dot(self, other) -> Union[Scalar, "Series"]:
def dot(self, other: Union["Series", DataFrame]) -> Union[Scalar, "Series"]:
"""
Compute the dot product between the Series and the columns of other.

Expand All @@ -4730,7 +4730,7 @@ def dot(self, other) -> Union[Scalar, "Series"]:

It can also be called using `self @ other` in Python >= 3.5.

.. note:: This API is slightly different from pandas when indexes from both
.. note:: This API is slightly different from pandas when indexes from both Series
are not aligned. Matching pandas' behavior would require reading the whole data,
for example, for counting. pandas raises an exception; however, Koalas just proceeds
and permissively ignores mismatches with NaN.
Expand Down Expand Up @@ -4772,20 +4772,49 @@ def dot(self, other) -> Union[Scalar, "Series"]:

>>> s @ s
14

>>> kdf = ks.DataFrame({'x': [0, 1, 2, 3], 'y': [0, -1, -2, -3]})
>>> kdf
x y
0 0 0
1 1 -1
2 2 -2
3 3 -3

>>> with ks.option_context("compute.ops_on_diff_frames", True):
... s.dot(kdf)
...
x 14
y -14
dtype: int64
"""
if isinstance(other, DataFrame):
raise ValueError(
"Series.dot() is currently not supported with DataFrame since "
"it will cause expansive calculation as many as the number "
"of columns of DataFrame"
)
if self._kdf is not other._kdf:
if len(self.index) != len(other.index):
if not same_anchor(self, other) and not self.index.sort_values().equals(
xinrong-meng marked this conversation as resolved.
Show resolved Hide resolved
other.index.sort_values()
):
raise ValueError("matrices are not aligned")
if isinstance(other, Series):
result = (self * other).sum()

return result
other = other.copy()
column_labels = other._internal.column_labels

self_column_label = verify_temp_column_name(other, "__self_column__")
other[self_column_label] = self
xinrong-meng marked this conversation as resolved.
Show resolved Hide resolved
self_kser = other._kser_for(self_column_label)

product_ksers = [other._kser_for(label) * self_kser for label in column_labels]

dot_product_kser = DataFrame(
other._internal.with_new_columns(product_ksers, column_labels)
).sum()

return cast(Series, dot_product_kser).rename(self.name)

else:
assert isinstance(other, Series)
if not same_anchor(self, other):
if len(self.index) != len(other.index):
raise ValueError("matrices are not aligned")
return (self * other).sum()

def __matmul__(self, other):
"""
Expand Down Expand Up @@ -4940,7 +4969,7 @@ def asof(self, where) -> Union[Scalar, "Series"]:
should_return_series = True
if isinstance(self.index, ks.MultiIndex):
raise ValueError("asof is not supported for a MultiIndex")
if isinstance(where, (ks.Index, ks.Series, ks.DataFrame)):
if isinstance(where, (ks.Index, ks.Series, DataFrame)):
raise ValueError("where cannot be an Index, Series or a DataFrame")
if not self.index.is_monotonic_increasing:
raise ValueError("asof requires a sorted index")
Expand Down
44 changes: 36 additions & 8 deletions databricks/koalas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,13 +870,6 @@ def test_dot(self):
with self.assertRaisesRegex(ValueError, "matrices are not aligned"):
kser.dot(kser_other)

# DataFrame is not supported for now because of a performance issue,
# so we raise a ValueError with a proper message instead.
kdf = ks.DataFrame([[0, 1], [-2, 3], [4, -5]], index=[2, 4, 1])

with self.assertRaisesRegex(ValueError, r"Series\.dot\(\) is currently not supported*"):
kser.dot(kdf)

# for MultiIndex
midx = pd.MultiIndex(
[["lama", "cow", "falcon"], ["speed", "weight", "length"]],
Expand All @@ -886,9 +879,44 @@ def test_dot(self):
kser = ks.from_pandas(pser)
pser_other = pd.Series([-450, 20, 12, -30, -250, 15, -320, 100, 3], index=midx)
kser_other = ks.from_pandas(pser_other)

self.assert_eq(kser.dot(kser_other), pser.dot(pser_other))

pser = pd.Series([0, 1, 2, 3])
kser = ks.from_pandas(pser)

# DataFrame "other" without Index/MultiIndex as columns
pdf = pd.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]])
kdf = ks.from_pandas(pdf)
self.assert_eq(kser.dot(kdf), pser.dot(pdf))

# DataFrame "other" with Index as columns
pdf.columns = pd.Index(["x", "y"])
xinrong-meng marked this conversation as resolved.
Show resolved Hide resolved
kdf = ks.from_pandas(pdf)
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
pdf.columns = pd.Index(["x", "y"], name="cols_name")
kdf = ks.from_pandas(pdf)
self.assert_eq(kser.dot(kdf), pser.dot(pdf))

pdf = pdf.reindex([1, 0, 2, 3])
kdf = ks.from_pandas(pdf)
self.assert_eq(kser.dot(kdf), pser.dot(pdf))

# DataFrame "other" with MultiIndex as columns
pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y")])
xinrong-meng marked this conversation as resolved.
Show resolved Hide resolved
kdf = ks.from_pandas(pdf)
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
pdf.columns = pd.MultiIndex.from_tuples(
[("a", "x"), ("b", "y")], names=["cols_name1", "cols_name2"]
)
kdf = ks.from_pandas(pdf)
self.assert_eq(kser.dot(kdf), pser.dot(pdf))

kser = ks.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).b
pser = kser.to_pandas()
kdf = ks.DataFrame({"c": [7, 8, 9]})
pdf = kdf.to_pandas()
self.assert_eq(kser.dot(kdf), pser.dot(pdf))

def test_to_series_comparison(self):
kidx1 = ks.Index([1, 2, 3, 4, 5])
kidx2 = ks.Index([1, 2, 3, 4, 5])
Expand Down
8 changes: 8 additions & 0 deletions databricks/koalas/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2123,6 +2123,14 @@ def test_droplevel(self):
pser.droplevel([("a", "1"), ("c", "3")]), kser.droplevel([("a", "1"), ("c", "3")])
)

def test_dot(self):
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
kdf = ks.from_pandas(pdf)

self.assert_eq((kdf["b"] * 10).dot(kdf["a"]), (pdf["b"] * 10).dot(pdf["a"]))
self.assert_eq((kdf["b"] * 10).dot(kdf), (pdf["b"] * 10).dot(pdf))
self.assert_eq((kdf["b"] * 10).dot(kdf + 1), (pdf["b"] * 10).dot(pdf + 1))

@unittest.skipIf(
LooseVersion(pyspark.__version__) < LooseVersion("3.0"),
"tail won't work properly with PySpark<3.0",
Expand Down