refactor: Simplify join and read table nodes

TrevorBergeron committed Sep 30, 2024
1 parent ef76f13 commit a051511

Showing 6 changed files with 215 additions and 251 deletions.
88 changes: 64 additions & 24 deletions bigframes/core/__init__.py
@@ -67,11 +67,20 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):
 
         iobytes = io.BytesIO()
         pa_feather.write_feather(adapted_table, iobytes)
+        # Scan all columns by default, we define this list as it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+
         node = nodes.ReadLocalNode(
             iobytes.getvalue(),
             data_schema=schema,
             session=session,
             n_rows=arrow_table.num_rows,
+            scan_list=scan_list,
         )
         return cls(node)
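The ScanList added here decouples the columns a node exposes from the definition of its source, so later passes can prune scan items while the source definition (and anything keyed on it, such as caching or row-hashing) stays stable. A minimal sketch of that idea, using simplified stand-ins for the real bigframes ScanList/ScanItem types:

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ScanItem:
    id: str         # output column id
    dtype: str      # bigframes dtype
    source_id: str  # physical column name in the source

@dataclass(frozen=True)
class ScanList:
    items: Tuple[ScanItem, ...]

    def prune(self, keep) -> "ScanList":
        # Dropping items narrows the visible columns while the underlying
        # source definition (table, predicate, ordering) stays identical.
        return ScanList(tuple(item for item in self.items if item.id in keep))

scan = ScanList((ScanItem("a", "Int64", "a"), ScanItem("b", "string", "b")))
assert [i.id for i in scan.prune({"a"}).items] == ["a"]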

@@ -94,14 +103,30 @@ def from_table(
                 "Interpreting JSON column(s) as StringDtype. This behavior may change in future versions.",
                 bigframes.exceptions.PreviewWarning,
             )
+        # define data source only for needed columns, this makes row-hashing cheaper
+        table = nodes.GbqTable.from_table(table, columns=schema.names)
+
+        # create ordering from info
+        ordering = None
+        if offsets_col:
+            ordering = orderings.TotalOrdering.from_offset_col(offsets_col)
+        elif primary_key:
+            ordering = orderings.TotalOrdering.from_primary_key(primary_key)
+
+        # Scan all columns by default, we define this list as it can be pruned while preserving source_def
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
+                for item in schema.items
+            )
+        )
+        source_def = nodes.BigqueryDataSource(
+            table=table, at_time=at_time, sql_predicate=predicate, ordering=ordering
+        )
         node = nodes.ReadTableNode(
-            table=nodes.GbqTable.from_table(table),
-            total_order_cols=(offsets_col,) if offsets_col else tuple(primary_key),
-            order_col_is_sequential=(offsets_col is not None),
-            columns=schema,
-            at_time=at_time,
+            source=source_def,
+            scan_list=scan_list,
             table_session=session,
-            sql_predicate=predicate,
         )
         return cls(node)
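The two ordering constructors used above encode guarantees of different strength: an offsets column is already a dense 0..N-1 row numbering, while a primary key only guarantees uniqueness, so rows are totally ordered but not pre-numbered. A rough sketch of that distinction (the real orderings.TotalOrdering carries more state, e.g. integer encodings; these stand-ins are illustrative):

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class TotalOrdering:
    # Simplified stand-in for bigframes orderings.TotalOrdering.
    ordering_columns: Tuple[str, ...]
    is_sequential: bool = False  # True only for a dense 0..N-1 offsets column

    @classmethod
    def from_offset_col(cls, col: str) -> "TotalOrdering":
        # Offsets are already row numbers: total AND sequential.
        return cls((col,), is_sequential=True)

    @classmethod
    def from_primary_key(cls, cols) -> "TotalOrdering":
        # A primary key uniquely identifies rows: total, but row numbers
        # would still have to be computed by sorting.
        return cls(tuple(cols), is_sequential=False)

assert TotalOrdering.from_offset_col("offsets").is_sequential
assert not TotalOrdering.from_primary_key(["user_id"]).is_sequential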

@@ -147,12 +172,22 @@ def as_cached(
         ordering: Optional[orderings.RowOrdering],
     ) -> ArrayValue:
         """
-        Replace the node with an equivalent one that references a tabel where the value has been materialized to.
+        Replace the node with an equivalent one that references a table where the value has been materialized to.
         """
+        table = nodes.GbqTable.from_table(cache_table)
+        source = nodes.BigqueryDataSource(table, ordering=ordering)
+        # Assumption: GBQ cached table uses field name as bq column name
+        scan_list = nodes.ScanList(
+            tuple(
+                nodes.ScanItem(field.id, field.dtype, field.id.name)
+                for field in self.node.fields
+            )
+        )
         node = nodes.CachedTableNode(
             original_node=self.node,
-            table=nodes.GbqTable.from_table(cache_table),
-            ordering=ordering,
+            source=source,
+            table_session=self.session,
+            scan_list=scan_list,
         )
         return ArrayValue(node)
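The scan list built here is a pure identity mapping, which is only valid under the stated assumption that BigQuery wrote the cached table's columns under the node's field names. A toy check of that invariant (shapes simplified; names illustrative):

# Each visible field reads the physical column of the same name, so no
# rename layer is needed between the cached table and the node's schema.
fields = [("col_a", "Int64"), ("col_b", "string")]
scan_items = [(name, dtype, name) for name, dtype in fields]  # (id, dtype, source_id)
assert all(out_id == source_id for out_id, _, source_id in scan_items)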

@@ -369,28 +404,33 @@ def relational_join(
         conditions: typing.Tuple[typing.Tuple[str, str], ...] = (),
         type: typing.Literal["inner", "outer", "left", "right", "cross"] = "inner",
     ) -> typing.Tuple[ArrayValue, typing.Tuple[dict[str, str], dict[str, str]]]:
+        l_mapping = {  # Identity mapping, only rename right side
+            lcol.name: lcol.name for lcol in self.node.ids
+        }
+        r_mapping = {  # Rename conflicting names
+            rcol.name: rcol.name
+            if (rcol.name not in l_mapping)
+            else bigframes.core.guid.generate_guid()
+            for rcol in other.node.ids
+        }
+        other_node = other.node
+        if set(other_node.ids) & set(self.node.ids):
+            other_node = nodes.SelectionNode(
+                other_node,
+                tuple(
+                    (ex.deref(old_id), ids.ColumnId(new_id))
+                    for old_id, new_id in r_mapping.items()
+                ),
+            )
+
         join_node = nodes.JoinNode(
             left_child=self.node,
-            right_child=other.node,
+            right_child=other_node,
             conditions=tuple(
                 (ex.deref(l_col), ex.deref(r_col)) for l_col, r_col in conditions
             ),
             type=type,
         )
-        # Maps input ids to output ids for caller convenience
-        l_size = len(self.node.schema)
-        l_mapping = {
-            lcol: ocol
-            for lcol, ocol in zip(
-                self.node.schema.names, join_node.schema.names[:l_size]
-            )
-        }
-        r_mapping = {
-            rcol: ocol
-            for rcol, ocol in zip(
-                other.node.schema.names, join_node.schema.names[l_size:]
-            )
-        }
         return ArrayValue(join_node), (l_mapping, r_mapping)
 
     def try_align_as_projection(
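The new approach renames conflicting right-side columns before the join rather than re-deriving both mappings from the joined schema afterwards. A standalone sketch of the mapping logic, where generate_guid is a stand-in for bigframes.core.guid.generate_guid:

import itertools

_ids = itertools.count()

def generate_guid() -> str:
    # Stand-in for bigframes.core.guid.generate_guid.
    return f"bfcol_{next(_ids)}"

def plan_join_names(left_cols, right_cols):
    # Left ids map to themselves; right ids are renamed only on collision,
    # which is what lets the right child be wrapped in a SelectionNode.
    l_mapping = {name: name for name in left_cols}
    r_mapping = {
        name: name if name not in l_mapping else generate_guid()
        for name in right_cols
    }
    return l_mapping, r_mapping

l_map, r_map = plan_join_names(["id", "x"], ["id", "y"])
assert l_map == {"id": "id", "x": "x"}
assert r_map["y"] == "y" and r_map["id"] != "id"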
144 changes: 52 additions & 92 deletions bigframes/core/compile/compiler.py
@@ -16,7 +16,6 @@
 import dataclasses
 import functools
 import io
-import itertools
 import typing
 
 import ibis
@@ -29,9 +28,12 @@
 import bigframes.core.compile.concat as concat_impl
 import bigframes.core.compile.default_ordering as default_ordering
 import bigframes.core.compile.ibis_types
-import bigframes.core.compile.scalar_op_compiler
+import bigframes.core.compile.scalar_op_compiler as compile_scalar
 import bigframes.core.compile.schema_translator
 import bigframes.core.compile.single_column
 import bigframes.core.expression as ex
+import bigframes.core.guid as guids
+import bigframes.core.identifiers as ids
 import bigframes.core.nodes as nodes
 import bigframes.core.ordering as bf_ordering
Expand All @@ -45,6 +47,7 @@ class Compiler:
# In strict mode, ordering will always be deterministic
# In unstrict mode, ordering from ReadTable or after joins may be ambiguous to improve query performance.
strict: bool = True
scalar_op_compiler = compile_scalar.ScalarOpCompiler()

def compile_ordered_ir(self, node: nodes.BigFrameNode) -> compiled.OrderedIR:
ir = typing.cast(compiled.OrderedIR, self.compile_node(node, True))
@@ -121,138 +124,95 @@ def compile_readlocal(self, node: nodes.ReadLocalNode, ordered: bool = True):
         else:
             return ordered_ir.to_unordered()
 
-    @_compile_node.register
-    def compile_cached_table(self, node: nodes.CachedTableNode, ordered: bool = True):
-        full_table_name = (
-            f"{node.table.project_id}.{node.table.dataset_id}.{node.table.table_id}"
-        )
-        physical_schema = ibis.backends.bigquery.BigQuerySchema.to_ibis(
-            node.table.physical_schema
-        )
-        ibis_table = ibis.table(physical_schema, full_table_name)
-        if ordered:
-            if node.ordering is None:
-                # If this happens, session malfunctioned while applying cached results.
-                raise ValueError(
-                    "Cannot use unordered cached value. Result requires ordering information."
-                )
-            if self.strict and not isinstance(node.ordering, bf_ordering.TotalOrdering):
-                raise ValueError(
-                    "Cannot use partially ordered cached value. Result requires total ordering information."
-                )
-            ir = compiled.OrderedIR(
-                ibis_table,
-                columns=tuple(
-                    bigframes.core.compile.ibis_types.ibis_value_to_canonical_type(
-                        ibis_table[col.sql]
-                    )
-                    for col in itertools.chain(
-                        map(lambda x: x.id, node.fields), node._hidden_columns
-                    )
-                ),
-                ordering=node.ordering,
-            )
-            ir = ir._select(
-                tuple(ir._get_ibis_column(name) for name in node.schema.names)
-            )
-            return ir
-        else:
-            return compiled.UnorderedIR(
-                ibis_table,
-                columns=tuple(
-                    bigframes.core.compile.ibis_types.ibis_value_to_canonical_type(
-                        ibis_table[col]
-                    )
-                    for col in node.schema.names
-                ),
-            )
-
     @_compile_node.register
     def compile_readtable(self, node: nodes.ReadTableNode, ordered: bool = True):
         if ordered:
-            return self.compile_read_table_ordered(node)
+            return self.compile_read_table_ordered(node.source, node.scan_list)
         else:
-            return self.compile_read_table_unordered(node)
+            return self.compile_read_table_unordered(node.source, node.scan_list)
 
     def read_table_as_unordered_ibis(
-        self, node: nodes.ReadTableNode
+        self, source: nodes.BigqueryDataSource
     ) -> ibis.expr.types.Table:
-        full_table_name = (
-            f"{node.table.project_id}.{node.table.dataset_id}.{node.table.table_id}"
-        )
-        used_columns = (
-            *node.schema.names,
-            *[i for i in node.total_order_cols if i not in node.schema.names],
-        )
+        full_table_name = f"{source.table.project_id}.{source.table.dataset_id}.{source.table.table_id}"
+        used_columns = tuple(col.name for col in source.table.physical_schema)
         # Physical schema might include unused columns, unsupported datatypes like JSON
         physical_schema = ibis.backends.bigquery.BigQuerySchema.to_ibis(
-            list(i for i in node.table.physical_schema if i.name in used_columns)
+            list(i for i in source.table.physical_schema if i.name in used_columns)
        )
-        if node.at_time is not None or node.sql_predicate is not None:
+        if source.at_time is not None or source.sql_predicate is not None:
             import bigframes.session._io.bigquery
 
             sql = bigframes.session._io.bigquery.to_query(
                 full_table_name,
                 columns=used_columns,
-                sql_predicate=node.sql_predicate,
-                time_travel_timestamp=node.at_time,
+                sql_predicate=source.sql_predicate,
+                time_travel_timestamp=source.at_time,
             )
             return ibis.backends.bigquery.Backend().sql(
                 schema=physical_schema, query=sql
             )
         else:
             return ibis.table(physical_schema, full_table_name)
 
-    def compile_read_table_unordered(self, node: nodes.ReadTableNode):
-        ibis_table = self.read_table_as_unordered_ibis(node)
+    def compile_read_table_unordered(
+        self, source: nodes.BigqueryDataSource, scan: nodes.ScanList
+    ):
+        ibis_table = self.read_table_as_unordered_ibis(source)
         return compiled.UnorderedIR(
             ibis_table,
             tuple(
                 bigframes.core.compile.ibis_types.ibis_value_to_canonical_type(
-                    ibis_table[col]
+                    ibis_table[scan_item.source_id].name(scan_item.id.sql)
                 )
-                for col in node.schema.names
+                for scan_item in scan.items
             ),
         )
 
-    def compile_read_table_ordered(self, node: nodes.ReadTableNode):
-        ibis_table = self.read_table_as_unordered_ibis(node)
-        if node.total_order_cols:
-            ordering_value_columns = tuple(
-                bf_ordering.ascending_over(col) for col in node.total_order_cols
-            )
-            if node.order_col_is_sequential:
-                integer_encoding = bf_ordering.IntegerEncoding(
-                    is_encoded=True, is_sequential=True
-                )
-            else:
-                integer_encoding = bf_ordering.IntegerEncoding()
-            ordering: bf_ordering.RowOrdering = bf_ordering.TotalOrdering(
-                ordering_value_columns,
-                integer_encoding=integer_encoding,
-                total_ordering_columns=frozenset(map(ex.deref, node.total_order_cols)),
-            )
-            hidden_columns = ()
-        elif self.strict:
-            ibis_table, ordering = default_ordering.gen_default_ordering(
-                ibis_table, use_double_hash=True
-            )
-            hidden_columns = tuple(
-                ibis_table[col]
-                for col in ibis_table.columns
-                if col not in node.schema.names
-            )
+    def compile_read_table_ordered(
+        self, source: nodes.BigqueryDataSource, scan_list: nodes.ScanList
+    ):
+        ibis_table = self.read_table_as_unordered_ibis(source)
+        if source.ordering is not None:
+            visible_column_mapping = {
+                ids.ColumnId(scan_item.source_id): scan_item.id
+                for scan_item in scan_list.items
+            }
+            full_mapping = {
+                ids.ColumnId(col.name): ids.ColumnId(guids.generate_guid())
+                for col in source.ordering.referenced_columns
+            }
+            full_mapping.update(visible_column_mapping)
+
+            ordering = source.ordering.remap_column_refs(full_mapping)
+            hidden_columns = tuple(
+                ibis_table[source_id.sql].name(out_id.sql)
+                for source_id, out_id in full_mapping.items()
+                if source_id not in visible_column_mapping
+            )
+        elif self.strict:  # In strict mode, we fallback to ordering by row hash
+            order_values = [
+                col.name(guids.generate_guid())
+                for col in default_ordering.gen_default_ordering(
+                    ibis_table, use_double_hash=True
+                )
+            ]
+            ordering = bf_ordering.TotalOrdering.from_primary_key(
+                [value.get_name() for value in order_values]
+            )
+            hidden_columns = tuple(order_values)
         else:
             # In unstrict mode, don't generate total ordering from hashing as this is
             # expensive (prevent removing any columns from table scan)
             ordering, hidden_columns = bf_ordering.RowOrdering(), ()
 
         return compiled.OrderedIR(
             ibis_table,
             columns=tuple(
                 bigframes.core.compile.ibis_types.ibis_value_to_canonical_type(
-                    ibis_table[col]
+                    ibis_table[scan_item.source_id].name(scan_item.id.sql)
                 )
-                for col in node.schema.names
+                for scan_item in scan_list.items
             ),
             ordering=ordering,
             hidden_ordering_columns=hidden_columns,
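In the new ordered read path, any column the ordering references but the scan list does not expose gets a fresh guid name and travels as a hidden ordering column, while visible columns keep their scan-list ids. A compressed sketch of that remapping (plain strings stand in for the ColumnId objects used above):

import itertools

_ids = itertools.count()
visible_column_mapping = {"a": "out_a"}   # source column -> visible output id
ordering_referenced = ["a", "ts"]         # columns the row ordering needs

# Every ordering reference first gets a fresh hidden name...
full_mapping = {col: f"bfuid_{next(_ids)}" for col in ordering_referenced}
# ...then visible columns override, so only non-visible refs stay hidden.
full_mapping.update(visible_column_mapping)

hidden = [col for col in full_mapping if col not in visible_column_mapping]
assert hidden == ["ts"] and full_mapping["a"] == "out_a"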
23 changes: 4 additions & 19 deletions bigframes/core/compile/default_ordering.py
@@ -18,7 +18,6 @@
 
 from __future__ import annotations
 
-import itertools
 from typing import cast
 
 import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops
@@ -27,9 +26,7 @@
 import ibis.expr.datatypes as ibis_dtypes
 import ibis.expr.types as ibis_types
 
-import bigframes.core.expression as ex
 import bigframes.core.guid as guid
-import bigframes.core.ordering as order
 
 
 def _convert_to_nonnull_string(column: ibis_types.Column) -> ibis_types.StringValue:
@@ -59,7 +56,9 @@ def _convert_to_nonnull_string(column: ibis_types.Column) -> ibis_types.StringValue:
     return cast(ibis_types.StringColumn, ibis.literal("\\")).concat(escaped)
 
 
-def gen_default_ordering(table: ibis.table, use_double_hash: bool = True):
+def gen_default_ordering(
+    table: ibis.table, use_double_hash: bool = True
+) -> list[ibis.Value]:
     ordering_hash_part = guid.generate_guid("bigframes_ordering_")
     ordering_hash_part2 = guid.generate_guid("bigframes_ordering_")
     ordering_rand_part = guid.generate_guid("bigframes_ordering_")
@@ -82,18 +81,4 @@ def gen_default_ordering(table: ibis.table, use_double_hash: bool = True):
         if use_double_hash
         else [full_row_hash, random_value]
     )
-
-    original_column_ids = table.columns
-    table_with_ordering = table.select(
-        itertools.chain(original_column_ids, order_values)
-    )
-
-    ordering = order.TotalOrdering(
-        ordering_value_columns=tuple(
-            order.ascending_over(col.get_name()) for col in order_values
-        ),
-        total_ordering_columns=frozenset(
-            ex.deref(col.get_name()) for col in order_values
-        ),
-    )
-    return table_with_ordering, ordering
+    return order_values
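For context on what gen_default_ordering now returns: it builds order-key expressions by hashing a string rendering of each row, and (after this change) hands those raw value expressions back so the caller constructs the ordering itself. A simplified sketch of the hashing idea in ibis — the real implementation escapes values, generates guid names, and optionally doubles the hash:

import ibis

t = ibis.table({"a": "int64", "b": "string"}, name="example")

# Render each row as one string (nulls coalesced away) and hash it.
row_repr = ibis.coalesce(t.a.cast("string"), "") + "," + ibis.coalesce(t.b, "")
full_row_hash = row_repr.hash().name("bigframes_ordering_hash")

# Duplicate rows collide on the hash, so a random value (or a second,
# differently-salted hash) is appended as a tiebreaker.
order_values = [full_row_hash, ibis.random().name("bigframes_ordering_rand")]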