Refine concat to handle the same anchor DataFrames properly. #1627

Merged: 2 commits, Jul 8, 2020
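
For context, a minimal usage sketch (not part of the PR; it assumes a running Spark session and uses illustrative data) of the case this change targets: concatenating columns that share the same anchor DataFrame along axis=1.

    import databricks.koalas as ks

    kdf = ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

    # Both inputs are Series backed by the same anchor `kdf`; with this change,
    # concat appends their Spark columns directly instead of joining two frames.
    result = ks.concat([kdf["A"], kdf["B"]], axis=1)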
15 changes: 7 additions & 8 deletions databricks/koalas/frame.py
@@ -10285,8 +10285,9 @@ def __setitem__(self, key, value):

if isinstance(value, (DataFrame, Series)) and not same_anchor(value, self):
# Different Series or DataFrames
key = self._index_normalized_label(key)
value = self._index_normalized_frame(value)
level = self._internal.column_labels_level
key = DataFrame._index_normalized_label(level, key)
value = DataFrame._index_normalized_frame(level, value)

def assign_columns(kdf, this_column_labels, that_column_labels):
assert len(key) == len(that_column_labels)
@@ -10311,14 +10312,13 @@ def assign_columns(kdf, this_column_labels, that_column_labels):

self._update_internal_frame(kdf._internal)

def _index_normalized_label(self, labels):
@staticmethod
def _index_normalized_label(level, labels):
"""
Returns a label that is normalized against the current column index level.
For example, the key "abc" can be ("abc", "", "") if the current Frame has
a multi-index for its column
"""
level = self._internal.column_labels_level

if isinstance(labels, str):
labels = [(labels,)]
elif isinstance(labels, tuple):
@@ -10334,16 +10334,15 @@ def _index_normalized_label(self, labels):
)
return [tuple(list(label) + ([""] * (level - len(label)))) for label in labels]

def _index_normalized_frame(self, kser_or_kdf):
@staticmethod
def _index_normalized_frame(level, kser_or_kdf):
"""
Returns a frame that is normalized against the current column index level.
For example, the name in `pd.Series([...], name="abc")` can be
("abc", "", "") if the current DataFrame has a multi-index for its column
"""

from databricks.koalas.series import Series

level = self._internal.column_labels_level
if isinstance(kser_or_kdf, Series):
kdf = kser_or_kdf.to_frame()
else:
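To see what the now-static `DataFrame._index_normalized_label` does, here is a small standalone sketch (illustrative values, not the method itself) of the same list comprehension applied to a 3-level column index:

    level = 3                  # column_labels_level of the target frame
    labels = [("abc",)]        # the key, already wrapped as a list of tuples
    normalized = [tuple(list(label) + [""] * (level - len(label))) for label in labels]
    # normalized == [("abc", "", "")], matching the docstring example above
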
88 changes: 46 additions & 42 deletions databricks/koalas/namespace.py
@@ -46,11 +46,12 @@
from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
from databricks.koalas.base import IndexOpsMixin
from databricks.koalas.utils import (
align_diff_frames,
default_session,
name_like_string,
same_anchor,
scol_for,
validate_axis,
align_diff_frames,
)
from databricks.koalas.frame import DataFrame, _reduce_spark_multi
from databricks.koalas.internal import InternalFrame, SPARK_DEFAULT_SERIES_NAME
@@ -1765,56 +1766,61 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
"and ks.DataFrame are valid".format(name=type(objs).__name__)
)

if join not in ["inner", "outer"]:
raise ValueError("Only can inner (intersect) or outer (union) join the other axis.")

axis = validate_axis(axis)
if axis == 1:
if isinstance(objs[0], ks.Series):
concat_kdf = objs[0].to_frame()
else:
concat_kdf = objs[0]
kdfs = [obj.to_frame() if isinstance(obj, Series) else obj for obj in objs]

with ks.option_context("compute.ops_on_diff_frames", True):
level = min(kdf._internal.column_labels_level for kdf in kdfs)
kdfs = [
DataFrame._index_normalized_frame(level, kdf)
if kdf._internal.column_labels_level > level
else kdf
for kdf in kdfs
]

def resolve_func(kdf, this_column_labels, that_column_labels):
duplicated_names = set(
this_column_label[1:] for this_column_label in this_column_labels
).intersection(
set(that_column_label[1:] for that_column_label in that_column_labels)
)
assert (
len(duplicated_names) > 0
), "inner or full join type does not include non-common columns"
pretty_names = [name_like_string(column_label) for column_label in duplicated_names]
concat_kdf = kdfs[0]
column_labels = concat_kdf._internal.column_labels.copy()

kdfs_not_same_anchor = []
Member review comment on this line: kdfs_different_anchor :-)?

for kdf in kdfs[1:]:
duplicated = [label for label in kdf._internal.column_labels if label in column_labels]
if len(duplicated) > 0:
pretty_names = [name_like_string(label) for label in duplicated]
raise ValueError(
"Labels have to be unique; however, got duplicated labels %s." % pretty_names
)
column_labels.extend(kdf._internal.column_labels)

for kser_or_kdf in objs[1:]:
# TODO: there is a corner case to optimize - when the series are from
# the same DataFrame.
# FIXME: force to create a new Spark DataFrame to make sure the anchors are
# different.
that_kdf = DataFrame(kser_or_kdf._internal.resolved_copy)
if same_anchor(concat_kdf, kdf):
concat_kdf = DataFrame(
concat_kdf._internal.with_new_columns(
concat_kdf._internal.data_spark_columns + kdf._internal.data_spark_columns,
concat_kdf._internal.column_labels + kdf._internal.column_labels,
)
)
else:
kdfs_not_same_anchor.append(kdf)

this_index_level = concat_kdf._internal.column_labels_level
that_index_level = that_kdf._internal.column_labels_level
if len(kdfs_not_same_anchor) > 0:
with ks.option_context("compute.ops_on_diff_frames", True):

if this_index_level > that_index_level:
concat_kdf = that_kdf._index_normalized_frame(concat_kdf)
if this_index_level < that_index_level:
that_kdf = concat_kdf._index_normalized_frame(that_kdf)
def resolve_func(kdf, this_column_labels, that_column_labels):
raise AssertionError("This should not happen.")

if join == "inner":
concat_kdf = align_diff_frames(
resolve_func, concat_kdf, that_kdf, fillna=False, how="inner",
)
elif join == "outer":
concat_kdf = align_diff_frames(
resolve_func, concat_kdf, that_kdf, fillna=False, how="full",
)
else:
raise ValueError(
"Only can inner (intersect) or outer (union) join the other axis."
)
for kdf in kdfs_not_same_anchor:
if join == "inner":
concat_kdf = align_diff_frames(
resolve_func, concat_kdf, kdf, fillna=False, how="inner",
)
elif join == "outer":
concat_kdf = align_diff_frames(
resolve_func, concat_kdf, kdf, fillna=False, how="full",
)

concat_kdf = concat_kdf[column_labels]

if ignore_index:
concat_kdf.columns = list(map(str, _range(len(concat_kdf.columns))))
@@ -1904,8 +1910,6 @@ def resolve_func(kdf, this_column_labels, that_column_labels):
)

kdfs.append(kdf[merged_columns])
else:
raise ValueError("Only can inner (intersect) or outer (union) join the other axis.")

if ignore_index:
sdfs = [kdf._internal.spark_frame.select(kdf._internal.data_spark_columns) for kdf in kdfs]
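A hedged sketch (names and data are illustrative) of two other behaviors visible in the refactored axis=1 path: duplicate column labels are rejected up front, and inputs with a different anchor are aligned via align_diff_frames with a "full" or "inner" join depending on the join argument. Same-anchor inputs take the column-append fast path shown in the sketch near the top of this page.

    import databricks.koalas as ks

    kdf = ks.DataFrame({"A": [1, 2], "B": [3, 4]})

    # Overlapping label "A" is rejected up front, roughly:
    #   ValueError: Labels have to be unique; however, got duplicated labels ['A'].
    try:
        ks.concat([kdf, ks.DataFrame({"A": [5, 6]})], axis=1)
    except ValueError as e:
        print(e)

    # A different anchor with distinct labels goes through align_diff_frames;
    # join="outer" maps to a full join, join="inner" to an inner join.
    other = ks.DataFrame({"C": [5, 6]})
    outer = ks.concat([kdf, other], axis=1, join="outer")
    inner = ks.concat([kdf, other], axis=1, join="inner")
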
2 changes: 1 addition & 1 deletion databricks/koalas/tests/test_namespace.py
@@ -197,7 +197,7 @@ def test_concat_column_axis(self):
for ignore_index, join in itertools.product(ignore_indexes, joins):
for obj in objs:
kdfs, pdfs = obj
with self.subTest(ignore_index=ignore_index, join=join, objs=obj):
with self.subTest(ignore_index=ignore_index, join=join, objs=pdfs):
actual = ks.concat(kdfs, axis=1, ignore_index=ignore_index, join=join)
expected = pd.concat(pdfs, axis=1, ignore_index=ignore_index, join=join)
self.assert_eq(