Make is_monotonic/is_monotonic_decreasing distributed
HyukjinKwon committed Mar 18, 2020
1 parent da3740d commit 534148b
Showing 3 changed files with 103 additions and 74 deletions.
99 changes: 80 additions & 19 deletions databricks/koalas/base.py
@@ -319,10 +319,9 @@ def is_monotonic(self):
"""
Return boolean if values in the object are monotonically increasing.
.. note:: the current implementation of is_monotonic_increasing uses Spark's
Window without specifying partition specification. This moves all data into a
single partition on a single machine and could cause serious
performance degradation. Avoid this method on very large datasets.
.. note:: the current implementation of is_monotonic requires shuffling
and aggregating multiple times to check the order locally and globally,
which is potentially expensive.
Returns
-------
@@ -385,12 +384,7 @@ def is_monotonic(self):
>>> midx.is_monotonic
False
"""
return self._is_monotonic().all()

def _is_monotonic(self):
col = self._scol
window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-1, -1)
return self._with_new_scol((col >= F.lag(col, 1).over(window)) & col.isNotNull())
return self._is_monotonic("increasing")

is_monotonic_increasing = is_monotonic

@@ -399,10 +393,9 @@ def is_monotonic_decreasing(self):
"""
Return boolean if values in the object are monotonically decreasing.
.. note:: the current implementation of is_monotonic_decreasing uses Spark's
Window without specifying partition specification. This moves all data into a
single partition on a single machine and could cause serious
performance degradation. Avoid this method on very large datasets.
.. note:: the current implementation of is_monotonic_decreasing requires shuffling
and aggregating multiple times to check the order locally and globally,
which is potentially expensive.
Returns
-------
@@ -465,12 +458,80 @@ def is_monotonic_decreasing(self):
>>> midx.is_monotonic_decreasing
True
"""
return self._is_monotonic_decreasing().all()
return self._is_monotonic("decreasing")

def _is_monotonic_decreasing(self):
col = self._scol
window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-1, -1)
return self._with_new_scol((col <= F.lag(col, 1).over(window)) & col.isNotNull())
def _is_locally_monotonic_spark_column(self, order):
window = (
Window.partitionBy(F.col("__partition_id"))
.orderBy(NATURAL_ORDER_COLUMN_NAME)
.rowsBetween(-1, -1)
)

if order == "increasing":
return (F.col("__origin") >= F.lag(F.col("__origin"), 1).over(window)) & F.col(
"__origin"
).isNotNull()
else:
return (F.col("__origin") <= F.lag(F.col("__origin"), 1).over(window)) & F.col(
"__origin"
).isNotNull()

def _is_monotonic(self, order):
assert order in ("increasing", "decreasing")

sdf = self._internal.spark_frame

sdf = (
sdf.select(
F.spark_partition_id().alias(
"__partition_id"
), # Make sure we use the same partition id in the whole job.
F.col(NATURAL_ORDER_COLUMN_NAME),
self._scol.alias("__origin"),
)
.select(
F.col("__partition_id"),
F.col("__origin"),
self._is_locally_monotonic_spark_column(order).alias(
"__comparison_within_partition"
),
)
.groupby(F.col("__partition_id"))
.agg(
F.min(F.col("__origin")).alias("__partition_min"),
F.max(F.col("__origin")).alias("__partition_max"),
F.min(F.coalesce(F.col("__comparison_within_partition"), F.lit(True))).alias(
"__comparison_within_partition"
),
)
)

# Now we're windowing the aggregation results without a partition specification.
# The number of rows here equals the number of partitions, which is expected
# to be small.
window = Window.orderBy(F.col("__partition_id")).rowsBetween(-1, -1)
if order == "increasing":
comparison_col = F.col("__partition_min") >= F.lag(F.col("__partition_max"), 1).over(
window
)
else:
comparison_col = F.col("__partition_min") <= F.lag(F.col("__partition_max"), 1).over(
window
)

sdf = sdf.select(
comparison_col.alias("__comparison_between_partitions"),
F.col("__comparison_within_partition"),
)

ret = sdf.select(
F.min(F.coalesce(F.col("__comparison_between_partitions"), F.lit(True)))
& F.min(F.coalesce(F.col("__comparison_within_partition"), F.lit(True)))
).collect()[0][0]
if ret is None:
return True
else:
return ret

@property
def ndim(self):
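The new `_is_monotonic` above boils down to a two-phase check: each partition verifies its own rows against their predecessors and reduces to one `(min, max, locally_ok)` row, and a second, cheap window pass compares adjacent partition boundaries. The following is a minimal, self-contained PySpark sketch of the same idea for an increasing column; the names `v`, `ord`, `__min`, `__max`, `__local_ok`, `__between_ok`, and the use of `spark.range` to get contiguously ordered partitions are illustrative assumptions, not code from this commit.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# spark.range(..., numPartitions=3) gives contiguous, ordered partitions,
# standing in for the natural ordering that koalas tracks internally.
sdf = (
    spark.range(0, 10, 1, 3)
    .select(F.col("id").alias("ord"), (F.col("id") * 2).alias("v"))
    .withColumn("__partition_id", F.spark_partition_id())
)

# Phase 1: a window partitioned by physical partition id compares each row
# with its predecessor, then everything collapses to one row per partition:
# (min, max, whether the partition was locally monotonic).
w = Window.partitionBy("__partition_id").orderBy("ord").rowsBetween(-1, -1)
local_ok = (F.col("v") >= F.lag("v", 1).over(w)) & F.col("v").isNotNull()
agg = (
    sdf.select("__partition_id", "v", local_ok.alias("__local_ok"))
    .groupby("__partition_id")
    .agg(
        F.min("v").alias("__min"),
        F.max("v").alias("__max"),
        F.min(F.coalesce("__local_ok", F.lit(True))).alias("__local_ok"),
    )
)

# Phase 2: an unpartitioned window is now cheap because only one row per
# partition remains; each partition's min must not drop below the previous
# partition's max.
w2 = Window.orderBy("__partition_id").rowsBetween(-1, -1)
flags = agg.select(
    (F.col("__min") >= F.lag("__max", 1).over(w2)).alias("__between_ok"),
    "__local_ok",
)
result = flags.select(
    F.min(F.coalesce("__between_ok", F.lit(True)))
    & F.min(F.coalesce("__local_ok", F.lit(True)))
).first()[0]

print(True if result is None else result)  # True: 0, 2, 4, ... is increasing
```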
59 changes: 19 additions & 40 deletions databricks/koalas/indexes.py
@@ -1930,31 +1930,6 @@ def _comparator_for_monotonic_increasing(data_type):
else:
return compare_null_last

def _is_monotonic(self):
scol = self._scol
window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-1, -1)
prev = F.lag(scol, 1).over(window)

cond = F.lit(True)
for field in self.spark_type[::-1]:
left = scol.getField(field.name)
right = prev.getField(field.name)
compare = MultiIndex._comparator_for_monotonic_increasing(field.dataType)
cond = F.when(left.eqNullSafe(right), cond).otherwise(
compare(left, right, spark.Column.__gt__)
)

cond = prev.isNull() | cond

internal = _InternalFrame(
spark_frame=self._internal.spark_frame.select(
self._internal.index_spark_columns + [cond]
),
index_map=self._internal.index_map,
)

return _col(DataFrame(internal))

@staticmethod
def _comparator_for_monotonic_decreasing(data_type):
if isinstance(data_type, StringType):
@@ -1966,30 +1941,34 @@ def _comparator_for_monotonic_decreasing(data_type):
else:
return compare_null_first

def _is_monotonic_decreasing(self):
scol = self._scol
window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-1, -1)
prev = F.lag(scol, 1).over(window)
def _is_locally_monotonic_spark_column(self, order):
window = (
Window.partitionBy(F.col("__partition_id"))
.orderBy(NATURAL_ORDER_COLUMN_NAME)
.rowsBetween(-1, -1)
)

scol = F.col("__origin")
prev = F.lag(scol, 1).over(window)
cond = F.lit(True)
for field in self.spark_type[::-1]:
left = scol.getField(field.name)
right = prev.getField(field.name)
compare = MultiIndex._comparator_for_monotonic_decreasing(field.dataType)
if order == "increasing":
compare = MultiIndex._comparator_for_monotonic_increasing(field.dataType)
else:
compare = MultiIndex._comparator_for_monotonic_decreasing(field.dataType)

cond = F.when(left.eqNullSafe(right), cond).otherwise(
compare(left, right, spark.Column.__lt__)
compare(
left,
right,
spark.Column.__gt__ if order == "increasing" else spark.Column.__lt__,
)
)

cond = prev.isNull() | cond

internal = _InternalFrame(
spark_frame=self._internal.spark_frame.select(
self._internal.index_spark_columns + [cond]
),
index_map=self._internal.index_map,
)

return _col(DataFrame(internal))
return cond

def to_frame(self, index=True, name=None) -> DataFrame:
"""
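The `F.when` chain above folds a lexicographic comparison over the struct fields of the MultiIndex tuple, from the last field to the first: equal fields defer to the inner condition, and the first differing field decides. A plain-Python analogue of the increasing case (the helper `tuple_ge` is hypothetical, and the null handling that `eqNullSafe` and the per-type comparators provide is omitted):

```python
def tuple_ge(left, right):
    """Lexicographic >= for equal-length tuples, folded like the F.when chain."""
    # Fold from the last field to the first, mirroring spark_type[::-1]:
    # equal fields defer to the inner condition; the first differing field
    # (scanning from the left) decides the outcome.
    cond = True  # all fields equal -> order still holds
    for l, r in zip(reversed(left), reversed(right)):
        cond = cond if l == r else (l > r)
    return cond

assert tuple_ge(("b", 1), ("a", 9))        # earlier field dominates
assert not tuple_ge(("a", 1), ("a", 2))    # tie defers to the later field
assert tuple_ge(("a", 2), ("a", 2))        # equal tuples keep monotonicity
```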
19 changes: 4 additions & 15 deletions databricks/koalas/indexing.py
@@ -568,30 +568,19 @@ def _select_rows(self, rows_sel):
if (start is None and rows_sel.start is not None) or (
stop is None and rows_sel.stop is not None
):
inc, dec = (
sdf.select(
index_column._is_monotonic()._scol.alias("__increasing__"),
index_column._is_monotonic_decreasing()._scol.alias("__decreasing__"),
)
.select(
F.min(F.coalesce("__increasing__", F.lit(True))),
F.min(F.coalesce("__decreasing__", F.lit(True))),
)
.first()
)
if start is None and rows_sel.start is not None:
start = rows_sel.start
if inc is not False:
if index_column.is_monotonic_increasing is not False:
cond.append(index_column._scol >= F.lit(start).cast(index_data_type))
elif dec is not False:
elif index_column.is_monotonic_decreasing is not False:
cond.append(index_column._scol <= F.lit(start).cast(index_data_type))
else:
raise KeyError(rows_sel.start)
if stop is None and rows_sel.stop is not None:
stop = rows_sel.stop
if inc is not False:
if index_column.is_monotonic_increasing is not False:
cond.append(index_column._scol <= F.lit(stop).cast(index_data_type))
elif dec is not False:
elif index_column.is_monotonic_decreasing is not False:
cond.append(index_column._scol >= F.lit(stop).cast(index_data_type))
else:
raise KeyError(rows_sel.stop)
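For context on the caller's side: `.loc` slicing with bounds that are absent from the index is only well-defined on a sorted index, and after this change the accessor consults the distributed `is_monotonic_increasing`/`is_monotonic_decreasing` properties directly instead of recomputing the windows inline. A usage sketch, assuming the koalas import name and `Series` constructor of this era:

```python
import databricks.koalas as ks

kser = ks.Series([1, 2, 3, 4], index=[10, 20, 30, 40])

# 15 and 35 are not labels in the index; because the index is monotonically
# increasing, .loc can still resolve the slice to 20 <= index <= 35 and
# return the rows labelled 20 and 30. On an unsorted index the same call
# raises a KeyError, matching the branch above.
print(kser.loc[15:35])
```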
