diff --git a/bigframes/session.py b/bigframes/session.py
index 4c71657798..28a38f9307 100644
--- a/bigframes/session.py
+++ b/bigframes/session.py
@@ -88,6 +88,8 @@
 _BIGQUERYCONNECTION_REGIONAL_ENDPOINT = "{location}-bigqueryconnection.googleapis.com"
 _BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "{location}-bigquerystorage.googleapis.com"
 
+_MAX_CLUSTER_COLUMNS = 4
+
 # TODO(swast): Need to connect to regional endpoints when performing remote
 # functions operations (BQ Connection IAM, Cloud Run / Cloud Functions).
 
@@ -356,15 +358,10 @@ def read_gbq_query(
         else:
             index_cols = list(index_col)
 
-        # Make sure we cluster by the index column so that subsequent
-        # operations are as speedy as they can be.
-        if index_cols:
-            destination: bigquery.Table | bigquery.TableReference = (
-                self._query_to_session_table(query, index_cols)
-            )
-        else:
-            _, query_job = self._start_query(query)
-            destination = query_job.destination
+        # Can't cluster because we don't know whether index_cols are clusterable types.
+        # TODO(tbergeron): Maybe use a dry run to check index_col types before clustering.
+        _, query_job = self._start_query(query)
+        destination = query_job.destination
 
         # If there was no destination table, that means the query must have
         # been DDL or DML. Return some job metadata, instead.
@@ -607,7 +604,7 @@ def _read_bigquery_load_job(
             index_cols = list(index_col)
 
         if not job_config.clustering_fields and index_cols:
-            job_config.clustering_fields = index_cols
+            job_config.clustering_fields = index_cols[:_MAX_CLUSTER_COLUMNS]
 
         if isinstance(filepath_or_buffer, str):
             if filepath_or_buffer.startswith("gs://"):
@@ -724,7 +721,9 @@ def read_pandas(self, pandas_dataframe: pandas.DataFrame) -> dataframe.DataFrame
             filter(lambda name: name is not None, pandas_dataframe_copy.index.names)
         )
         index_labels = typing.cast(List[Optional[str]], index_cols)
-        cluster_cols = index_cols + [ordering_col]
+
+        # Clustering is probably not needed anyway, as pandas tables are small.
+        cluster_cols = [ordering_col]
 
         if len(index_cols) == 0:
             # Block constructor will implicitly build default index
@@ -937,9 +936,15 @@ def _create_sequential_ordering(
             ibis.row_number().cast(ibis_dtypes.int64).name(default_ordering_name)
         )
         table = table.mutate(**{default_ordering_name: default_ordering_col})
+        clusterable_index_cols = [
+            col for col in index_cols if _can_cluster(table[col].type())
+        ]
+        cluster_cols = (clusterable_index_cols + [default_ordering_name])[
+            :_MAX_CLUSTER_COLUMNS
+        ]
         table_ref = self._query_to_session_table(
             self.ibis_client.compile(table),
-            cluster_cols=list(index_cols) + [default_ordering_name],
+            cluster_cols=cluster_cols,
         )
         table = self.ibis_client.sql(f"SELECT * FROM `{table_ref.table_id}`")
         ordering_reference = core.OrderingColumnReference(default_ordering_name)
@@ -953,6 +958,10 @@ def _create_sequential_ordering(
     def _query_to_session_table(
         self, query_text: str, cluster_cols: Iterable[str]
     ) -> bigquery.TableReference:
+        if len(list(cluster_cols)) > _MAX_CLUSTER_COLUMNS:
+            raise ValueError(
+                f"Too many cluster columns: {list(cluster_cols)}, max {_MAX_CLUSTER_COLUMNS} allowed."
+            )
         # Can't set a table in _SESSION as destination via query job API, so we
         # run DDL, instead.
         table = self._create_session_table()
@@ -1136,3 +1145,16 @@ def _start_generic_job(self, job: formatting_helpers.GenericJob):
 
 def connect(context: Optional[bigquery_options.BigQueryOptions] = None) -> Session:
     return Session(context)
+
+
+def _can_cluster(ibis_type: ibis_dtypes.DataType):
+    # https://cloud.google.com/bigquery/docs/clustered-tables
+    # Notably, float is excluded.
+    return (
+        ibis_type.is_integer()
+        or ibis_type.is_string()
+        or ibis_type.is_decimal()
+        or ibis_type.is_date()
+        or ibis_type.is_timestamp()
+        or ibis_type.is_boolean()
+    )
diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py
index 8dfd840ede..599b8aabbc 100644
--- a/tests/system/small/test_session.py
+++ b/tests/system/small/test_session.py
@@ -113,6 +113,33 @@ def test_read_gbq_w_col_order(
             ["bool_col"],
             id="non_unique_index",
         ),
+        pytest.param(
+            "{scalars_table_id}",
+            ["float64_col"],
+            id="non_unique_float_index",
+        ),
+        pytest.param(
+            "{scalars_table_id}",
+            [
+                "timestamp_col",
+                "float64_col",
+                "datetime_col",
+                "int64_too",
+            ],
+            id="multi_part_index_direct",
+        ),
+        pytest.param(
+            "SELECT * FROM {scalars_table_id}",
+            [
+                "timestamp_col",
+                "float64_col",
+                "string_col",
+                "bool_col",
+                "int64_col",
+                "int64_too",
+            ],
+            id="multi_part_index_w_query",
+        ),
     ],
 )
 def test_read_gbq_w_index_col(
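
The new `_can_cluster` helper mirrors the clusterable-type rules from https://cloud.google.com/bigquery/docs/clustered-tables: integer, string, decimal, date, timestamp, and boolean columns may be clustered on, while floats may not. A minimal sketch of its behavior, assuming `_can_cluster` is importable from `bigframes.session` as added above:

    import ibis.expr.datatypes as ibis_dtypes

    from bigframes.session import _can_cluster

    # Clusterable per the BigQuery docs referenced in the helper.
    assert _can_cluster(ibis_dtypes.int64)
    assert _can_cluster(ibis_dtypes.string)
    assert _can_cluster(ibis_dtypes.boolean)
    # FLOAT64 is excluded, so the float64_col indexes exercised by the new
    # tests fall back to clustering on the ordering column only.
    assert not _can_cluster(ibis_dtypes.float64)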
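`_MAX_CLUSTER_COLUMNS = 4` encodes BigQuery's documented limit of four clustering columns per table: `_read_bigquery_load_job` and `_create_sequential_ordering` truncate to it, and `_query_to_session_table` now rejects anything longer. A hypothetical sketch of the `CLUSTER BY` DDL such a session-table helper would run; the table naming and exact statement shape are assumptions, since the diff shows only the validation:

    from typing import Iterable


    def _cluster_by_ddl(table_id: str, query_text: str, cluster_cols: Iterable[str]) -> str:
        # Hypothetical: a CREATE TABLE ... CLUSTER BY ... AS statement for the
        # session; BigQuery accepts at most four clustering columns.
        cols = ", ".join(f"`{col}`" for col in cluster_cols)
        return f"CREATE TEMP TABLE `{table_id}` CLUSTER BY {cols} AS {query_text}"

    ddl = _cluster_by_ddl("bqdf_tmp", "SELECT 1 AS a, 2 AS b", ["a", "b"])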
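The TODO(tbergeron) in `read_gbq_query` points at a dry run as a way to restore clustering there. A hedged sketch of that idea, assuming a google-cloud-bigquery version whose dry-run `QueryJob` exposes a `schema` property; the helper name and the `_CLUSTERABLE_TYPES` set are illustrative, not part of this change:

    from typing import List

    from google.cloud import bigquery

    # Legacy type names reported by SchemaField.field_type for the types
    # BigQuery can cluster on (note that FLOAT is absent).
    _CLUSTERABLE_TYPES = {
        "INTEGER", "STRING", "NUMERIC", "BIGNUMERIC",
        "DATE", "DATETIME", "TIMESTAMP", "BOOLEAN",
    }

    _MAX_CLUSTER_COLUMNS = 4


    def _clusterable_index_cols(
        client: bigquery.Client, query: str, index_cols: List[str]
    ) -> List[str]:
        """Dry-run the query and keep only index columns safe to CLUSTER BY."""
        job_config = bigquery.QueryJobConfig(dry_run=True, use_query_cache=False)
        dry_run_job = client.query(query, job_config=job_config)
        field_types = {field.name: field.field_type for field in dry_run_job.schema}
        clusterable = [c for c in index_cols if field_types.get(c) in _CLUSTERABLE_TYPES]
        return clusterable[:_MAX_CLUSTER_COLUMNS]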