perf: Make repr cache the block where appropriate #350

Merged · 4 commits · Jan 30, 2024
13 changes: 11 additions & 2 deletions bigframes/core/blocks.py
@@ -1695,10 +1695,19 @@ def to_sql_query(
idx_labels,
)

def cached(self) -> Block:
def cached(self, *, optimize_offsets=False, force: bool = False) -> Block:
"""Write the block to a session table and create a new block object that references it."""
# use a heuristic for whether something needs to be cached
if (not force) and self.session._is_trivially_executable(self.expr):
return self
if optimize_offsets:
expr = self.session._cache_with_offsets(self.expr)
else:
expr = self.session._cache_with_cluster_cols(
self.expr, cluster_cols=self.index_columns
)
return Block(
self.session._execute_and_cache(self.expr, cluster_cols=self.index_columns),
expr,
index_columns=self.index_columns,
column_labels=self.column_labels,
index_labels=self.index_labels,
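
In short, Block.cached is now a conditional materialization: it consults a heuristic before writing anything, and callers choose a clustering strategy. A hedged usage sketch of the three modes this change introduces (the block variable and call sites are illustrative, not part of the diff):

block = block.cached()                       # skips the write when the plan is trivially executable
block = block.cached(force=True)             # always materializes, e.g. for BQML input copies
block = block.cached(optimize_offsets=True)  # clusters on a sequential offsets column instead of the index
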
69 changes: 68 additions & 1 deletion bigframes/core/nodes.py
@@ -50,6 +50,19 @@ def deterministic(self) -> bool:
"""Whether this node will evaluates deterministically."""
return True

@property
def row_preserving(self) -> bool:
"""Whether this node preserves input rows."""
return True

@property
def non_local(self) -> bool:
"""
Whether this node combines information across multiple rows instead of processing rows independently.
Used as an approximation for whether the expression may require shuffling to execute (and therefore be expensive).
"""
return False

@property
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
"""Direct children of this node"""
@@ -104,6 +117,14 @@ class JoinNode(BigFrameNode):
join: JoinDefinition
allow_row_identity_join: bool = True

@property
def row_preserving(self) -> bool:
return False

@property
def non_local(self) -> bool:
return True

@property
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
return (self.left_child, self.right_child)
@@ -184,11 +205,19 @@ def __hash__(self):
def peekable(self) -> bool:
return False

@property
def non_local(self) -> bool:
return False


@dataclass(frozen=True)
class FilterNode(UnaryNode):
predicate: ex.Expression

@property
def row_preserving(self) -> bool:
return False

def __hash__(self):
return self._node_hash

@@ -221,7 +250,13 @@ def __hash__(self):
# TODO: Merge RowCount and Corr into Aggregate Node
@dataclass(frozen=True)
class RowCountNode(UnaryNode):
pass
@property
def row_preserving(self) -> bool:
return False

@property
def non_local(self) -> bool:
return True


@dataclass(frozen=True)
@@ -230,13 +265,21 @@ class AggregateNode(UnaryNode):
by_column_ids: typing.Tuple[str, ...] = tuple([])
dropna: bool = True

@property
def row_preserving(self) -> bool:
return False

def __hash__(self):
return self._node_hash

@property
def peekable(self) -> bool:
return False

@property
def non_local(self) -> bool:
return True


# TODO: Unify into aggregate
@dataclass(frozen=True)
@@ -246,10 +289,18 @@ class CorrNode(UnaryNode):
def __hash__(self):
return self._node_hash

@property
def row_preserving(self) -> bool:
return False

@property
def peekable(self) -> bool:
return False

@property
def non_local(self) -> bool:
return True


@dataclass(frozen=True)
class WindowOpNode(UnaryNode):
@@ -267,6 +318,10 @@ def __hash__(self):
def peekable(self) -> bool:
return False

@property
def non_local(self) -> bool:
return True


@dataclass(frozen=True)
class ReprojectOpNode(UnaryNode):
@@ -290,6 +345,14 @@ class UnpivotNode(UnaryNode):
def __hash__(self):
return self._node_hash

@property
def row_preserving(self) -> bool:
return False

@property
def non_local(self) -> bool:
return True

@property
def peekable(self) -> bool:
return False
@@ -303,5 +366,9 @@ class RandomSampleNode(UnaryNode):
def deterministic(self) -> bool:
return False

@property
def row_preserving(self) -> bool:
return False

def __hash__(self):
return self._node_hash
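
Taken together, row_preserving and non_local give the planner a cheap, static cost signal: a node tree is inexpensive exactly when every node keeps its input rows and touches them independently. A minimal sketch of how such flags compose over a tree (the helper below is hypothetical, not part of the diff):

import bigframes.core.nodes as nodes

def may_require_shuffle(node: nodes.BigFrameNode) -> bool:
    # Hypothetical helper: a plan is shuffle-prone if any node in the tree
    # combines information across rows (joins, aggregates, windows, unpivots).
    return node.non_local or any(
        may_require_shuffle(child) for child in node.child_nodes
    )
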
8 changes: 8 additions & 0 deletions bigframes/core/ordering.py
@@ -92,6 +92,14 @@ class ExpressionOrdering:
# Therefore, any modifications(or drops) done to these columns must result in hidden copies being made.
total_ordering_columns: frozenset[str] = field(default_factory=frozenset)

@classmethod
def from_offset_col(cls, col: str) -> ExpressionOrdering:
return ExpressionOrdering(
(OrderingColumnReference(col),),
integer_encoding=IntegerEncoding(True, is_sequential=True),
total_ordering_columns=frozenset({col}),
)

def with_non_sequential(self):
"""Create a copy that is marked as non-sequential.

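
For reference, the new constructor encodes the strongest ordering guarantee available: a single integer column that is sequential (0, 1, 2, ...) is automatically a total ordering. A usage sketch (the column name mirrors the one used by the session code later in this diff):

from bigframes.core.ordering import ExpressionOrdering

# A sequential offsets column is, by construction, a total ordering:
ordering = ExpressionOrdering.from_offset_col("bigframes_offsets")
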
27 changes: 27 additions & 0 deletions bigframes/core/traversal.py
@@ -0,0 +1,27 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bigframes.core.nodes as nodes


def is_trivially_executable(node: nodes.BigFrameNode) -> bool:
if local_only(node):
return True
children_trivial = all(is_trivially_executable(child) for child in node.child_nodes)
self_trivial = (not node.non_local) and (node.row_preserving)
return children_trivial and self_trivial


def local_only(node: nodes.BigFrameNode) -> bool:
return all(isinstance(node, nodes.ReadLocalNode) for node in node.roots)
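
The predicate is deliberately conservative: a plan counts as trivial only when every leaf is local data and every interior node is both row-preserving and row-local, so re-running it costs roughly as little as reading a cache would. A hedged sketch of how a caller might consume it (the should_cache helper is illustrative, not part of the diff):

import bigframes.core.nodes as nodes
import bigframes.core.traversal as traversals

def should_cache(plan: nodes.BigFrameNode) -> bool:
    # Illustrative: caching only pays off when the plan is *not* trivially
    # executable; e.g. a FilterNode (row_preserving=False) makes it worth
    # caching, while a pure projection over local data does not.
    return not traversals.is_trivially_executable(plan)
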
12 changes: 10 additions & 2 deletions bigframes/dataframe.py
@@ -592,6 +592,8 @@ def __repr__(self) -> str:
max_results = opts.max_rows
if opts.repr_mode == "deferred":
return formatter.repr_query_job(self.query_job)

self._cached()
# TODO(swast): pass max_columns and get the true column count back. Maybe
# get 1 more column than we have requested so that pandas can add the
# ... for us?
@@ -629,6 +631,8 @@ def _repr_html_(self) -> str:
max_results = bigframes.options.display.max_rows
if opts.repr_mode == "deferred":
return formatter.repr_query_job_html(self.query_job)

self._cached()
# TODO(swast): pass max_columns and get the true column count back. Maybe
# get 1 more column than we have requested so that pandas can add the
# ... for us?
@@ -3100,8 +3104,12 @@ def _set_block(self, block: blocks.Block):
def _get_block(self) -> blocks.Block:
return self._block

def _cached(self) -> DataFrame:
self._set_block(self._block.cached())
def _cached(self, *, force: bool = False) -> DataFrame:
"""Materialize dataframe to a temporary table.
No-op if the dataframe represents a trivial transformation of an existing materialization.
force=True is used for BQML integration, where data must be copied rather than referenced via a snapshot.
"""
self._set_block(self._block.cached(force=force))
return self

_DataFrameOrSeries = typing.TypeVar("_DataFrameOrSeries")
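
The net effect for users is that the first repr of an expensive frame pays the materialization cost once, and later reprs read back from the session table. A hedged end-to-end sketch (the query, column, and session setup are illustrative):

import bigframes.pandas as bpd

df = bpd.read_gbq("bigquery-public-data.usa_names.usa_1910_2013")  # illustrative table
df = df[df["number"] > 1000]   # FilterNode: not trivially executable
repr(df)                       # materializes once via _cached()
repr(df)                       # served from the cached session table
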
10 changes: 7 additions & 3 deletions bigframes/ml/core.py
@@ -247,9 +247,11 @@ def create_model(
# Cache dataframes to make sure base table is not a snapshot
# cached dataframe creates a full copy, never uses snapshot
if y_train is None:
input_data = X_train._cached()
input_data = X_train._cached(force=True)
else:
input_data = X_train._cached().join(y_train._cached(), how="outer")
input_data = X_train._cached(force=True).join(
y_train._cached(force=True), how="outer"
)
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})

session = X_train._session
@@ -281,7 +283,9 @@ def create_time_series_model(
options = dict(options)
# Cache dataframes to make sure base table is not a snapshot
# cached dataframe creates a full copy, never uses snapshot
input_data = X_train._cached().join(y_train._cached(), how="outer")
input_data = X_train._cached(force=True).join(
y_train._cached(force=True), how="outer"
)
options.update({"TIME_SERIES_TIMESTAMP_COL": X_train.columns.tolist()[0]})
options.update({"TIME_SERIES_DATA_COL": y_train.columns.tolist()[0]})

5 changes: 3 additions & 2 deletions bigframes/series.py
@@ -259,6 +259,7 @@ def __repr__(self) -> str:
if opts.repr_mode == "deferred":
return formatter.repr_query_job(self.query_job)

self._cached()
pandas_df, _, query_job = self._block.retrieve_repr_request_results(max_results)
self._set_internal_query_job(query_job)

@@ -1521,8 +1522,8 @@ def _slice(
),
)

def _cached(self) -> Series:
self._set_block(self._block.cached())
def _cached(self, *, force: bool = True) -> Series:
self._set_block(self._block.cached(force=force))
return self


38 changes: 37 additions & 1 deletion bigframes/session/__init__.py
@@ -71,6 +71,7 @@
import bigframes.core.guid as guid
from bigframes.core.ordering import IntegerEncoding, OrderingColumnReference
import bigframes.core.ordering as orderings
import bigframes.core.traversal as traversals
import bigframes.core.utils as utils
import bigframes.dataframe as dataframe
import bigframes.formatting_helpers as formatting_helpers
@@ -1475,7 +1476,7 @@ def _start_query(
results_iterator = query_job.result(max_results=max_results)
return results_iterator, query_job

def _execute_and_cache(
def _cache_with_cluster_cols(
self, array_value: core.ArrayValue, cluster_cols: typing.Sequence[str]
) -> core.ArrayValue:
"""Executes the query and uses the resulting table to rewrite future executions."""
@@ -1506,6 +1507,41 @@ def _execute_and_cache(
ordering=compiled_value._ordering,
)

def _cache_with_offsets(self, array_value: core.ArrayValue) -> core.ArrayValue:
"""Executes the query and uses the resulting table to rewrite future executions."""
# TODO: Use this for all executions? Problem is that caching materializes extra
# ordering columns
compiled_value = self._compile_ordered(array_value)

ibis_expr = compiled_value._to_ibis_expr(
ordering_mode="offset_col", order_col_name="bigframes_offsets"
)
tmp_table = self._ibis_to_temp_table(
ibis_expr, cluster_cols=["bigframes_offsets"], api_name="cached"
)
table_expression = self.ibis_client.table(
f"{tmp_table.project}.{tmp_table.dataset_id}.{tmp_table.table_id}"
)
new_columns = [table_expression[column] for column in compiled_value.column_ids]
new_hidden_columns = [table_expression["bigframes_offsets"]]
# TODO: Instead, keep session-wide map of cached results and automatically reuse
return core.ArrayValue.from_ibis(
self,
table_expression,
columns=new_columns,
hidden_ordering_columns=new_hidden_columns,
ordering=orderings.ExpressionOrdering.from_offset_col("bigframes_offsets"),
)

def _is_trivially_executable(self, array_value: core.ArrayValue):
"""
Can the block be evaluated very cheaply?
If True, the array_value probably is not worth caching.
"""
# Once rewriting is available, will want to rewrite before
# evaluating execution cost.
return traversals.is_trivially_executable(array_value.node)

def _execute(
self,
array_value: core.ArrayValue,
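
Summarizing the session-level split: _cache_with_cluster_cols keeps the existing ordering expression and clusters the materialized table on caller-supplied columns, while _cache_with_offsets materializes an explicit bigframes_offsets column and derives a sequential integer ordering from it via ExpressionOrdering.from_offset_col. A hedged sketch of the two call sites (session, array_value, and index_cols are illustrative):

# Default path used by Block.cached(): cluster on the index columns,
# preserving the current ordering.
cached = session._cache_with_cluster_cols(array_value, cluster_cols=index_cols)

# Offsets path used by Block.cached(optimize_offsets=True): sequential
# integer ordering, which makes downstream row-position access cheap.
cached = session._cache_with_offsets(array_value)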