Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_arrow to get a pyarrow.Table from query results. #8609

Merged
merged 9 commits into from
Jul 10, 2019
Merged
15 changes: 9 additions & 6 deletions bigquery/google/cloud/bigquery/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ def _field_to_index_mapping(schema):
return {f.name: i for i, f in enumerate(schema)}


def _field_from_json(resource, field):
converter = _CELLDATA_FROM_JSON.get(field.field_type, lambda value, _: value)
if field.mode == "REPEATED":
return [converter(item["v"], field) for item in resource]
else:
return converter(resource, field)


def _row_tuple_from_json(row, schema):
"""Convert JSON row data to row with appropriate types.

Expand All @@ -214,12 +222,7 @@ def _row_tuple_from_json(row, schema):
"""
row_data = []
for field, cell in zip(schema, row["f"]):
converter = _CELLDATA_FROM_JSON[field.field_type]
if field.mode == "REPEATED":
row_data.append([converter(item["v"], field) for item in cell["v"]])
else:
row_data.append(converter(cell["v"], field))

row_data.append(_field_from_json(cell["v"], field))
return tuple(row_data)


Expand Down
62 changes: 53 additions & 9 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""Shared helper functions for connecting BigQuery and pandas."""

import collections
import concurrent.futures
import warnings

Expand Down Expand Up @@ -115,7 +114,7 @@ def bq_to_arrow_data_type(field):
"""
if field.mode is not None and field.mode.upper() == "REPEATED":
inner_type = bq_to_arrow_data_type(
schema.SchemaField(field.name, field.field_type)
schema.SchemaField(field.name, field.field_type, fields=field.fields)
)
if inner_type:
return pyarrow.list_(inner_type)
Expand Down Expand Up @@ -144,6 +143,23 @@ def bq_to_arrow_field(bq_field):
return None


def bq_to_arrow_schema(bq_schema):
"""Return the Arrow schema, corresponding to a given BigQuery schema.

Raises:
ValueError:
If the Arrow type of any column cannot be determined.
"""
arrow_fields = []
for bq_field in bq_schema:
arrow_field = bq_to_arrow_field(bq_field)
if arrow_field is None:
# Auto-detect the schema if there is an unknown field type.
return None
arrow_fields.append(arrow_field)
return pyarrow.schema(arrow_fields)


def bq_to_arrow_array(series, bq_field):
arrow_type = bq_to_arrow_data_type(bq_field)
if bq_field.mode.upper() == "REPEATED":
Expand Down Expand Up @@ -210,13 +226,41 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath):
pyarrow.parquet.write_table(arrow_table, filepath)


def _tabledata_list_page_to_arrow(page, column_names, arrow_types):
# Iterate over the page to force the API request to get the page data.
try:
next(iter(page))
except StopIteration:
pass

arrays = []
for column_index, arrow_type in enumerate(arrow_types):
arrays.append(pyarrow.array(page._columns[column_index], type=arrow_type))

return pyarrow.RecordBatch.from_arrays(arrays, column_names)


def download_arrow_tabledata_list(pages, schema):
"""Use tabledata.list to construct an iterable of RecordBatches."""
column_names = bq_to_arrow_schema(schema) or [field.name for field in schema]
arrow_types = [bq_to_arrow_data_type(field) for field in schema]

for page in pages:
yield _tabledata_list_page_to_arrow(page, column_names, arrow_types)


def _tabledata_list_page_to_dataframe(page, column_names, dtypes):
columns = collections.defaultdict(list)
for row in page:
for column in column_names:
columns[column].append(row[column])
for column in dtypes:
columns[column] = pandas.Series(columns[column], dtype=dtypes[column])
# Iterate over the page to force the API request to get the page data.
try:
next(iter(page))
except StopIteration:
pass

columns = {}
for column_index, column_name in enumerate(column_names):
dtype = dtypes.get(column_name)
columns[column_name] = pandas.Series(page._columns[column_index], dtype=dtype)

return pandas.DataFrame(columns, columns=column_names)


