diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py index 71b1214d01..3b1bf48558 100644 --- a/bigframes/core/__init__.py +++ b/bigframes/core/__init__.py @@ -26,7 +26,6 @@ import pyarrow as pa import pyarrow.feather as pa_feather -import bigframes.core.compile import bigframes.core.expression as ex import bigframes.core.guid import bigframes.core.identifiers as ids @@ -35,7 +34,6 @@ import bigframes.core.nodes as nodes from bigframes.core.ordering import OrderingExpression import bigframes.core.ordering as orderings -import bigframes.core.rewrite import bigframes.core.schema as schemata import bigframes.core.tree_properties import bigframes.core.utils @@ -43,7 +41,6 @@ import bigframes.dtypes import bigframes.operations as ops import bigframes.operations.aggregations as agg_ops -import bigframes.session._io.bigquery if typing.TYPE_CHECKING: from bigframes.session import Session @@ -199,6 +196,8 @@ def as_cached( def _try_evaluate_local(self): """Use only for unit testing paths - not fully featured. Will throw exception if fails.""" + import bigframes.core.compile + return bigframes.core.compile.test_only_try_evaluate(self.node) def get_column_type(self, key: str) -> bigframes.dtypes.Dtype: @@ -422,22 +421,7 @@ def relational_join( l_mapping = { # Identity mapping, only rename right side lcol.name: lcol.name for lcol in self.node.ids } - r_mapping = { # Rename conflicting names - rcol.name: rcol.name - if (rcol.name not in l_mapping) - else bigframes.core.guid.generate_guid() - for rcol in other.node.ids - } - other_node = other.node - if set(other_node.ids) & set(self.node.ids): - other_node = nodes.SelectionNode( - other_node, - tuple( - (ex.deref(old_id), ids.ColumnId(new_id)) - for old_id, new_id in r_mapping.items() - ), - ) - + other_node, r_mapping = self.prepare_join_names(other) join_node = nodes.JoinNode( left_child=self.node, right_child=other_node, @@ -449,14 +433,63 @@ def relational_join( ) return ArrayValue(join_node), (l_mapping, r_mapping) - def try_align_as_projection( + def try_row_join( + self, + other: ArrayValue, + conditions: typing.Tuple[typing.Tuple[str, str], ...] = (), + ) -> Optional[ + typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]] + ]: + l_mapping = { # Identity mapping, only rename right side + lcol.name: lcol.name for lcol in self.node.ids + } + other_node, r_mapping = self.prepare_join_names(other) + import bigframes.core.rewrite + + result_node = bigframes.core.rewrite.try_join_as_projection( + self.node, other_node, conditions + ) + if result_node is None: + return None + + return ( + ArrayValue(result_node), + (l_mapping, r_mapping), + ) + + def prepare_join_names( + self, other: ArrayValue + ) -> Tuple[bigframes.core.nodes.BigFrameNode, dict[str, str]]: + if set(other.node.ids) & set(self.node.ids): + r_mapping = { # Rename conflicting names + rcol.name: rcol.name + if (rcol.name not in self.column_ids) + else bigframes.core.guid.generate_guid() + for rcol in other.node.ids + } + return ( + nodes.SelectionNode( + other.node, + tuple( + (ex.deref(old_id), ids.ColumnId(new_id)) + for old_id, new_id in r_mapping.items() + ), + ), + r_mapping, + ) + else: + return other.node, {id: id for id in other.column_ids} + + def try_legacy_row_join( self, other: ArrayValue, join_type: join_def.JoinType, join_keys: typing.Tuple[join_def.CoalescedColumnMapping, ...], mappings: typing.Tuple[join_def.JoinColumnMapping, ...], ) -> typing.Optional[ArrayValue]: - result = bigframes.core.rewrite.join_as_projection( + import bigframes.core.rewrite + + result = bigframes.core.rewrite.legacy_join_as_projection( self.node, other.node, join_keys, mappings, join_type ) if result is not None: @@ -488,11 +521,4 @@ def _gen_namespaced_uid(self) -> str: return self._gen_namespaced_uids(1)[0] def _gen_namespaced_uids(self, n: int) -> List[str]: - i = len(self.node.defined_variables) - genned_ids: List[str] = [] - while len(genned_ids) < n: - attempted_id = f"col_{i}" - if attempted_id not in self.node.defined_variables: - genned_ids.append(attempted_id) - i = i + 1 - return genned_ids + return [ids.ColumnId.unique().name for _ in range(n)] diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 4fc663817c..574bed00eb 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -2341,7 +2341,9 @@ def join( # Handle null index, which only supports row join # This is the canonical way of aligning on null index, so always allow (ignore block_identity_join) if self.index.nlevels == other.index.nlevels == 0: - result = try_row_join(self, other, how=how) + result = try_legacy_row_join(self, other, how=how) or try_new_row_join( + self, other + ) if result is not None: return result raise bigframes.exceptions.NullIndexError( @@ -2354,7 +2356,9 @@ def join( and (self.index.nlevels == other.index.nlevels) and (self.index.dtypes == other.index.dtypes) ): - result = try_row_join(self, other, how=how) + result = try_legacy_row_join(self, other, how=how) or try_new_row_join( + self, other + ) if result is not None: return result @@ -2693,7 +2697,35 @@ def is_uniquely_named(self: BlockIndexProperties): return len(set(self.names)) == len(self.names) -def try_row_join( +def try_new_row_join( + left: Block, right: Block +) -> Optional[Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]]: + join_keys = tuple( + (left_id, right_id) + for left_id, right_id in zip(left.index_columns, right.index_columns) + ) + join_result = left.expr.try_row_join(right.expr, join_keys) + if join_result is None: # did not succeed + return None + combined_expr, (get_column_left, get_column_right) = join_result + # Keep the left index column, and drop the matching right column + index_cols_post_join = [get_column_left[id] for id in left.index_columns] + combined_expr = combined_expr.drop_columns( + [get_column_right[id] for id in right.index_columns] + ) + block = Block( + combined_expr, + index_columns=index_cols_post_join, + column_labels=left.column_labels.append(right.column_labels), + index_labels=left.index.names, + ) + return ( + block, + (get_column_left, get_column_right), + ) + + +def try_legacy_row_join( left: Block, right: Block, *, @@ -2727,7 +2759,7 @@ def try_row_join( ) for id in right.value_columns ] - combined_expr = left_expr.try_align_as_projection( + combined_expr = left_expr.try_legacy_row_join( right_expr, join_type=how, join_keys=join_keys, diff --git a/bigframes/core/identifiers.py b/bigframes/core/identifiers.py index 8c2f7e910f..b7ae0e2434 100644 --- a/bigframes/core/identifiers.py +++ b/bigframes/core/identifiers.py @@ -18,6 +18,8 @@ import itertools from typing import Generator +import bigframes.core.guid + def standard_id_strings(prefix: str = "col_") -> Generator[str, None, None]: i = 0 @@ -47,6 +49,10 @@ def local_normalized(self) -> ColumnId: def __lt__(self, other: ColumnId) -> bool: return self.sql < other.sql + @classmethod + def unique(cls) -> ColumnId: + return ColumnId(name=bigframes.core.guid.generate_guid()) + @dataclasses.dataclass(frozen=True) class SerialColumnId(ColumnId): diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 30a130bbac..420348cca9 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -15,7 +15,7 @@ from __future__ import annotations import abc -from dataclasses import dataclass, field, fields, replace +import dataclasses import datetime import functools import itertools @@ -46,13 +46,13 @@ COLUMN_SET = frozenset[bfet_ids.ColumnId] -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class Field: id: bfet_ids.ColumnId dtype: bigframes.dtypes.Dtype -@dataclass(eq=False, frozen=True) +@dataclasses.dataclass(eq=False, frozen=True) class BigFrameNode(abc.ABC): """ Immutable node for representing 2D typed array as a tree of operators. @@ -83,6 +83,10 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]: """Direct children of this node""" return tuple([]) + @property + def projection_base(self) -> BigFrameNode: + return self + @property @abc.abstractmethod def row_count(self) -> typing.Optional[int]: @@ -123,11 +127,14 @@ def validate_tree(self) -> bool: for child in self.child_nodes: child.validate_tree() self._validate() + field_list = list(self.fields) + if len(set(field_list)) != len(field_list): + raise ValueError(f"Non unique field ids {list(self.fields)}") return True def _as_tuple(self) -> Tuple: """Get all fields as tuple.""" - return tuple(getattr(self, field.name) for field in fields(self)) + return tuple(getattr(self, field.name) for field in dataclasses.fields(self)) def __hash__(self) -> int: # Custom hash that uses cache to avoid costly recomputation @@ -282,7 +289,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: return self.transform_children(lambda x: x.prune(used_cols)) -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class UnaryNode(BigFrameNode): child: BigFrameNode @@ -301,18 +308,22 @@ def explicitly_ordered(self) -> bool: def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] ) -> BigFrameNode: - transformed = replace(self, child=t(self.child)) + transformed = dataclasses.replace(self, child=t(self.child)) if self == transformed: # reusing existing object speeds up eq, and saves a small amount of memory return self return transformed + def replace_child(self, new_child: BigFrameNode) -> UnaryNode: + new_self = dataclasses.replace(self, child=new_child) # type: ignore + return new_self + @property def order_ambiguous(self) -> bool: return self.child.order_ambiguous -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class SliceNode(UnaryNode): """Logical slice node conditionally becomes limit or filter over row numbers.""" @@ -375,7 +386,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class JoinNode(BigFrameNode): left_child: BigFrameNode right_child: BigFrameNode @@ -437,7 +448,7 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] ) -> BigFrameNode: - transformed = replace( + transformed = dataclasses.replace( self, left_child=t(self.left_child), right_child=t(self.right_child) ) if self == transformed: @@ -467,10 +478,10 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): ) for l_cond, r_cond in self.conditions ) - return replace(self, conditions=new_conds) # type: ignore + return dataclasses.replace(self, conditions=new_conds) # type: ignore -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class ConcatNode(BigFrameNode): # TODO: Explcitly map column ids from each child children: Tuple[BigFrameNode, ...] @@ -526,7 +537,9 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] ) -> BigFrameNode: - transformed = replace(self, children=tuple(t(child) for child in self.children)) + transformed = dataclasses.replace( + self, children=tuple(t(child) for child in self.children) + ) if self == transformed: # reusing existing object speeds up eq, and saves a small amount of memory return self @@ -540,13 +553,13 @@ def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: new_ids = tuple(mappings.get(id, id) for id in self.output_ids) - return replace(self, output_ids=new_ids) + return dataclasses.replace(self, output_ids=new_ids) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class FromRangeNode(BigFrameNode): # TODO: Enforce single-row, single column constraint start: BigFrameNode @@ -594,7 +607,7 @@ def defines_namespace(self) -> bool: def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] ) -> BigFrameNode: - transformed = replace(self, start=t(self.start), end=t(self.end)) + transformed = dataclasses.replace(self, start=t(self.start), end=t(self.end)) if self == transformed: # reusing existing object speeds up eq, and saves a small amount of memory return self @@ -607,7 +620,9 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: - return replace(self, output_id=mappings.get(self.output_id, self.output_id)) + return dataclasses.replace( + self, output_id=mappings.get(self.output_id, self.output_id) + ) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self @@ -616,7 +631,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): # Input Nodex # TODO: Most leaf nodes produce fixed column names based on the datasource # They should support renaming -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class LeafNode(BigFrameNode): @property def roots(self) -> typing.Set[BigFrameNode]: @@ -642,12 +657,12 @@ class ScanItem(typing.NamedTuple): source_id: str # Flexible enough for both local data and bq data -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class ScanList: items: typing.Tuple[ScanItem, ...] -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class ReadLocalNode(LeafNode): feather_bytes: bytes data_schema: schemata.ArraySchema @@ -713,20 +728,20 @@ def remap_vars( for item in self.scan_list.items ) ) - return replace(self, scan_list=new_scan_list) + return dataclasses.replace(self, scan_list=new_scan_list) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class GbqTable: - project_id: str = field() - dataset_id: str = field() - table_id: str = field() - physical_schema: Tuple[bq.SchemaField, ...] = field() - n_rows: int = field() - is_physically_stored: bool = field() + project_id: str = dataclasses.field() + dataset_id: str = dataclasses.field() + table_id: str = dataclasses.field() + physical_schema: Tuple[bq.SchemaField, ...] = dataclasses.field() + n_rows: int = dataclasses.field() + is_physically_stored: bool = dataclasses.field() cluster_cols: typing.Optional[Tuple[str, ...]] @staticmethod @@ -749,7 +764,7 @@ def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable: ) -@dataclass(frozen=True) +@dataclasses.dataclass(frozen=True) class BigqueryDataSource: """ Google BigQuery Data source. @@ -765,14 +780,14 @@ class BigqueryDataSource: ## Put ordering in here or just add order_by node above? -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class ReadTableNode(LeafNode): source: BigqueryDataSource # Subset of physical schema column # Mapping of table schema ids to bfet id. scan_list: ScanList - table_session: bigframes.session.Session = field() + table_session: bigframes.session.Session = dataclasses.field() def _validate(self): # enforce invariants @@ -848,7 +863,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: tuple(item for item in self.scan_list.items if item.id in used_cols) or (self.scan_list.items[0],) ) - return replace(self, scan_list=new_scan_list) + return dataclasses.replace(self, scan_list=new_scan_list) def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] @@ -859,21 +874,21 @@ def remap_vars( for item in self.scan_list.items ) ) - return replace(self, scan_list=new_scan_list) + return dataclasses.replace(self, scan_list=new_scan_list) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class CachedTableNode(ReadTableNode): # The original BFET subtree that was cached # note: this isn't a "child" node. - original_node: BigFrameNode = field() + original_node: BigFrameNode = dataclasses.field() # Unary nodes -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class PromoteOffsetsNode(UnaryNode): col_id: bigframes.core.identifiers.ColumnId @@ -903,6 +918,14 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return (self.col_id,) + @property + def projection_base(self) -> BigFrameNode: + return self.child.projection_base + + @property + def added_fields(self) -> Tuple[Field, ...]: + return (Field(self.col_id, bigframes.dtypes.INT_DTYPE),) + def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: if self.col_id not in used_cols: return self.child.prune(used_cols) @@ -913,13 +936,13 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: - return replace(self, col_id=mappings.get(self.col_id, self.col_id)) + return dataclasses.replace(self, col_id=mappings.get(self.col_id, self.col_id)) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class FilterNode(UnaryNode): predicate: ex.Expression @@ -950,7 +973,7 @@ def remap_vars( return self def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): - return replace( + return dataclasses.replace( self, predicate=self.predicate.remap_column_refs( mappings, allow_partial_bindings=True @@ -958,7 +981,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): ) -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class OrderByNode(UnaryNode): by: Tuple[OrderingExpression, ...] @@ -1008,10 +1031,10 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): for by_expr in self.by ), ) - return replace(self, by=new_by) + return dataclasses.replace(self, by=new_by) -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class ReversedNode(UnaryNode): # useless field to make sure has distinct hash reversed: bool = True @@ -1042,7 +1065,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class SelectionNode(UnaryNode): input_output_pairs: typing.Tuple[ typing.Tuple[ex.DerefOp, bigframes.core.identifiers.ColumnId], ... @@ -1056,8 +1079,8 @@ def _validate(self): @functools.cached_property def fields(self) -> Iterable[Field]: return tuple( - Field(output, self.child.get_type(input.id)) - for input, output in self.input_output_pairs + Field(output, self.child.get_type(ref.id)) + for ref, output in self.input_output_pairs ) @property @@ -1072,6 +1095,10 @@ def variables_introduced(self) -> int: def defines_namespace(self) -> bool: return True + @property + def projection_base(self) -> BigFrameNode: + return self.child.projection_base + @property def row_count(self) -> Optional[int]: return self.child.row_count @@ -1098,17 +1125,17 @@ def remap_vars( new_pairs = tuple( (ref, mappings.get(id, id)) for ref, id in self.input_output_pairs ) - return replace(self, input_output_pairs=new_pairs) + return dataclasses.replace(self, input_output_pairs=new_pairs) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): new_fields = tuple( (ex.remap_column_refs(mappings, allow_partial_bindings=True), id) for ex, id in self.input_output_pairs ) - return replace(self, input_output_pairs=new_fields) # type: ignore + return dataclasses.replace(self, input_output_pairs=new_fields) # type: ignore -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class ProjectionNode(UnaryNode): """Assigns new variables (without modifying existing ones)""" @@ -1146,6 +1173,10 @@ def variables_introduced(self) -> int: def row_count(self) -> Optional[int]: return self.child.row_count + @property + def projection_base(self) -> BigFrameNode: + return self.child.projection_base + @property def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return tuple(id for _, id in self.assignments) @@ -1164,19 +1195,19 @@ def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: new_fields = tuple((ex, mappings.get(id, id)) for ex, id in self.assignments) - return replace(self, assignments=new_fields) + return dataclasses.replace(self, assignments=new_fields) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): new_fields = tuple( (ex.remap_column_refs(mappings, allow_partial_bindings=True), id) for ex, id in self.assignments ) - return replace(self, assignments=new_fields) + return dataclasses.replace(self, assignments=new_fields) # TODO: Merge RowCount into Aggregate Node? # Row count can be compute from table metadata sometimes, so it is a bit special. -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class RowCountNode(UnaryNode): col_id: bfet_ids.ColumnId = bfet_ids.ColumnId("count") @@ -1211,7 +1242,7 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: - return replace(self, col_id=mappings.get(self.col_id, self.col_id)) + return dataclasses.replace(self, col_id=mappings.get(self.col_id, self.col_id)) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): return self @@ -1221,7 +1252,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: return self -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class AggregateNode(UnaryNode): aggregations: typing.Tuple[ typing.Tuple[ex.Aggregation, bigframes.core.identifiers.ColumnId], ... @@ -1292,7 +1323,7 @@ def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: new_aggs = tuple((agg, mappings.get(id, id)) for agg, id in self.aggregations) - return replace(self, aggregations=new_aggs) + return dataclasses.replace(self, aggregations=new_aggs) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): new_aggs = tuple( @@ -1300,10 +1331,12 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): for agg, id in self.aggregations ) new_by_ids = tuple(id.remap_column_refs(mappings) for id in self.by_column_ids) - return replace(self, by_column_ids=new_by_ids, aggregations=new_aggs) + return dataclasses.replace( + self, by_column_ids=new_by_ids, aggregations=new_aggs + ) -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class WindowOpNode(UnaryNode): column_name: ex.DerefOp op: agg_ops.UnaryWindowOp @@ -1312,6 +1345,10 @@ class WindowOpNode(UnaryNode): never_skip_nulls: bool = False skip_reproject_unsafe: bool = False + def _validate(self): + """Validate the local data in the node.""" + assert self.column_name.id in self.child.ids + @property def non_local(self) -> bool: return True @@ -1324,6 +1361,14 @@ def fields(self) -> Iterable[Field]: def variables_introduced(self) -> int: return 1 + @property + def projection_base(self) -> BigFrameNode: + return self.child.projection_base + + @property + def added_fields(self) -> Tuple[Field, ...]: + return (self.added_field,) + @property def relation_ops_created(self) -> int: # Assume that if not reprojecting, that there is a sequence of window operations sharing the same window @@ -1356,12 +1401,12 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: - return replace( + return dataclasses.replace( self, output_name=mappings.get(self.output_name, self.output_name) ) def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): - return replace( + return dataclasses.replace( self, column_name=self.column_name.remap_column_refs( mappings, allow_partial_bindings=True @@ -1372,7 +1417,7 @@ def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): ) -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class RandomSampleNode(UnaryNode): fraction: float @@ -1408,7 +1453,7 @@ def remap_refs( # TODO: Explode should create a new column instead of overriding the existing one -@dataclass(frozen=True, eq=False) +@dataclasses.dataclass(frozen=True, eq=False) class ExplodeNode(UnaryNode): column_ids: typing.Tuple[ex.DerefOp, ...] @@ -1460,4 +1505,4 @@ def remap_refs( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> BigFrameNode: new_ids = tuple(id.remap_column_refs(mappings) for id in self.column_ids) - return replace(self, column_ids=new_ids) # type: ignore + return dataclasses.replace(self, column_ids=new_ids) # type: ignore diff --git a/bigframes/core/rewrite/__init__.py b/bigframes/core/rewrite/__init__.py new file mode 100644 index 0000000000..f7ee3c87c2 --- /dev/null +++ b/bigframes/core/rewrite/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from bigframes.core.rewrite.identifiers import remap_variables +from bigframes.core.rewrite.implicit_align import try_join_as_projection +from bigframes.core.rewrite.legacy_align import legacy_join_as_projection +from bigframes.core.rewrite.slices import pullup_limit_from_slice, replace_slice_ops + +__all__ = [ + "legacy_join_as_projection", + "try_join_as_projection", + "replace_slice_ops", + "pullup_limit_from_slice", + "remap_variables", +] diff --git a/bigframes/core/rewrite/identifiers.py b/bigframes/core/rewrite/identifiers.py new file mode 100644 index 0000000000..d49e5c1b42 --- /dev/null +++ b/bigframes/core/rewrite/identifiers.py @@ -0,0 +1,59 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Generator, Tuple + +import bigframes.core.identifiers +import bigframes.core.nodes + + +# TODO: May as well just outright remove selection nodes in this process. +def remap_variables( + root: bigframes.core.nodes.BigFrameNode, + id_generator: Generator[bigframes.core.identifiers.ColumnId, None, None], +) -> Tuple[ + bigframes.core.nodes.BigFrameNode, + dict[bigframes.core.identifiers.ColumnId, bigframes.core.identifiers.ColumnId], +]: + """ + Remap all variables in the BFET using the id_generator. + + Note: this will convert a DAG to a tree. + """ + child_replacement_map = dict() + ref_mapping = dict() + # Sequential ids are assigned bottom-up left-to-right + for child in root.child_nodes: + new_child, child_var_mapping = remap_variables(child, id_generator=id_generator) + child_replacement_map[child] = new_child + ref_mapping.update(child_var_mapping) + + # This is actually invalid until we've replaced all of children, refs and var defs + with_new_children = root.transform_children( + lambda node: child_replacement_map[node] + ) + + with_new_refs = with_new_children.remap_refs(ref_mapping) + + node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids} + with_new_vars = with_new_refs.remap_vars(node_var_mapping) + with_new_vars._validate() + + return ( + with_new_vars, + node_var_mapping + if root.defines_namespace + else (ref_mapping | node_var_mapping), + ) diff --git a/bigframes/core/rewrite/implicit_align.py b/bigframes/core/rewrite/implicit_align.py new file mode 100644 index 0000000000..1d7fed09d2 --- /dev/null +++ b/bigframes/core/rewrite/implicit_align.py @@ -0,0 +1,190 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +from typing import Optional, Tuple + +import bigframes.core.expression +import bigframes.core.guid +import bigframes.core.identifiers +import bigframes.core.join_def +import bigframes.core.nodes +import bigframes.core.window_spec +import bigframes.operations.aggregations + +# Additive nodes leave existing columns completely intact, and only add new columns to the end +ADDITIVE_NODES = ( + bigframes.core.nodes.ProjectionNode, + bigframes.core.nodes.WindowOpNode, + bigframes.core.nodes.PromoteOffsetsNode, +) + + +@dataclasses.dataclass(frozen=True) +class ExpressionSpec: + expression: bigframes.core.expression.Expression + node: bigframes.core.nodes.BigFrameNode + + +def get_expression_spec( + node: bigframes.core.nodes.BigFrameNode, id: bigframes.core.identifiers.ColumnId +) -> ExpressionSpec: + """Normalizes column value by chaining expressions across multiple selection and projection nodes if possible. + This normalization helps identify whether columns are equivalent. + """ + # TODO: While we chain expression fragments from different nodes + # we could further normalize with constant folding and other scalar expression rewrites + expression: bigframes.core.expression.Expression = ( + bigframes.core.expression.DerefOp(id) + ) + curr_node = node + while True: + if isinstance(curr_node, bigframes.core.nodes.SelectionNode): + select_mappings = { + col_id: ref for ref, col_id in curr_node.input_output_pairs + } + expression = expression.bind_refs( + select_mappings, allow_partial_bindings=True + ) + elif isinstance(curr_node, bigframes.core.nodes.ProjectionNode): + proj_mappings = {col_id: expr for expr, col_id in curr_node.assignments} + expression = expression.bind_refs( + proj_mappings, allow_partial_bindings=True + ) + elif isinstance( + curr_node, + ( + bigframes.core.nodes.WindowOpNode, + bigframes.core.nodes.PromoteOffsetsNode, + ), + ): + # we don't yet have a way of normalizing window ops into a ExpressionSpec, which only + # handles normalizing scalar expressions at the moment. + pass + else: + return ExpressionSpec(expression, curr_node) + curr_node = curr_node.child + + +def _linearize_trees( + base_tree: bigframes.core.nodes.BigFrameNode, + append_tree: bigframes.core.nodes.BigFrameNode, +) -> bigframes.core.nodes.BigFrameNode: + """Linearize two divergent tree who only diverge through different additive nodes.""" + assert append_tree.projection_base == base_tree.projection_base + # base case: append tree does not have any additive nodes to linearize + if append_tree == append_tree.projection_base: + return base_tree + else: + assert isinstance(append_tree, ADDITIVE_NODES) + return append_tree.replace_child(_linearize_trees(base_tree, append_tree.child)) + + +def combine_nodes( + l_node: bigframes.core.nodes.BigFrameNode, + r_node: bigframes.core.nodes.BigFrameNode, +) -> bigframes.core.nodes.BigFrameNode: + assert l_node.projection_base == r_node.projection_base + l_node, l_selection = pull_up_selection(l_node) + r_node, r_selection = pull_up_selection( + r_node, rename_vars=True + ) # Rename only right vars to avoid collisions with left vars + combined_selection = (*l_selection, *r_selection) + merged_node = _linearize_trees(l_node, r_node) + return bigframes.core.nodes.SelectionNode(merged_node, combined_selection) + + +def try_join_as_projection( + l_node: bigframes.core.nodes.BigFrameNode, + r_node: bigframes.core.nodes.BigFrameNode, + join_keys: Tuple[Tuple[str, str], ...], +) -> Optional[bigframes.core.nodes.BigFrameNode]: + """Joins the two nodes""" + if l_node.projection_base != r_node.projection_base: + return None + # check join keys are equivalent by normalizing the expressions as much as posisble + # instead of just comparing ids + for l_key, r_key in join_keys: + # Caller is block, so they still work with raw strings rather than ids + left_id = bigframes.core.identifiers.ColumnId(l_key) + right_id = bigframes.core.identifiers.ColumnId(r_key) + if get_expression_spec(l_node, left_id) != get_expression_spec( + r_node, right_id + ): + return None + return combine_nodes(l_node, r_node) + + +def pull_up_selection( + node: bigframes.core.nodes.BigFrameNode, rename_vars: bool = False +) -> Tuple[ + bigframes.core.nodes.BigFrameNode, + Tuple[ + Tuple[bigframes.core.expression.DerefOp, bigframes.core.identifiers.ColumnId], + ..., + ], +]: + """Remove all selection nodes above the base node. Returns stripped tree. + + Args: + node (BigFrameNode): + The node from which to pull up SelectionNode ops + rename_vars (bool): + If true, will rename projected columns to new unique ids. + + Returns: + BigFrameNode, Selections + """ + if node == node.projection_base: # base case + return node, tuple( + (bigframes.core.expression.DerefOp(field.id), field.id) + for field in node.fields + ) + assert isinstance(node, (bigframes.core.nodes.SelectionNode, *ADDITIVE_NODES)) + child_node, child_selections = pull_up_selection( + node.child, rename_vars=rename_vars + ) + mapping = {out: ref.id for ref, out in child_selections} + if isinstance(node, ADDITIVE_NODES): + new_node: bigframes.core.nodes.BigFrameNode = node.replace_child(child_node) + new_node = new_node.remap_refs(mapping) + if rename_vars: + var_renames = { + field.id: bigframes.core.identifiers.ColumnId.unique() + for field in node.added_fields + } + new_node = new_node.remap_vars(var_renames) + else: + var_renames = {} + assert isinstance(new_node, ADDITIVE_NODES) + added_selections = ( + ( + bigframes.core.expression.DerefOp(var_renames.get(field.id, field.id)), + field.id, + ) + for field in node.added_fields + ) + new_selection = (*child_selections, *added_selections) + return new_node, new_selection + elif isinstance(node, bigframes.core.nodes.SelectionNode): + new_selection = tuple( + ( + bigframes.core.expression.DerefOp(mapping[ref.id]), + out, + ) + for ref, out in node.input_output_pairs + ) + return child_node, new_selection + raise ValueError(f"Couldn't pull up select from node: {node}") diff --git a/bigframes/core/rewrite.py b/bigframes/core/rewrite/legacy_align.py similarity index 61% rename from bigframes/core/rewrite.py rename to bigframes/core/rewrite/legacy_align.py index 8187b16d87..77ae9b3bb4 100644 --- a/bigframes/core/rewrite.py +++ b/bigframes/core/rewrite/legacy_align.py @@ -16,15 +16,13 @@ import dataclasses import functools import itertools -from typing import cast, Generator, Mapping, Optional, Sequence, Tuple +from typing import Mapping, Optional, Sequence, Tuple import bigframes.core.expression as scalar_exprs -import bigframes.core.guid as guids import bigframes.core.identifiers as ids import bigframes.core.join_def as join_defs import bigframes.core.nodes as nodes import bigframes.core.ordering as order -import bigframes.core.slices as slices import bigframes.operations as ops Selection = Tuple[Tuple[scalar_exprs.Expression, ids.ColumnId], ...] @@ -238,7 +236,7 @@ def expand(self) -> nodes.BigFrameNode: ) -def join_as_projection( +def legacy_join_as_projection( l_node: nodes.BigFrameNode, r_node: nodes.BigFrameNode, join_keys: Tuple[join_defs.CoalescedColumnMapping, ...], @@ -383,234 +381,3 @@ def common_selection_root( if r_node in l_nodes: return r_node return None - - -def pullup_limit_from_slice( - root: nodes.BigFrameNode, -) -> Tuple[nodes.BigFrameNode, Optional[int]]: - """ - This is a BQ-sql specific optimization that can be helpful as ORDER BY LIMIT is more efficient than WHERE + ROW_NUMBER(). - - Only use this if writing to an unclustered table. Clustering is not compatible with ORDER BY. - """ - if isinstance(root, nodes.SliceNode): - # head case - # More cases could be handled, but this is by far the most important, as it is used by df.head(), df[:N] - if root.is_limit: - assert not root.start - assert root.step == 1 - assert root.stop is not None - limit = root.stop - new_root, prior_limit = pullup_limit_from_slice(root.child) - if (prior_limit is not None) and (prior_limit < limit): - limit = prior_limit - return new_root, limit - elif ( - isinstance(root, (nodes.SelectionNode, nodes.ProjectionNode)) - and root.row_preserving - ): - new_child, prior_limit = pullup_limit_from_slice(root.child) - if prior_limit is not None: - return root.transform_children(lambda _: new_child), prior_limit - # Most ops don't support pulling up slice, like filter, agg, join, etc. - return root, None - - -def replace_slice_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode: - # TODO: we want to pull up some slices into limit op if near root. - if isinstance(root, nodes.SliceNode): - root = root.transform_children(replace_slice_ops) - return rewrite_slice(cast(nodes.SliceNode, root)) - else: - return root.transform_children(replace_slice_ops) - - -def rewrite_slice(node: nodes.SliceNode): - slice_def = (node.start, node.stop, node.step) - - # no-op (eg. df[::1]) - if slices.is_noop(slice_def, node.child.row_count): - return node.child - - # No filtering, just reverse (eg. df[::-1]) - if slices.is_reverse(slice_def, node.child.row_count): - return nodes.ReversedNode(node.child) - - if node.child.row_count: - slice_def = slices.to_forward_offsets(slice_def, node.child.row_count) - return slice_as_filter(node.child, *slice_def) - - -def slice_as_filter( - node: nodes.BigFrameNode, start: Optional[int], stop: Optional[int], step: int -) -> nodes.BigFrameNode: - if ( - ((start is None) or (start >= 0)) - and ((stop is None) or (stop >= 0)) - and (step > 0) - ): - node_w_offset = add_offsets(node) - predicate = convert_simple_slice( - scalar_exprs.DerefOp(node_w_offset.col_id), start or 0, stop, step - ) - filtered = nodes.FilterNode(node_w_offset, predicate) - return drop_cols(filtered, (node_w_offset.col_id,)) - - # fallback cases, generate both forward and backward offsets - if step < 0: - forward_offsets = add_offsets(node) - reversed_offsets = add_offsets(nodes.ReversedNode(forward_offsets)) - dual_indexed = reversed_offsets - else: - reversed_offsets = add_offsets(nodes.ReversedNode(node)) - forward_offsets = add_offsets(nodes.ReversedNode(reversed_offsets)) - dual_indexed = forward_offsets - default_start = 0 if step >= 0 else -1 - predicate = convert_complex_slice( - scalar_exprs.DerefOp(forward_offsets.col_id), - scalar_exprs.DerefOp(reversed_offsets.col_id), - start if (start is not None) else default_start, - stop, - step, - ) - filtered = nodes.FilterNode(dual_indexed, predicate) - return drop_cols(filtered, (forward_offsets.col_id, reversed_offsets.col_id)) - - -def add_offsets(node: nodes.BigFrameNode) -> nodes.PromoteOffsetsNode: - # Allow providing custom id generator? - offsets_id = ids.ColumnId(guids.generate_guid()) - return nodes.PromoteOffsetsNode(node, offsets_id) - - -def drop_cols( - node: nodes.BigFrameNode, drop_cols: Tuple[ids.ColumnId, ...] -) -> nodes.SelectionNode: - # adding a whole node that redefines the schema is a lot of overhead, should do something more efficient - selections = tuple( - (scalar_exprs.DerefOp(id), id) for id in node.ids if id not in drop_cols - ) - return nodes.SelectionNode(node, selections) - - -def convert_simple_slice( - offsets: scalar_exprs.Expression, - start: int = 0, - stop: Optional[int] = None, - step: int = 1, -) -> scalar_exprs.Expression: - """Performs slice but only for positive step size.""" - assert start >= 0 - assert (stop is None) or (stop >= 0) - - conditions = [] - if start > 0: - conditions.append(ops.ge_op.as_expr(offsets, scalar_exprs.const(start))) - if (stop is not None) and (stop >= 0): - conditions.append(ops.lt_op.as_expr(offsets, scalar_exprs.const(stop))) - if step > 1: - start_diff = ops.sub_op.as_expr(offsets, scalar_exprs.const(start)) - step_cond = ops.eq_op.as_expr( - ops.mod_op.as_expr(start_diff, scalar_exprs.const(step)), - scalar_exprs.const(0), - ) - conditions.append(step_cond) - - return merge_predicates(conditions) or scalar_exprs.const(True) - - -def convert_complex_slice( - forward_offsets: scalar_exprs.Expression, - reverse_offsets: scalar_exprs.Expression, - start: int, - stop: Optional[int], - step: int = 1, -) -> scalar_exprs.Expression: - conditions = [] - assert step != 0 - if start or ((start is not None) and step < 0): - if start > 0 and step > 0: - start_cond = ops.ge_op.as_expr(forward_offsets, scalar_exprs.const(start)) - elif start >= 0 and step < 0: - start_cond = ops.le_op.as_expr(forward_offsets, scalar_exprs.const(start)) - elif start < 0 and step > 0: - start_cond = ops.le_op.as_expr( - reverse_offsets, scalar_exprs.const(-start - 1) - ) - else: - assert start < 0 and step < 0 - start_cond = ops.ge_op.as_expr( - reverse_offsets, scalar_exprs.const(-start - 1) - ) - conditions.append(start_cond) - if stop is not None: - if stop >= 0 and step > 0: - stop_cond = ops.lt_op.as_expr(forward_offsets, scalar_exprs.const(stop)) - elif stop >= 0 and step < 0: - stop_cond = ops.gt_op.as_expr(forward_offsets, scalar_exprs.const(stop)) - elif stop < 0 and step > 0: - stop_cond = ops.gt_op.as_expr( - reverse_offsets, scalar_exprs.const(-stop - 1) - ) - else: - assert (stop < 0) and (step < 0) - stop_cond = ops.lt_op.as_expr( - reverse_offsets, scalar_exprs.const(-stop - 1) - ) - conditions.append(stop_cond) - if step != 1: - if step > 1 and start >= 0: - start_diff = ops.sub_op.as_expr(forward_offsets, scalar_exprs.const(start)) - elif step > 1 and start < 0: - start_diff = ops.sub_op.as_expr( - reverse_offsets, scalar_exprs.const(-start + 1) - ) - elif step < 0 and start >= 0: - start_diff = ops.add_op.as_expr(forward_offsets, scalar_exprs.const(start)) - else: - assert step < 0 and start < 0 - start_diff = ops.add_op.as_expr( - reverse_offsets, scalar_exprs.const(-start + 1) - ) - step_cond = ops.eq_op.as_expr( - ops.mod_op.as_expr(start_diff, scalar_exprs.const(step)), - scalar_exprs.const(0), - ) - conditions.append(step_cond) - return merge_predicates(conditions) or scalar_exprs.const(True) - - -# TODO: May as well just outright remove selection nodes in this process. -def remap_variables( - root: nodes.BigFrameNode, id_generator: Generator[ids.ColumnId, None, None] -) -> Tuple[nodes.BigFrameNode, dict[ids.ColumnId, ids.ColumnId]]: - """ - Remap all variables in the BFET using the id_generator. - - Note: this will convert a DAG to a tree. - """ - child_replacement_map = dict() - ref_mapping = dict() - # Sequential ids are assigned bottom-up left-to-right - for child in root.child_nodes: - new_child, child_var_mapping = remap_variables(child, id_generator=id_generator) - child_replacement_map[child] = new_child - ref_mapping.update(child_var_mapping) - - # This is actually invalid until we've replaced all of children, refs and var defs - with_new_children = root.transform_children( - lambda node: child_replacement_map[node] - ) - - with_new_refs = with_new_children.remap_refs(ref_mapping) - - node_var_mapping = {old_id: next(id_generator) for old_id in root.node_defined_ids} - with_new_vars = with_new_refs.remap_vars(node_var_mapping) - with_new_vars._validate() - - return ( - with_new_vars, - node_var_mapping - if root.defines_namespace - else (ref_mapping | node_var_mapping), - ) diff --git a/bigframes/core/rewrite/slices.py b/bigframes/core/rewrite/slices.py new file mode 100644 index 0000000000..906d635e93 --- /dev/null +++ b/bigframes/core/rewrite/slices.py @@ -0,0 +1,228 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +from typing import cast, Optional, Sequence, Tuple + +import bigframes.core.expression as scalar_exprs +import bigframes.core.guid as guids +import bigframes.core.identifiers as ids +import bigframes.core.nodes as nodes +import bigframes.core.slices as slices +import bigframes.operations as ops + + +def pullup_limit_from_slice( + root: nodes.BigFrameNode, +) -> Tuple[nodes.BigFrameNode, Optional[int]]: + """ + This is a BQ-sql specific optimization that can be helpful as ORDER BY LIMIT is more efficient than WHERE + ROW_NUMBER(). + + Only use this if writing to an unclustered table. Clustering is not compatible with ORDER BY. + """ + if isinstance(root, nodes.SliceNode): + # head case + # More cases could be handled, but this is by far the most important, as it is used by df.head(), df[:N] + if root.is_limit: + assert not root.start + assert root.step == 1 + assert root.stop is not None + limit = root.stop + new_root, prior_limit = pullup_limit_from_slice(root.child) + if (prior_limit is not None) and (prior_limit < limit): + limit = prior_limit + return new_root, limit + elif ( + isinstance(root, (nodes.SelectionNode, nodes.ProjectionNode)) + and root.row_preserving + ): + new_child, prior_limit = pullup_limit_from_slice(root.child) + if prior_limit is not None: + return root.transform_children(lambda _: new_child), prior_limit + # Most ops don't support pulling up slice, like filter, agg, join, etc. + return root, None + + +def replace_slice_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode: + # TODO: we want to pull up some slices into limit op if near root. + if isinstance(root, nodes.SliceNode): + root = root.transform_children(replace_slice_ops) + return rewrite_slice(cast(nodes.SliceNode, root)) + else: + return root.transform_children(replace_slice_ops) + + +def rewrite_slice(node: nodes.SliceNode): + slice_def = (node.start, node.stop, node.step) + + # no-op (eg. df[::1]) + if slices.is_noop(slice_def, node.child.row_count): + return node.child + + # No filtering, just reverse (eg. df[::-1]) + if slices.is_reverse(slice_def, node.child.row_count): + return nodes.ReversedNode(node.child) + + if node.child.row_count: + slice_def = slices.to_forward_offsets(slice_def, node.child.row_count) + return slice_as_filter(node.child, *slice_def) + + +def slice_as_filter( + node: nodes.BigFrameNode, start: Optional[int], stop: Optional[int], step: int +) -> nodes.BigFrameNode: + if ( + ((start is None) or (start >= 0)) + and ((stop is None) or (stop >= 0)) + and (step > 0) + ): + node_w_offset = add_offsets(node) + predicate = convert_simple_slice( + scalar_exprs.DerefOp(node_w_offset.col_id), start or 0, stop, step + ) + filtered = nodes.FilterNode(node_w_offset, predicate) + return drop_cols(filtered, (node_w_offset.col_id,)) + + # fallback cases, generate both forward and backward offsets + if step < 0: + forward_offsets = add_offsets(node) + reversed_offsets = add_offsets(nodes.ReversedNode(forward_offsets)) + dual_indexed = reversed_offsets + else: + reversed_offsets = add_offsets(nodes.ReversedNode(node)) + forward_offsets = add_offsets(nodes.ReversedNode(reversed_offsets)) + dual_indexed = forward_offsets + default_start = 0 if step >= 0 else -1 + predicate = convert_complex_slice( + scalar_exprs.DerefOp(forward_offsets.col_id), + scalar_exprs.DerefOp(reversed_offsets.col_id), + start if (start is not None) else default_start, + stop, + step, + ) + filtered = nodes.FilterNode(dual_indexed, predicate) + return drop_cols(filtered, (forward_offsets.col_id, reversed_offsets.col_id)) + + +def add_offsets(node: nodes.BigFrameNode) -> nodes.PromoteOffsetsNode: + # Allow providing custom id generator? + offsets_id = ids.ColumnId(guids.generate_guid()) + return nodes.PromoteOffsetsNode(node, offsets_id) + + +def drop_cols( + node: nodes.BigFrameNode, drop_cols: Tuple[ids.ColumnId, ...] +) -> nodes.SelectionNode: + # adding a whole node that redefines the schema is a lot of overhead, should do something more efficient + selections = tuple( + (scalar_exprs.DerefOp(id), id) for id in node.ids if id not in drop_cols + ) + return nodes.SelectionNode(node, selections) + + +def convert_simple_slice( + offsets: scalar_exprs.Expression, + start: int = 0, + stop: Optional[int] = None, + step: int = 1, +) -> scalar_exprs.Expression: + """Performs slice but only for positive step size.""" + assert start >= 0 + assert (stop is None) or (stop >= 0) + + conditions = [] + if start > 0: + conditions.append(ops.ge_op.as_expr(offsets, scalar_exprs.const(start))) + if (stop is not None) and (stop >= 0): + conditions.append(ops.lt_op.as_expr(offsets, scalar_exprs.const(stop))) + if step > 1: + start_diff = ops.sub_op.as_expr(offsets, scalar_exprs.const(start)) + step_cond = ops.eq_op.as_expr( + ops.mod_op.as_expr(start_diff, scalar_exprs.const(step)), + scalar_exprs.const(0), + ) + conditions.append(step_cond) + + return merge_predicates(conditions) or scalar_exprs.const(True) + + +def convert_complex_slice( + forward_offsets: scalar_exprs.Expression, + reverse_offsets: scalar_exprs.Expression, + start: int, + stop: Optional[int], + step: int = 1, +) -> scalar_exprs.Expression: + conditions = [] + assert step != 0 + if start or ((start is not None) and step < 0): + if start > 0 and step > 0: + start_cond = ops.ge_op.as_expr(forward_offsets, scalar_exprs.const(start)) + elif start >= 0 and step < 0: + start_cond = ops.le_op.as_expr(forward_offsets, scalar_exprs.const(start)) + elif start < 0 and step > 0: + start_cond = ops.le_op.as_expr( + reverse_offsets, scalar_exprs.const(-start - 1) + ) + else: + assert start < 0 and step < 0 + start_cond = ops.ge_op.as_expr( + reverse_offsets, scalar_exprs.const(-start - 1) + ) + conditions.append(start_cond) + if stop is not None: + if stop >= 0 and step > 0: + stop_cond = ops.lt_op.as_expr(forward_offsets, scalar_exprs.const(stop)) + elif stop >= 0 and step < 0: + stop_cond = ops.gt_op.as_expr(forward_offsets, scalar_exprs.const(stop)) + elif stop < 0 and step > 0: + stop_cond = ops.gt_op.as_expr( + reverse_offsets, scalar_exprs.const(-stop - 1) + ) + else: + assert (stop < 0) and (step < 0) + stop_cond = ops.lt_op.as_expr( + reverse_offsets, scalar_exprs.const(-stop - 1) + ) + conditions.append(stop_cond) + if step != 1: + if step > 1 and start >= 0: + start_diff = ops.sub_op.as_expr(forward_offsets, scalar_exprs.const(start)) + elif step > 1 and start < 0: + start_diff = ops.sub_op.as_expr( + reverse_offsets, scalar_exprs.const(-start + 1) + ) + elif step < 0 and start >= 0: + start_diff = ops.add_op.as_expr(forward_offsets, scalar_exprs.const(start)) + else: + assert step < 0 and start < 0 + start_diff = ops.add_op.as_expr( + reverse_offsets, scalar_exprs.const(-start + 1) + ) + step_cond = ops.eq_op.as_expr( + ops.mod_op.as_expr(start_diff, scalar_exprs.const(step)), + scalar_exprs.const(0), + ) + conditions.append(step_cond) + return merge_predicates(conditions) or scalar_exprs.const(True) + + +def merge_predicates( + predicates: Sequence[scalar_exprs.Expression], +) -> Optional[scalar_exprs.Expression]: + if len(predicates) == 0: + return None + + return functools.reduce(ops.and_op.as_expr, predicates) diff --git a/tests/system/small/test_null_index.py b/tests/system/small/test_null_index.py index da9baa0069..f70e16447a 100644 --- a/tests/system/small/test_null_index.py +++ b/tests/system/small/test_null_index.py @@ -364,7 +364,9 @@ def test_null_index_align_error(scalars_df_null_index): with pytest.raises(bigframes.exceptions.NullIndexError): _ = ( scalars_df_null_index["int64_col"] - + scalars_df_null_index["int64_col"].cumsum() + + scalars_df_null_index["int64_col"].cumsum()[ + scalars_df_null_index["int64_col"] > 3 + ] ) diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index 0cc8cd4cbe..218708a19d 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -2323,6 +2323,34 @@ def test_cumsum_nested(scalars_df_index, scalars_pandas_df_index): ) +@skip_legacy_pandas +def test_nested_analytic_ops_align(scalars_df_index, scalars_pandas_df_index): + col_name = "float64_col" + # set non-unique index to check implicit alignment + bf_series = scalars_df_index.set_index("bool_col")[col_name].fillna(0.0) + pd_series = scalars_pandas_df_index.set_index("bool_col")[col_name].fillna(0.0) + + bf_result = ( + (bf_series + 5) + + (bf_series.cumsum().cumsum().cumsum() + bf_series.rolling(window=3).mean()) + + bf_series.expanding().max() + ).to_pandas() + # cumsum does not behave well on nullable ints in pandas, produces object type and never ignores NA + pd_result = ( + (pd_series + 5) + + ( + pd_series.cumsum().cumsum().cumsum().astype(pd.Float64Dtype()) + + pd_series.rolling(window=3).mean() + ) + + pd_series.expanding().max() + ) + + pd.testing.assert_series_equal( + bf_result, + pd_result, + ) + + def test_cumsum_int_filtered(scalars_df_index, scalars_pandas_df_index): col_name = "int64_col" diff --git a/tests/unit/core/test_rewrite.py b/tests/unit/core/test_rewrite.py index 0965238fcd..1f1a2c3db9 100644 --- a/tests/unit/core/test_rewrite.py +++ b/tests/unit/core/test_rewrite.py @@ -17,7 +17,7 @@ import bigframes.core as core import bigframes.core.nodes as nodes -import bigframes.core.rewrite as rewrites +import bigframes.core.rewrite.slices import bigframes.core.schema TABLE_REF = google.cloud.bigquery.TableReference.from_string("project.dataset.table") @@ -40,18 +40,18 @@ def test_rewrite_noop_slice(): slice = nodes.SliceNode(LEAF, None, None) - result = rewrites.rewrite_slice(slice) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) assert result == LEAF def test_rewrite_reverse_slice(): slice = nodes.SliceNode(LEAF, None, None, -1) - result = rewrites.rewrite_slice(slice) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) assert result == nodes.ReversedNode(LEAF) def test_rewrite_filter_slice(): slice = nodes.SliceNode(LEAF, None, 2) - result = rewrites.rewrite_slice(slice) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) assert list(result.fields) == list(LEAF.fields) assert isinstance(result.child, nodes.FilterNode) diff --git a/tests/unit/test_planner.py b/tests/unit/test_planner.py index 491f3de6fa..c64b50395b 100644 --- a/tests/unit/test_planner.py +++ b/tests/unit/test_planner.py @@ -103,13 +103,10 @@ def test_session_aware_caching_fork_after_window_op(): Windowing is expensive, so caching should always compute the window function, in order to avoid later recomputation. """ - other = LEAF.promote_offsets()[0].create_constant(5, pd.Int64Dtype())[0] - target = ( - LEAF.promote_offsets()[0] - .create_constant(4, pd.Int64Dtype())[0] - .filter( - ops.eq_op.as_expr("col_a", ops.add_op.as_expr(ex.const(4), ex.const(3))) - ) + leaf_with_offsets = LEAF.promote_offsets()[0] + other = leaf_with_offsets.create_constant(5, pd.Int64Dtype())[0] + target = leaf_with_offsets.create_constant(4, pd.Int64Dtype())[0].filter( + ops.eq_op.as_expr("col_a", ops.add_op.as_expr(ex.const(4), ex.const(3))) ) result, cluster_cols = planner.session_aware_cache_plan( target.node, @@ -117,5 +114,5 @@ def test_session_aware_caching_fork_after_window_op(): other.node, ], ) - assert result == LEAF.promote_offsets()[0].node + assert result == leaf_with_offsets.node assert cluster_cols == [ids.ColumnId("col_a")]