Skip to content

Commit

Permalink
[SPARK-30434][PYTHON][SQL] Move pandas related functionalities into '…
Browse files Browse the repository at this point in the history
…pandas' sub-package

### What changes were proposed in this pull request?

This PR proposes to move pandas related functionalities into pandas package. Namely:

```bash
pyspark/sql/pandas
├── __init__.py
├── conversion.py  # Conversion between pandas <> PySpark DataFrames
├── functions.py   # pandas_udf
├── group_ops.py   # Grouped UDF / Cogrouped UDF + groupby.apply, groupby.cogroup.apply
├── map_ops.py     # Map Iter UDF + mapInPandas
├── serializers.py # pandas <> PyArrow serializers
├── types.py       # Type utils between pandas <> PyArrow
└── utils.py       # Version requirement checks
```

In order to separately locate `groupby.apply`, `groupby.cogroup.apply`, `mapInPandas`, `toPandas`, and `createDataFrame(pdf)` under `pandas` sub-package, I had to use a mix-in approach which Scala side uses often by `trait`, and also pandas itself uses this approach (see `IndexOpsMixin` as an example) to group related functionalities. Currently, you can think it's like Scala's self typed trait. See the structure below:

```python
class PandasMapOpsMixin(object):
    def mapInPandas(self, ...):
        ...
        return ...

    # other Pandas <> PySpark APIs
```

```python
class DataFrame(PandasMapOpsMixin):

    # other DataFrame APIs equivalent to Scala side.

```

Yes, This is a big PR but they are mostly just moving around except one case `createDataFrame` which I had to split the methods.

### Why are the changes needed?

There are pandas functionalities here and there and I myself gets lost where it was. Also, when you have to make a change commonly for all of pandas related features, it's almost impossible now.

Also, after this change, `DataFrame` and `SparkSession` become more consistent with Scala side since pandas is specific to Python, and this change separates pandas-specific APIs away from `DataFrame` or `SparkSession`.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests should cover. Also, I manually built the PySpark API documentation and checked.

Closes #27109 from HyukjinKwon/pandas-refactoring.

Authored-by: HyukjinKwon <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
  • Loading branch information
HyukjinKwon committed Jan 9, 2020
1 parent 18daa37 commit ee8d661
Show file tree
Hide file tree
Showing 25 changed files with 1,840 additions and 1,514 deletions.
7 changes: 7 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,13 @@ def __hash__(self):
"pyspark.sql.udf",
"pyspark.sql.window",
"pyspark.sql.avro.functions",
"pyspark.sql.pandas.conversion",
"pyspark.sql.pandas.map_ops",
"pyspark.sql.pandas.functions",
"pyspark.sql.pandas.group_ops",
"pyspark.sql.pandas.types",
"pyspark.sql.pandas.serializers",
"pyspark.sql.pandas.utils",
# unittests
"pyspark.sql.tests.test_arrow",
"pyspark.sql.tests.test_catalog",
Expand Down
2 changes: 1 addition & 1 deletion examples/src/main/python/sql/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from __future__ import print_function

from pyspark.sql import SparkSession
from pyspark.sql.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version

require_minimum_pandas_version()
require_minimum_pyarrow_version()
Expand Down
1 change: 1 addition & 0 deletions python/docs/pyspark.sql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Module Context
.. automodule:: pyspark.sql
:members:
:undoc-members:
:inherited-members:
:exclude-members: builder
.. We need `exclude-members` to prevent default description generations
as a workaround for old Sphinx (< 1.6.6).
Expand Down
242 changes: 0 additions & 242 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,248 +185,6 @@ def loads(self, obj):
raise NotImplementedError


class ArrowCollectSerializer(Serializer):
"""
Deserialize a stream of batches followed by batch order information. Used in
DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
"""

def __init__(self):
self.serializer = ArrowStreamSerializer()

def dump_stream(self, iterator, stream):
return self.serializer.dump_stream(iterator, stream)

