Skip to content

Commit

Permalink
Add to_arrow with support for Arrow data format. (#8644)
Browse files Browse the repository at this point in the history
* BQ Storage: Add basic arrow stream parser

* BQ Storage: Add tests for to_dataframe with arrow data

* Add to_arrow with BQ Storage API.
  • Loading branch information
tswast authored Jul 11, 2019
1 parent aba3216 commit c5a7cd2
Show file tree
Hide file tree
Showing 5 changed files with 446 additions and 30 deletions.
164 changes: 153 additions & 11 deletions bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,29 @@
import pandas
except ImportError: # pragma: NO COVER
pandas = None
try:
import pyarrow
except ImportError: # pragma: NO COVER
pyarrow = None
import six

try:
import pyarrow
except ImportError: # pragma: NO COVER
pyarrow = None

from google.cloud.bigquery_storage_v1beta1 import types


_STREAM_RESUMPTION_EXCEPTIONS = (google.api_core.exceptions.ServiceUnavailable,)

_FASTAVRO_REQUIRED = (
"fastavro is required to parse ReadRowResponse messages with Avro bytes."
)
_PANDAS_REQUIRED = "pandas is required to create a DataFrame"
_PYARROW_REQUIRED = (
"pyarrow is required to parse ReadRowResponse messages with Arrow bytes."
)


class ReadRowsStream(object):
Expand Down Expand Up @@ -113,7 +126,7 @@ def __iter__(self):
while True:
try:
for message in self._wrapped:
rowcount = message.avro_rows.row_count
rowcount = message.row_count
self._position.offset += rowcount
yield message

Expand Down Expand Up @@ -152,11 +165,28 @@ def rows(self, read_session):
Iterable[Mapping]:
A sequence of rows, represented as dictionaries.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)

return ReadRowsIterable(self, read_session)

def to_arrow(self, read_session):
    """Read every row in the stream into a single :class:`pyarrow.Table`.

    Requires the pyarrow library and a stream that uses the Arrow format.

    Args:
        read_session ( \
            ~google.cloud.bigquery_storage_v1beta1.types.ReadSession \
        ):
            The read session associated with this read rows stream. It
            carries the schema, which is required to parse the data
            messages.

    Returns:
        pyarrow.Table:
            A table of all rows in the stream.
    """
    # Delegate to the row iterable, which does the per-page conversion.
    row_iterable = self.rows(read_session)
    return row_iterable.to_arrow()

def to_dataframe(self, read_session, dtypes=None):
"""Create a :class:`pandas.DataFrame` of all rows in the stream.
Expand Down Expand Up @@ -186,8 +216,6 @@ def to_dataframe(self, read_session, dtypes=None):
pandas.DataFrame:
A data frame of all rows in the stream.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)
if pandas is None:
raise ImportError(_PANDAS_REQUIRED)

Expand All @@ -212,6 +240,7 @@ def __init__(self, reader, read_session):
self._status = None
self._reader = reader
self._read_session = read_session
self._stream_parser = _StreamParser.from_read_session(self._read_session)

@property
def total_rows(self):
Expand All @@ -231,17 +260,31 @@ def pages(self):
"""
# Each page is an iterator of rows. But also has num_items, remaining,
# and to_dataframe.
stream_parser = _StreamParser(self._read_session)
for message in self._reader:
self._status = message.status
yield ReadRowsPage(stream_parser, message)
yield ReadRowsPage(self._stream_parser, message)

def __iter__(self):
"""Iterator for each row in all pages."""
for page in self.pages:
for row in page:
yield row

def to_arrow(self):
    """Create a :class:`pyarrow.Table` of all rows in the stream.

    This method requires the pyarrow library and a stream using the Arrow
    format.

    Returns:
        pyarrow.Table:
            A table of all rows in the stream. When the stream yields no
            pages, an empty table with the session's Arrow schema.

    Raises:
        NotImplementedError:
            If the stream yields no pages and its parser has no Arrow
            schema to build an empty table from.
    """
    record_batches = [page.to_arrow() for page in self.pages]
    if record_batches:
        return pyarrow.Table.from_batches(record_batches)

    # No pages were received. ``Table.from_batches`` cannot infer a schema
    # from an empty list, so take it from the stream parser instead.
    parser = self._stream_parser
    parse_schema = getattr(parser, "_parse_arrow_schema", None)
    if parse_schema is None:
        raise NotImplementedError(
            "to_arrow is only supported for streams using the Arrow format."
        )
    parse_schema()
    return pyarrow.Table.from_batches([], schema=parser._schema)

