From 314abe4a6ea5975745bf4079045e86acd380c0dd Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Tue, 16 Apr 2024 06:00:46 -0400
Subject: [PATCH] fix(bigquery): ensure session creation before creating temp
 tables (#8976)

Fixes #8975.
---
 ibis/backends/bigquery/__init__.py            | 23 ++++++++-----------
 .../bigquery/tests/system/test_client.py      | 13 +++++++++--
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py
index d832eeb20acd..668c8653b6ca 100644
--- a/ibis/backends/bigquery/__init__.py
+++ b/ibis/backends/bigquery/__init__.py
@@ -133,15 +133,19 @@ class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._session_dataset: bq.DatasetReference | None = None
+        self.__session_dataset: bq.DatasetReference | None = None
         self._query_cache.lookup = lambda name: self.table(
             name,
             database=(self._session_dataset.project, self._session_dataset.dataset_id),
         ).op()
 
-    def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
-        self._make_session()
+    @property
+    def _session_dataset(self):
+        if self.__session_dataset is None:
+            self.__session_dataset = self._make_session()
+        return self.__session_dataset
 
+    def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
         raw_name = op.name
 
         project = self._session_dataset.project
@@ -574,24 +578,20 @@ def table(
         return rename_partitioned_column(table_expr, bq_table, self.partition_column)
 
     def _make_session(self) -> tuple[str, str]:
-        if (
-            self._session_dataset is None
-            and (client := getattr(self, "client", None)) is not None
-        ):
+        if (client := getattr(self, "client", None)) is not None:
             job_config = bq.QueryJobConfig(use_query_cache=False)
             query = client.query(
                 "SELECT 1", job_config=job_config, project=self.billing_project
             )
             query.result()
 
-            self._session_dataset = bq.DatasetReference(
+            return bq.DatasetReference(
                 project=query.destination.project,
                 dataset_id=query.destination.dataset_id,
             )
+        return None
 
     def _get_schema_using_query(self, query: str) -> sch.Schema:
-        self._make_session()
-
         job = self.client.query(
             query,
             job_config=bq.QueryJobConfig(dry_run=True, use_query_cache=False),
@@ -600,8 +600,6 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
         return BigQuerySchema.to_ibis(job.schema)
 
     def _execute(self, stmt, query_parameters=None):
-        self._make_session()
-
         job_config = bq.job.QueryJobConfig(query_parameters=query_parameters or [])
         query = self.client.query(
             stmt, job_config=job_config, project=self.billing_project
         )
@@ -637,7 +635,6 @@ def _to_sqlglot(
             backend.
 
         """
-        self._make_session()
         self._define_udf_translation_rules(expr)
 
         sql = super()._to_sqlglot(expr, limit=limit, params=params, **kwargs)
diff --git a/ibis/backends/bigquery/tests/system/test_client.py b/ibis/backends/bigquery/tests/system/test_client.py
index 68f51c9148f3..2a66a8f42246 100644
--- a/ibis/backends/bigquery/tests/system/test_client.py
+++ b/ibis/backends/bigquery/tests/system/test_client.py
@@ -372,8 +372,9 @@ def test_fully_qualified_table_creation(con, project_id, dataset_id, temp_table):
 def test_fully_qualified_memtable_compile(project_id, dataset_id):
     new_bq_con = ibis.bigquery.connect(project_id=project_id, dataset_id=dataset_id)
 
-    # New connection shouldn't have _session_dataset populated after connection
-    assert new_bq_con._session_dataset is None
+    # New connection shouldn't have __session_dataset populated after
+    # connection
+    assert new_bq_con._Backend__session_dataset is None
 
     t = ibis.memtable(
         {"a": [1, 2, 3], "b": [4, 5, 6]},
@@ -422,3 +423,11 @@ def test_list_tables_schema_warning_refactor(con):
 
     assert con.list_tables(database="ibis-gbq.pypi") == pypi_tables
     assert con.list_tables(database=("ibis-gbq", "pypi")) == pypi_tables
+
+
+def test_create_temp_table_from_scratch(project_id, dataset_id):
+    con = ibis.bigquery.connect(project_id=project_id, dataset_id=dataset_id)
+    name = gen_name("bigquery_temp_table")
+    df = con.tables.functional_alltypes.limit(1)
+    t = con.create_table(name, obj=df, temp=True)
+    assert len(t.execute()) == 1