refactor: Make expression nodes prunable #1030

Merged 5 commits on Oct 2, 2024
91 changes: 66 additions & 25 deletions bigframes/core/__init__.py
@@ -67,11 +67,20 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):

iobytes = io.BytesIO()
pa_feather.write_feather(adapted_table, iobytes)
# Scan all columns by default; we define this list explicitly so it can be pruned later while preserving source_def
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
for item in schema.items
)
)

node = nodes.ReadLocalNode(
iobytes.getvalue(),
data_schema=schema,
session=session,
n_rows=arrow_table.num_rows,
scan_list=scan_list,
)
return cls(node)
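
The reason to materialize the full scan list eagerly is that later rewrites can narrow it without altering the source definition. A minimal sketch of such a pruning pass, assuming ScanList exposes its ScanItem tuples as .items and each ScanItem carries its output id as .id (prune_scan_list itself is hypothetical, not part of this change):

    def prune_scan_list(
        scan_list: nodes.ScanList, used_ids: frozenset
    ) -> nodes.ScanList:
        # Keep only the scan items still referenced downstream; the node's
        # source definition is untouched, so cached results stay valid.
        return nodes.ScanList(
            tuple(item for item in scan_list.items if item.id in used_ids)
        )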

@@ -104,14 +113,30 @@ def from_table(
"Interpreting JSON column(s) as StringDtype. This behavior may change in future versions.",
bigframes.exceptions.PreviewWarning,
)
# Define the data source only for the needed columns; this makes row-hashing cheaper
table_def = nodes.GbqTable.from_table(table, columns=schema.names)

# Create the ordering from table metadata; an offsets column takes precedence over a primary key
ordering = None
if offsets_col:
ordering = orderings.TotalOrdering.from_offset_col(offsets_col)
elif primary_key:
ordering = orderings.TotalOrdering.from_primary_key(primary_key)

# Scan all columns by default; we define this list explicitly so it can be pruned later while preserving source_def
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
for item in schema.items
)
)
source_def = nodes.BigqueryDataSource(
table=table_def, at_time=at_time, sql_predicate=predicate, ordering=ordering
)
node = nodes.ReadTableNode(
table=nodes.GbqTable.from_table(table),
total_order_cols=(offsets_col,) if offsets_col else tuple(primary_key),
order_col_is_sequential=(offsets_col is not None),
columns=schema,
at_time=at_time,
source=source_def,
scan_list=scan_list,
table_session=session,
sql_predicate=predicate,
)
return cls(node)
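
With the predicate, time-travel point, and ordering now carried on the shared BigqueryDataSource, a column-pruning rewrite only has to swap out the scan_list. An illustrative (not actual) rewrite against the node built above:

    # Illustrative only: a narrower scan of the same source. The pruned node
    # references an identical BigqueryDataSource, so the sql_predicate,
    # at_time, and ordering all survive pruning.
    pruned = nodes.ReadTableNode(
        source=source_def,
        scan_list=nodes.ScanList(scan_list.items[:1]),
        table_session=session,
    )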

@@ -157,12 +182,22 @@ def as_cached(
ordering: Optional[orderings.RowOrdering],
) -> ArrayValue:
"""
Replace the node with an equivalent one that references a tabel where the value has been materialized to.
Replace the node with an equivalent one that references a table where the value has been materialized.
"""
table = nodes.GbqTable.from_table(cache_table)
source = nodes.BigqueryDataSource(table, ordering=ordering)
# Assumption: GBQ cached table uses field name as bq column name
scan_list = nodes.ScanList(
tuple(
nodes.ScanItem(field.id, field.dtype, field.id.name)
for field in self.node.fields
)
)
node = nodes.CachedTableNode(
original_node=self.node,
table=nodes.GbqTable.from_table(cache_table),
ordering=ordering,
source=source,
table_session=self.session,
scan_list=scan_list,
)
return ArrayValue(node)
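
The identity scan above relies on the stated assumption that the cache table's BigQuery column names match the field ids. Spelled out for a single hypothetical field:

    # Illustrative: a field with id "col_a" is scanned from the BQ column
    # literally named "col_a" in the cache table (the dtype shown is assumed).
    item = nodes.ScanItem(
        ids.ColumnId("col_a"), bigframes.dtypes.INT_DTYPE, "col_a"
    )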

@@ -379,28 +414,34 @@ def relational_join(
conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
l_mapping = { # Identity mapping, only rename right side
lcol.name: lcol.name for lcol in self.node.ids
}
r_mapping = { # Rename conflicting names
rcol.name: rcol.name
if (rcol.name not in l_mapping)
else bigframes.core.guid.generate_guid()
for rcol in other.node.ids
}
other_node = other.node
if set(other_node.ids) & set(self.node.ids):
other_node = nodes.SelectionNode(
other_node,
tuple(
(ex.deref(old_id), ids.ColumnId(new_id))
for old_id, new_id in r_mapping.items()
),
)

join_node = nodes.JoinNode(
left_child=self.node,
right_child=other.node,
right_child=other_node,
conditions=tuple(
(ex.deref(l_col), ex.deref(r_col)) for l_col, r_col in conditions
(ex.deref(l_mapping[l_col]), ex.deref(r_mapping[r_col]))
for l_col, r_col in conditions
),
type=type,
)
# Maps input ids to output ids for caller convenience
l_size = len(self.node.schema)
l_mapping = {
lcol: ocol
for lcol, ocol in zip(
self.node.schema.names, join_node.schema.names[:l_size]
)
}
r_mapping = {
rcol: ocol
for rcol, ocol in zip(
other.node.schema.names, join_node.schema.names[l_size:]
)
}
return ArrayValue(join_node), (l_mapping, r_mapping)
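
To make the rename rule concrete, here is the right-side mapping in isolation: names that collide with a left-side column get a fresh guid, everything else passes through (standalone sketch, outside the node machinery; generate_guid is the same helper used above):

    # Illustrative inputs with one colliding name.
    left_names = ["id", "value"]
    right_names = ["id", "score"]
    r_mapping = {
        name: (bigframes.core.guid.generate_guid() if name in left_names else name)
        for name in right_names
    }
    # r_mapping -> {"id": "<fresh guid>", "score": "score"}, so a join
    # condition ("id", "id") dereferences to the renamed right column.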

def try_align_as_projection(
14 changes: 7 additions & 7 deletions bigframes/core/compile/compiled.py
@@ -45,7 +45,6 @@
RowOrdering,
TotalOrdering,
)
import bigframes.core.schema as schemata
import bigframes.core.sql
from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec
import bigframes.dtypes
@@ -585,9 +584,7 @@ def has_total_order(self) -> bool:

@classmethod
def from_pandas(
cls,
pd_df: pandas.DataFrame,
schema: schemata.ArraySchema,
cls, pd_df: pandas.DataFrame, scan_cols: bigframes.core.nodes.ScanList
) -> OrderedIR:
"""
Builds an in-memory only (SQL only) expr from a pandas dataframe.
@@ -603,18 +600,21 @@ def from_pandas(
# derive the ibis schema from the original pandas schema
ibis_schema = [
(
name,
local_label,
bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(dtype),
)
for name, dtype in zip(schema.names, schema.dtypes)
for id, dtype, local_label in scan_cols.items
]
ibis_schema.append((ORDER_ID_COLUMN, ibis_dtypes.int64))

keys_memtable = ibis.memtable(ibis_values, schema=ibis.schema(ibis_schema))

return cls(
keys_memtable,
columns=[keys_memtable[column].name(column) for column in pd_df.columns],
columns=[
keys_memtable[local_label].name(col_id.sql)
for col_id, _, local_label in scan_cols.items
],
ordering=TotalOrdering.from_offset_col(ORDER_ID_COLUMN),
hidden_ordering_columns=(keys_memtable[ORDER_ID_COLUMN],),
)
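
A sketch of what the new scan_cols argument carries for a two-column frame (ScanList/ScanItem shapes as used elsewhere in this change; the dtype constants and the ids alias are assumed import names):

    # Illustrative: each ScanItem pairs the output ColumnId with a dtype and
    # the local pandas label used to look the column up in the memtable.
    scan_cols = bigframes.core.nodes.ScanList(
        (
            bigframes.core.nodes.ScanItem(
                ids.ColumnId("col_0"), bigframes.dtypes.INT_DTYPE, "a"
            ),
            bigframes.core.nodes.ScanItem(
                ids.ColumnId("col_1"), bigframes.dtypes.STRING_DTYPE, "b"
            ),
        )
    )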