refactor: Make expression nodes prunable (#1030)
TrevorBergeron authored Oct 2, 2024
1 parent 057f3f0 commit 4105dba
Showing 9 changed files with 358 additions and 257 deletions.
91 changes: 66 additions & 25 deletions bigframes/core/__init__.py
@@ -67,11 +67,20 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):

iobytes = io.BytesIO()
pa_feather.write_feather(adapted_table, iobytes)
# Scan all columns by default, we define this list as it can be pruned while preserving source_def
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
for item in schema.items
)
)

node = nodes.ReadLocalNode(
iobytes.getvalue(),
data_schema=schema,
session=session,
n_rows=arrow_table.num_rows,
scan_list=scan_list,
)
return cls(node)
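
The scan_list built above is what makes the read node prunable: column selection can drop scan items without touching the source definition. Below is a minimal, self-contained sketch of that idea in plain Python; the ScanItem/ScanList classes here are illustrative stand-ins, not the real bigframes.core.nodes types.

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ScanItem:  # stand-in for nodes.ScanItem
    col_id: str        # internal column id used by the rest of the tree
    dtype: str         # logical dtype
    source_label: str  # column name in the underlying source

@dataclass(frozen=True)
class ScanList:  # stand-in for nodes.ScanList
    items: Tuple[ScanItem, ...]

    def prune(self, needed: set) -> "ScanList":
        # Keep only the columns a downstream consumer references; the
        # source definition (table, predicate, ordering) is left untouched.
        return ScanList(tuple(i for i in self.items if i.col_id in needed))

scan = ScanList((ScanItem("c0", "Int64", "a"), ScanItem("c1", "string", "b")))
pruned = scan.prune({"c0"})  # only column "a" still needs to be scanned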

@@ -104,14 +113,30 @@ def from_table(
"Interpreting JSON column(s) as StringDtype. This behavior may change in future versions.",
bigframes.exceptions.PreviewWarning,
)
# define data source only for needed columns, this makes row-hashing cheaper
table_def = nodes.GbqTable.from_table(table, columns=schema.names)

# create ordering from info
ordering = None
if offsets_col:
ordering = orderings.TotalOrdering.from_offset_col(offsets_col)
elif primary_key:
ordering = orderings.TotalOrdering.from_primary_key(primary_key)

# Scan all columns by default, we define this list as it can be pruned while preserving source_def
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
for item in schema.items
)
)
source_def = nodes.BigqueryDataSource(
table=table_def, at_time=at_time, sql_predicate=predicate, ordering=ordering
)
node = nodes.ReadTableNode(
table=nodes.GbqTable.from_table(table),
total_order_cols=(offsets_col,) if offsets_col else tuple(primary_key),
order_col_is_sequential=(offsets_col is not None),
columns=schema,
at_time=at_time,
source=source_def,
scan_list=scan_list,
table_session=session,
sql_predicate=predicate,
)
return cls(node)

@@ -157,12 +182,22 @@ def as_cached(
ordering: Optional[orderings.RowOrdering],
) -> ArrayValue:
"""
Replace the node with an equivalent one that references a tabel where the value has been materialized to.
Replace the node with an equivalent one that references a table where the value has been materialized to.
"""
table = nodes.GbqTable.from_table(cache_table)
source = nodes.BigqueryDataSource(table, ordering=ordering)
# Assumption: GBQ cached table uses field name as bq column name
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(field.id, field.dtype, field.id.name)
for field in self.node.fields
)
)
node = nodes.CachedTableNode(
original_node=self.node,
table=nodes.GbqTable.from_table(cache_table),
ordering=ordering,
source=source,
table_session=self.session,
scan_list=scan_list,
)
return ArrayValue(node)
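
as_cached keeps the original node's column ids and, per the assumption noted in the hunk, maps each id back to a cached-table column of the same name. A rough sketch of that mapping with stand-in types (ColumnId and Field below are illustrative, not the real bigframes classes):

from dataclasses import dataclass

@dataclass(frozen=True)
class ColumnId:  # stand-in for ids.ColumnId
    name: str

@dataclass(frozen=True)
class Field:  # stand-in for a node field (id + dtype)
    id: ColumnId
    dtype: str

fields = (Field(ColumnId("bfcol_0"), "Int64"), Field(ColumnId("bfcol_1"), "string"))

# The cached table is assumed to use the field name as its BigQuery column
# name, so each scan item reconnects an existing id to a same-named column.
# Downstream expressions keep referencing the same ids and stay valid.
scan_items = tuple((f.id, f.dtype, f.id.name) for f in fields)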

@@ -379,28 +414,34 @@ def relational_join(
conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
l_mapping = { # Identity mapping, only rename right side
lcol.name: lcol.name for lcol in self.node.ids
}
r_mapping = { # Rename conflicting names
rcol.name: rcol.name
if (rcol.name not in l_mapping)
else bigframes.core.guid.generate_guid()
for rcol in other.node.ids
}
other_node = other.node
if set(other_node.ids) & set(self.node.ids):
other_node = nodes.SelectionNode(
other_node,
tuple(
(ex.deref(old_id), ids.ColumnId(new_id))
for old_id, new_id in r_mapping.items()
),
)

join_node = nodes.JoinNode(
left_child=self.node,
right_child=other.node,
right_child=other_node,
conditions=tuple(
(ex.deref(l_col), ex.deref(r_col)) for l_col, r_col in conditions
(ex.deref(l_mapping[l_col]), ex.deref(r_mapping[r_col]))
for l_col, r_col in conditions
),
type=type,
)
# Maps input ids to output ids for caller convenience
l_size = len(self.node.schema)
l_mapping = {
lcol: ocol
for lcol, ocol in zip(
self.node.schema.names, join_node.schema.names[:l_size]
)
}
r_mapping = {
rcol: ocol
for rcol, ocol in zip(
other.node.schema.names, join_node.schema.names[l_size:]
)
}
return ArrayValue(join_node), (l_mapping, r_mapping)
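
Conflicting right-side column names are now renamed up front (via a SelectionNode) rather than recovered from the joined schema afterwards. A small, self-contained sketch of the renaming rule; generate_guid below is an illustrative stand-in for bigframes.core.guid.generate_guid:

import itertools

_ids = itertools.count()

def generate_guid() -> str:  # stand-in, not the bigframes implementation
    return f"bfcol_{next(_ids)}"

def plan_join_renames(left_cols, right_cols):
    # Left side keeps its names; only right-side names that collide get a
    # fresh id, matching the l_mapping / r_mapping construction above.
    l_mapping = {name: name for name in left_cols}
    r_mapping = {
        name: name if name not in l_mapping else generate_guid()
        for name in right_cols
    }
    return l_mapping, r_mapping

l_map, r_map = plan_join_renames(["id", "value"], ["id", "score"])
# l_map == {"id": "id", "value": "value"}
# r_map == {"id": "bfcol_0", "score": "score"}

Join conditions are then dereferenced through these mappings, which is why the conditions above use l_mapping[l_col] and r_mapping[r_col].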

def try_align_as_projection(
14 changes: 7 additions & 7 deletions bigframes/core/compile/compiled.py
@@ -45,7 +45,6 @@
RowOrdering,
TotalOrdering,
)
import bigframes.core.schema as schemata
import bigframes.core.sql
from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec
import bigframes.dtypes
@@ -585,9 +584,7 @@ def has_total_order(self) -> bool:

@classmethod
def from_pandas(
cls,
pd_df: pandas.DataFrame,
schema: schemata.ArraySchema,
cls, pd_df: pandas.DataFrame, scan_cols: bigframes.core.nodes.ScanList
) -> OrderedIR:
"""
Builds an in-memory only (SQL only) expr from a pandas dataframe.
@@ -603,18 +600,21 @@ def from_pandas(
# derive the ibis schema from the original pandas schema
ibis_schema = [
(
name,
local_label,
bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(dtype),
)
for name, dtype in zip(schema.names, schema.dtypes)
for id, dtype, local_label in scan_cols.items
]
ibis_schema.append((ORDER_ID_COLUMN, ibis_dtypes.int64))

keys_memtable = ibis.memtable(ibis_values, schema=ibis.schema(ibis_schema))

return cls(
keys_memtable,
columns=[keys_memtable[column].name(column) for column in pd_df.columns],
columns=[
keys_memtable[local_label].name(col_id.sql)
for col_id, _, local_label in scan_cols.items
],
ordering=TotalOrdering.from_offset_col(ORDER_ID_COLUMN),
hidden_ordering_columns=(keys_memtable[ORDER_ID_COLUMN],),
)