fix: add df snapshots lookup for read_gbq #229

Merged · 18 commits · Nov 29, 2023
6 changes: 6 additions & 0 deletions bigframes/pandas/__init__.py
@@ -486,6 +486,7 @@ def read_gbq(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible(query_or_table)
return global_session.with_default_session(
@@ -494,6 +495,7 @@ def read_gbq(
index_col=index_col,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)


@@ -516,6 +518,7 @@ def read_gbq_query(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible(query)
return global_session.with_default_session(
@@ -524,6 +527,7 @@ def read_gbq_query(
index_col=index_col,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)


@@ -536,6 +540,7 @@ def read_gbq_table(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> bigframes.dataframe.DataFrame:
_set_default_session_location_if_possible(query)
return global_session.with_default_session(
@@ -544,6 +549,7 @@ def read_gbq_table(
index_col=index_col,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)


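With these one-line additions, each module-level wrapper simply forwards use_cache to the default session. A minimal usage sketch (the table path is hypothetical):

import bigframes.pandas as bpd

# Default behavior: reuse this session's cached snapshot timestamp for the table.
df_cached = bpd.read_gbq("my-project.my_dataset.my_table")

# Opt out: mint a fresh snapshot timestamp and bypass the BigQuery query cache.
df_fresh = bpd.read_gbq("my-project.my_dataset.my_table", use_cache=False)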
56 changes: 32 additions & 24 deletions bigframes/session/__init__.py
@@ -177,6 +177,7 @@ def __init__(
# Now that we're starting the session, don't allow the options to be
# changed.
context._session_started = True
self._df_snapshot: Dict[bigquery.TableReference, datetime.datetime] = {}

@property
def bqclient(self):
@@ -232,6 +233,7 @@ def read_gbq(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
Collaborator (review comment): Unresolving. You didn't populate the QueryJobConfig.use_query_cache property as I requested.

# Add a verify index argument that fails if the index is not unique.
) -> dataframe.DataFrame:
# TODO(b/281571214): Generate prompt to show the progress of read_gbq.
@@ -242,6 +244,7 @@ def read_gbq(
col_order=col_order,
max_results=max_results,
api_name="read_gbq",
use_cache=use_cache,
)
else:
# TODO(swast): Query the snapshot table but mark it as a
@@ -253,13 +256,15 @@ def read_gbq(
col_order=col_order,
max_results=max_results,
api_name="read_gbq",
use_cache=use_cache,
)

def _query_to_destination(
self,
query: str,
index_cols: List[str],
api_name: str,
use_cache: bool = True,
) -> Tuple[Optional[bigquery.TableReference], Optional[bigquery.QueryJob]]:
# If a dry_run indicates this is not a query type job, then don't
# bother trying to do a CREATE TEMP TABLE ... AS SELECT ... statement.
@@ -284,6 +289,7 @@ def _query_to_destination(
job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
job_config.destination = temp_table
job_config.use_query_cache = use_cache

try:
# Write to temp table to workaround BigQuery 10 GB query results
@@ -305,6 +311,7 @@ def read_gbq_query(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> dataframe.DataFrame:
"""Turn a SQL query into a DataFrame.

@@ -362,6 +369,7 @@ def read_gbq_query(
col_order=col_order,
max_results=max_results,
api_name="read_gbq_query",
use_cache=use_cache,
)

def _read_gbq_query(
@@ -372,14 +380,18 @@ def _read_gbq_query(
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
api_name: str = "read_gbq_query",
use_cache: bool = True,
) -> dataframe.DataFrame:
if isinstance(index_col, str):
index_cols = [index_col]
else:
index_cols = list(index_col)

destination, query_job = self._query_to_destination(
- query, index_cols, api_name=api_name
+ query,
+ index_cols,
+ api_name=api_name,
+ use_cache=use_cache,
)

# If there was no destination table, that means the query must have
@@ -403,6 +415,7 @@ def _read_gbq_query(
index_col=index_cols,
col_order=col_order,
max_results=max_results,
use_cache=use_cache,
)

def read_gbq_table(
@@ -412,6 +425,7 @@ def read_gbq_table(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
) -> dataframe.DataFrame:
"""Turn a BigQuery table into a DataFrame.

@@ -434,33 +448,22 @@ def read_gbq_table(
col_order=col_order,
max_results=max_results,
api_name="read_gbq_table",
use_cache=use_cache,
)

def _get_snapshot_sql_and_primary_key(
self,
table_ref: bigquery.table.TableReference,
*,
api_name: str,
use_cache: bool = True,
) -> Tuple[ibis_types.Table, Optional[Sequence[str]]]:
"""Create a read-only Ibis table expression representing a table.

If we can get a total ordering from the table, such as via primary key
column(s), then return those too so that ordering generation can be
avoided.
"""
- if table_ref.dataset_id.upper() == "_SESSION":
-     # _SESSION tables aren't supported by the tables.get REST API.
-     return (
-         self.ibis_client.sql(
-             f"SELECT * FROM `_SESSION`.`{table_ref.table_id}`"
-         ),
-         None,
-     )
table_expression = self.ibis_client.table(
table_ref.table_id,
database=f"{table_ref.project}.{table_ref.dataset_id}",
)

# If there are primary keys defined, the query engine assumes these
# columns are unique, even if the constraint is not enforced. We make
# the same assumption and use these columns as the total ordering keys.
@@ -481,14 +484,18 @@ def _get_snapshot_sql_and_primary_key(

job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
- current_timestamp = list(
-     self.bqclient.query(
-         "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
-         job_config=job_config,
-     ).result()
- )[0][0]
+ if use_cache and table_ref in self._df_snapshot.keys():
+     snapshot_timestamp = self._df_snapshot[table_ref]
+ else:
+     snapshot_timestamp = list(
+         self.bqclient.query(
+             "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
+             job_config=job_config,
+         ).result()
+     )[0][0]
+     self._df_snapshot[table_ref] = snapshot_timestamp
  table_expression = self.ibis_client.sql(
-     bigframes_io.create_snapshot_sql(table_ref, current_timestamp)
+     bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
  )
return table_expression, primary_keys

@@ -500,20 +507,21 @@ def _read_gbq_table(
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
api_name: str,
use_cache: bool = True,
) -> dataframe.DataFrame:
if max_results and max_results <= 0:
raise ValueError("`max_results` should be a positive number.")

# TODO(swast): Can we re-use the temp table from other reads in the
# session, if the original table wasn't modified?
table_ref = bigquery.table.TableReference.from_string(
query, default_project=self.bqclient.project
)

(
table_expression,
total_ordering_cols,
- ) = self._get_snapshot_sql_and_primary_key(table_ref, api_name=api_name)
+ ) = self._get_snapshot_sql_and_primary_key(
+     table_ref, api_name=api_name, use_cache=use_cache
+ )

for key in col_order:
if key not in table_expression.columns:
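The core of the change is the per-session _df_snapshot dict: the first read of a table records a server-side CURRENT_TIMESTAMP(), and later reads of the same TableReference reuse it, so repeated read_gbq calls resolve to the same time-travel snapshot. The flag is also forwarded to QueryJobConfig.use_query_cache in _query_to_destination, which addresses the reviewer's note above. A stripped-down sketch of the lookup, with the session plumbing elided and names borrowed from the diff:

import datetime
from typing import Dict

from google.cloud import bigquery


class SnapshotCache:
    """Minimal sketch of the session-level snapshot-timestamp lookup."""

    def __init__(self, bqclient: bigquery.Client):
        self._bqclient = bqclient
        self._df_snapshot: Dict[bigquery.TableReference, datetime.datetime] = {}

    def snapshot_timestamp(
        self, table_ref: bigquery.TableReference, use_cache: bool = True
    ) -> datetime.datetime:
        # Cache hit: reuse the timestamp recorded on the first read of this table.
        if use_cache and table_ref in self._df_snapshot:
            return self._df_snapshot[table_ref]
        # Cache miss (or use_cache=False): ask BigQuery for "now" and remember it.
        rows = self._bqclient.query("SELECT CURRENT_TIMESTAMP()").result()
        timestamp = list(rows)[0][0]
        self._df_snapshot[table_ref] = timestamp
        return timestamp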
5 changes: 0 additions & 5 deletions bigframes/session/_io/bigquery.py
@@ -117,11 +117,6 @@ def create_snapshot_sql(
table_ref: bigquery.TableReference, current_timestamp: datetime.datetime
) -> str:
"""Query a table via 'time travel' for consistent reads."""

# If we have a _SESSION table, assume that it's already a copy. Nothing to do here.
if table_ref.dataset_id.upper() == "_SESSION":
return f"SELECT * FROM `_SESSION`.`{table_ref.table_id}`"

# If we have an anonymous query results table, it can't be modified and
# there isn't any BigQuery time travel.
if table_ref.dataset_id.startswith("_"):
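create_snapshot_sql pins reads to the recorded timestamp via BigQuery time travel; with the _SESSION branch removed, only anonymous ('_'-prefixed) datasets skip it. Roughly, the generated SQL looks like this sketch (simplified; the real helper is stricter about identifier quoting and timestamp formatting):

import datetime

from google.cloud import bigquery


def snapshot_sql_sketch(table_ref: bigquery.TableReference, at: datetime.datetime) -> str:
    full_name = f"`{table_ref.project}`.`{table_ref.dataset_id}`.`{table_ref.table_id}`"
    # Anonymous query-results tables can't be modified, so time travel is unnecessary.
    if table_ref.dataset_id.startswith("_"):
        return f"SELECT * FROM {full_name}"
    # Pin the read to the snapshot timestamp for consistent re-reads.
    return f"SELECT * FROM {full_name} FOR SYSTEM_TIME AS OF TIMESTAMP('{at.isoformat()}')"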
18 changes: 18 additions & 0 deletions tests/system/small/test_session.py
@@ -16,6 +16,7 @@
import random
import tempfile
import textwrap
import time
import typing
from typing import List

@@ -308,6 +309,23 @@ def test_read_gbq_w_script_no_select(session, dataset_id: str):
assert df["statement_type"][0] == "SCRIPT"


def test_read_gbq_twice_with_same_timestamp(session, penguins_table_id):
df1 = session.read_gbq(penguins_table_id)
time.sleep(1)
df2 = session.read_gbq(penguins_table_id)
df1.columns = [
"species1",
"island1",
"culmen_length_mm1",
"culmen_depth_mm1",
"flipper_length_mm1",
"body_mass_g1",
"sex1",
]
df3 = df1.join(df2)
assert df3 is not None


def test_read_gbq_model(session, penguins_linear_model_name):
model = session.read_gbq_model(penguins_linear_model_name)
assert isinstance(model, bigframes.ml.linear_model.LinearRegression)
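This test exercises the cache-hit path: both reads resolve to the same snapshot timestamp, so the two frames join consistently despite the one-second gap between them. A hypothetical companion test (not part of this PR) could assert the reuse directly through the private cache:

def test_read_gbq_twice_reuses_snapshot_timestamp(session, penguins_table_id):
    # Hypothetical check: peeks at the private _df_snapshot dict added in this PR.
    session.read_gbq(penguins_table_id)
    first_snapshot = dict(session._df_snapshot)
    session.read_gbq(penguins_table_id)
    # A cache hit records no new timestamp.
    assert dict(session._df_snapshot) == first_snapshot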
14 changes: 0 additions & 14 deletions tests/unit/session/test_io_bigquery.py
@@ -147,20 +147,6 @@ def test_create_snapshot_sql_doesnt_timetravel_anonymous_datasets():
assert "`my-test-project`.`_e8166e0cdb`.`anonbb92cd`" in sql


- def test_create_snapshot_sql_doesnt_timetravel_session_tables():
-     table_ref = bigquery.TableReference.from_string("my-test-project._session.abcdefg")
-
-     sql = bigframes.session._io.bigquery.create_snapshot_sql(
-         table_ref, datetime.datetime.now(datetime.timezone.utc)
-     )
-
-     # We aren't modifying _SESSION tables, so don't use time travel.
-     assert "SYSTEM_TIME" not in sql
-
-     # Don't need the project ID for _SESSION tables.
-     assert "my-test-project" not in sql


def test_create_temp_table_default_expiration():
"""Make sure the created table has an expiration."""
bqclient = mock.create_autospec(bigquery.Client)
3 changes: 3 additions & 0 deletions third_party/bigframes_vendored/pandas/io/gbq.py
@@ -16,6 +16,7 @@ def read_gbq(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
use_cache: bool = True,
):
"""Loads a DataFrame from BigQuery.

@@ -83,6 +84,8 @@
max_results (Optional[int], default None):
If set, limit the maximum number of rows to fetch from the
query results.
use_cache (bool, default True):
Whether to cache the query inputs. Defaults to True.

Returns:
bigframes.dataframe.DataFrame: A DataFrame representing results of the query or table.
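Under the hood, use_cache maps onto the BigQuery client's job-config option. A standalone illustration of that mapping, independent of bigframes (assumes application-default credentials are configured):

from google.cloud import bigquery

client = bigquery.Client()
job_config = bigquery.QueryJobConfig()
job_config.use_query_cache = False  # the switch that read_gbq(use_cache=False) flips
rows = client.query("SELECT 1 AS x", job_config=job_config).result()
print([row["x"] for row in rows])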