def to_dataframe(self, dtypes=None):
"""Create a :class:`pandas.DataFrame` of all rows in the stream.
Expand Down Expand Up @@ -291,8 +334,8 @@ def __init__(self, stream_parser, message):
self._stream_parser = stream_parser
self._message = message
self._iter_rows = None
self._num_items = self._message.avro_rows.row_count
self._remaining = self._message.avro_rows.row_count
self._num_items = self._message.row_count
self._remaining = self._message.row_count

def _parse_rows(self):
"""Parse rows from the message only once."""
Expand Down Expand Up @@ -326,6 +369,15 @@ def next(self):
# Alias needed for Python 2/3 support.
__next__ = next

def to_arrow(self):
    """Convert this page's rows into a :class:`pyarrow.RecordBatch`.

    Returns:
        pyarrow.RecordBatch:
            Rows from the message, as an Arrow record batch.
    """
    # The stream parser owns the decoding; the page only supplies its
    # message.
    parser = self._stream_parser
    return parser.to_arrow(self._message)

def to_dataframe(self, dtypes=None):
"""Create a :class:`pandas.DataFrame` of rows in the page.
Expand Down Expand Up @@ -355,21 +407,61 @@ def to_dataframe(self, dtypes=None):


class _StreamParser(object):
def to_arrow(self, message):
raise NotImplementedError("Not implemented.")

def to_dataframe(self, message, dtypes=None):
raise NotImplementedError("Not implemented.")

def to_rows(self, message):
raise NotImplementedError("Not implemented.")

@staticmethod
def from_read_session(read_session):
schema_type = read_session.WhichOneof("schema")
if schema_type == "avro_schema":
return _AvroStreamParser(read_session)
elif schema_type == "arrow_schema":
return _ArrowStreamParser(read_session)
else:
raise TypeError(
"Unsupported schema type in read_session: {0}".format(schema_type)
)


class _AvroStreamParser(_StreamParser):
"""Helper to parse Avro messages into useful representations."""

def __init__(self, read_session):
"""Construct a _StreamParser.
"""Construct an _AvroStreamParser.
Args:
read_session (google.cloud.bigquery_storage_v1beta1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)

self._read_session = read_session
self._avro_schema_json = None
self._fastavro_schema = None
self._column_names = None

def to_arrow(self, message):
"""Create an :class:`pyarrow.RecordBatch` of rows in the page.
Args:
message (google.cloud.bigquery_storage_v1beta1.types.ReadRowsResponse):
Protocol buffer from the read rows stream, to convert into an
Arrow record batch.
Returns:
pyarrow.RecordBatch:
Rows from the message, as an Arrow record batch.
"""
raise NotImplementedError("to_arrow not implemented for Avro streams.")

def to_dataframe(self, message, dtypes=None):
"""Create a :class:`pandas.DataFrame` of rows in the page.
Expand Down Expand Up @@ -447,6 +539,56 @@ def to_rows(self, message):
break # Finished with message


class _ArrowStreamParser(_StreamParser):
    """Helper to parse Arrow messages into useful representations."""

    def __init__(self, read_session):
        """Construct an _ArrowStreamParser.

        Args:
            read_session (google.cloud.bigquery_storage_v1beta1.types.ReadSession):
                A read session. Its ``arrow_schema.serialized_schema`` is
                read when the first message is parsed.

        Raises:
            ImportError: If pyarrow is not installed.
        """
        if pyarrow is None:
            raise ImportError(_PYARROW_REQUIRED)

        self._read_session = read_session
        # Both attributes are filled in lazily by _parse_arrow_schema().
        self._schema = None
        self._column_names = None

    def to_arrow(self, message):
        """Return the record batch decoded from ``message``."""
        return self._parse_arrow_message(message)

    def to_rows(self, message):
        """Yield one dict per row of ``message``, keyed by column name.

        The values are whatever iterating the record batch's columns yields.
        """
        record_batch = self._parse_arrow_message(message)

        # Iterate through each column simultaneously, and make a dict from the
        # row values
        for row in zip(*record_batch.columns):
            yield dict(zip(self._column_names, row))

    def to_dataframe(self, message, dtypes=None):
        """Create a :class:`pandas.DataFrame` of the rows in ``message``.

        Args:
            message:
                Read rows response whose Arrow record batch is converted.
            dtypes (Optional[Mapping]):
                Column name to dtype; each listed column is re-created as a
                ``pandas.Series`` with that dtype.

        Returns:
            pandas.DataFrame: The rows of the message.
        """
        record_batch = self._parse_arrow_message(message)

        if dtypes is None:
            dtypes = {}

        df = record_batch.to_pandas()

        for column in dtypes:
            df[column] = pandas.Series(df[column], dtype=dtypes[column])

        return df

    def _parse_arrow_message(self, message):
        """Decode the serialized record batch in ``message`` with the schema."""
        self._parse_arrow_schema()

        return pyarrow.read_record_batch(
            pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch),
            self._schema,
        )

    def _parse_arrow_schema(self):
        """Parse and cache the session's Arrow schema and column names."""
        # Compare against None rather than relying on truthiness: a schema
        # with zero fields is falsy and would otherwise be re-parsed for
        # every message.
        if self._schema is not None:
            return

        self._schema = pyarrow.read_schema(
            pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
        )
        self._column_names = [field.name for field in self._schema]