Expand Down Expand Up @@ -350,7 +394,7 @@ def download_dataframe_bqstorage(
continue

# Return any remaining values after the workers finished.
while not worker_queue.empty():
while not worker_queue.empty(): # pragma: NO COVER
try:
# Include a timeout because even though the queue is
# non-empty, it doesn't guarantee that a subsequent call to
Expand Down
38 changes: 38 additions & 0 deletions bigquery/google/cloud/bigquery/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2896,6 +2896,44 @@ def result(self, timeout=None, page_size=None, retry=DEFAULT_RETRY):
rows._preserve_order = _contains_order_by(self.query)
return rows

def to_arrow(self, progress_bar_type=None):
"""[Beta] Create a class:`pyarrow.Table` by loading all pages of a
table or query.

Args:
progress_bar_type (Optional[str]):
If set, use the `tqdm <https://tqdm.github.io/>`_ library to
display a progress bar while the data downloads. Install the
``tqdm`` package to use this feature.

Possible values of ``progress_bar_type`` include:

``None``
No progress bar.
``'tqdm'``
Use the :func:`tqdm.tqdm` function to print a progress bar
to :data:`sys.stderr`.
``'tqdm_notebook'``
Use the :func:`tqdm.tqdm_notebook` function to display a
progress bar as a Jupyter notebook widget.
``'tqdm_gui'``
Use the :func:`tqdm.tqdm_gui` function to display a
progress bar as a graphical dialog box.

Returns:
pyarrow.Table
A :class:`pyarrow.Table` populated with row data and column
headers from the query results. The column headers are derived
from the destination table's schema.

Raises:
ValueError:
If the :mod:`pyarrow` library cannot be imported.

..versionadded:: 1.17.0
"""
return self.result().to_arrow(progress_bar_type=progress_bar_type)

def to_dataframe(self, bqstorage_client=None, dtypes=None, progress_bar_type=None):
"""Return a pandas DataFrame from a QueryJob

Expand Down
98 changes: 98 additions & 0 deletions bigquery/google/cloud/bigquery/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@
except ImportError: # pragma: NO COVER
pandas = None

try:
import pyarrow
except ImportError: # pragma: NO COVER
pyarrow = None

try:
import tqdm
except ImportError: # pragma: NO COVER
Expand All @@ -58,6 +63,10 @@
"The pandas library is not installed, please install "
"pandas to use the to_dataframe() function."
)
_NO_PYARROW_ERROR = (
"The pyarrow library is not installed, please install "
"pandas to use the to_arrow() function."
)
_NO_TQDM_ERROR = (
"A progress bar was requested, but there was an error loading the tqdm "
"library. Please install tqdm to use the progress bar functionality."
Expand Down Expand Up @@ -1394,6 +1403,72 @@ def _get_progress_bar(self, progress_bar_type):
warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)
return None

def _to_arrow_iterable(self):
"""Create an iterable of arrow RecordBatches, to process the table as a stream."""
for record_batch in _pandas_helpers.download_arrow_tabledata_list(
iter(self.pages), self.schema
):
yield record_batch

def to_arrow(self, progress_bar_type=None):
"""[Beta] Create a class:`pyarrow.Table` by loading all pages of a
table or query.

Args:
progress_bar_type (Optional[str]):
If set, use the `tqdm <https://tqdm.github.io/>`_ library to
display a progress bar while the data downloads. Install the
``tqdm`` package to use this feature.

Possible values of ``progress_bar_type`` include:

``None``
No progress bar.
``'tqdm'``
Use the :func:`tqdm.tqdm` function to print a progress bar
to :data:`sys.stderr`.
``'tqdm_notebook'``
Use the :func:`tqdm.tqdm_notebook` function to display a
progress bar as a Jupyter notebook widget.
``'tqdm_gui'``
Use the :func:`tqdm.tqdm_gui` function to display a
progress bar as a graphical dialog box.

Returns:
pyarrow.Table
A :class:`pyarrow.Table` populated with row data and column
headers from the query results. The column headers are derived
from the destination table's schema.

Raises:
ValueError:
If the :mod:`pyarrow` library cannot be imported.

..versionadded:: 1.17.0
"""
if pyarrow is None:
raise ValueError(_NO_PYARROW_ERROR)

progress_bar = self._get_progress_bar(progress_bar_type)

record_batches = []
for record_batch in self._to_arrow_iterable():
record_batches.append(record_batch)

if progress_bar is not None:
# In some cases, the number of total rows is not populated
# until the first page of rows is fetched. Update the
# progress bar's total to keep an accurate count.
progress_bar.total = progress_bar.total or self.total_rows
progress_bar.update(record_batch.num_rows)

if progress_bar is not None:
# Indicate that the download has finished.
progress_bar.close()

arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema)
return pyarrow.Table.from_batches(record_batches, schema=arrow_schema)

def _to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
"""Create an iterable of pandas DataFrames, to process the table as a stream.

Expand Down Expand Up @@ -1734,6 +1809,25 @@ def _item_to_row(iterator, resource):
)