def load_stream(self, stream):
"""
Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
a list of indices that can be used to put the RecordBatches in the correct order.
"""
# load the batches
for batch in self.serializer.load_stream(stream):
yield batch

# load the batch order indices or propagate any error that occurred in the JVM
num = read_int(stream)
if num == -1:
error_msg = UTF8Deserializer().loads(stream)
raise RuntimeError("An error occurred while calling "
"ArrowCollectSerializer.load_stream: {}".format(error_msg))
batch_order = []
for i in xrange(num):
index = read_int(stream)
batch_order.append(index)
yield batch_order

def __repr__(self):
return "ArrowCollectSerializer(%s)" % self.serializer


class ArrowStreamSerializer(Serializer):
"""
Serializes Arrow record batches as a stream.
"""

def dump_stream(self, iterator, stream):
import pyarrow as pa
writer = None
try:
for batch in iterator:
if writer is None:
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
finally:
if writer is not None:
writer.close()

def load_stream(self, stream):
import pyarrow as pa
reader = pa.ipc.open_stream(stream)
for batch in reader:
yield batch

def __repr__(self):
return "ArrowStreamSerializer"


class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
Serializes Pandas.Series as Arrow data with Arrow streaming format.
:param timezone: A timezone to respect when handling timestamp values
:param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
:param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
"""

def __init__(self, timezone, safecheck, assign_cols_by_name):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name

def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import _check_series_localize_timestamps

# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
s = arrow_column.to_pandas(date_as_object=True)

s = _check_series_localize_timestamps(s, self._timezone)
return s

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series,
with optional type.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
"conversions warned by Arrow. Arrow safe type check can be " + \
"disabled by using SQL config " + \
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
raise RuntimeError(error_msg % (s.dtype, t), e)
return array

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s)))

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, basestring)
for name in s.columns):
arrs_names = [(create_array(s[field.name], field.type), field.name)
for field in t]
# Assign result columns by position
else:
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)]

struct_arrs, struct_names = zip(*arrs_names)
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])

def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
batches = (self._create_batch(series) for series in iterator)
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)

def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
import pyarrow as pa
for batch in batches:
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

def __repr__(self):
return "ArrowStreamPandasSerializer"


class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""
Serializer used by Python worker to evaluate Pandas UDFs
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
super(ArrowStreamPandasUDFSerializer, self) \
.__init__(timezone, safecheck, assign_cols_by_name)
self._df_for_struct = df_for_struct

def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types

if self._df_for_struct and types.is_struct(arrow_column.type):
import pandas as pd
series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
.rename(field.name)
for column, field in zip(arrow_column.flatten(), arrow_column.type)]
s = pd.concat(series, axis=1)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
return s

def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""

def init_stream_yield_batches():
should_write_start_length = True
for series in iterator:
batch = self._create_batch(series)
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch

return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

def __repr__(self):
return "ArrowStreamPandasUDFSerializer"


class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):

def load_stream(self, stream):
"""
Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two
lists of pandas.Series.
"""
import pyarrow as pa
dataframes_in_group = None

while dataframes_in_group is None or dataframes_in_group > 0:
dataframes_in_group = read_int(stream)

if dataframes_in_group == 2:
batch1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
batch2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)]
yield (
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch1).itercolumns()],
[self.arrow_to_pandas(c) for c in pa.Table.from_batches(batch2).itercolumns()]
)

elif dataframes_in_group != 0:
raise ValueError(
'Invalid number of pandas.DataFrames in group {0}'.format(dataframes_in_group))


class BatchedSerializer(Serializer):

"""
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@
from pyspark.sql.group import GroupedData
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
from pyspark.sql.window import Window, WindowSpec
from pyspark.sql.cogroup import CoGroupedData
from pyspark.sql.pandas.group_ops import PandasCogroupedOps


__all__ = [
'SparkSession', 'SQLContext', 'UDFRegistration',
'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
'DataFrameReader', 'DataFrameWriter', 'PandasCogroupedOps'
]
Loading

0 comments on commit ee8d661

Please sign in to comment.