Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented argmin & argmax for Series #1790

Merged
merged 14 commits into from
Sep 25, 2020
2 changes: 0 additions & 2 deletions databricks/koalas/missing/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ class MissingPandasLikeSeries(object):
compound = _unsupported_function("compound", deprecated=True)
put = _unsupported_function("put", deprecated=True)
ptp = _unsupported_function("ptp", deprecated=True)
argmax = _unsupported_function("argmax", deprecated=True)
argmin = _unsupported_function("argmin", deprecated=True)

# Functions we won't support.
real = _unsupported_property(
Expand Down
94 changes: 94 additions & 0 deletions databricks/koalas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5189,6 +5189,100 @@ def explode(self) -> "Series":
internal = internal.copy(spark_frame=internal.spark_frame.drop(NATURAL_ORDER_COLUMN_NAME))
return first_series(DataFrame(internal))

def argmax(self) -> int:
    """
    Return int position of the largest value in the Series.

    If the maximum is achieved in multiple locations,
    the first row position is returned.

    Returns
    -------
    int
        Row position of the maximum value, or -1 when the Series contains
        only null values.

    Raises
    ------
    ValueError
        If the Series is empty.

    Examples
    --------
    Consider a dataset containing cereal calories

    >>> s = ks.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
    ...                'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
    >>> s  # doctest: +SKIP
    Corn Flakes              100.0
    Almond Delight           110.0
    Cinnamon Toast Crunch    120.0
    Cocoa Puff               110.0
    dtype: float64

    >>> s.argmax()  # doctest: +SKIP
    2
    """
    if self.empty:
        raise ValueError("attempt to get argmax of an empty sequence")

    sdf = self._internal.spark_frame.select(self._internal.data_spark_columns)

    # Take the column name from the Spark frame itself rather than from the first
    # level of the column label: the label is not the Spark column name for
    # multi-level labels and may not even be a string. The name is backtick-quoted
    # so that a dot in it is not parsed as a nested-field access.
    data_scol = F.col("`{}`".format(sdf.columns[0]))

    # We should remember the natural sequence started from 0
    seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
    sdf = InternalFrame.attach_distributed_sequence_column(sdf, seq_col_name)

    max_value = sdf.select(F.max(data_scol)).head(1)[0][0]
    if max_value is None:
        # Every value is null, so no row can match the filter below.
        # NOTE(review): -1 is intended to mirror pandas for an all-NA Series;
        # confirm against the pandas versions this project supports.
        return -1

    # If the maximum is achieved in multiple locations, the first row position is
    # returned. Aggregating with min() makes that explicit; a bare head(1) on the
    # filtered frame carries no ordering guarantee.
    return sdf.filter(data_scol == max_value).select(F.min(seq_col_name)).head(1)[0][0]

def argmin(self) -> int:
    """
    Return int position of the smallest value in the Series.

    If the minimum is achieved in multiple locations,
    the first row position is returned.

    Returns
    -------
    int
        Row position of the minimum value, or -1 when the Series contains
        only null values.

    Raises
    ------
    ValueError
        If the Series is empty.

    Examples
    --------
    Consider a dataset containing cereal calories

    >>> s = ks.Series({'Corn Flakes': 100.0, 'Almond Delight': 110.0,
    ...                'Cinnamon Toast Crunch': 120.0, 'Cocoa Puff': 110.0})
    >>> s  # doctest: +SKIP
    Corn Flakes              100.0
    Almond Delight           110.0
    Cinnamon Toast Crunch    120.0
    Cocoa Puff               110.0
    dtype: float64

    >>> s.argmin()  # doctest: +SKIP
    0
    """
    if self.empty:
        raise ValueError("attempt to get argmin of an empty sequence")

    sdf = self._internal.spark_frame.select(self._internal.data_spark_columns)

    # Take the column name from the Spark frame itself rather than from the first
    # level of the column label: the label is not the Spark column name for
    # multi-level labels and may not even be a string. The name is backtick-quoted
    # so that a dot in it is not parsed as a nested-field access.
    data_scol = F.col("`{}`".format(sdf.columns[0]))

    # We should remember the natural sequence started from 0
    seq_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__")
    sdf = InternalFrame.attach_distributed_sequence_column(sdf, seq_col_name)

    min_value = sdf.select(F.min(data_scol)).head(1)[0][0]
    if min_value is None:
        # Every value is null, so no row can match the filter below.
        # NOTE(review): -1 is intended to mirror pandas for an all-NA Series;
        # confirm against the pandas versions this project supports.
        return -1

    # If the minimum is achieved in multiple locations, the first row position is
    # returned. Aggregating with min() makes that explicit; a bare head(1) on the
    # filtered frame carries no ordering guarantee.
    return sdf.filter(data_scol == min_value).select(F.min(seq_col_name)).head(1)[0][0]

def _cum(self, func, skipna, part_cols=(), ascending=True):
# This is used to cummin, cummax, cumsum, etc.

Expand Down
42 changes: 42 additions & 0 deletions databricks/koalas/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2045,3 +2045,45 @@ def test_explode(self):
kser = ks.from_pandas(pser)
expected = pser
self.assert_eq(kser.explode(), expected)

def test_argmin_argmax(self):
    """Compare Series.argmin/argmax with pandas on flat and MultiIndex data, and on empty input."""
    calories = {
        "Corn Flakes": 100.0,
        "Almond Delight": 110.0,
        "Cinnamon Toast Crunch": 120.0,
        "Cocoa Puff": 110.0,
        "Expensive Flakes": 120.0,
        "Cheap Flakes": 100.0,
    }
    index_tuples = [("a", "u"), ("b", "v"), ("c", "w"), ("d", "x"), ("e", "y"), ("f", "z")]
    pser = pd.Series(calories, name="Koalas")

    # With pandas >= 1.0 the pandas Series is queried directly; with older
    # versions the expected positions come from the underlying ndarray.
    compare_with_series = LooseVersion(pd.__version__) >= LooseVersion("1.0")

    def check_against_pandas(pandas_series):
        koalas_series = ks.from_pandas(pandas_series)
        expected_source = pandas_series if compare_with_series else pandas_series.values
        self.assert_eq(expected_source.argmin(), koalas_series.argmin())
        self.assert_eq(expected_source.argmax(), koalas_series.argmax())

    check_against_pandas(pser)

    # MultiIndex
    pser.index = pd.MultiIndex.from_tuples(index_tuples)
    check_against_pandas(pser)

    # An empty Series must raise, with a message naming the requested operation.
    empty_cases = (
        ("argmin", "attempt to get argmin of an empty sequence"),
        ("argmax", "attempt to get argmax of an empty sequence"),
    )
    for method_name, message in empty_cases:
        with self.assertRaisesRegex(ValueError, message):
            getattr(ks.Series([]), method_name)()
2 changes: 2 additions & 0 deletions docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ Reshaping, sorting, transposing
.. autosummary::
:toctree: api/

Series.argmin
Series.argmax
Series.sort_index
Series.sort_values
Series.unstack
Expand Down