def _tabledata_list_page_columns(schema, response):
"""Make a generator of all the columns in a page from tabledata.list.

This enables creating a :class:`pandas.DataFrame` and other
column-oriented data structures such as :class:`pyarrow.RecordBatch`
"""
columns = []
rows = response.get("rows", [])

def get_column_data(field_index, field):
for row in rows:
yield _helpers._field_from_json(row["f"][field_index]["v"], field)

for field_index, field in enumerate(schema):
columns.append(get_column_data(field_index, field))

return columns


# pylint: disable=unused-argument
def _rows_page_start(iterator, page, response):
"""Grab total rows when :class:`~google.cloud.iterator.Page` starts.
Expand All @@ -1747,6 +1841,10 @@ def _rows_page_start(iterator, page, response):
:type response: dict
:param response: The JSON API response for a page of rows in a table.
"""
# Make a (lazy) copy of the page in column-oriented format for use in data
# science packages.
page._columns = _tabledata_list_page_columns(iterator._schema, response)

total_rows = response.get("totalRows")
if total_rows is not None:
total_rows = int(total_rows)
Expand Down
58 changes: 58 additions & 0 deletions bigquery/samples/query_to_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def main(client):
# [START bigquery_query_to_arrow]
# TODO(developer): Import the client library.
# from google.cloud import bigquery

# TODO(developer): Construct a BigQuery client object.
# client = bigquery.Client()

sql = """
WITH races AS (
SELECT "800M" AS race,
[STRUCT("Rudisha" as name, [23.4, 26.3, 26.4, 26.1] as splits),
STRUCT("Makhloufi" as name, [24.5, 25.4, 26.6, 26.1] as splits),
STRUCT("Murphy" as name, [23.9, 26.0, 27.0, 26.0] as splits),
STRUCT("Bosse" as name, [23.6, 26.2, 26.5, 27.1] as splits),
STRUCT("Rotich" as name, [24.7, 25.6, 26.9, 26.4] as splits),
STRUCT("Lewandowski" as name, [25.0, 25.7, 26.3, 27.2] as splits),
STRUCT("Kipketer" as name, [23.2, 26.1, 27.3, 29.4] as splits),
STRUCT("Berian" as name, [23.7, 26.1, 27.0, 29.3] as splits)]
AS participants)
SELECT
race,
participant
FROM races r
CROSS JOIN UNNEST(r.participants) as participant;
"""
query_job = client.query(sql)
arrow_table = query_job.to_arrow()

print(
"Downloaded {} rows, {} columns.".format(
arrow_table.num_rows, arrow_table.num_columns
)
)
print("\nSchema:\n{}".format(repr(arrow_table.schema)))
# [END bigquery_query_to_arrow]
return arrow_table


if __name__ == "__main__":
from google.cloud import bigquery

main(bigquery.Client())
29 changes: 29 additions & 0 deletions bigquery/samples/tests/test_query_to_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pyarrow

from .. import query_to_arrow


def test_main(capsys, client):

arrow_table = query_to_arrow.main(client)
out, err = capsys.readouterr()
assert "Downloaded 8 rows, 2 columns." in out

arrow_schema = arrow_table.schema
assert arrow_schema.names == ["race", "participant"]
assert pyarrow.types.is_string(arrow_schema.types[0])
assert pyarrow.types.is_struct(arrow_schema.types[1])
Loading