Refine concat to handle the same anchor DataFrames properly. (#1627)
Refine `concat` to handle the same anchor DataFrames properly and reduce unnecessary internal joins.
ueshin authored Jul 8, 2020
1 parent ca3d277 commit 6a47ff6
Showing 3 changed files with 54 additions and 51 deletions.
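In practice, the change targets calls like the one below. This is a minimal usage sketch, assuming a running Spark session and the databricks.koalas package; the data and column names are made up. Before this commit, concatenating columns taken from the same DataFrame along axis=1 still went through an internal join; with the same_anchor check introduced here, the columns are combined directly on the shared anchor.

import databricks.koalas as ks

# Hypothetical frame; both inputs to concat share the same anchor.
kdf = ks.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

# axis=1 concat of columns from the same DataFrame no longer needs an internal join.
result = ks.concat([kdf["a"], kdf["b"]], axis=1)
print(result.to_pandas())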
15 changes: 7 additions & 8 deletions databricks/koalas/frame.py
@@ -9907,8 +9907,9 @@ def __setitem__(self, key, value):
 
         if isinstance(value, (DataFrame, Series)) and not same_anchor(value, self):
             # Different Series or DataFrames
-            key = self._index_normalized_label(key)
-            value = self._index_normalized_frame(value)
+            level = self._internal.column_labels_level
+            key = DataFrame._index_normalized_label(level, key)
+            value = DataFrame._index_normalized_frame(level, value)
 
             def assign_columns(kdf, this_column_labels, that_column_labels):
                 assert len(key) == len(that_column_labels)
@@ -9933,14 +9934,13 @@ def assign_columns(kdf, this_column_labels, that_column_labels):
 
         self._update_internal_frame(kdf._internal)
 
-    def _index_normalized_label(self, labels):
+    @staticmethod
+    def _index_normalized_label(level, labels):
         """
         Returns a label that is normalized against the current column index level.
         For example, the key "abc" can be ("abc", "", "") if the current Frame has
         a multi-index for its column
         """
-        level = self._internal.column_labels_level
-
         if isinstance(labels, str):
             labels = [(labels,)]
         elif isinstance(labels, tuple):
@@ -9956,16 +9956,15 @@ def _index_normalized_label(self, labels):
             )
         return [tuple(list(label) + ([""] * (level - len(label)))) for label in labels]
 
-    def _index_normalized_frame(self, kser_or_kdf):
+    @staticmethod
+    def _index_normalized_frame(level, kser_or_kdf):
         """
         Returns a frame that is normalized against the current column index level.
         For example, the name in `pd.Series([...], name="abc")` can be can be
         ("abc", "", "") if the current DataFrame has a multi-index for its column
         """
-
         from databricks.koalas.series import Series
 
-        level = self._internal.column_labels_level
         if isinstance(kser_or_kdf, Series):
             kdf = kser_or_kdf.to_frame()
         else:
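The two helpers above become staticmethods so that namespace.py can call them with an explicit level (see the concat changes below). Their padding rule can be sketched in plain Python as follows; this is an illustrative re-implementation of the logic shown in the diff, not a call into the private Koalas API, and the level value is made up.

# Sketch of the padding rule in _index_normalized_label: each label is
# extended with "" entries up to the column index level.
def normalize_labels(level, labels):
    if isinstance(labels, str):
        labels = [(labels,)]
    elif isinstance(labels, tuple):
        labels = [labels]
    else:
        labels = [label if isinstance(label, tuple) else (label,) for label in labels]
    return [tuple(list(label) + [""] * (level - len(label))) for label in labels]

# With a 3-level column index, the key "abc" becomes ("abc", "", "").
print(normalize_labels(3, "abc"))  # [('abc', '', '')]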
88 changes: 46 additions & 42 deletions databricks/koalas/namespace.py
@@ -46,11 +46,12 @@
 from databricks import koalas as ks  # For running doctests and reference resolution in PyCharm.
 from databricks.koalas.base import IndexOpsMixin
 from databricks.koalas.utils import (
+    align_diff_frames,
     default_session,
     name_like_string,
+    same_anchor,
     scol_for,
     validate_axis,
-    align_diff_frames,
 )
 from databricks.koalas.frame import DataFrame, _reduce_spark_multi
 from databricks.koalas.internal import InternalFrame, SPARK_DEFAULT_SERIES_NAME
@@ -1765,56 +1766,61 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
             "and ks.DataFrame are valid".format(name=type(objs).__name__)
         )
 
+    if join not in ["inner", "outer"]:
+        raise ValueError("Only can inner (intersect) or outer (union) join the other axis.")
+
     axis = validate_axis(axis)
     if axis == 1:
-        if isinstance(objs[0], ks.Series):
-            concat_kdf = objs[0].to_frame()
-        else:
-            concat_kdf = objs[0]
+        kdfs = [obj.to_frame() if isinstance(obj, Series) else obj for obj in objs]
 
-        with ks.option_context("compute.ops_on_diff_frames", True):
+        level = min(kdf._internal.column_labels_level for kdf in kdfs)
+        kdfs = [
+            DataFrame._index_normalized_frame(level, kdf)
+            if kdf._internal.column_labels_level > level
+            else kdf
+            for kdf in kdfs
+        ]
 
-            def resolve_func(kdf, this_column_labels, that_column_labels):
-                duplicated_names = set(
-                    this_column_label[1:] for this_column_label in this_column_labels
-                ).intersection(
-                    set(that_column_label[1:] for that_column_label in that_column_labels)
-                )
-                assert (
-                    len(duplicated_names) > 0
-                ), "inner or full join type does not include non-common columns"
-                pretty_names = [name_like_string(column_label) for column_label in duplicated_names]
+        concat_kdf = kdfs[0]
+        column_labels = concat_kdf._internal.column_labels.copy()
+
+        kdfs_not_same_anchor = []
+        for kdf in kdfs[1:]:
+            duplicated = [label for label in kdf._internal.column_labels if label in column_labels]
+            if len(duplicated) > 0:
+                pretty_names = [name_like_string(label) for label in duplicated]
                 raise ValueError(
                     "Labels have to be unique; however, got duplicated labels %s." % pretty_names
                 )
+            column_labels.extend(kdf._internal.column_labels)
 
-            for kser_or_kdf in objs[1:]:
-                # TODO: there is a corner case to optimize - when the series are from
-                # the same DataFrame.
-                # FIXME: force to create a new Spark DataFrame to make sure the anchors are
-                # different.
-                that_kdf = DataFrame(kser_or_kdf._internal.resolved_copy)
+            if same_anchor(concat_kdf, kdf):
+                concat_kdf = DataFrame(
+                    concat_kdf._internal.with_new_columns(
+                        concat_kdf._internal.data_spark_columns + kdf._internal.data_spark_columns,
+                        concat_kdf._internal.column_labels + kdf._internal.column_labels,
+                    )
+                )
+            else:
+                kdfs_not_same_anchor.append(kdf)
 
-                this_index_level = concat_kdf._internal.column_labels_level
-                that_index_level = that_kdf._internal.column_labels_level
+        if len(kdfs_not_same_anchor) > 0:
+            with ks.option_context("compute.ops_on_diff_frames", True):
 
-                if this_index_level > that_index_level:
-                    concat_kdf = that_kdf._index_normalized_frame(concat_kdf)
-                if this_index_level < that_index_level:
-                    that_kdf = concat_kdf._index_normalized_frame(that_kdf)
+                def resolve_func(kdf, this_column_labels, that_column_labels):
+                    raise AssertionError("This should not happen.")
 
-                if join == "inner":
-                    concat_kdf = align_diff_frames(
-                        resolve_func, concat_kdf, that_kdf, fillna=False, how="inner",
-                    )
-                elif join == "outer":
-                    concat_kdf = align_diff_frames(
-                        resolve_func, concat_kdf, that_kdf, fillna=False, how="full",
-                    )
-                else:
-                    raise ValueError(
-                        "Only can inner (intersect) or outer (union) join the other axis."
-                    )
+                for kdf in kdfs_not_same_anchor:
+                    if join == "inner":
+                        concat_kdf = align_diff_frames(
+                            resolve_func, concat_kdf, kdf, fillna=False, how="inner",
+                        )
+                    elif join == "outer":
+                        concat_kdf = align_diff_frames(
+                            resolve_func, concat_kdf, kdf, fillna=False, how="full",
+                        )
+
+        concat_kdf = concat_kdf[column_labels]
 
         if ignore_index:
             concat_kdf.columns = list(map(str, _range(len(concat_kdf.columns))))
@@ -1904,8 +1910,6 @@ def resolve_func(kdf, this_column_labels, that_column_labels):
                 )
 
                 kdfs.append(kdf[merged_columns])
-        else:
-            raise ValueError("Only can inner (intersect) or outer (union) join the other axis.")
 
     if ignore_index:
         sdfs = [kdf._internal.spark_frame.select(kdf._internal.data_spark_columns) for kdf in kdfs]
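Beyond the same-anchor fast path, the rewritten concat validates its inputs up front, as the new code above shows: an unsupported join is rejected before any axis handling, and duplicate column labels across the inputs are rejected before any join is attempted. A minimal sketch of that behavior (hypothetical data; the error messages in the comments are taken from the diff):

import databricks.koalas as ks

kdf = ks.DataFrame({"a": [1, 2], "b": [3, 4]})

# join is validated before the axis-specific code runs.
try:
    ks.concat([kdf, kdf], axis=1, join="left")
except ValueError as e:
    print(e)  # Only can inner (intersect) or outer (union) join the other axis.

# Duplicate labels across the inputs raise before any join work happens.
try:
    ks.concat([kdf["a"], kdf["a"]], axis=1)
except ValueError as e:
    print(e)  # Labels have to be unique; however, got duplicated labels ...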
2 changes: 1 addition & 1 deletion databricks/koalas/tests/test_namespace.py
@@ -197,7 +197,7 @@ def test_concat_column_axis(self):
         for ignore_index, join in itertools.product(ignore_indexes, joins):
             for obj in objs:
                 kdfs, pdfs = obj
-                with self.subTest(ignore_index=ignore_index, join=join, objs=obj):
+                with self.subTest(ignore_index=ignore_index, join=join, objs=pdfs):
                     actual = ks.concat(kdfs, axis=1, ignore_index=ignore_index, join=join)
                     expected = pd.concat(pdfs, axis=1, ignore_index=ignore_index, join=join)
                     self.assert_eq(