def _copy_stream_position(position):
"""Copy a StreamPosition.
Expand Down
4 changes: 2 additions & 2 deletions bigquery_storage/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def default(session):
session.install('mock', 'pytest', 'pytest-cov')
for local_dep in LOCAL_DEPS:
session.install('-e', local_dep)
session.install('-e', '.[pandas,fastavro]')
session.install('-e', '.[pandas,fastavro,pyarrow]')

# Run py.test against the unit tests.
session.run(
Expand Down Expand Up @@ -121,7 +121,7 @@ def system(session):
session.install('-e', os.path.join('..', 'test_utils'))
for local_dep in LOCAL_DEPS:
session.install('-e', local_dep)
session.install('-e', '.[pandas,fastavro]')
session.install('-e', '.[fastavro,pandas,pyarrow]')

# Run py.test against the system tests.
session.run('py.test', '--quiet', 'tests/system/')
Expand Down
1 change: 1 addition & 0 deletions bigquery_storage/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
extras = {
'pandas': 'pandas>=0.17.1',
'fastavro': 'fastavro>=0.21.2',
'pyarrow': 'pyarrow>=0.13.0',
}

package_root = os.path.abspath(os.path.dirname(__file__))
Expand Down
71 changes: 70 additions & 1 deletion bigquery_storage/tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os

import numpy
import pyarrow.types
import pytest

from google.cloud import bigquery_storage_v1beta1
Expand Down Expand Up @@ -67,14 +68,82 @@ def test_read_rows_full_table(client, project_id, small_table_reference):
assert len(block.avro_rows.serialized_binary_rows) > 0


def test_read_rows_to_dataframe(client, project_id):
def test_read_rows_to_arrow(client, project_id):
    """Read selected citibike station columns as Arrow and check the types."""
    table_ref = bigquery_storage_v1beta1.types.TableReference()
    table_ref.project_id = "bigquery-public-data"
    table_ref.dataset_id = "new_york_citibike"
    table_ref.table_id = "citibike_stations"

    read_options = bigquery_storage_v1beta1.types.TableReadOptions()
    for field_name in ("station_id", "latitude", "longitude", "name"):
        read_options.selected_fields.append(field_name)

    parent = "projects/{}".format(project_id)
    session = client.create_read_session(
        table_ref,
        parent,
        format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
        read_options=read_options,
        requested_streams=1,
    )
    stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
        stream=session.streams[0]
    )

    reader = client.read_rows(stream_pos)
    tbl = reader.to_arrow(session)

    assert tbl.num_columns == 4
    schema = tbl.schema
    # Look fields up by name because the column order doesn't currently match
    # that of selected_fields.
    expected_type_checks = (
        ("station_id", pyarrow.types.is_int64),
        ("latitude", pyarrow.types.is_float64),
        ("longitude", pyarrow.types.is_float64),
        ("name", pyarrow.types.is_string),
    )
    for column_name, is_expected_type in expected_type_checks:
        assert is_expected_type(schema.field_by_name(column_name).type)


def test_read_rows_to_dataframe_w_avro(client, project_id):
    """Read citibike stations over an Avro session into a DataFrame."""
    table_ref = bigquery_storage_v1beta1.types.TableReference()
    table_ref.project_id = "bigquery-public-data"
    table_ref.dataset_id = "new_york_citibike"
    table_ref.table_id = "citibike_stations"

    parent = "projects/{}".format(project_id)
    session = client.create_read_session(table_ref, parent, requested_streams=1)
    # No format was requested, and the session is expected to carry an Avro
    # schema.
    assert session.WhichOneof("schema") == "avro_schema"

    stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
        stream=session.streams[0]
    )

    requested_dtypes = {"latitude": numpy.float16}
    reader = client.read_rows(stream_pos)
    frame = reader.to_dataframe(session, dtypes=requested_dtypes)

    # Station ID is a required field (no nulls), so the datatype should always
    # be integer.
    assert frame.station_id.dtype.name == "int64"
    # latitude was overridden through ``dtypes``; longitude was not.
    assert frame.latitude.dtype.name == "float16"
    assert frame.longitude.dtype.name == "float64"
    assert frame["name"].str.startswith("Central Park").any()


def test_read_rows_to_dataframe_w_arrow(client, project_id):
table_ref = bigquery_storage_v1beta1.types.TableReference()
table_ref.project_id = "bigquery-public-data"
table_ref.dataset_id = "new_york_citibike"
table_ref.table_id = "citibike_stations"

session = client.create_read_session(
table_ref,
"projects/{}".format(project_id),
format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
requested_streams=1,
)
schema_type = session.WhichOneof("schema")
assert schema_type == "arrow_schema"

stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
stream=session.streams[0]
)
Expand Down
Loading

0 comments on commit c5a7cd2

Please sign in to comment.