Skip to content

Commit

Permalink
Implement DataFrame.dot supporting Series parameter (#1945)
Browse files Browse the repository at this point in the history
Compute the matrix multiplication between the DataFrame and other **series only**.

```
        >>> kdf = ks.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]])
        >>> kser = ks.Series([1, 1, 2, 1])
        >>> kdf.dot(kser)
        0   -4
        1    5
        dtype: int64

        Note how shuffling of the objects does not change the result.

        >>> kser2 = kser.reindex([1, 0, 2, 3])
        >>> kdf.dot(kser2)
        0   -4
        1    5
        dtype: int64
```
  • Loading branch information
xinrong-meng authored Dec 8, 2020
1 parent 01ada38 commit 9e8d99b
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 2 deletions.
88 changes: 88 additions & 0 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4008,6 +4008,94 @@ def duplicated(self, subset=None, keep="first") -> "Series":
)
)

# TODO: support other as DataFrame or array-like
def dot(self, other: "Series") -> "Series":
"""
Compute the matrix multiplication between the DataFrame and other.
This method computes the matrix product between the DataFrame and the
values of an other Series
It can also be called using ``self @ other`` in Python >= 3.5.
.. note:: This method is based on an expensive operation due to the nature
of big data. Internally it needs to generate each row for each value, and
then group twice - it is a huge operation. To prevent misusage, this method
has the 'compute.max_rows' default limit of input length, and raises a ValueError.
>>> from databricks.koalas.config import option_context
>>> with option_context(
... 'compute.max_rows', 1000, "compute.ops_on_diff_frames", True
... ): # doctest: +NORMALIZE_WHITESPACE
... kdf = ks.DataFrame({'a': range(1001)})
... kser = ks.Series([2], index=['a'])
... kdf.dot(kser)
Traceback (most recent call last):
...
ValueError: Current DataFrame has more then the given limit 1000 rows.
Please set 'compute.max_rows' by using 'databricks.koalas.config.set_option'
to retrieve to retrieve more than 1000 rows. Note that, before changing the
'compute.max_rows', this operation is considerably expensive.
Parameters
----------
other : Series
The other object to compute the matrix product with.
Returns
-------
Series
Return the matrix product between self and other as a Series.
See Also
--------
Series.dot: Similar method for Series.
Notes
-----
The dimensions of DataFrame and other must be compatible in order to
compute the matrix multiplication. In addition, the column names of
DataFrame and the index of other must contain the same values, as they
will be aligned prior to the multiplication.
The dot method for Series computes the inner product, instead of the
matrix product here.
Examples
--------
>>> from databricks.koalas.config import set_option, reset_option
>>> set_option("compute.ops_on_diff_frames", True)
>>> kdf = ks.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]])
>>> kser = ks.Series([1, 1, 2, 1])
>>> kdf.dot(kser)
0 -4
1 5
dtype: int64
Note how shuffling of the objects does not change the result.
>>> kser2 = kser.reindex([1, 0, 2, 3])
>>> kdf.dot(kser2)
0 -4
1 5
dtype: int64
>>> kdf @ kser2
0 -4
1 5
dtype: int64
>>> reset_option("compute.ops_on_diff_frames")
"""
if not isinstance(other, ks.Series):
raise TypeError("Unsupported type {}".format(type(other).__name__))
else:
return cast(ks.Series, other.dot(self.transpose())).rename(None)

def __matmul__(self, other):
"""
Matrix multiplication using binary `@` operator in Python>=3.5.
"""
return self.dot(other)

def to_koalas(self, index_col: Optional[Union[str, List[str]]] = None) -> "DataFrame":
"""
Converts the existing DataFrame into a Koalas DataFrame.
Expand Down
1 change: 0 additions & 1 deletion databricks/koalas/missing/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class _MissingPandasLikeDataFrame(object):
convert_dtypes = _unsupported_function("convert_dtypes")
corrwith = _unsupported_function("corrwith")
cov = _unsupported_function("cov")
dot = _unsupported_function("dot")
ewm = _unsupported_function("ewm")
first = _unsupported_function("first")
infer_objects = _unsupported_function("infer_objects")
Expand Down
53 changes: 52 additions & 1 deletion databricks/koalas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ def test_multi_index_column_assignment_frame(self):
with self.assertRaisesRegex(KeyError, "Key length \\(3\\) exceeds index depth \\(2\\)"):
kdf[("1", "2", "3")] = ks.Series([100, 200, 300, 200])

def test_dot(self):
def test_series_dot(self):
pser = pd.Series([90, 91, 85], index=[2, 4, 1])
kser = ks.from_pandas(pser)
pser_other = pd.Series([90, 91, 85], index=[2, 4, 1])
Expand Down Expand Up @@ -1200,6 +1200,57 @@ def test_dot(self):
pdf = kdf.to_pandas()
self.assert_eq(kser.dot(kdf), pser.dot(pdf))

def test_frame_dot(self):
pdf = pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]])
kdf = ks.from_pandas(pdf)

pser = pd.Series([1, 1, 2, 1])
kser = ks.from_pandas(pser)
self.assert_eq(kdf.dot(kser), pdf.dot(pser))

# Index reorder
pser = pser.reindex([1, 0, 2, 3])
kser = ks.from_pandas(pser)
self.assert_eq(kdf.dot(kser), pdf.dot(pser))

# ser with name
pser.name = "ser"
kser = ks.from_pandas(pser)
self.assert_eq(kdf.dot(kser), pdf.dot(pser))

# df with MultiIndex as column (ser with MultiIndex)
arrays = [[1, 1, 2, 2], ["red", "blue", "red", "blue"]]
pidx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
pser = pd.Series([1, 1, 2, 1], index=pidx)
pdf = pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]], columns=pidx)
kdf = ks.from_pandas(pdf)
kser = ks.from_pandas(pser)
self.assert_eq(kdf.dot(kser), pdf.dot(pser))

# df with Index as column (ser with Index)
pidx = pd.Index([1, 2, 3, 4], name="number")
pser = pd.Series([1, 1, 2, 1], index=pidx)
pdf = pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]], columns=pidx)
kdf = ks.from_pandas(pdf)
kser = ks.from_pandas(pser)
self.assert_eq(kdf.dot(kser), pdf.dot(pser))

# df with Index
pdf.index = pd.Index(["x", "y"], name="char")
kdf = ks.from_pandas(pdf)
self.assert_eq(kdf.dot(kser), pdf.dot(pser))

# df with MultiIndex
pdf.index = pd.MultiIndex.from_arrays([[1, 1], ["red", "blue"]], names=("number", "color"))
kdf = ks.from_pandas(pdf)
self.assert_eq(kdf.dot(kser), pdf.dot(pser))

pdf = pd.DataFrame([[1, 2], [3, 4]])
kdf = ks.from_pandas(pdf)
self.assert_eq(kdf.dot(kdf[0]), pdf.dot(pdf[0]))
self.assert_eq(kdf.dot(kdf[0] * 10), pdf.dot(pdf[0] * 10))
self.assert_eq((kdf + 1).dot(kdf[0] * 10), (pdf + 1).dot(pdf[0] * 10))

def test_to_series_comparison(self):
kidx1 = ks.Index([1, 2, 3, 4, 5])
kidx2 = ks.Index([1, 2, 3, 4, 5])
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ Binary operator functions
DataFrame.ge
DataFrame.ne
DataFrame.eq
DataFrame.dot

Function application, GroupBy & Window
--------------------------------------
Expand Down

0 comments on commit 9e8d99b

Please sign in to comment.