BigQuery Storage: Add support for arrow format in BQ Read API #8644

Merged · 10 commits · Jul 11, 2019
170 changes: 156 additions & 14 deletions bigquery_storage/google/cloud/bigquery_storage_v1beta1/reader.py
@@ -27,16 +27,27 @@
import pandas
except ImportError: # pragma: NO COVER
pandas = None
try:
import pyarrow
except ImportError: # pragma: NO COVER
pyarrow = None
import six

try:
import pyarrow
except ImportError: # pragma: NO COVER
pyarrow = None

from google.cloud.bigquery_storage_v1beta1 import types


_STREAM_RESUMPTION_EXCEPTIONS = (google.api_core.exceptions.ServiceUnavailable,)
_FASTAVRO_REQUIRED = (
"fastavro is required to parse ReadRowResponse messages with Avro bytes."
)

_AVRO_BYTES_OPERATION = "parse ReadRowResponse messages with Avro bytes"
_ARROW_BYTES_OPERATION = "parse ReadRowResponse messages with Arrow bytes"
_FASTAVRO_REQUIRED = "fastavro is required to {operation}."
Contributor commented:

I assume these parameterized errors are for when you do things like to_arrow with avro bytes?

@tswast (Contributor Author) commented on Jul 11, 2019:

Yeah, this is vestigial from when I was planning to implement to_arrow for Avro streams. Removed for now.

_PANDAS_REQUIRED = "pandas is required to create a DataFrame"
_PYARROW_REQUIRED = "pyarrow is required to {operation}."


class ReadRowsStream(object):
@@ -113,7 +124,7 @@ def __iter__(self):
while True:
try:
for message in self._wrapped:
rowcount = message.avro_rows.row_count
rowcount = message.row_count
self._position.offset += rowcount
yield message

@@ -152,11 +163,28 @@ def rows(self, read_session):
Iterable[Mapping]:
A sequence of rows, represented as dictionaries.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)

return ReadRowsIterable(self, read_session)

def to_arrow(self, read_session):
"""Create a :class:`pyarrow.Table` of all rows in the stream.

This method requires the pyarrow library and a stream using the Arrow
format.

Args:
read_session ( \
~google.cloud.bigquery_storage_v1beta1.types.ReadSession \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.

Returns:
pyarrow.Table:
A table of all rows in the stream.
"""
return self.rows(read_session).to_arrow()
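
For orientation, a hedged usage sketch of the new method, mirroring the system test added later in this PR; it needs credentials and network access, and "your-project-id" is a placeholder:

    from google.cloud import bigquery_storage_v1beta1

    # Sketch only: read one stream of an ARROW-format session into a table.
    client = bigquery_storage_v1beta1.BigQueryStorageClient()
    table_ref = bigquery_storage_v1beta1.types.TableReference()
    table_ref.project_id = "bigquery-public-data"
    table_ref.dataset_id = "new_york_citibike"
    table_ref.table_id = "citibike_stations"
    session = client.create_read_session(
        table_ref,
        "projects/your-project-id",
        format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
        requested_streams=1,
    )
    stream_pos = bigquery_storage_v1beta1.types.StreamPosition(stream=session.streams[0])
    table = client.read_rows(stream_pos).to_arrow(session)
    print(table.num_rows, table.num_columns)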

def to_dataframe(self, read_session, dtypes=None):
"""Create a :class:`pandas.DataFrame` of all rows in the stream.

@@ -186,8 +214,6 @@ def to_dataframe(self, read_session, dtypes=None):
pandas.DataFrame:
A data frame of all rows in the stream.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)
if pandas is None:
raise ImportError(_PANDAS_REQUIRED)

@@ -212,6 +238,7 @@ def __init__(self, reader, read_session):
self._status = None
self._reader = reader
self._read_session = read_session
self._stream_parser = _StreamParser.from_read_session(self._read_session)

@property
def total_rows(self):
@@ -231,17 +258,31 @@ def pages(self):
"""
# Each page is an iterator of rows. But also has num_items, remaining,
# and to_dataframe.
stream_parser = _StreamParser(self._read_session)
for message in self._reader:
self._status = message.status
yield ReadRowsPage(stream_parser, message)
yield ReadRowsPage(self._stream_parser, message)

def __iter__(self):
"""Iterator for each row in all pages."""
for page in self.pages:
for row in page:
yield row

def to_arrow(self):
"""Create a :class:`pyarrow.Table` of all rows in the stream.

This method requires the pyarrow library and a stream using the Arrow
format.

Returns:
pyarrow.Table:
A table of all rows in the stream.
"""
record_batches = []
for page in self.pages:
record_batches.append(page.to_arrow())
return pyarrow.Table.from_batches(record_batches)
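
As a standalone illustration of the pyarrow call used above (a sketch with made-up data, not tied to the Read API), combining per-message record batches into a single table:

    import pyarrow as pa

    # Two small batches standing in for the per-page record batches.
    batches = [
        pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["station_id"]),
        pa.RecordBatch.from_arrays([pa.array([3])], names=["station_id"]),
    ]
    table = pa.Table.from_batches(batches)
    assert table.num_rows == 3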

def to_dataframe(self, dtypes=None):
"""Create a :class:`pandas.DataFrame` of all rows in the stream.

@@ -291,8 +332,8 @@ def __init__(self, stream_parser, message):
self._stream_parser = stream_parser
self._message = message
self._iter_rows = None
self._num_items = self._message.avro_rows.row_count
self._remaining = self._message.avro_rows.row_count
self._num_items = self._message.row_count
self._remaining = self._message.row_count

def _parse_rows(self):
"""Parse rows from the message only once."""
@@ -326,6 +367,15 @@ def next(self):
# Alias needed for Python 2/3 support.
__next__ = next

def to_arrow(self):
"""Create an :class:`pyarrow.RecordBatch` of rows in the page.

Returns:
pyarrow.RecordBatch:
Rows from the message, as an Arrow record batch.
"""
return self._stream_parser.to_arrow(self._message)

def to_dataframe(self, dtypes=None):
"""Create a :class:`pandas.DataFrame` of rows in the page.

@@ -355,21 +405,61 @@ def to_dataframe(self, dtypes=None):


class _StreamParser(object):
def to_arrow(self, message):
raise NotImplementedError("Not implemented.")

def to_dataframe(self, message, dtypes=None):
raise NotImplementedError("Not implemented.")

def to_rows(self, message):
raise NotImplementedError("Not implemented.")

@staticmethod
def from_read_session(read_session):
schema_type = read_session.WhichOneof("schema")
if schema_type == "avro_schema":
return _AvroStreamParser(read_session)
elif schema_type == "arrow_schema":
return _ArrowStreamParser(read_session)
else:
raise TypeError(
"Unsupported schema type in read_session: {0}".format(schema_type)
)
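
A minimal sketch of the schema-based dispatch, assuming only that the session's schema oneof is populated (the empty ArrowSchema below is illustrative):

    from google.cloud.bigquery_storage_v1beta1 import types

    # Illustrative only: a session whose schema oneof is arrow_schema selects
    # _ArrowStreamParser; an avro_schema session selects _AvroStreamParser.
    session = types.ReadSession(arrow_schema=types.ArrowSchema(serialized_schema=b""))
    assert session.WhichOneof("schema") == "arrow_schema"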


class _AvroStreamParser(_StreamParser):
"""Helper to parse Avro messages into useful representations."""

def __init__(self, read_session):
"""Construct a _StreamParser.
"""Construct an _AvroStreamParser.

Args:
read_session (google.cloud.bigquery_storage_v1beta1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)
Contributor commented:

does this need to be parameterized as well?

Contributor Author commented:

Removed the {operation} from error message since I didn't actually need it.


self._read_session = read_session
self._avro_schema_json = None
self._fastavro_schema = None
self._column_names = None

def to_arrow(self, message):
"""Create an :class:`pyarrow.RecordBatch` of rows in the page.

Args:
message (google.cloud.bigquery_storage_v1beta1.types.ReadRowsResponse):
Protocol buffer from the read rows stream, to convert into an
Arrow record batch.

Returns:
pyarrow.RecordBatch:
Rows from the message, as an Arrow record batch.
"""
raise NotImplementedError("to_arrow not implemented for Avro streams.")

def to_dataframe(self, message, dtypes=None):
"""Create a :class:`pandas.DataFrame` of rows in the page.

@@ -447,6 +537,58 @@ def to_rows(self, message):
break # Finished with message


class _ArrowStreamParser(_StreamParser):
def __init__(self, read_session):
if pyarrow is None:
raise ImportError(
_PYARROW_REQUIRED.format(operation=_ARROW_BYTES_OPERATION)
)

self._read_session = read_session
self._schema = None

def to_arrow(self, message):
return self._parse_arrow_message(message)

def to_rows(self, message):
record_batch = self._parse_arrow_message(message)

# Iterate through each column simultaneously, and make a dict from the
# row values
for row in zip(*record_batch.columns):
yield dict(zip(self._column_names, row))

def to_dataframe(self, message, dtypes=None):
record_batch = self._parse_arrow_message(message)

if dtypes is None:
dtypes = {}

df = record_batch.to_pandas()

for column in dtypes:
df[column] = pandas.Series(df[column], dtype=dtypes[column])

return df

def _parse_arrow_message(self, message):
self._parse_arrow_schema()

return pyarrow.read_record_batch(
pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch),
self._schema,
)

def _parse_arrow_schema(self):
if self._schema:
return

self._schema = pyarrow.read_schema(
pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
)
self._column_names = [field.name for field in self._schema]
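
To make the serialization round trip concrete, a self-contained sketch that mirrors _parse_arrow_schema and _parse_arrow_message with a locally built record batch (no Read API involved; it uses the pyarrow.ipc module rather than the top-level helpers called above):

    import pyarrow as pa

    # A batch standing in for the rows carried by one ReadRowsResponse.
    batch = pa.RecordBatch.from_arrays(
        [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])],
        names=["station_id", "name"],
    )

    # The Read API sends schema bytes once per session and batch bytes per message.
    schema_bytes = batch.schema.serialize().to_pybytes()
    batch_bytes = batch.serialize().to_pybytes()

    # Deserialize the same way the parser does.
    schema = pa.ipc.read_schema(pa.py_buffer(schema_bytes))
    roundtrip = pa.ipc.read_record_batch(pa.py_buffer(batch_bytes), schema)
    assert roundtrip.equals(batch)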


def _copy_stream_position(position):
"""Copy a StreamPosition.

4 changes: 2 additions & 2 deletions bigquery_storage/noxfile.py
@@ -37,7 +37,7 @@ def default(session):
session.install('mock', 'pytest', 'pytest-cov')
for local_dep in LOCAL_DEPS:
session.install('-e', local_dep)
session.install('-e', '.[pandas,fastavro]')
session.install('-e', '.[pandas,fastavro,pyarrow]')

# Run py.test against the unit tests.
session.run(
@@ -121,7 +121,7 @@ def system(session):
session.install('-e', os.path.join('..', 'test_utils'))
for local_dep in LOCAL_DEPS:
session.install('-e', local_dep)
session.install('-e', '.[pandas,fastavro]')
session.install('-e', '.[fastavro,pandas,pyarrow]')

# Run py.test against the system tests.
session.run('py.test', '--quiet', 'tests/system/')
1 change: 1 addition & 0 deletions bigquery_storage/setup.py
@@ -31,6 +31,7 @@
extras = {
'pandas': 'pandas>=0.17.1',
'fastavro': 'fastavro>=0.21.2',
'pyarrow': 'pyarrow>=0.13.0',
}

package_root = os.path.abspath(os.path.dirname(__file__))
71 changes: 70 additions & 1 deletion bigquery_storage/tests/system/test_system.py
@@ -18,6 +18,7 @@
import os

import numpy
import pyarrow.types
import pytest

from google.cloud import bigquery_storage_v1beta1
@@ -67,14 +68,82 @@ def test_read_rows_full_table(client, project_id, small_table_reference):
assert len(block.avro_rows.serialized_binary_rows) > 0


def test_read_rows_to_dataframe(client, project_id):
def test_read_rows_to_arrow(client, project_id):
table_ref = bigquery_storage_v1beta1.types.TableReference()
table_ref.project_id = "bigquery-public-data"
table_ref.dataset_id = "new_york_citibike"
table_ref.table_id = "citibike_stations"

read_options = bigquery_storage_v1beta1.types.TableReadOptions()
read_options.selected_fields.append("station_id")
read_options.selected_fields.append("latitude")
read_options.selected_fields.append("longitude")
read_options.selected_fields.append("name")
session = client.create_read_session(
table_ref,
"projects/{}".format(project_id),
format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
read_options=read_options,
requested_streams=1,
)
stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
stream=session.streams[0]
)

tbl = client.read_rows(stream_pos).to_arrow(session)

assert tbl.num_columns == 4
schema = tbl.schema
# Use field_by_name because the order doesn't currently match that of
# selected_fields.
assert pyarrow.types.is_int64(schema.field_by_name("station_id").type)
assert pyarrow.types.is_float64(schema.field_by_name("latitude").type)
assert pyarrow.types.is_float64(schema.field_by_name("longitude").type)
assert pyarrow.types.is_string(schema.field_by_name("name").type)


def test_read_rows_to_dataframe_w_avro(client, project_id):
table_ref = bigquery_storage_v1beta1.types.TableReference()
table_ref.project_id = "bigquery-public-data"
table_ref.dataset_id = "new_york_citibike"
table_ref.table_id = "citibike_stations"
session = client.create_read_session(
table_ref, "projects/{}".format(project_id), requested_streams=1
)
schema_type = session.WhichOneof("schema")
assert schema_type == "avro_schema"

stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
stream=session.streams[0]
)

frame = client.read_rows(stream_pos).to_dataframe(
session, dtypes={"latitude": numpy.float16}
)

# Station ID is a required field (no nulls), so the datatype should always
# be integer.
assert frame.station_id.dtype.name == "int64"
assert frame.latitude.dtype.name == "float16"
assert frame.longitude.dtype.name == "float64"
assert frame["name"].str.startswith("Central Park").any()


def test_read_rows_to_dataframe_w_arrow(client, project_id):
table_ref = bigquery_storage_v1beta1.types.TableReference()
table_ref.project_id = "bigquery-public-data"
table_ref.dataset_id = "new_york_citibike"
table_ref.table_id = "citibike_stations"

session = client.create_read_session(
table_ref,
"projects/{}".format(project_id),
format_=bigquery_storage_v1beta1.enums.DataFormat.ARROW,
requested_streams=1,
)
schema_type = session.WhichOneof("schema")
assert schema_type == "arrow_schema"

stream_pos = bigquery_storage_v1beta1.types.StreamPosition(
stream=session.streams[0]
)