Manage data_spark_columns to avoid creating very many Spark DataFrames. #1554

Merged: 3 commits merged on Jun 1, 2020
Changes from 2 commits
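The diff below follows one theme: operations resolve their Spark columns against a single managed Spark DataFrame, exposed as _internal.spark_frame (or _internal.applied.spark_frame where pending column operations apparently need to be applied first), instead of reaching for the private _sdf and spawning new Spark DataFrames along the way. A toy PySpark sketch of that idea; the class and attribute names here are illustrative stand-ins, not the actual koalas InternalFrame API beyond the names visible in the diff.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()


class ToyInternalFrame:
    # Illustrative only: one managed Spark DataFrame per frame.
    def __init__(self, sdf, data_columns):
        self._sdf = sdf                    # the single managed Spark DataFrame
        self._data_columns = data_columns  # names of the data columns

    @property
    def spark_frame(self):
        # Public accessor used in the diff in place of the private _sdf.
        return self._sdf

    @property
    def data_spark_columns(self):
        # Columns are always resolved against the managed frame, so derived
        # results stay Column expressions instead of new DataFrames.
        return [self._sdf[name] for name in self._data_columns]


sdf = spark.createDataFrame([(1, "a"), (2, "b")], ["x", "y"])
internal = ToyInternalFrame(sdf, ["x", "y"])
internal.spark_frame.select(internal.data_spark_columns).show()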
12 changes: 6 additions & 6 deletions databricks/koalas/base.py
@@ -351,7 +351,7 @@ def empty(self):
>>> ks.DataFrame({}, index=list('abc')).index.empty
False
"""
return self._internal._sdf.rdd.isEmpty()
return self._internal.applied.spark_frame.rdd.isEmpty()

@property
def hasnans(self):
@@ -860,7 +860,7 @@ def all(self, axis: Union[int, str] = 0) -> bool:
if axis != 0:
raise NotImplementedError('axis should be either 0 or "index" currently.')

sdf = self._internal._sdf.select(self.spark.column)
sdf = self._internal.spark_frame.select(self.spark.column)
col = scol_for(sdf, sdf.columns[0])

# Note that we're ignoring `None`s here for now.
@@ -923,7 +923,7 @@ def any(self, axis: Union[int, str] = 0) -> bool:
if axis != 0:
raise NotImplementedError('axis should be either 0 or "index" currently.')

sdf = self._internal._sdf.select(self.spark.column)
sdf = self._internal.spark_frame.select(self.spark.column)
col = scol_for(sdf, sdf.columns[0])

# Note that we're ignoring `None`s here for now.
@@ -1156,9 +1156,9 @@ def value_counts(self, normalize=False, sort=True, ascending=False, bins=None, d
raise NotImplementedError("value_counts currently does not support bins")

if dropna:
sdf_dropna = self._internal._sdf.select(self.spark.column).dropna()
sdf_dropna = self._internal.spark_frame.select(self.spark.column).dropna()
else:
sdf_dropna = self._internal._sdf.select(self.spark.column)
sdf_dropna = self._internal.spark_frame.select(self.spark.column)
index_name = SPARK_DEFAULT_INDEX_NAME
column_name = self._internal.data_spark_column_names[0]
sdf = sdf_dropna.groupby(scol_for(sdf_dropna, column_name).alias(index_name)).count()
@@ -1241,7 +1241,7 @@ def nunique(self, dropna: bool = True, approx: bool = False, rsd: float = 0.05)
>>> idx.nunique(dropna=False)
3
"""
res = self._internal._sdf.select([self._nunique(dropna, approx, rsd)])
res = self._internal.spark_frame.select([self._nunique(dropna, approx, rsd)])
return res.collect()[0][0]

def _nunique(self, dropna=True, approx=False, rsd=0.05):
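The base.py changes are mechanical: self._internal._sdf becomes self._internal.spark_frame (applied.spark_frame for empty), and the selected column is then looked up on that frame with scol_for(sdf, name). A minimal sketch of that resolve-from-this-frame pattern; scol_for is re-implemented as a stand-in under the assumption that the real helper simply resolves a column by name on the given DataFrame, and the all-style reduction mirrors the "ignoring `None`s" note above.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()


def scol_for(sdf, column_name):
    # Stand-in for the koalas helper (assumed behaviour: resolve by name).
    return sdf[column_name]


sdf = spark.createDataFrame([(True,), (True,), (None,)], "flag boolean")
sdf = sdf.select(sdf["flag"])          # project only the column of interest
col = scol_for(sdf, sdf.columns[0])    # resolve it against the projected frame

# Nulls are folded to True so they do not affect the result, matching the
# "ignoring `None`s" comment in the hunk above.
print(sdf.select(F.min(F.coalesce(col.cast("boolean"), F.lit(True)))).collect()[0][0])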
190 changes: 113 additions & 77 deletions databricks/koalas/frame.py

Large diffs are not rendered by default.

135 changes: 72 additions & 63 deletions databricks/koalas/groupby.py
@@ -241,10 +241,9 @@ def aggregate(self, func_or_funcs=None, *args, **kwargs):

@staticmethod
def _spark_groupby(kdf, func, groupkeys=()):
sdf = kdf._sdf
groupkey_scols = [
s.spark.column.alias(SPARK_INDEX_NAME_FORMAT(i)) for i, s in enumerate(groupkeys)
]
groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
groupkey_scols = [s.spark.column.alias(name) for s, name in zip(groupkeys, groupkey_names)]

multi_aggs = any(isinstance(v, list) for v in func.values())
reordered = []
data_columns = []
@@ -273,11 +272,12 @@ def _spark_groupby(kdf, func, groupkeys=()):

else:
reordered.append(F.expr("{1}(`{0}`) as `{2}`".format(name, aggfunc, data_col)))
sdf = sdf.groupby(*groupkey_scols).agg(*reordered)

sdf = kdf._internal.spark_frame.select(groupkey_scols + kdf._internal.data_spark_columns)
sdf = sdf.groupby(*groupkey_names).agg(*reordered)
if len(groupkeys) > 0:
index_map = OrderedDict(
(SPARK_INDEX_NAME_FORMAT(i), s._internal.column_labels[0])
for i, s in enumerate(groupkeys)
(name, s._internal.column_labels[0]) for s, name in zip(groupkeys, groupkey_names)
)
else:
index_map = None
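_spark_groupby now aliases each group key to a stable name, selects those aliases together with the data columns out of the one spark_frame, and groups by the names rather than by Column objects. A minimal PySpark sketch of the aliasing pattern; the alias "__index_level_0__" is only a stand-in for whatever SPARK_INDEX_NAME_FORMAT(0) produces.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
sdf = spark.createDataFrame(
    [("a", 1, 10.0), ("a", 2, 20.0), ("b", 3, 30.0)], ["key", "v1", "v2"]
)

groupkey_names = ["__index_level_0__"]                 # stand-in alias
groupkey_scols = [sdf["key"].alias(groupkey_names[0])]
data_scols = [sdf["v1"], sdf["v2"]]

# One projection of the same frame, then group by the alias names.
projected = sdf.select(groupkey_scols + data_scols)
projected.groupby(*groupkey_names).agg(
    F.expr("sum(`v1`) as `v1_sum`"), F.expr("max(`v2`) as `v2_max`")
).show()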
@@ -565,16 +565,16 @@ def size(self):
Name: B, dtype: int64
"""
groupkeys = self._groupkeys
groupkey_cols = [
s.alias(SPARK_INDEX_NAME_FORMAT(i)) for i, s in enumerate(self._groupkeys_scols)
]
sdf = self._kdf._sdf
sdf = sdf.groupby(*groupkey_cols).count()
groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
groupkey_scols = [s.spark.column.alias(name) for s, name in zip(groupkeys, groupkey_names)]
sdf = self._kdf._internal.spark_frame.select(
groupkey_scols + self._kdf._internal.data_spark_columns
)
sdf = sdf.groupby(*groupkey_names).count()
internal = InternalFrame(
spark_frame=sdf,
index_map=OrderedDict(
(SPARK_INDEX_NAME_FORMAT(i), s._internal.column_labels[0])
for i, s in enumerate(groupkeys)
(name, s._internal.column_labels[0]) for s, name in zip(groupkeys, groupkey_names)
),
data_spark_columns=[scol_for(sdf, "count")],
)
@@ -1012,7 +1012,7 @@ def pandas_apply(pdf, *a, **k):
else:
kdf_from_pandas = kser_or_kdf

return_schema = kdf_from_pandas._sdf.drop(*HIDDEN_COLUMNS).schema
return_schema = kdf_from_pandas._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema
else:
if not is_series_groupby and getattr(return_sig, "__origin__", None) == ks.Series:
raise TypeError(
@@ -1139,7 +1139,7 @@ def filter(self, func):
if label not in self._column_labels_to_exlcude
]

data_schema = self._kdf[agg_columns]._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema
data_schema = kdf[agg_columns]._internal.applied.spark_frame.drop(*HIDDEN_COLUMNS).schema

kdf, groupkey_labels, groupkey_names = GroupBy._prepare_group_map_apply(
kdf, self._groupkeys, agg_columns
@@ -1181,13 +1181,13 @@ def _prepare_group_map_apply(kdf, groupkeys, agg_columns):
]
kdf = kdf[[s.rename(label) for s, label in zip(groupkeys, groupkey_labels)] + agg_columns]
groupkey_names = [label if len(label) > 1 else label[0] for label in groupkey_labels]
return kdf, groupkey_labels, groupkey_names
return DataFrame(kdf._internal.applied), groupkey_labels, groupkey_names

@staticmethod
def _spark_group_map_apply(kdf, func, groupkeys_scols, return_schema, retain_index):
output_func = GroupBy._make_pandas_df_builder_func(kdf, func, return_schema, retain_index)
grouped_map_func = pandas_udf(return_schema, PandasUDFType.GROUPED_MAP)(output_func)
sdf = kdf._sdf.drop(*HIDDEN_COLUMNS)
sdf = kdf._internal.spark_frame.drop(*HIDDEN_COLUMNS)
return sdf.groupby(*groupkeys_scols).apply(grouped_map_func)

@staticmethod
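_spark_group_map_apply drops the hidden bookkeeping columns and runs a GROUPED_MAP pandas UDF per group. A self-contained sketch of that mechanism with a toy function and schema; the pandas_udf(..., PandasUDFType.GROUPED_MAP) plus groupby(...).apply(...) form matches the Spark 2.x-era API used here, whereas Spark 3 would favour applyInPandas.

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.master("local[2]").getOrCreate()
sdf = spark.createDataFrame([("a", 1.0), ("a", 2.0), ("b", 5.0)], ["key", "v"])


@pandas_udf("key string, v double", PandasUDFType.GROUPED_MAP)
def demean(pdf):
    # Each group arrives as a pandas DataFrame and must be returned with the
    # declared return schema.
    return pdf.assign(v=pdf["v"] - pdf["v"].mean())


sdf.groupby("key").apply(demean).show()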
@@ -1363,11 +1363,12 @@ def idxmax(self, skipna=True):
"""
if len(self._kdf._internal.index_names) != 1:
raise ValueError("idxmax only support one-level index now")
groupkeys = self._groupkeys
groupkey_cols = [
s.alias(SPARK_INDEX_NAME_FORMAT(i)) for i, s in enumerate(self._groupkeys_scols)
]
sdf = self._kdf._sdf

groupkey_names = ["__groupkey_{}__".format(i) for i in range(len(self._groupkeys))]

sdf = self._kdf._internal.spark_frame
for s, name in zip(self._groupkeys, groupkey_names):
sdf = sdf.withColumn(name, s.spark.column)
index = self._kdf._internal.index_spark_column_names[0]

stat_exprs = []
@@ -1378,19 +1379,21 @@
order_column = Column(c._jc.desc_nulls_last())
else:
order_column = Column(c._jc.desc_nulls_first())
window = Window.partitionBy(groupkey_cols).orderBy(
window = Window.partitionBy(groupkey_names).orderBy(
order_column, NATURAL_ORDER_COLUMN_NAME
)
sdf = sdf.withColumn(
name, F.when(F.row_number().over(window) == 1, scol_for(sdf, index)).otherwise(None)
)
stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)

sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)

internal = InternalFrame(
spark_frame=sdf,
index_map=OrderedDict(
(SPARK_INDEX_NAME_FORMAT(i), s._internal.column_labels[0])
for i, s in enumerate(groupkeys)
(name, s._internal.column_labels[0])
for s, name in zip(self._groupkeys, groupkey_names)
),
column_labels=[kser._internal.column_labels[0] for kser in self._agg_columns],
data_spark_columns=[
@@ -1440,11 +1443,12 @@ def idxmin(self, skipna=True):
"""
if len(self._kdf._internal.index_names) != 1:
raise ValueError("idxmin only support one-level index now")
groupkeys = self._groupkeys
groupkey_cols = [
s.alias(SPARK_INDEX_NAME_FORMAT(i)) for i, s in enumerate(self._groupkeys_scols)
]
sdf = self._kdf._sdf

groupkey_names = ["__groupkey_{}__".format(i) for i in range(len(self._groupkeys))]

sdf = self._kdf._internal.spark_frame
for s, name in zip(self._groupkeys, groupkey_names):
sdf = sdf.withColumn(name, s.spark.column)
index = self._kdf._internal.index_spark_column_names[0]

stat_exprs = []
@@ -1455,19 +1459,21 @@
order_column = Column(c._jc.asc_nulls_last())
else:
order_column = Column(c._jc.asc_nulls_first())
window = Window.partitionBy(groupkey_cols).orderBy(
window = Window.partitionBy(groupkey_names).orderBy(
order_column, NATURAL_ORDER_COLUMN_NAME
)
sdf = sdf.withColumn(
name, F.when(F.row_number().over(window) == 1, scol_for(sdf, index)).otherwise(None)
)
stat_exprs.append(F.max(scol_for(sdf, name)).alias(name))
sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)

sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)

internal = InternalFrame(
spark_frame=sdf,
index_map=OrderedDict(
(SPARK_INDEX_NAME_FORMAT(i), s._internal.column_labels[0])
for i, s in enumerate(groupkeys)
(name, s._internal.column_labels[0])
for s, name in zip(self._groupkeys, groupkey_names)
),
column_labels=[kser._internal.column_labels[0] for kser in self._agg_columns],
data_spark_columns=[
@@ -1704,9 +1710,10 @@ def head(self, n=5):
]

kdf, groupkey_labels, _ = self._prepare_group_map_apply(kdf, self._groupkeys, agg_columns)

groupkey_scols = [kdf._internal.spark_column_for(label) for label in groupkey_labels]

sdf = kdf._sdf
sdf = kdf._internal.spark_frame
tmp_col = verify_temp_column_name(sdf, "__row_number__")
window = Window.partitionBy(groupkey_scols).orderBy(NATURAL_ORDER_COLUMN_NAME)
sdf = (
@@ -1912,7 +1919,7 @@ def pandas_transform(pdf):
pdf = kdf.head(limit + 1)._to_internal_pandas()
pdf = pdf.groupby(groupkey_names).transform(func, *args, **kwargs)
kdf_from_pandas = DataFrame(pdf)
return_schema = kdf_from_pandas._sdf.drop(*HIDDEN_COLUMNS).schema
return_schema = kdf_from_pandas._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema
if len(pdf) <= limit:
return kdf_from_pandas
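transform, like apply earlier in the file, infers its Spark return schema by running the user function on a limited pandas sample, converting that sample to Spark once, and dropping the hidden columns. A sketch of just the schema-inference step; the HIDDEN_COLUMNS value below is a placeholder for the internal bookkeeping column names, not the real constant.

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

HIDDEN_COLUMNS = {"__natural_order__"}  # placeholder for the internal constant

sample = pd.DataFrame({"a": [1, 2], "b": [1.5, 2.5], "__natural_order__": [0, 1]})
transformed = sample.assign(b=sample["b"] * 2)  # stand-in for func(pdf)

# Convert the small transformed sample to Spark once, purely to get a schema.
return_schema = spark.createDataFrame(transformed).drop(*HIDDEN_COLUMNS).schema
print(return_schema.simpleString())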

@@ -2057,41 +2064,41 @@ def _reduce_for_stat_function(self, sfun, only_numeric, should_include_groupkeys
agg_columns = self._agg_columns
agg_columns_scols = self._agg_columns_scols

groupkey_cols = [
s.alias(SPARK_INDEX_NAME_FORMAT(i)) for i, s in enumerate(self._groupkeys_scols)
]
groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
groupkey_scols = [s.alias(name) for s, name in zip(self._groupkeys_scols, groupkey_names)]

sdf = self._kdf._sdf
sdf = self._kdf._internal.spark_frame.select(groupkey_scols + agg_columns_scols)

data_columns = []
column_labels = []
if len(agg_columns) > 0:
stat_exprs = []
for kser, c in zip(agg_columns, agg_columns_scols):
for kser in agg_columns:
spark_type = kser.spark.data_type
name = kser._internal.data_spark_column_names[0]
label = kser._internal.column_labels[0]
scol = scol_for(sdf, name)
# TODO: we should have a function that takes dataframes and converts the numeric
# types. Converting the NaNs is used in a few places, it should be in utils.
# Special handle floating point types because Spark's count treats nan as a valid
# value, whereas Pandas count doesn't include nan.
if isinstance(spark_type, DoubleType) or isinstance(spark_type, FloatType):
stat_exprs.append(sfun(F.nanvl(c, F.lit(None))).alias(name))
stat_exprs.append(sfun(F.nanvl(scol, F.lit(None))).alias(name))
data_columns.append(name)
column_labels.append(label)
elif isinstance(spark_type, NumericType) or not only_numeric:
stat_exprs.append(sfun(c).alias(name))
stat_exprs.append(sfun(scol).alias(name))
data_columns.append(name)
column_labels.append(label)
sdf = sdf.groupby(*groupkey_cols).agg(*stat_exprs)
sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
else:
sdf = sdf.select(*groupkey_cols).distinct()
sdf = sdf.select(*groupkey_names).distinct()

internal = InternalFrame(
spark_frame=sdf,
index_map=OrderedDict(
(SPARK_INDEX_NAME_FORMAT(i), s._internal.column_labels[0])
for i, s in enumerate(self._groupkeys)
(name, s._internal.column_labels[0])
for s, name in zip(self._groupkeys, groupkey_names)
),
column_labels=column_labels,
data_spark_columns=[scol_for(sdf, col) for col in data_columns],
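_reduce_for_stat_function now resolves each aggregation column from the projected frame via scol_for(sdf, name), and the float special case is kept: Spark treats NaN as a countable value while pandas does not, so NaNs are turned into nulls with F.nanvl before aggregating. A small standalone illustration of that NaN-to-null step on toy data.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
sdf = spark.createDataFrame([(1.0,), (float("nan"),), (None,)], "x double")

sdf.select(
    F.count(F.col("x")).alias("count_raw"),                                # 2: NaN is counted
    F.count(F.nanvl(F.col("x"), F.lit(None))).alias("count_pandas_like"),  # 1: NaN becomes null
).show()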
@@ -2337,7 +2344,7 @@ def describe(self):
)

kdf = self.agg(["count", "mean", "std", "min", "quartiles", "max"]).reset_index()
sdf = kdf._sdf
sdf = kdf._internal.spark_frame
agg_cols = [col.name for col in self._agg_columns]
formatted_percentiles = ["25%", "50%", "75%"]

@@ -2482,12 +2489,14 @@ def nsmallest(self, n=5):
"""
if len(self._kdf._internal.index_names) > 1:
raise ValueError("nsmallest do not support multi-index now")
sdf = self._kdf._sdf

sdf = self._kdf._internal.spark_frame
name = self._agg_columns[0]._internal.data_spark_column_names[0]
window = Window.partitionBy(self._groupkeys_scols).orderBy(
scol_for(sdf, name), NATURAL_ORDER_COLUMN_NAME
self._agg_columns[0].spark.column, NATURAL_ORDER_COLUMN_NAME
)
sdf = sdf.withColumn("rank", F.row_number().over(window)).filter(F.col("rank") <= n)

internal = InternalFrame(
spark_frame=sdf.drop(NATURAL_ORDER_COLUMN_NAME),
index_map=OrderedDict(
@@ -2533,12 +2542,14 @@ def nlargest(self, n=5):
"""
if len(self._kdf._internal.index_names) > 1:
raise ValueError("nlargest do not support multi-index now")
sdf = self._kdf._sdf

sdf = self._kdf._internal.spark_frame
name = self._agg_columns[0]._internal.data_spark_column_names[0]
window = Window.partitionBy(self._groupkeys_scols).orderBy(
F.col(name).desc(), NATURAL_ORDER_COLUMN_NAME
self._agg_columns[0].spark.column.desc(), NATURAL_ORDER_COLUMN_NAME
)
sdf = sdf.withColumn("rank", F.row_number().over(window)).filter(F.col("rank") <= n)

internal = InternalFrame(
spark_frame=sdf.drop(NATURAL_ORDER_COLUMN_NAME),
index_map=OrderedDict(
@@ -2594,25 +2605,23 @@ def value_counts(self, sort=None, ascending=None, dropna=True):
Name: B, dtype: int64
"""
groupkeys = self._groupkeys + self._agg_columns
groupkey_cols = [
s.alias(SPARK_INDEX_NAME_FORMAT(i))
for i, s in enumerate(self._groupkeys_scols + self._agg_columns_scols)
]
sdf = self._kdf._sdf
groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
groupkey_cols = [s.spark.column.alias(name) for s, name in zip(groupkeys, groupkey_names)]

sdf = self._kdf._internal.spark_frame
agg_column = self._agg_columns[0]._internal.data_spark_column_names[0]
sdf = sdf.groupby(*groupkey_cols).count().withColumnRenamed("count", agg_column)

if sort:
if ascending:
sdf = sdf.orderBy(F.col(agg_column).asc())
sdf = sdf.orderBy(scol_for(sdf, agg_column).asc())
else:
sdf = sdf.orderBy(F.col(agg_column).desc())
sdf = sdf.orderBy(scol_for(sdf, agg_column).desc())

internal = InternalFrame(
spark_frame=sdf,
index_map=OrderedDict(
(SPARK_INDEX_NAME_FORMAT(i), s._internal.column_labels[0])
for i, s in enumerate(groupkeys)
(name, s._internal.column_labels[0]) for s, name in zip(groupkeys, groupkey_names)
),
data_spark_columns=[scol_for(sdf, agg_column)],
)
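Several of the hunks above (idxmax, idxmin, head, nsmallest, nlargest) share one window recipe: partition by the group keys, order within the partition, number the rows with row_number, and filter on that number. A compact PySpark sketch of a per-group top-n built that way, with toy column names.

from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
sdf = spark.createDataFrame(
    [("a", 3), ("a", 1), ("a", 2), ("b", 9), ("b", 7)], ["key", "v"]
)

n = 2
# nsmallest-style: ascending order within each group, keep the first n rows.
window = Window.partitionBy("key").orderBy(F.col("v").asc())
sdf.withColumn("rank", F.row_number().over(window)).filter(F.col("rank") <= n).show()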