diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 386ede17090f..e210c99f1c28 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -1,379 +1,22 @@ from __future__ import annotations -from collections import defaultdict -from typing import TYPE_CHECKING - -import toolz - import ibis.common.graph as g import ibis.expr.operations as ops -import ibis.expr.operations.relations as rels -import ibis.expr.types as ir -from ibis import util from ibis.common.deferred import deferred, var -from ibis.common.exceptions import ExpressionError, IbisTypeError, IntegrityError -from ibis.common.patterns import Eq, In, pattern, replace +from ibis.common.patterns import pattern from ibis.util import Namespace -if TYPE_CHECKING: - from collections.abc import Iterable, Iterator - p = Namespace(pattern, module=ops) c = Namespace(deferred, module=ops) x = var("x") y = var("y") -# --------------------------------------------------------------------- -# Some expression metaprogramming / graph transformations to support -# compilation later - - -def sub_immediate_parents(node: ops.Node, table: ops.TableNode) -> ops.Node: - """Replace immediate parent tables in `op` with `table`.""" - parents = find_immediate_parent_tables(node) - return node.replace(In(parents) >> table) - - -def find_immediate_parent_tables(input_node, keep_input=True): - """Find every first occurrence of a `ir.Table` object in `input_node`. - - This function does not traverse into `Table` objects. For example, the - underlying `PhysicalTable` of a `Selection` will not be yielded. - - Parameters - ---------- - input_node - Input node - keep_input - Whether to keep the input when traversing - - Yields - ------ - ir.Expr - Parent table expression - - Examples - -------- - >>> import ibis, toolz - >>> t = ibis.table([("a", "int64")], name="t") - >>> expr = t.mutate(foo=t.a + 1) - >>> (result,) = find_immediate_parent_tables(expr.op()) - >>> result.equals(expr.op()) - True - >>> (result,) = find_immediate_parent_tables(expr.op(), keep_input=False) - >>> result.equals(t.op()) - True - """ - assert all(isinstance(arg, ops.Node) for arg in util.promote_list(input_node)) - - def finder(node): - if isinstance(node, ops.TableNode): - if keep_input or node != input_node: - return g.halt, node - else: - return g.proceed, None - - # HACK: special case ops.Contains to only consider the needle's base - # table, since that's the only expression that matters for determining - # cardinality - elif isinstance(node, ops.InColumn): - # we allow InColumn.options to be a column from a foreign table - return [node.value], None - else: - return g.proceed, None - - return list(toolz.unique(g.traverse(finder, input_node))) - - -def get_mutation_exprs(exprs: list[ir.Expr], table: ir.Table) -> list[ir.Expr | None]: - """Return the exprs to use to instantiate the mutation.""" - # The below logic computes the mutation node exprs by splitting the - # assignment exprs into two disjoint sets: - # 1) overwriting_cols_to_expr, which maps a column name to its expr - # if the expr contains a column that overwrites an existing table column. - # All keys in this dict are columns in the original table that are being - # overwritten by an assignment expr. - # 2) non_overwriting_exprs, which is a list of all exprs that do not do - # any overwriting. That is, if an expr is in this list, then its column - # name does not exist in the original table. - # Given these two data structures, we can compute the mutation node exprs - # based on whether any columns are being overwritten. - overwriting_cols_to_expr: dict[str, ir.Expr | None] = {} - non_overwriting_exprs: list[ir.Expr] = [] - table_schema = table.schema() - for expr in exprs: - expr_contains_overwrite = False - if isinstance(expr, ir.Value) and expr.get_name() in table_schema: - overwriting_cols_to_expr[expr.get_name()] = expr - expr_contains_overwrite = True - - if not expr_contains_overwrite: - non_overwriting_exprs.append(expr) - - columns = table.columns - if overwriting_cols_to_expr: - return [ - overwriting_cols_to_expr.get(column, table[column]) - for column in columns - if overwriting_cols_to_expr.get(column, table[column]) is not None - ] + non_overwriting_exprs - - table_expr: ir.Expr = table - return [table_expr] + exprs - - -def pushdown_selection_filters(parent, predicates): - if not predicates: - return parent - - default = ops.Selection(parent, selections=[], predicates=predicates) - if not isinstance(parent, (ops.Selection, ops.Aggregation)): - return default - - projected_column_names = set() - for value in parent._projection.selections: - if isinstance(value, (ops.Relation, ops.TableColumn)): - # we are only interested in projected value expressions, not tables - # nor column references which are not changing the projection - continue - elif value.find((ops.WindowFunction, ops.ExistsSubquery), filter=ops.Value): - # the parent has analytic projections like window functions so we - # can't push down filters to that level - return default - else: - # otherwise collect the names of newly projected value expressions - # which are not just plain column references - projected_column_names.add(value.name) - - conflicting_projection = p.TableColumn(parent, In(projected_column_names)) - pushdown_pattern = Eq(parent) >> parent.table - - simplified = [] - for pred in predicates: - if pred.find(conflicting_projection, filter=p.Value): - return default - try: - simplified.append(pred.replace(pushdown_pattern)) - except (IntegrityError, IbisTypeError): - # former happens when there is a duplicate column name in the parent - # which is a join, the latter happens for semi/anti joins - return default - - return parent.copy(predicates=parent.predicates + tuple(simplified)) - - -@replace(p.Analytic | p.Reduction) -def wrap_analytic(_, default_frame): - return ops.WindowFunction(_, default_frame) - - -@replace(p.WindowFunction) -def merge_windows(_, default_frame): - if _.frame.start and default_frame.start and _.frame.start != default_frame.start: - raise ExpressionError( - "Unable to merge windows with conflicting `start` boundary" - ) - if _.frame.end and default_frame.end and _.frame.end != default_frame.end: - raise ExpressionError("Unable to merge windows with conflicting `end` boundary") - - start = _.frame.start or default_frame.start - end = _.frame.end or default_frame.end - group_by = tuple(toolz.unique(_.frame.group_by + default_frame.group_by)) - - order_by = {} - # iterate in the order of the existing keys followed by the new keys - # - # this allows duplicates to be overridden with no effect on the original - # position - # - # see https://github.com/ibis-project/ibis/issues/7940 for how this - # originally manifested - for sort_key in default_frame.order_by + _.frame.order_by: - order_by[sort_key.expr] = sort_key.ascending - order_by = tuple(ops.SortKey(k, v) for k, v in order_by.items()) - - frame = _.frame.copy(start=start, end=end, group_by=group_by, order_by=order_by) - return ops.WindowFunction(_.func, frame) - - -def windowize_function(expr, default_frame): - ctx = {"default_frame": default_frame} - node = expr.op() - node = node.replace(merge_windows, filter=p.Value, context=ctx) - node = node.replace(wrap_analytic, filter=p.Value & ~p.WindowFunction, context=ctx) - return node.to_expr() - - -def contains_first_or_last_agg(exprs): - def fn(node: ops.Node) -> tuple[bool, bool | None]: - if not isinstance(node, ops.Value): - return g.halt, None - return g.proceed, isinstance(node, (ops.First, ops.Last)) - - return any(g.traverse(fn, exprs)) - - -def simplify_aggregation(agg): - def _pushdown(nodes): - subbed = [] - for node in nodes: - new_node = node.replace(Eq(agg.table) >> agg.table.table) - subbed.append(new_node) - - # TODO(kszucs): perhaps this validation could be omitted - if subbed: - valid = shares_all_roots(subbed, agg.table.table) - else: - valid = True - - return valid, subbed - - table = agg.table - if ( - isinstance(table, ops.Selection) - and not table.selections - # more aggressive than necessary, a better solution would be to check - # whether the selections have any order sensitive aggregates that - # *depend on* the sort_keys - and not (table.sort_keys or contains_first_or_last_agg(table.selections)) - ): - metrics_valid, lowered_metrics = _pushdown(agg.metrics) - by_valid, lowered_by = _pushdown(agg.by) - having_valid, lowered_having = _pushdown(agg.having) - - if metrics_valid and by_valid and having_valid: - valid_lowered_sort_keys = frozenset(lowered_metrics).union(lowered_by) - return ops.Aggregation( - table.table, - lowered_metrics, - by=lowered_by, - having=lowered_having, - predicates=agg.table.predicates, - # only the sort keys that exist as grouping keys or metrics can - # be included - sort_keys=[ - key - for key in agg.table.sort_keys - if key.expr in valid_lowered_sort_keys - ], - ) - - return agg - - -class Projector: - """Analysis and validation of projection operation. - - This pass tries to take advantage of projection fusion opportunities where - they exist, i.e. combining compatible projections together rather than - nesting them. - - Translation / evaluation later will not attempt to do any further fusion / - simplification. - """ - - def __init__(self, parent, proj_exprs): - # TODO(kszucs): rewrite projector to work with operations exclusively - proj_exprs = util.promote_list(proj_exprs) - self.parent = parent - self.input_exprs = proj_exprs - self.resolved_exprs = [parent._ensure_expr(e) for e in proj_exprs] - - default_frame = ops.RowsWindowFrame(table=parent) - self.clean_exprs = [ - windowize_function(expr, default_frame) for expr in self.resolved_exprs - ] - - def get_result(self): - roots = find_immediate_parent_tables(self.parent.op()) - first_root = roots[0] - parent_op = self.parent.op() - - # reprojection of the same selections - if len(self.clean_exprs) == 1: - first = self.clean_exprs[0].op() - if isinstance(first, ops.Selection): - if first.selections == parent_op.selections: - return parent_op - - if len(roots) == 1 and isinstance(first_root, ops.Selection): - fused_op = self.try_fusion(first_root) - if fused_op is not None: - return fused_op - - return ops.Selection(self.parent, self.clean_exprs) - - def try_fusion(self, root): - assert self.parent.op() == root - - root_table = root.table - root_table_expr = root_table.to_expr() - roots = find_immediate_parent_tables(root_table) - fused_exprs = [] - clean_exprs = self.clean_exprs - - if not isinstance(root_table, ops.Join): - try: - resolved = [ - root_table_expr._ensure_expr(expr) for expr in self.input_exprs - ] - except (AttributeError, IbisTypeError): - resolved = clean_exprs - else: - # if any expressions aren't exactly equivalent then don't try - # to fuse them - if any( - not res_root_root.equals(res_root) - for res_root_root, res_root in zip(resolved, clean_exprs) - ): - return None - else: - # joins cannot be used to resolve expressions, but we still may be - # able to fuse columns from a projection off of a join. In that - # case, use the projection's input expressions as the columns with - # which to attempt fusion - resolved = clean_exprs - - root_selections = root.selections - parent_op = self.parent.op() - for val in resolved: - # a * projection - if isinstance(val, ir.Table) and ( - parent_op.equals(val.op()) - # gross we share the same table root. Better way to - # detect? - or len(roots) == 1 - and find_immediate_parent_tables(val.op())[0] == roots[0] - ): - have_root = False - for root_sel in root_selections: - # Don't add the * projection twice - if root_sel.equals(root_table): - fused_exprs.append(root_table) - have_root = True - continue - fused_exprs.append(root_sel) - - # This was a filter, so implicitly a select * - if not have_root and not root_selections: - fused_exprs = [root_table, *fused_exprs] - elif shares_all_roots(val.op(), root_table): - fused_exprs.append(val) - else: - return None - - return ops.Selection( - root_table, - fused_exprs, - predicates=root.predicates, - sort_keys=root.sort_keys, - ) - +# TODO(kszucs): should be removed def find_first_base_table(node): def predicate(node): - if isinstance(node, ops.TableNode): + if isinstance(node, ops.Relation): return g.halt, node else: return g.proceed, None @@ -384,42 +27,7 @@ def predicate(node): return None -def _find_projections(node): - if isinstance(node, ops.Selection): - # remove predicates and sort_keys, so that child tables are considered - # equivalent even if their predicates and sort_keys are not - return g.proceed, node._projection - elif isinstance(node, ops.SelfReference): - return g.proceed, node - elif isinstance(node, ops.Aggregation): - return g.proceed, node._projection - elif isinstance(node, ops.Join): - return g.proceed, None - elif isinstance(node, ops.TableNode): - return g.halt, node - elif isinstance(node, ops.InColumn): - # we allow InColumn.options to be a column from a foreign table - return [node.value], None - else: - return g.proceed, None - - -def shares_all_roots(exprs, parents): - # unique table dependencies of exprs and parents - exprs_deps = set(g.traverse(_find_projections, exprs)) - parents_deps = set(g.traverse(_find_projections, parents)) - return exprs_deps <= parents_deps - - -def shares_some_roots(exprs, parents): - # unique table dependencies of exprs and parents - exprs_deps = set(g.traverse(_find_projections, exprs)) - parents_deps = set(g.traverse(_find_projections, parents)) - # Also return True if exprs has no roots (e.g. literal-only expressions) - return bool(exprs_deps & parents_deps) or not exprs_deps - - -def flatten_predicate(node): +def flatten_predicates(node): """Yield the expressions corresponding to the `And` nodes of a predicate. Examples @@ -449,45 +57,3 @@ def predicate(node): return g.halt, node return list(g.traverse(predicate, node)) - - -def find_predicates(node, flatten=True): - # TODO(kszucs): consider to remove flatten argument and compose with - # flatten_predicates instead - def predicate(node): - assert isinstance(node, ops.Node), type(node) - if isinstance(node, ops.Value) and node.dtype.is_boolean(): - if flatten and isinstance(node, ops.And): - return g.proceed, None - else: - return g.halt, node - return g.proceed, None - - return list(g.traverse(predicate, node)) - - -def find_subqueries(node: ops.Node, min_dependents=1) -> tuple[ops.Node, ...]: - subquery_dependents = defaultdict(set) - for n in filter(None, util.promote_list(node)): - dependents = g.Graph.from_dfs(n).invert() - for u, vs in dependents.toposort().items(): - # count the number of table-node dependents on the current node - # but only if the current node is a selection or aggregation - if isinstance(u, (rels.Projection, rels.Aggregation, rels.Limit)): - subquery_dependents[u].update(vs) - - return tuple( - node - for node, dependents in reversed(subquery_dependents.items()) - if len(dependents) >= min_dependents - ) - - -def find_toplevel_unnest_children(nodes: Iterable[ops.Node]) -> Iterator[ops.Table]: - def finder(node): - return ( - isinstance(node, ops.Value), - find_first_base_table(node) if isinstance(node, ops.Unnest) else None, - ) - - return g.traverse(finder, nodes) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index 8631d3e630dd..687bdcd12fdd 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -5,6 +5,7 @@ import builtins import datetime import functools +import itertools import numbers import operator from collections import Counter @@ -303,6 +304,9 @@ def schema( return sch.Schema.from_tuples(zip(names, types)) +_table_names = (f"unbound_table_{i:d}" for i in itertools.count()) + + def table( schema: SupportsSchema | None = None, name: str | None = None, @@ -333,9 +337,12 @@ def table( a int64 b string """ - if isinstance(schema, type) and name is None: - name = schema.__name__ - return ops.UnboundTable(schema=schema, name=name).to_expr() + if name is None: + if isinstance(schema, type): + name = schema.__name__ + else: + name = next(_table_names) + return ops.UnboundTable(name=name, schema=schema).to_expr() @lazy_singledispatch diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index 9009518a19ea..333d3456bf43 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -9,13 +9,11 @@ import ibis.expr.rules as rlz import ibis.expr.types as ir from ibis import util -from ibis.common.annotations import annotated +from ibis.common.annotations import annotated, attribute from ibis.common.deferred import Deferred, Resolver, deferrable from ibis.common.exceptions import IbisInputError from ibis.common.grounds import Concrete from ibis.common.typing import VarTuple # noqa: TCH001 -from ibis.expr.operations.relations import Relation # noqa: TCH001 -from ibis.expr.types.relations import bind_expr if TYPE_CHECKING: from typing_extensions import Self @@ -146,6 +144,25 @@ class WindowBuilder(Builder): orderings: VarTuple[Union[str, Resolver, ops.Value]] = () max_lookback: Optional[ops.Value[dt.Interval]] = None + @attribute + def _table(self): + inputs = ( + self.start, + self.end, + *self.groupings, + *self.orderings, + self.max_lookback, + ) + valuerels = (v.relations for v in inputs if isinstance(v, ops.Value)) + relations = frozenset().union(*valuerels) + if len(relations) == 0: + return None + elif len(relations) == 1: + (table,) = relations + return table + else: + raise IbisInputError("Window frame can only depend on a single relation") + def _maybe_cast_boundary(self, boundary, dtype): if boundary.dtype == dtype: return boundary @@ -214,9 +231,23 @@ def lookback(self, value) -> Self: return self.copy(max_lookback=value) @annotated - def bind(self, table: Relation): - groupings = bind_expr(table.to_expr(), self.groupings) - orderings = bind_expr(table.to_expr(), self.orderings) + def bind(self, table: Optional[ops.Relation]): + table = table or self._table + if table is None: + raise IbisInputError("Unable to bind window frame to a table") + + table = table.to_expr() + + def bind_value(value): + if isinstance(value, str): + return table._get_column(value) + elif isinstance(value, Resolver): + return value.resolve({"_": table}) + else: + return value + + groupings = map(bind_value, self.groupings) + orderings = map(bind_value, self.orderings) if self.how == "rows": return ops.RowsWindowFrame( table=table, diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index a6cbf56de716..af279447bd76 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -10,6 +10,7 @@ import ibis.expr.operations as ops import ibis.expr.types as ir from ibis.common.graph import Graph +from ibis.expr.rewrites import simplify from ibis.util import experimental _method_overrides = { @@ -31,11 +32,8 @@ ops.ExtractYear: "year", ops.Intersection: "intersect", ops.IsNull: "isnull", - ops.LeftAntiJoin: "anti_join", - ops.LeftSemiJoin: "semi_join", ops.Lowercase: "lower", ops.RegexSearch: "re_search", - ops.SelfReference: "view", ops.StartsWith: "startswith", ops.StringContains: "contains", ops.StringSQLILike: "ilike", @@ -87,7 +85,6 @@ def translate(op, *args, **kwargs): @translate.register(ops.Value) -@translate.register(ops.TableNode) def value(op, *args, **kwargs): method = _get_method_name(op) kwargs = [(k, v) for k, v in kwargs.items() if v is not None] @@ -125,44 +122,80 @@ def _try_unwrap(stmt): if len(stmt) == 1: return stmt[0] else: - return f"[{', '.join(stmt)}]" + stmt = map(str, stmt) + values = ", ".join(stmt) + return f"[{values}]" + + +def _wrap_alias(values, rendered): + result = [] + for k, v in values.items(): + text = rendered[k] + if v.name != k: + text = f"{text}.name({k!r})" + result.append(text) + return result + + +def _inline(args): + return ", ".join(map(str, args)) + +@translate.register(ops.Project) +def project(op, parent, values): + out = f"{parent}" + if not values: + return out -@translate.register(ops.Selection) -def selection(op, table, selections, predicates, sort_keys): - out = f"{table}" - if selections: - out = f"{out}.select({_try_unwrap(selections)})" + values = _wrap_alias(op.values, values) + return f"{out}.select({_inline(values)})" + + +@translate.register(ops.Filter) +def filter_(op, parent, predicates): + out = f"{parent}" if predicates: - out = f"{out}.filter({_try_unwrap(predicates)})" - if sort_keys: - out = f"{out}.order_by({_try_unwrap(sort_keys)})" + out = f"{out}.filter({_inline(predicates)})" return out -@translate.register(ops.Aggregation) -def aggregation(op, table, by, metrics, predicates, having, sort_keys): - out = f"{table}" - if predicates: - out = f"{out}.filter({_try_unwrap(predicates)})" - if by: - out = f"{out}.group_by({_try_unwrap(by)})" - if having: - out = f"{out}.having({_try_unwrap(having)})" - if metrics: - out = f"{out}.aggregate({_try_unwrap(metrics)})" - if sort_keys: - out = f"{out}.order_by({_try_unwrap(sort_keys)})" +@translate.register(ops.Sort) +def sort(op, parent, keys): + out = f"{parent}" + if keys: + out = f"{out}.order_by({_inline(keys)})" return out -@translate.register(ops.Join) -def join(op, left, right, predicates): - method = _get_method_name(op) - return f"{left}.{method}({right}, {_try_unwrap(predicates)})" +@translate.register(ops.Aggregate) +def aggregation(op, parent, groups, metrics): + groups = _wrap_alias(op.groups, groups) + metrics = _wrap_alias(op.metrics, metrics) + if groups and metrics: + return f"{parent}.aggregate([{_inline(metrics)}], by=[{_inline(groups)}])" + elif metrics: + return f"{parent}.aggregate([{_inline(metrics)}])" + else: + raise ValueError("No metrics to aggregate") + + +@translate.register(ops.SelfReference) +def self_reference(op, parent, identifier): + return parent + + +@translate.register(ops.JoinLink) +def join_link(op, table, predicates, how): + return f".{how}_join({table}, {_try_unwrap(predicates)})" + + +@translate.register(ops.JoinChain) +def join(op, first, rest, values): + calls = "".join(rest) + return f"{first}{calls}" -@translate.register(ops.SetOp) +@translate.register(ops.Set) def union(op, left, right, distinct): method = _get_method_name(op) if distinct: @@ -172,16 +205,16 @@ def union(op, left, right, distinct): @translate.register(ops.Limit) -def limit(op, table, n, offset): +def limit(op, parent, n, offset): if offset: - return f"{table}.limit({n}, {offset})" + return f"{parent}.limit({n}, {offset})" else: - return f"{table}.limit({n})" + return f"{parent}.limit({n})" -@translate.register(ops.TableColumn) -def table_column(op, table, name): - return f"{table}.{name}" +@translate.register(ops.Field) +def table_column(op, rel, name): + return f"{rel}.{name}" @translate.register(ops.SortKey) @@ -292,14 +325,22 @@ def isin(op, value, options): class CodeContext: - always_assign = (ops.ScalarParameter, ops.UnboundTable, ops.Aggregation) - always_ignore = (ops.TableColumn, dt.Primitive, dt.Variadic, dt.Temporal) + always_assign = (ops.ScalarParameter, ops.UnboundTable, ops.Aggregate) + always_ignore = ( + ops.SelfReference, + ops.Field, + dt.Primitive, + dt.Variadic, + dt.Temporal, + ) shorthands = { - ops.Aggregation: "agg", + ops.Aggregate: "agg", ops.Literal: "lit", ops.ScalarParameter: "param", - ops.Selection: "proj", - ops.TableNode: "t", + ops.Project: "p", + ops.Relation: "r", + ops.Filter: "f", + ops.Sort: "s", } def __init__(self, assign_result_to="result"): @@ -308,7 +349,7 @@ def __init__(self, assign_result_to="result"): def variable_for(self, node): klass = type(node) - if isinstance(node, ops.TableNode) and isinstance(node, ops.Named): + if isinstance(node, ops.Relation) and hasattr(node, "name"): name = node.name elif klass in self.shorthands: name = self.shorthands[klass] @@ -345,7 +386,7 @@ def render(self, node, code, n_dependents): @experimental def decompile( - node: ops.Node | ir.Expr, + expr: ir.Expr, render_import: bool = True, assign_result_to: str = "result", format: bool = False, @@ -354,7 +395,7 @@ def decompile( Parameters ---------- - node + expr node or expression to decompile render_import Whether to add `import ibis` to the result. @@ -368,13 +409,11 @@ def decompile( str Equivalent Python source code for `node`. """ - if isinstance(node, ir.Expr): - node = node.op() - elif not isinstance(node, ops.Node): - raise TypeError( - f"Expected ibis expression or operation, got {type(node).__name__}" - ) + if not isinstance(expr, ir.Expr): + raise TypeError(f"Expected ibis expression, got {type(expr).__name__}") + node = expr.op() + node = simplify(node) out = io.StringIO() ctx = CodeContext(assign_result_to=assign_result_to) dependents = Graph(node).invert() diff --git a/ibis/expr/format.py b/ibis/expr/format.py index c0a244287e14..6ac9dfeb7b8a 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -192,11 +192,14 @@ def fmt(op, **kwargs): @fmt.register(ops.Relation) -@fmt.register(ops.DummyTable) @fmt.register(ops.WindowingTVF) -def _relation(op, **kwargs): - schema = render_schema(op.schema, indent_level=1) - return f"{op.__class__.__name__}\n{schema}" +def _relation(op, parent=None, **kwargs): + if parent is None: + top = f"{op.__class__.__name__}\n" + else: + top = f"{op.__class__.__name__}[{parent}]\n" + kwargs["schema"] = render_schema(op.schema) + return top + render_fields(kwargs, 1) @fmt.register(ops.PhysicalTable) @@ -218,6 +221,7 @@ def _in_memory_table(op, data, **kwargs): @fmt.register(ops.SQLStringView) def _sql_query_result(op, query, **kwargs): clsname = op.__class__.__name__ + if isinstance(op, ops.SQLStringView): child, name = kwargs["child"], kwargs["name"] top = f"{clsname}[{child}]: {name}\n" @@ -235,38 +239,54 @@ def _sql_query_result(op, query, **kwargs): @fmt.register(ops.FillNa) @fmt.register(ops.DropNa) -def _fill_na(op, table, **kwargs): - name = f"{op.__class__.__name__}[{table}]\n" +def _fill_na(op, parent, **kwargs): + name = f"{op.__class__.__name__}[{parent}]\n" return name + render_fields(kwargs, 1) -@fmt.register(ops.Aggregation) -def _aggregation(op, table, **kwargs): - name = f"{op.__class__.__name__}[{table}]\n" - kwargs["by"] = {node.name: r for node, r in zip(op.by, kwargs["by"])} - kwargs["metrics"] = {node.name: r for node, r in zip(op.metrics, kwargs["metrics"])} +@fmt.register(ops.Aggregate) +def _aggregate(op, parent, **kwargs): + name = f"{op.__class__.__name__}[{parent}]\n" return name + render_fields(kwargs, 1) -@fmt.register(ops.Selection) -def _selection(op, table, selections, **kwargs): - name = f"{op.__class__.__name__}[{table}]\n" +@fmt.register(ops.Project) +def _project(op, parent, values): + name = f"{op.__class__.__name__}[{parent}]\n" - # special handling required to support both relation and value selections - rels, values = [], {} - for node, rendered in zip(op.selections, selections): - if isinstance(node, ops.Relation): - rels.append(rendered) - else: - values[node.name] = f"{rendered}{type_info(node.dtype)}" + fields = {} + for k, v in values.items(): + node = op.values[k] + fields[f"{k}:"] = f"{v}{type_info(node.dtype)}" - segments = filter(None, [render(rels), render(values)]) - kwargs["selections"] = "\n".join(segments) + return name + render_schema(fields, 1) + + +@fmt.register(ops.DummyTable) +def _dummy_table(op, values): + name = op.__class__.__name__ + "\n" + + fields = {} + for k, v in values.items(): + node = op.values[k] + fields[f"{k}:"] = f"{v}{type_info(node.dtype)}" + + return name + render_schema(fields, 1) - return name + render_fields(kwargs, 1) +@fmt.register(ops.Filter) +def _project(op, parent, predicates): + name = f"{op.__class__.__name__}[{parent}]\n" + return name + render(predicates, 1) -@fmt.register(ops.SetOp) + +@fmt.register(ops.Sort) +def _sort(op, parent, keys): + name = f"{op.__class__.__name__}[{parent}]\n" + return name + render(keys, 1) + + +@fmt.register(ops.Set) def _set_op(op, left, right, distinct): args = [str(left), str(right)] if op.distinct is not None: @@ -274,7 +294,7 @@ def _set_op(op, left, right, distinct): return f"{op.__class__.__name__}[{', '.join(args)}]" -@fmt.register(ops.Join) +@fmt.register(ops.JoinChain) def _join(op, left, right, predicates, **kwargs): args = [str(left), str(right)] name = f"{op.__class__.__name__}[{', '.join(args)}]" @@ -291,30 +311,48 @@ def _join(op, left, right, predicates, **kwargs): return f"{top}\n{fields}" if fields else top +@fmt.register(ops.JoinLink) +def _join(op, how, table, predicates): + args = [str(how), str(table)] + name = f"{op.__class__.__name__}[{', '.join(args)}]" + return f"{name}\n{render(predicates, 1)}" + + +@fmt.register(ops.JoinChain) +def _join_project(op, first, rest, **kwargs): + name = f"{op.__class__.__name__}[{first}]\n" + return name + render(rest, 1) + "\n" + render_fields(kwargs, 1) + + @fmt.register(ops.Limit) @fmt.register(ops.Sample) -def _limit(op, table, **kwargs): +def _limit(op, parent, **kwargs): params = inline_args(kwargs) - return f"{op.__class__.__name__}[{table}, {params}]" + return f"{op.__class__.__name__}[{parent}, {params}]" @fmt.register(ops.SelfReference) @fmt.register(ops.Distinct) -def _self_reference(op, table, **kwargs): - return f"{op.__class__.__name__}[{table}]" +def _self_reference(op, parent, **kwargs): + return f"{op.__class__.__name__}[{parent}]" @fmt.register(ops.Literal) def _literal(op, value, **kwargs): if op.dtype.is_interval(): return f"{value!r} {op.dtype.unit.short}" + elif op.dtype.is_array(): + return f"{list(value)!r}" else: return f"{value!r}" -@fmt.register(ops.TableColumn) -def _table_column(op, table, name): - return f"{table}.{name}" +@fmt.register(ops.Field) +def _relation_field(op, rel, name): + if name.isidentifier(): + return f"{rel}.{name}" + else: + return f"{rel}[{name!r}]" @fmt.register(ops.Value) diff --git a/ibis/expr/operations/core.py b/ibis/expr/operations/core.py index c7b4e0a1dc75..5db1e2c2f17a 100644 --- a/ibis/expr/operations/core.py +++ b/ibis/expr/operations/core.py @@ -32,13 +32,14 @@ def op(self) -> Self: """Make `Node` backwards compatible with code that uses `Expr.op()`.""" return self - @abstractmethod - def to_expr(self): - ... - # Avoid custom repr for performance reasons __repr__ = object.__repr__ + # TODO(kszucs): hidrate the __children__ traversable attribute + # @attribute + # def __children__(self): + # return super().__children__ + # TODO(kszucs): remove this mixin @public @@ -126,6 +127,12 @@ def shape(self) -> S: ds.Shape """ + @attribute + def relations(self): + """Set of relations the value node depends on.""" + children = (n.relations for n in self.__children__ if isinstance(n, Value)) + return frozenset().union(*children) + @property @util.deprecated(as_of="7.0", instead="use .dtype property instead") def output_dtype(self): @@ -167,10 +174,14 @@ class Unary(Value): arg: Value - @property + @attribute def shape(self) -> ds.DataShape: return self.arg.shape + @attribute + def relations(self): + return self.arg.relations + @public class Binary(Value): @@ -179,10 +190,14 @@ class Binary(Value): left: Value right: Value - @property + @attribute def shape(self) -> ds.DataShape: return max(self.left.shape, self.right.shape) + @attribute + def relations(self): + return self.left.relations | self.right.relations + @public class Argument(Value): diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index be349cd45777..15cbd5a5d345 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -1,13 +1,12 @@ from __future__ import annotations import itertools -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any, Optional from typing import Literal as LiteralType from public import public from typing_extensions import TypeVar -import ibis.common.exceptions as com import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz @@ -20,33 +19,6 @@ from ibis.expr.operations.relations import Relation # noqa: TCH001 -@public -class TableColumn(Value, Named): - """Selects a column from a `Table`.""" - - table: Relation - name: Union[str, int] - - shape = ds.columnar - - def __init__(self, table, name): - if isinstance(name, int): - name = table.schema.name_at_position(name) - - if name not in table.schema: - columns_formatted = ", ".join(map(repr, table.schema.names)) - raise com.IbisTypeError( - f"Column {name!r} is not found in table. " - f"Existing columns: {columns_formatted}." - ) - - super().__init__(table=table, name=name) - - @property - def dtype(self): - return self.table.schema[self.name] - - @public class RowID(Value, Named): """The row number (an autonumeric) of the returned result.""" @@ -57,22 +29,9 @@ class RowID(Value, Named): shape = ds.columnar dtype = dt.int64 - -@public -class TableArrayView(Value, Named): - """Helper operation class for creating scalar subqueries.""" - - table: Relation - - shape = ds.columnar - - @property - def dtype(self): - return self.table.schema[self.name] - - @property - def name(self): - return self.table.schema.names[0] + @attribute + def relations(self): + return frozenset({self.table}) @public diff --git a/ibis/expr/operations/geospatial.py b/ibis/expr/operations/geospatial.py index 21067dc1d161..ba0d6d93a01c 100644 --- a/ibis/expr/operations/geospatial.py +++ b/ibis/expr/operations/geospatial.py @@ -4,7 +4,7 @@ import ibis.expr.datatypes as dt from ibis.expr.operations.core import Binary, Unary, Value -from ibis.expr.operations.reductions import Reduction +from ibis.expr.operations.reductions import Filterable, Reduction @public @@ -181,9 +181,11 @@ class GeoTouches(GeoSpatialBinOp): @public -class GeoUnaryUnion(Reduction, GeoSpatialUnOp): +class GeoUnaryUnion(Filterable, Reduction): """Returns the pointwise union of the geometries in the column.""" + arg: Value[dt.GeoSpatial] + dtype = dt.geometry diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index 3ca8675d49c2..78a33de77e1c 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -2,14 +2,12 @@ from public import public -import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import ValidationError, attribute from ibis.common.exceptions import IbisTypeError from ibis.common.typing import VarTuple # noqa: TCH001 -from ibis.expr.operations.core import Binary, Column, Unary, Value -from ibis.expr.operations.relations import Relation # noqa: TCH001 +from ibis.expr.operations.core import Binary, Unary, Value @public @@ -137,15 +135,6 @@ def shape(self): return rlz.highest_precedence_shape(args) -@public -class InColumn(Value): - value: Value - options: Column[dt.Any] - - dtype = dt.boolean - shape = rlz.shape_like("args") - - @public class IfElse(Value): """Ternary case expression, equivalent to. @@ -164,67 +153,3 @@ class IfElse(Value): @attribute def dtype(self): return rlz.highest_precedence_dtype([self.true_expr, self.false_null_expr]) - - -@public -class ExistsSubquery(Value): - foreign_table: Relation - predicates: VarTuple[Value[dt.Boolean]] - - dtype = dt.boolean - shape = ds.columnar - - -@public -class UnresolvedExistsSubquery(Value): - """An exists subquery whose outer leaf table is unknown. - - Notes - ----- - Consider the following ibis expressions - - ```python - import ibis - - t = ibis.table(dict(a="string")) - s = ibis.table(dict(a="string")) - - cond = (t.a == s.a).any() - ``` - - Without knowing the table to use as the outer query there are two ways to - turn this expression into a SQL `EXISTS` predicate, depending on which of - `t` or `s` is filtered on. - - Filtering from `t`: - - ```sql - SELECT * - FROM t - WHERE EXISTS (SELECT 1 FROM s WHERE t.a = s.a) - ``` - - Filtering from `s`: - - ```sql - SELECT * - FROM s - WHERE EXISTS (SELECT 1 FROM t WHERE t.a = s.a) - ``` - - Notably the correlated subquery cannot stand on its own. - - The purpose of `UnresolvedExistsSubquery` is to capture enough information - about an exists predicate such that it can be resolved when predicates are - resolved against the outer leaf table when `Selection`s are constructed. - """ - - tables: VarTuple[Relation] - predicates: VarTuple[Value[dt.Boolean]] - - dtype = dt.boolean - shape = ds.columnar - - def resolve(self, table) -> ExistsSubquery: - (foreign_table,) = (t for t in self.tables if t != table) - return ExistsSubquery(foreign_table, self.predicates) diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index e0b63041395e..2a85dbfcbab5 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -22,6 +22,7 @@ def __window_op__(self): return self +# TODO(kszucs): all reductions all filterable so we could remove Filterable class Filterable(Value): where: Optional[Value[dt.Boolean]] = None @@ -39,6 +40,10 @@ class CountStar(Filterable, Reduction): dtype = dt.int64 + @attribute + def relations(self): + return frozenset({self.arg}) + @public class CountDistinctStar(Filterable, Reduction): @@ -46,6 +51,10 @@ class CountDistinctStar(Filterable, Reduction): dtype = dt.int64 + @attribute + def relations(self): + return frozenset({self.arg}) + @public class Arbitrary(Filterable, Reduction): diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 3ba7ff5b7e39..d42637013730 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -1,671 +1,457 @@ from __future__ import annotations import itertools +import typing from abc import abstractmethod -from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional -from typing import Union as UnionType +from typing import Annotated, Any, Literal, Optional, TypeVar from public import public -import ibis.common.exceptions as com +import ibis.expr.datashape as ds import ibis.expr.datatypes as dt -import ibis.expr.operations as ops -from ibis import util -from ibis.common.annotations import annotated, attribute -from ibis.common.collections import FrozenDict # noqa: TCH001 -from ibis.common.deferred import Deferred +import ibis.expr.rules as rlz +from ibis.common.annotations import attribute +from ibis.common.collections import FrozenDict +from ibis.common.exceptions import IbisTypeError, IntegrityError, RelationError from ibis.common.grounds import Concrete -from ibis.common.patterns import Between, Coercible, Eq -from ibis.common.typing import VarTuple # noqa: TCH001 -from ibis.expr.operations.core import Column, Named, Node, Scalar, Value +from ibis.common.patterns import Between, InstanceOf +from ibis.common.typing import Coercible, VarTuple +from ibis.expr.operations.core import Alias, Column, Node, Scalar, Value from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001 from ibis.expr.schema import Schema from ibis.formats import TableProxy # noqa: TCH001 +from ibis.util import gen_name -if TYPE_CHECKING: - import ibis.expr.types as ir +T = TypeVar("T") - -_table_names = (f"unbound_table_{i:d}" for i in itertools.count()) - - -@public -def genname(): - return next(_table_names) +Unaliased = Annotated[T, ~InstanceOf(Alias)] @public class Relation(Node, Coercible): @classmethod def __coerce__(cls, value): - import pandas as pd + from ibis.expr.types import TableExpr - import ibis - import ibis.expr.types as ir - - if isinstance(value, pd.DataFrame): - return ibis.memtable(value).op() - elif isinstance(value, ir.Expr): + if isinstance(value, Relation): + return value + elif isinstance(value, TableExpr): return value.op() else: - return value + raise TypeError(f"Cannot coerce {value!r} to a Relation") + + @property + @abstractmethod + def values(self) -> FrozenDict[str, Value]: + """A mapping of column names to expressions which build up the relation. - def order_by(self, sort_exprs): - return Selection(self, [], sort_keys=sort_exprs) + This attribute is heavily used in rewrites as well as during field + dereferencing in the API layer. The returned expressions must only + originate from parent relations, depending on the relation type. + """ @property @abstractmethod def schema(self) -> Schema: + """The schema of the relation. + + All relations must have a well-defined schema. + """ ... - def to_expr(self): - import ibis.expr.types as ir + @property + def fields(self) -> FrozenDict[str, Column]: + """A mapping of column names to fields of the relation. - return ir.Table(self) + This calculated property shouldn't be overridden in subclasses since it + is mostly used for convenience. + """ + return FrozenDict({k: Field(self, k) for k in self.schema}) + def to_expr(self): + from ibis.expr.types import TableExpr -TableNode = Relation + return TableExpr(self) @public -class Namespace(Concrete): - database: Optional[str] = None - schema: Optional[str] = None +class Field(Value): + rel: Relation + name: str + shape = ds.columnar -@public -class PhysicalTable(Relation, Named): - pass + def __init__(self, rel, name): + if name not in rel.schema: + columns_formatted = ", ".join(map(repr, rel.schema.names)) + raise IbisTypeError( + f"Column {name!r} is not found in table. " + f"Existing columns: {columns_formatted}." + ) + super().__init__(rel=rel, name=name) + + @attribute + def dtype(self): + return self.rel.schema[self.name] + + @attribute + def relations(self): + return frozenset({self.rel}) -# TODO(kszucs): PhysicalTable should have a source attribute and UnbountTable -# should just extend TableNode @public -class UnboundTable(PhysicalTable): - schema: Schema - name: Optional[str] = None - namespace: Namespace = Namespace() +class Subquery(Value): + rel: Relation + shape = ds.columnar - def __init__(self, schema, name, namespace) -> None: - if name is None: - name = genname() - super().__init__(schema=schema, name=name, namespace=namespace) + def __init__(self, rel, **kwargs): + if len(rel.schema) != 1: + raise IntegrityError( + f"Subquery must have exactly one column, got {len(rel.schema)}" + ) + super().__init__(rel=rel, **kwargs) + @attribute + def name(self): + return self.rel.schema.names[0] -@public -class DatabaseTable(PhysicalTable): - name: str - schema: Schema - source: Any - namespace: Namespace = Namespace() + @attribute + def value(self): + return self.rel.values[self.name] + + @attribute + def relations(self): + return frozenset() + + @property + def dtype(self): + return self.value.dtype @public -class SQLQueryResult(TableNode): - """A table sourced from the result set of a select query.""" +class ScalarSubquery(Subquery): + def __init__(self, rel): + from ibis.expr.rewrites import ReductionValue - query: str - schema: Schema - source: Any + super().__init__(rel=rel) + if not self.value.find(ReductionValue, filter=Value): + raise IntegrityError( + f"Subquery {self.value!r} is not scalar, it must be turned into a scalar subquery first" + ) @public -class InMemoryTable(PhysicalTable): - name: str - schema: Schema - data: TableProxy +class ExistsSubquery(Subquery): + dtype = dt.boolean + +@public +class InSubquery(Subquery): + needle: Value + dtype = dt.boolean -# TODO(kszucs): desperately need to clean this up, the majority of this -# functionality should be handled by input rules for the Join class -def _clean_join_predicates(left, right, predicates): - import ibis.expr.analysis as an - import ibis.expr.types as ir - from ibis.expr.analysis import shares_all_roots - - result = [] - - for pred in predicates: - if isinstance(pred, tuple): - if len(pred) != 2: - raise com.ExpressionError("Join key tuple must be length 2") - lk, rk = pred - lk = left.to_expr()._ensure_expr(lk) - rk = right.to_expr()._ensure_expr(rk) - pred = lk == rk - elif isinstance(pred, str): - pred = left.to_expr()[pred] == right.to_expr()[pred] - elif pred is True or pred is False: - pred = ops.Literal(pred, dtype="bool").to_expr() - elif isinstance(pred, Value): - pred = pred.to_expr() - elif isinstance(pred, Deferred): - # resolve deferred expressions on the left table - pred = pred.resolve(left.to_expr()) - elif not isinstance(pred, ir.Expr): - raise NotImplementedError(pred) - - if not isinstance(pred, ir.BooleanValue): - raise com.ExpressionError("Join predicate must be a boolean expression") - - preds = an.flatten_predicate(pred.op()) - result.extend(preds) - - # Validate join predicates. Each predicate must be valid jointly when - # considering the roots of each input table - for predicate in result: - if not shares_all_roots(predicate, [left, right]): - raise com.RelationError( - f"The expression {predicate!r} does not fully " - "originate from dependencies of the table " - "expression." + def __init__(self, **kwargs): + super().__init__(**kwargs) + if not rlz.comparable(self.value, self.needle): + raise IntegrityError( + f"Subquery {self.needle!r} is not comparable to {self.value!r}" ) - assert all(isinstance(pred, ops.Node) for pred in result) + @attribute + def relations(self): + return self.needle.relations - return tuple(result) +def _check_integrity(values, allowed_parents): + for value in values: + for rel in value.relations: + if rel not in allowed_parents: + raise IntegrityError( + f"Cannot add {value!r} to projection, they belong to another relation" + ) -@public -class Join(Relation): - left: Relation - right: Relation - predicates: Any = () - def __init__(self, left, right, predicates, **kwargs): - # TODO(kszucs): predicates should be already a list of operations, need - # to update the validation rule for the Join classes which is a noop - # currently - import ibis.expr.operations as ops - import ibis.expr.types as ir +@public +class Project(Relation): + parent: Relation + values: FrozenDict[str, Unaliased[Value]] - # TODO(kszucs): need to factor this out to appropriate join predicate - # rules - predicates = [ - pred.op() if isinstance(pred, ir.Expr) else pred - for pred in util.promote_list(predicates) - ] - - if left.equals(right): - # GH #667: If left and right table have a common parent expression, - # e.g. they have different filters, we need to add a self-reference - # and make the appropriate substitution in the join predicates - right = ops.SelfReference(right) - elif isinstance(right, Join): - # for joins with joins on the right side we turn the right side - # into a view, otherwise the join tree is incorrectly flattened - # and tables on the right are incorrectly scoped - old = right - new = right = ops.SelfReference(right) - rule = Eq(old) >> new - predicates = [pred.replace(rule) for pred in predicates] - - predicates = _clean_join_predicates(left, right, predicates) - - super().__init__(left=left, right=right, predicates=predicates, **kwargs) + def __init__(self, parent, values): + _check_integrity(values.values(), {parent}) + super().__init__(parent=parent, values=values) - @property + @attribute def schema(self): - # TODO(kszucs): use `return self.left.schema | self.right.schema` instead which - # eliminates unnecessary projection over the join, but currently breaks the - # pandas backend - left, right = self.left.schema, self.right.schema - if duplicates := left.keys() & right.keys(): - raise com.IntegrityError(f"Duplicate column name(s): {duplicates}") - return Schema( - { - name: typ.copy(nullable=True) - for name, typ in itertools.chain(left.items(), right.items()) - } - ) + return Schema({k: v.dtype for k, v in self.values.items()}) -@public -class InnerJoin(Join): - pass +class Simple(Relation): + parent: Relation + @attribute + def values(self): + return self.parent.fields -@public -class LeftJoin(Join): - pass + @attribute + def schema(self): + return self.parent.schema @public -class RightJoin(Join): - pass +class SelfReference(Simple): + _uid_counter = itertools.count() + identifier: Optional[int] = None -@public -class OuterJoin(Join): - pass + def __init__(self, parent, identifier): + if identifier is None: + identifier = next(self._uid_counter) + super().__init__(parent=parent, identifier=identifier) + @attribute + def name(self) -> str: + if (name := getattr(self.parent, "name", None)) is not None: + return f"{name}_ref" + return gen_name("self_ref") -@public -class AnyInnerJoin(Join): - pass + +JoinKind = Literal[ + "inner", + "left", + "right", + "outer", + "asof", + "semi", + "anti", + "any_inner", + "any_left", + "cross", +] @public -class AnyLeftJoin(Join): - pass +class JoinLink(Node): + how: JoinKind + table: SelfReference + predicates: VarTuple[Value[dt.Boolean]] @public -class LeftSemiJoin(Join): +class JoinChain(Relation): + first: Relation + rest: VarTuple[JoinLink] + values: FrozenDict[str, Unaliased[Value]] + + def __init__(self, first, rest, values): + allowed_parents = {first} + for join in rest: + allowed_parents.add(join.table) + _check_integrity(join.predicates, allowed_parents) + _check_integrity(values.values(), allowed_parents) + super().__init__(first=first, rest=rest, values=values) + @attribute def schema(self): - return self.left.schema + return Schema({k: v.dtype.copy(nullable=True) for k, v in self.values.items()}) + + def to_expr(self): + import ibis.expr.types as ir + + return ir.JoinExpr(self) @public -class LeftAntiJoin(Join): - @attribute - def schema(self): - return self.left.schema +class Sort(Simple): + keys: VarTuple[SortKey] + + def __init__(self, parent, keys): + _check_integrity(keys, {parent}) + super().__init__(parent=parent, keys=keys) @public -class CrossJoin(Join): - pass +class Filter(Simple): + predicates: VarTuple[Value[dt.Boolean]] + + def __init__(self, parent, predicates): + from ibis.expr.rewrites import ReductionValue + + for pred in predicates: + if pred.find(ReductionValue, filter=Value): + raise IntegrityError( + f"Cannot add {pred!r} to filter, it is a reduction" + ) + if pred.relations and parent not in pred.relations: + raise IntegrityError( + f"Cannot add {pred!r} to filter, they belong to another relation" + ) + super().__init__(parent=parent, predicates=predicates) @public -class AsOfJoin(Join): - # TODO(kszucs): convert to proper predicate rules - by: Any = () - tolerance: Optional[Value[dt.Interval]] = None +class Limit(Simple): + n: typing.Union[int, Scalar[dt.Integer], None] = None + offset: typing.Union[int, Scalar[dt.Integer]] = 0 + + +@public +class Aggregate(Relation): + parent: Relation + groups: FrozenDict[str, Unaliased[Column]] + metrics: FrozenDict[str, Unaliased[Scalar]] + + def __init__(self, parent, groups, metrics): + _check_integrity(groups.values(), {parent}) + _check_integrity(metrics.values(), {parent}) + if duplicates := groups.keys() & metrics.keys(): + raise RelationError( + f"Cannot add {duplicates} to aggregate, they are already in the groupby" + ) + super().__init__(parent=parent, groups=groups, metrics=metrics) + + @attribute + def values(self): + return FrozenDict({**self.groups, **self.metrics}) - def __init__(self, left, right, by, predicates, **kwargs): - by = _clean_join_predicates(left, right, util.promote_list(by)) - super().__init__(left=left, right=right, by=by, predicates=predicates, **kwargs) + @attribute + def schema(self): + return Schema({k: v.dtype for k, v in self.values.items()}) @public -class SetOp(Relation): +class Set(Relation): left: Relation right: Relation distinct: bool = False def __init__(self, left, right, **kwargs): - # convert to dictionary first, to get key-unordered comparison - # semantics + # convert to dictionary first, to get key-unordered comparison semantics if dict(left.schema) != dict(right.schema): - raise com.RelationError("Table schemas must be equal for set operations") + raise RelationError("Table schemas must be equal for set operations") elif left.schema.names != right.schema.names: # rewrite so that both sides have the columns in the same order making it # easier for the backends to implement set operations - cols = [ops.TableColumn(right, name) for name in left.schema.names] - right = Selection(right, cols) + cols = {name: Field(right, name) for name in left.schema.names} + right = Project(right, cols) super().__init__(left=left, right=right, **kwargs) + @attribute + def values(self): + return FrozenDict() + @attribute def schema(self): return self.left.schema @public -class Union(SetOp): +class Union(Set): pass @public -class Intersection(SetOp): +class Intersection(Set): pass @public -class Difference(SetOp): +class Difference(Set): pass @public -class Limit(Relation): - table: Relation - n: UnionType[int, Scalar[dt.Integer], None] = None - offset: UnionType[int, Scalar[dt.Integer]] = 0 +class PhysicalTable(Relation): + name: str @attribute - def schema(self): - return self.table.schema + def values(self): + return FrozenDict() @public -class SelfReference(Relation): - table: Relation - - @attribute - def name(self) -> str: - if (name := getattr(self.table, "name", None)) is not None: - return f"{name}_ref" - return util.gen_name("self_ref") - - @attribute - def schema(self): - return self.table.schema - - -class Projection(Relation): - table: Relation - selections: VarTuple[Relation | Value] - - @attribute - def schema(self): - # Resolve schema and initialize - if not self.selections: - return self.table.schema - - types, names = [], [] - for projection in self.selections: - if isinstance(projection, Value): - names.append(projection.name) - types.append(projection.dtype) - elif isinstance(projection, TableNode): - schema = projection.schema - names.extend(schema.names) - types.extend(schema.types) - - return Schema.from_tuples(zip(names, types)) - - -def _add_alias(op: ops.Value | ops.TableNode): - """Add a name to a projected column if necessary.""" - if isinstance(op, ops.Value) and not isinstance(op, (ops.Alias, ops.TableColumn)): - return ops.Alias(op, op.name) - else: - return op +class UnboundTable(PhysicalTable): + schema: Schema @public -class Selection(Projection): - predicates: VarTuple[Value[dt.Boolean]] = () - sort_keys: VarTuple[SortKey] = () - - def __init__(self, table, selections, predicates, sort_keys, **kwargs): - from ibis.expr.analysis import shares_all_roots, shares_some_roots - - if not shares_all_roots(selections + sort_keys, table): - raise com.RelationError( - "Selection expressions don't fully originate from " - "dependencies of the table expression." - ) - - for predicate in predicates: - if isinstance(predicate, ops.Literal): - if not (dtype := predicate.dtype).is_boolean(): - raise com.IbisTypeError(f"Invalid predicate dtype: {dtype}") - elif not shares_some_roots(predicate, table): - raise com.RelationError("Predicate doesn't share any roots with table") - - super().__init__( - table=table, - selections=tuple(map(_add_alias, selections)), - predicates=predicates, - sort_keys=sort_keys, - **kwargs, - ) - - @annotated - def order_by(self, keys: VarTuple[SortKey]): - from ibis.expr.analysis import shares_all_roots, sub_immediate_parents - - if not self.selections: - if shares_all_roots(keys, table := self.table): - sort_keys = tuple(self.sort_keys) + tuple( - sub_immediate_parents(key, table) for key in keys - ) - - return Selection( - table, - self.selections, - predicates=self.predicates, - sort_keys=sort_keys, - ) - - return Selection(self, [], sort_keys=keys) - - @attribute - def _projection(self): - return Projection(self.table, self.selections) +class Namespace(Concrete): + database: Optional[str] = None + schema: Optional[str] = None @public -class DummyTable(Relation): - # TODO(kszucs): verify that it has at least one element: Length(at_least=1) - values: VarTuple[Value[dt.Any]] - - @attribute - def schema(self): - return Schema({op.name: op.dtype for op in self.values}) +class DatabaseTable(PhysicalTable): + schema: Schema + source: Any + namespace: Namespace = Namespace() @public -class Aggregation(Relation): - table: Relation - metrics: VarTuple[Scalar] = () - by: VarTuple[Column] = () - having: VarTuple[Scalar[dt.Boolean]] = () - predicates: VarTuple[Value[dt.Boolean]] = () - sort_keys: VarTuple[SortKey] = () - - def __init__(self, table, metrics, by, having, predicates, sort_keys): - from ibis.expr.analysis import shares_all_roots, shares_some_roots - - # All non-scalar refs originate from the input table - if not shares_all_roots(metrics + by + having + sort_keys, table): - raise com.RelationError( - "Selection expressions don't fully originate from " - "dependencies of the table expression." - ) - - # invariant due to Aggregation and AggregateSelection requiring a valid - # Selection - assert all(shares_some_roots(predicate, table) for predicate in predicates) - - if not by: - sort_keys = tuple() - - super().__init__( - table=table, - metrics=tuple(map(_add_alias, metrics)), - by=tuple(map(_add_alias, by)), - having=having, - predicates=predicates, - sort_keys=sort_keys, - ) - - @attribute - def _projection(self): - return Projection(self.table, self.metrics + self.by) - - @attribute - def schema(self): - names, types = [], [] - for value in self.by + self.metrics: - names.append(value.name) - types.append(value.dtype) - return Schema.from_tuples(zip(names, types)) - - @annotated - def order_by(self, keys: VarTuple[SortKey]): - from ibis.expr.analysis import shares_all_roots, sub_immediate_parents - - if shares_all_roots(keys, table := self.table): - sort_keys = tuple(self.sort_keys) + tuple( - sub_immediate_parents(key, table) for key in keys - ) - return Aggregation( - table, - metrics=self.metrics, - by=self.by, - having=self.having, - predicates=self.predicates, - sort_keys=sort_keys, - ) - - return Selection(self, [], sort_keys=keys) +class InMemoryTable(PhysicalTable): + schema: Schema + data: TableProxy @public -class Distinct(Relation): - """Distinct is a table-level unique-ing operation. +class SQLQueryResult(Relation): + """A table sourced from the result set of a select query.""" - In SQL, you might have: + query: str + schema: Schema + source: Any + values = FrozenDict() - SELECT DISTINCT foo - FROM table - SELECT DISTINCT foo, bar - FROM table - """ +@public +class SQLStringView(PhysicalTable): + """A view created from a SQL string.""" - table: Relation + child: Relation + query: str @attribute def schema(self): - return self.table.schema + # TODO(kszucs): avoid converting to expression + backend = self.child.to_expr()._find_backend() + return backend._get_schema_using_query(self.query) @public -class Sample(Relation): - """Sample performs random sampling of records in a table.""" - - table: Relation - fraction: Annotated[float, Between(0, 1)] - method: Literal["row", "block"] - seed: UnionType[int, None] = None +class DummyTable(Relation): + values: FrozenDict[str, Value] @attribute def schema(self): - return self.table.schema + return Schema({k: v.dtype for k, v in self.values.items()}) -# TODO(kszucs): split it into two operations, one working with a single replacement -# value and the other with a mapping -# TODO(kszucs): the single value case was limited to numeric and string types @public -class FillNa(Relation): +class FillNa(Simple): """Fill null values in the table.""" - table: Relation - replacements: UnionType[Value[dt.Numeric | dt.String], FrozenDict[str, Any]] - - @attribute - def schema(self): - return self.table.schema + replacements: typing.Union[Value[dt.Numeric | dt.String], FrozenDict[str, Any]] @public -class DropNa(Relation): +class DropNa(Simple): """Drop null values in the table.""" - table: Relation - how: Literal["any", "all"] - subset: Optional[VarTuple[Column[dt.Any]]] = None - - @attribute - def schema(self): - return self.table.schema + how: typing.Literal["any", "all"] + subset: Optional[VarTuple[Column]] = None @public -class View(PhysicalTable): - """A view created from an expression.""" - - child: Relation - name: str +class Sample(Simple): + """Sample performs random sampling of records in a table.""" - @attribute - def schema(self): - return self.child.schema + fraction: Annotated[float, Between(0, 1)] + method: typing.Literal["row", "block"] + seed: typing.Union[int, None] = None @public -class SQLStringView(PhysicalTable): - """A view created from a SQL string.""" - - child: Relation - name: str - query: str - - @attribute - def schema(self): - # TODO(kszucs): avoid converting to expression - backend = self.child.to_expr()._find_backend() - return backend._get_schema_using_query(self.query) +class Distinct(Simple): + """Distinct is a table-level unique-ing operation.""" -def _dedup_join_columns(expr: ir.Table, lname: str, rname: str): - from ibis.expr.operations.generic import TableColumn - from ibis.expr.operations.logical import Equals - - op = expr.op() - left = op.left.to_expr() - right = op.right.to_expr() - - right_columns = frozenset(right.columns) - overlap = frozenset(column for column in left.columns if column in right_columns) - equal = set() - - if isinstance(op, InnerJoin) and util.all_of(op.predicates, Equals): - # For inner joins composed exclusively of equality predicates, we can - # avoid renaming columns with colliding names if their values are - # guaranteed to be equal due to the predicate. Here we collect a set of - # colliding column names that are known to have equal values between - # the left and right tables in the join. - tables = {op.left, op.right} - for pred in op.predicates: - if ( - isinstance(pred.left, TableColumn) - and isinstance(pred.right, TableColumn) - and {pred.left.table, pred.right.table} == tables - and pred.left.name == pred.right.name - ): - equal.add(pred.left.name) - - if not overlap: - return expr - - # Rename columns in the left table that overlap, unless they're known to be - # equal to a column in the right - left_projections = [ - left[column] - .cast(left[column].type().copy(nullable=True)) - .name(lname.format(name=column) if lname else column) - if column in overlap and column not in equal - else left[column].cast(left[column].type().copy(nullable=True)).name(column) - for column in left.columns - ] - - # Rename columns in the right table that overlap, dropping any columns that - # are known to be equal to those in the left table - right_projections = [ - right[column] - .cast(right[column].type().copy(nullable=True)) - .name(rname.format(name=column) if rname else column) - if column in overlap - else right[column].cast(right[column].type().copy(nullable=True)).name(column) - for column in right.columns - if column not in equal - ] - projections = left_projections + right_projections - - # Certain configurations can result in the renamed columns still colliding, - # here we check for duplicates again, and raise a nicer error message if - # any exist. - seen = set() - collisions = set() - for column in projections: - name = column.get_name() - if name in seen: - collisions.add(name) - seen.add(name) - if collisions: - raise com.IntegrityError( - f"Joining with `lname={lname!r}, rname={rname!r}` resulted in multiple " - f"columns mapping to the following names `{sorted(collisions)}`. Please " - f"adjust `lname` and/or `rname` accordingly" - ) - return expr.select(projections) - - -public(TableNode=Relation) +# TODO(kszucs): support t.select(*t) syntax by implementing TableExpr.__iter__() diff --git a/ibis/expr/operations/sortkeys.py b/ibis/expr/operations/sortkeys.py index 643427bf93d3..4b65b5e7adfb 100644 --- a/ibis/expr/operations/sortkeys.py +++ b/ibis/expr/operations/sortkeys.py @@ -28,6 +28,7 @@ class SortKey(Value): """A sort operation.""" + # TODO(kszucs): rename expr to arg or something else except expr expr: Value ascending: bool = True diff --git a/ibis/expr/operations/strings.py b/ibis/expr/operations/strings.py index d17f22e412a1..9b40261b9d2a 100644 --- a/ibis/expr/operations/strings.py +++ b/ibis/expr/operations/strings.py @@ -4,7 +4,6 @@ from public import public -import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute @@ -78,7 +77,7 @@ class Repeat(Value): arg: Value[dt.String] times: Value[dt.Integer] - shape = rlz.shape_like("arg") + shape = rlz.shape_like("args") dtype = dt.string @@ -156,7 +155,7 @@ class ArrayStringJoin(Value): @public class StartsWith(Value): arg: Value[dt.String] - start: Value[dt.String, ds.Scalar] + start: Value[dt.String] dtype = dt.boolean shape = rlz.shape_like("arg") @@ -165,7 +164,7 @@ class StartsWith(Value): @public class EndsWith(Value): arg: Value[dt.String] - end: Value[dt.String, ds.Scalar] + end: Value[dt.String] dtype = dt.boolean shape = rlz.shape_like("arg") diff --git a/ibis/expr/operations/temporal_windows.py b/ibis/expr/operations/temporal_windows.py index 415f0b026fd2..8eec01e25713 100644 --- a/ibis/expr/operations/temporal_windows.py +++ b/ibis/expr/operations/temporal_windows.py @@ -5,6 +5,7 @@ from public import public import ibis.expr.datatypes as dt +from ibis.common.annotations import attribute from ibis.expr.operations.core import Column, Scalar # noqa: TCH001 from ibis.expr.operations.relations import Relation from ibis.expr.schema import Schema @@ -14,9 +15,14 @@ class WindowingTVF(Relation): """Generic windowing table-valued function.""" + # TODO(kszucs): rename to `parent` table: Relation time_col: Column[dt.Timestamp] # enforce timestamp column type here + @attribute + def values(self): + return self.table.fields + @property def schema(self): names = list(self.table.schema.names) @@ -26,6 +32,8 @@ def schema(self): # of original relation as well as additional 3 columns named “window_start”, # “window_end”, “window_time” to indicate the assigned window + # TODO(kszucs): this looks like an implementation detail leaked from the + # flink backend names.extend(["window_start", "window_end", "window_time"]) # window_start, window_end, window_time have type TIMESTAMP(3) in Flink types.extend([dt.timestamp(scale=3)] * 3) diff --git a/ibis/expr/operations/tests/test_structs.py b/ibis/expr/operations/tests/test_structs.py index 7b7c36fae402..efded74516df 100644 --- a/ibis/expr/operations/tests/test_structs.py +++ b/ibis/expr/operations/tests/test_structs.py @@ -15,7 +15,7 @@ def test_struct_column_shape(): assert op.shape == ds.scalar - col = ops.TableColumn( + col = ops.Field( ops.UnboundTable(schema=ibis.schema(dict(a="int64")), name="t"), "a" ) op = ops.StructColumn(names=("a",), values=(col,)) diff --git a/ibis/expr/operations/window.py b/ibis/expr/operations/window.py index c724615708f9..b87686853032 100644 --- a/ibis/expr/operations/window.py +++ b/ibis/expr/operations/window.py @@ -129,12 +129,9 @@ class WindowFunction(Value): shape = ds.columnar def __init__(self, func, frame): - from ibis.expr.analysis import shares_all_roots - - if not shares_all_roots(func, frame): + if func.relations and frame.table not in func.relations: raise com.RelationError( - "Window function expressions doesn't fully originate from the " - "dependencies of the window expression." + "The reduction has different parent relation than the window" ) super().__init__(func=func, frame=frame) diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 4ae694f0120c..74e2294ec3db 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -1,81 +1,161 @@ """Some common rewrite functions to be shared between backends.""" from __future__ import annotations -import functools -from collections.abc import Mapping +import toolz -import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.common.exceptions import UnsupportedOperationError -from ibis.common.patterns import pattern, replace +from ibis.common.deferred import Item, _, deferred, var +from ibis.common.exceptions import ExpressionError +from ibis.common.patterns import Check, pattern, replace from ibis.util import Namespace p = Namespace(pattern, module=ops) +d = Namespace(deferred, module=ops) -@replace(p.FillNa) -def rewrite_fillna(_): - """Rewrite FillNa expressions to use more common operations.""" - if isinstance(_.replacements, Mapping): - mapping = _.replacements - else: - mapping = { - name: _.replacements - for name, type in _.table.schema.items() - if type.nullable - } - - if not mapping: - return _.table - - selections = [] - for name in _.table.schema.names: - col = ops.TableColumn(_.table, name) - if (value := mapping.get(name)) is not None: - col = ops.Alias(ops.Coalesce((col, value)), name) - selections.append(col) - - return ops.Selection(_.table, selections, (), ()) - - -@replace(p.DropNa) -def rewrite_dropna(_): - """Rewrite DropNa expressions to use more common operations.""" - if _.subset is None: - columns = [ops.TableColumn(_.table, name) for name in _.table.schema.names] +y = var("y") +name = var("name") + + +@replace(ops.Analytic) +def project_wrap_analytic(_, rel): + # Wrap analytic functions in a window function + return ops.WindowFunction(_, ops.RowsWindowFrame(rel)) + + +@replace(ops.Reduction) +def project_wrap_reduction(_, rel): + # Query all the tables that the reduction depends on + if _.relations == {rel}: + # The reduction is fully originating from the `rel`, so turn + # it into a window function of `rel` + return ops.WindowFunction(_, ops.RowsWindowFrame(rel)) else: - columns = _.subset - - if columns: - preds = [ - functools.reduce( - ops.And if _.how == "any" else ops.Or, - [ops.NotNull(c) for c in columns], - ) - ] - elif _.how == "all": - preds = [ops.Literal(False, dtype=dt.bool)] + # 1. The reduction doesn't depend on any table, constructed from + # scalar values, so turn it into a scalar subquery. + # 2. The reduction is originating from `rel` and other tables, + # so this is a correlated scalar subquery. + # 3. The reduction is originating entirely from other tables, + # so this is an uncorrelated scalar subquery. + return ops.ScalarSubquery(_.to_expr().as_table()) + + +def rewrite_project_input(value, relation): + # we need to detect reductions which are either turned into window functions + # or scalar subqueries depending on whether they are originating from the + # relation + return value.replace( + project_wrap_analytic | project_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context={"rel": relation}, + ) + + +ReductionValue = p.Reduction | p.Field(p.Aggregate(groups={})) + + +@replace(ReductionValue) +def filter_wrap_reduction(_): + # Wrap reductions or fields referencing an aggregation without a group by - + # which are scalar fields - in a scalar subquery. In the latter case we + # use the reduction value from the aggregation. + if isinstance(_, ops.Field): + value = _.rel.values[_.name] else: - return _.table + value = _ + return ops.ScalarSubquery(value.to_expr().as_table()) + - return ops.Selection(_.table, (), preds, ()) +def rewrite_filter_input(value): + return value.replace(filter_wrap_reduction, filter=p.Value & ~p.WindowFunction) -@replace(p.Sample) -def rewrite_sample(_): - """Rewrite Sample as `t.filter(random() <= fraction)`. +@replace(p.Analytic | p.Reduction) +def window_wrap_reduction(_, frame): + # Wrap analytic and reduction functions in a window function. Used in the + # value.over() API. + return ops.WindowFunction(_, frame) - Errors as unsupported if a `seed` is specified. - """ - if _.seed is not None: - raise UnsupportedOperationError( - "`Table.sample` with a random seed is unsupported" +@replace(p.WindowFunction) +def window_merge_frames(_, frame): + # Merge window frames, used in the value.over() and groupby.select() APIs. + if _.frame.start and frame.start and _.frame.start != frame.start: + raise ExpressionError( + "Unable to merge windows with conflicting `start` boundary" ) + if _.frame.end and frame.end and _.frame.end != frame.end: + raise ExpressionError("Unable to merge windows with conflicting `end` boundary") + + start = _.frame.start or frame.start + end = _.frame.end or frame.end + group_by = tuple(toolz.unique(_.frame.group_by + frame.group_by)) + + order_by = {} + for sort_key in _.frame.order_by + frame.order_by: + order_by[sort_key.expr] = sort_key.ascending + order_by = tuple(ops.SortKey(k, v) for k, v in order_by.items()) + + frame = _.frame.copy(start=start, end=end, group_by=group_by, order_by=order_by) + return ops.WindowFunction(_.func, frame) - return ops.Selection( - _.table, - (), - (ops.LessEqual(ops.RandomScalar(), _.fraction),), - (), + +def rewrite_window_input(value, frame): + context = {"frame": frame} + # if self is a reduction or analytic function, wrap it in a window function + node = value.replace( + window_wrap_reduction, + filter=p.Value & ~p.WindowFunction, + context=context, ) + # if self is already a window function, merge the existing window frame + # with the requested window frame + return node.replace(window_merge_frames, filter=p.Value, context=context) + + +# TODO(kszucs): schema comparison should be updated to not distinguish between +# different column order +@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema)) +def complete_reprojection(_, y): + # TODO(kszucs): this could be moved to the pattern itself but not sure how + # to express it, especially in a shorter way then the following check + for name in _.schema: + if _.values[name] != ops.Field(y, name): + return _ + return y + + +@replace(p.Project(y @ p.Project)) +def subsequent_projects(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + values = {k: v.replace(rule) for k, v in _.values.items()} + return ops.Project(y.parent, values) + + +@replace(p.Filter(y @ p.Filter)) +def subsequent_filters(_, y): + rule = p.Field(y, name) >> d.Field(y.parent, name) + preds = tuple(v.replace(rule) for v in _.predicates) + return ops.Filter(y.parent, y.predicates + preds) + + +@replace(p.Filter(y @ p.Project)) +def reorder_filter_project(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + preds = tuple(v.replace(rule) for v in _.predicates) + + inner = ops.Filter(y.parent, preds) + rule = p.Field(y.parent, name) >> d.Field(inner, name) + projs = {k: v.replace(rule) for k, v in y.values.items()} + + return ops.Project(inner, projs) + + +def simplify(node): + # TODO(kszucs): add a utility to the graph module to do rewrites in multiple + # passes after each other + node = node.replace(reorder_filter_project) + node = node.replace(reorder_filter_project) + node = node.replace(subsequent_projects | subsequent_filters) + node = node.replace(complete_reprojection) + return node diff --git a/ibis/expr/sql.py b/ibis/expr/sql.py index 29509cd2629f..1e3a805e4b2f 100644 --- a/ibis/expr/sql.py +++ b/ibis/expr/sql.py @@ -125,18 +125,23 @@ def convert_join(join, catalog): left_name = join.name left_table = catalog[left_name] + for right_name, desc in join.joins.items(): right_table = catalog[right_name] join_kind = _join_types[desc["side"]] - predicate = None - for left_key, right_key in zip(desc["source_key"], desc["join_key"]): - left_key = convert(left_key, catalog=catalog) - right_key = convert(right_key, catalog=catalog) - if predicate is None: - predicate = left_key == right_key - else: - predicate &= left_key == right_key + if desc["join_key"]: + predicate = None + for left_key, right_key in zip(desc["source_key"], desc["join_key"]): + left_key = convert(left_key, catalog=catalog) + right_key = convert(right_key, catalog=catalog) + if predicate is None: + predicate = left_key == right_key + else: + predicate &= left_key == right_key + else: + condition = desc["condition"] + predicate = convert(condition, catalog=catalog) left_table = left_table.join(right_table, predicates=predicate, how=join_kind) @@ -179,6 +184,11 @@ def convert_literal(literal, catalog): return ibis.literal(value) +@convert.register(sge.Boolean) +def convert_boolean(boolean, catalog): + return ibis.literal(boolean.this) + + @convert.register(sge.Alias) def convert_alias(alias, catalog): this = convert(alias.this, catalog=catalog) @@ -367,6 +377,6 @@ def to_sql(expr: ir.Expr, dialect: str | None = None, **kwargs) -> SQLString: else: read = write = getattr(backend, "_sqlglot_dialect", dialect) - sql = backend._to_sql(expr, **kwargs) + sql = backend._to_sql(expr.unbind(), **kwargs) (pretty,) = sg.transpile(sql, read=read, write=write, pretty=True) return SQLString(pretty) diff --git a/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt index 44b15ca820f4..23cc70e5b6ac 100644 --- a/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt @@ -11,10 +11,10 @@ r0 := UnboundTable: alltypes j date k time -Aggregation[r0] +Aggregate[r0] + groups: + key1: r0.g + key2: Round(r0.f) metrics: c: Sum(r0.c) - d: Mean(r0.d) - by: - key1: r0.g - key2: Round(r0.f) \ No newline at end of file + d: Mean(r0.d) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt index aeaba71bcc34..263a594f7ef7 100644 --- a/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_asof_join/repr.txt @@ -1,20 +1,24 @@ -r0 := UnboundTable: right - time2 int32 - value2 float64 - -r1 := UnboundTable: left +r0 := UnboundTable: left time1 int32 value float64 -r2 := AsOfJoin[r1, r0] r1.time1 == r0.time2 +r1 := UnboundTable: right + time2 int32 + value2 float64 + +r2 := SelfReference[r1] -r3 := InnerJoin[r2, r0] r1.value == r0.value2 +r3 := SelfReference[r1] -Selection[r3] - selections: - time1: r2.time1 - value: r2.value +JoinChain[r0] + JoinLink[asof, r2] + r0.time1 == r2.time2 + JoinLink[inner, r3] + r0.value == r3.value2 + values: + time1: r0.time1 + value: r0.value time2: r2.time2 value2: r2.value2 - time2_right: r0.time2 - value2_right: r0.value2 \ No newline at end of file + time2_right: r3.time2 + value2_right: r3.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt b/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt index 0f9d5621fa4b..9334f1c07925 100644 --- a/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt @@ -1,20 +1,18 @@ r0 := UnboundTable: t a int64 -r1 := Selection[r0] - predicates: - r0.a < 42 - r0.a >= 42 +r1 := Filter[r0] + r0.a < 42 + r0.a >= 42 -r2 := Selection[r1] - selections: - r1 - x: r1.a + 42 +r2 := Project[r1] + a: r1.a + x: r1.a + 42 -r3 := Aggregation[r2] +r3 := Aggregate[r2] + groups: + x: r2.x metrics: y: Sum(r2.a) - by: - x: r2.x Limit[r3, n=10] \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_destruct_selection/repr.txt b/ibis/expr/tests/snapshots/test_format/test_destruct_selection/repr.txt index 013871ecfb27..05087799363d 100644 --- a/ibis/expr/tests/snapshots/test_format/test_destruct_selection/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_destruct_selection/repr.txt @@ -1,7 +1,7 @@ r0 := UnboundTable: t col int64 -Aggregation[r0] +Aggregate[r0] metrics: sum: StructField(ReductionVectorizedUDF(func=multi_output_udf, func_args=[r0.col], input_type=[int64], return_type={'sum': int64, 'mean': float64}), field='sum') mean: StructField(ReductionVectorizedUDF(func=multi_output_udf, func_args=[r0.col], input_type=[int64], return_type={'sum': int64, 'mean': float64}), field='mean') \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_int_repr.txt b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_int_repr.txt index d7aa4f2ee692..7ffb48f8a9f9 100644 --- a/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_int_repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_int_repr.txt @@ -2,9 +2,8 @@ r0 := UnboundTable: t a int64 b string -r1 := Selection[r0] - selections: - a: r0.a +r1 := Project[r0] + a: r0.a FillNa[r1] replacements: diff --git a/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_str_repr.txt b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_str_repr.txt index 887edd9ee5b9..e23131448904 100644 --- a/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_str_repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_fillna/fillna_str_repr.txt @@ -2,9 +2,8 @@ r0 := UnboundTable: t a int64 b string -r1 := Selection[r0] - selections: - b: r0.b +r1 := Project[r0] + b: r0.b FillNa[r1] replacements: diff --git a/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt index 168803538ebe..0563c0ba6211 100644 --- a/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_format_dummy_table/repr.txt @@ -1,2 +1,2 @@ DummyTable - foo array \ No newline at end of file + foo: [1] \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt index 057e2d8c8966..d1ed4735f67a 100644 --- a/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_format_multiple_join_with_projection/repr.txt @@ -1,36 +1,33 @@ -r0 := UnboundTable: three - bar_id string - value2 float64 - -r1 := UnboundTable: one +r0 := UnboundTable: one c int32 f float64 foo_id string bar_id string -r2 := UnboundTable: two +r1 := UnboundTable: two foo_id string value1 float64 -r3 := Selection[r1] - predicates: - r1.f > 0 +r2 := UnboundTable: three + bar_id string + value2 float64 -r4 := LeftJoin[r3, r2] r3.foo_id == r2.foo_id +r3 := SelfReference[r1] -r5 := Selection[r4] - selections: - c: r3.c - f: r3.f - foo_id: r3.foo_id - bar_id: r3.bar_id - foo_id_right: r2.foo_id - value1: r2.value1 +r4 := SelfReference[r2] -r6 := InnerJoin[r5, r0] r3.bar_id == r0.bar_id +r5 := Filter[r0] + r0.f > 0 -Selection[r6] - selections: - r3 - value1: r2.value1 - value2: r0.value2 \ No newline at end of file +JoinChain[r5] + JoinLink[left, r3] + r5.foo_id == r3.foo_id + JoinLink[inner, r4] + r5.bar_id == r4.bar_id + values: + c: r5.c + f: r5.f + foo_id: r5.foo_id + bar_id: r5.bar_id + value1: r3.value1 + value2: r4.value2 \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_new_relational_operation/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_new_relational_operation/repr.txt index f058f8b462d4..3169cef6f734 100644 --- a/ibis/expr/tests/snapshots/test_format/test_format_new_relational_operation/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_format_new_relational_operation/repr.txt @@ -11,20 +11,32 @@ r0 := UnboundTable: alltypes j date k time -r1 := MyRelation - a int8 - b int16 - c int32 - d int64 - e float32 - f float64 - g string - h boolean - i timestamp - j date - k time +r1 := MyRelation[r0] + kind: + foo + schema: + a int8 + b int16 + c int32 + d int64 + e float32 + f float64 + g string + h boolean + i timestamp + j date + k time -Selection[r1] - selections: - r1 - a2: r1.a \ No newline at end of file +Project[r1] + a: r1.a + b: r1.b + c: r1.c + d: r1.d + e: r1.e + f: r1.f + g: r1.g + h: r1.h + i: r1.i + j: r1.j + k: r1.k + a2: r1.a \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_format_projection/repr.txt b/ibis/expr/tests/snapshots/test_format/test_format_projection/repr.txt index c982128f1c0e..aff4c167e81f 100644 --- a/ibis/expr/tests/snapshots/test_format/test_format_projection/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_format_projection/repr.txt @@ -11,10 +11,9 @@ r0 := UnboundTable: alltypes j date k time -r1 := Selection[r0] - selections: - c: r0.c - a: r0.a - f: r0.f +r1 := Project[r0] + c: r0.c + a: r0.a + f: r0.f a: r1.a \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt index cfd72d2fff7c..2e6f5c480f1a 100644 --- a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt @@ -3,16 +3,16 @@ r0 := UnboundTable: airlines origin string arrdelay int32 -r1 := Aggregation[r0] +r1 := Filter[r0] + InValues(value=r0.dest, options=['ORD', 'JFK', 'SFO']) + +r2 := Aggregate[r1] + groups: + dest: r1.dest metrics: - Mean(arrdelay): Mean(r0.arrdelay) - by: - dest: r0.dest - predicates: - InValues(value=r0.dest, options=['ORD', 'JFK', 'SFO']) + Mean(arrdelay): Mean(r1.arrdelay) -r2 := Selection[r1] - sort_keys: - desc r1.Mean(arrdelay) +r3 := Sort[r2] + desc r2['Mean(arrdelay)'] -Limit[r2, n=10] \ No newline at end of file +Limit[r3, n=10] \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt index 69c7d1add031..95a08486a774 100644 --- a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt @@ -4,27 +4,26 @@ r0 := UnboundTable: purchases user int64 amount float64 -r1 := Aggregation[r0] - metrics: - total: Sum(r0.amount) - by: +r1 := Aggregate[r0] + groups: region: r0.region kind: r0.kind - predicates: - r0.kind == 'foo' - -r2 := Aggregation[r0] metrics: total: Sum(r0.amount) - by: - region: r0.region - kind: r0.kind - predicates: - r0.kind == 'bar' -r3 := InnerJoin[r1, r2] r1.region == r2.region +r2 := Filter[r1] + r1.kind == 'foo' + +r3 := Filter[r1] + r1.kind == 'bar' + +r4 := SelfReference[r3] -Selection[r3] - selections: - r1 - right_total: r2.total \ No newline at end of file +JoinChain[r2] + JoinLink[inner, r4] + r2.region == r4.region + values: + region: r2.region + kind: r2.kind + total: r2.total + right_total: r4.total \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_repr_exact/repr.txt b/ibis/expr/tests/snapshots/test_format/test_repr_exact/repr.txt index 38e341469f0c..ae0745d7299f 100644 --- a/ibis/expr/tests/snapshots/test_format/test_repr_exact/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_repr_exact/repr.txt @@ -3,7 +3,8 @@ r0 := UnboundTable: t col2 string col3 float64 -Selection[r0] - selections: - r0 - col4: StringLength(r0.col2) \ No newline at end of file +Project[r0] + col: r0.col + col2: r0.col2 + col3: r0.col3 + col4: StringLength(r0.col2) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_same_column_multiple_aliases/repr.txt b/ibis/expr/tests/snapshots/test_format/test_same_column_multiple_aliases/repr.txt index 1826aa9d8567..b3505df638d6 100644 --- a/ibis/expr/tests/snapshots/test_format/test_same_column_multiple_aliases/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_same_column_multiple_aliases/repr.txt @@ -1,7 +1,6 @@ r0 := UnboundTable: t col int64 -Selection[r0] - selections: - fakealias1: r0.col - fakealias2: r0.col \ No newline at end of file +Project[r0] + fakealias1: r0.col + fakealias2: r0.col \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/cnt_repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/cnt_repr.txt index a85e2bdb5dbb..fffb76933234 100644 --- a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/cnt_repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/cnt_repr.txt @@ -2,4 +2,4 @@ r0 := UnboundTable: t1 a int64 b float64 -CountStar(t1): CountStar(r0) \ No newline at end of file +CountStar(): CountStar(r0) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt index e63b05c8c635..96aa59a58a31 100644 --- a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/join_repr.txt @@ -6,12 +6,15 @@ r1 := UnboundTable: t2 a int64 b float64 -r2 := InnerJoin[r0, r1] r0.a == r1.a +r2 := SelfReference[r1] -r3 := Selection[r2] - selections: +r3 := JoinChain[r0] + JoinLink[inner, r2] + r0.a == r2.a + values: a: r0.a b: r0.b - b_right: r1.b + a_right: r2.a + b_right: r2.b CountStar(): CountStar(r3) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/union_repr.txt b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/union_repr.txt index caab7a357ba4..39d67ba6a7a6 100644 --- a/ibis/expr/tests/snapshots/test_format/test_table_count_expr/union_repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_table_count_expr/union_repr.txt @@ -8,9 +8,8 @@ r1 := UnboundTable: t2 r2 := Union[r0, r1, distinct=False] -r3 := Selection[r2] - selections: - a: r2.a - b: r2.b +r3 := Project[r2] + a: r2.a + b: r2.b CountStar(): CountStar(r3) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt index 959d15672b18..aa61982fec8f 100644 --- a/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_two_inner_joins/repr.txt @@ -1,25 +1,29 @@ -r0 := UnboundTable: right - time2 int32 - value2 float64 - b string - -r1 := UnboundTable: left +r0 := UnboundTable: left time1 int32 value float64 a string -r2 := InnerJoin[r1, r0] r1.a == r0.b +r1 := UnboundTable: right + time2 int32 + value2 float64 + b string + +r2 := SelfReference[r1] -r3 := InnerJoin[r2, r0] r1.value == r0.value2 +r3 := SelfReference[r1] -Selection[r3] - selections: - time1: r2.time1 - value: r2.value - a: r2.a +JoinChain[r0] + JoinLink[inner, r2] + r0.a == r2.b + JoinLink[inner, r3] + r0.value == r3.value2 + values: + time1: r0.time1 + value: r0.value + a: r0.a time2: r2.time2 value2: r2.value2 b: r2.b - time2_right: r0.time2 - value2_right: r0.value2 - b_right: r0.b \ No newline at end of file + time2_right: r3.time2 + value2_right: r3.value2 + b_right: r3.b \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py index 03fcc9f2791f..499385aab514 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py @@ -1,6 +1,10 @@ import ibis +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) call = ibis.table( name="call", schema={ @@ -14,28 +18,10 @@ call_outcome = ibis.table( name="call_outcome", schema={"outcome_text": "string", "id": "int64"} ) -employee = ibis.table( - name="employee", - schema={"first_name": "string", "last_name": "string", "id": "int64"}, +joinchain = employee.inner_join(call, employee.id == call.employee_id).inner_join( + call_outcome, call.call_outcome_id == call_outcome.id ) -innerjoin = employee.inner_join(call, employee.id == call.employee_id) -result = ( - innerjoin.inner_join(call_outcome, call.call_outcome_id == call_outcome.id) - .select( - [ - innerjoin.first_name, - innerjoin.last_name, - innerjoin.id, - innerjoin.start_time, - innerjoin.end_time, - innerjoin.employee_id, - innerjoin.call_outcome_id, - innerjoin.call_attempts, - call_outcome.outcome_text, - call_outcome.id.name("id_right"), - ] - ) - .group_by(call.employee_id) - .aggregate(call.call_attempts.mean().name("avg_attempts")) +result = joinchain.aggregate( + [joinchain.call_attempts.mean().name("avg_attempts")], by=[joinchain.employee_id] ) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py index b5bd4842d48b..85221e0535fa 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py @@ -12,6 +12,6 @@ }, ) -result = call.group_by(call.employee_id).aggregate( - call.call_attempts.sum().name("attempts") +result = call.aggregate( + [call.call_attempts.sum().name("attempts")], by=[call.employee_id] ) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py index 392f50271b0b..0b23d1687445 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py @@ -15,8 +15,8 @@ "call_attempts": "int64", }, ) -leftjoin = employee.left_join(call, employee.id == call.employee_id) +joinchain = employee.left_join(call, employee.id == call.employee_id) -result = leftjoin.group_by(leftjoin.id).aggregate( - call.call_attempts.sum().name("attempts") +result = joinchain.aggregate( + [joinchain.call_attempts.sum().name("attempts")], by=[joinchain.id] ) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py index 05f419d668db..8439fd762875 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py @@ -1,6 +1,10 @@ import ibis +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) call = ibis.table( name="call", schema={ @@ -11,24 +15,18 @@ "call_attempts": "int64", }, ) -employee = ibis.table( - name="employee", - schema={"first_name": "string", "last_name": "string", "id": "int64"}, -) -proj = employee.inner_join(call, employee.id == call.employee_id).filter( - employee.id < 5 -) +joinchain = employee.inner_join(call, employee.id == call.employee_id) +f = joinchain.filter(joinchain.id < 5) +s = f.order_by(f.id.desc()) -result = proj.select( - [ - proj.first_name, - proj.last_name, - proj.id, - call.start_time, - call.end_time, - call.employee_id, - call.call_outcome_id, - call.call_attempts, - proj.first_name.name("first"), - ] -).order_by(proj.id.desc()) +result = s.select( + s.first_name, + s.last_name, + s.id, + s.start_time, + s.end_time, + s.employee_id, + s.call_outcome_id, + s.call_attempts, + s.first_name.name("first"), +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py index 2ed2c808d726..3e375cd052d2 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py @@ -1,6 +1,10 @@ import ibis +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) call = ibis.table( name="call", schema={ @@ -11,22 +15,18 @@ "call_attempts": "int64", }, ) -employee = ibis.table( - name="employee", - schema={"first_name": "string", "last_name": "string", "id": "int64"}, -) -proj = employee.left_join(call, employee.id == call.employee_id).filter(employee.id < 5) +joinchain = employee.left_join(call, employee.id == call.employee_id) +f = joinchain.filter(joinchain.id < 5) +s = f.order_by(f.id.desc()) -result = proj.select( - [ - proj.first_name, - proj.last_name, - proj.id, - call.start_time, - call.end_time, - call.employee_id, - call.call_outcome_id, - call.call_attempts, - proj.first_name.name("first"), - ] -).order_by(proj.id.desc()) +result = s.select( + s.first_name, + s.last_name, + s.id, + s.start_time, + s.end_time, + s.employee_id, + s.call_outcome_id, + s.call_attempts, + s.first_name.name("first"), +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py index 0f31dffd1532..e9a8b2082dc1 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py @@ -1,6 +1,10 @@ import ibis +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) call = ibis.table( name="call", schema={ @@ -11,24 +15,18 @@ "call_attempts": "int64", }, ) -employee = ibis.table( - name="employee", - schema={"first_name": "string", "last_name": "string", "id": "int64"}, -) -proj = employee.right_join(call, employee.id == call.employee_id).filter( - employee.id < 5 -) +joinchain = employee.right_join(call, employee.id == call.employee_id) +f = joinchain.filter(joinchain.id < 5) +s = f.order_by(f.id.desc()) -result = proj.select( - [ - proj.first_name, - proj.last_name, - proj.id, - call.start_time, - call.end_time, - call.employee_id, - call.call_outcome_id, - call.call_attempts, - proj.first_name.name("first"), - ] -).order_by(proj.id.desc()) +result = s.select( + s.first_name, + s.last_name, + s.id, + s.start_time, + s.end_time, + s.employee_id, + s.call_outcome_id, + s.call_attempts, + s.first_name.name("first"), +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py index b6e37f2ab518..404a75f95cfc 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py @@ -5,8 +5,7 @@ name="employee", schema={"first_name": "string", "last_name": "string", "id": "int64"}, ) -proj = employee.filter(employee.id < 5) +f = employee.filter(employee.id < 5) +s = f.order_by(f.id.desc()) -result = proj.select( - [proj.first_name, proj.last_name, proj.id, proj.first_name.name("first")] -).order_by(proj.id.desc()) +result = s.select(s.first_name, s.last_name, s.id, s.first_name.name("first")) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_in_clause/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_in_clause/decompiled.py index cc4993250d02..b29504c90709 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_in_clause/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_in_clause/decompiled.py @@ -5,8 +5,7 @@ name="employee", schema={"first_name": "string", "last_name": "string", "id": "int64"}, ) - -result = employee.select(employee.first_name).filter( +f = employee.filter( employee.first_name.isin( ( ibis.literal("Graham"), @@ -17,3 +16,5 @@ ) ) ) + +result = f.select(f.first_name) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py index 2ed2c808d726..3e375cd052d2 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py @@ -1,6 +1,10 @@ import ibis +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) call = ibis.table( name="call", schema={ @@ -11,22 +15,18 @@ "call_attempts": "int64", }, ) -employee = ibis.table( - name="employee", - schema={"first_name": "string", "last_name": "string", "id": "int64"}, -) -proj = employee.left_join(call, employee.id == call.employee_id).filter(employee.id < 5) +joinchain = employee.left_join(call, employee.id == call.employee_id) +f = joinchain.filter(joinchain.id < 5) +s = f.order_by(f.id.desc()) -result = proj.select( - [ - proj.first_name, - proj.last_name, - proj.id, - call.start_time, - call.end_time, - call.employee_id, - call.call_outcome_id, - call.call_attempts, - proj.first_name.name("first"), - ] -).order_by(proj.id.desc()) +result = s.select( + s.first_name, + s.last_name, + s.id, + s.start_time, + s.end_time, + s.employee_id, + s.call_outcome_id, + s.call_attempts, + s.first_name.name("first"), +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py index ae6bfd9788f7..d6df17717b27 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py @@ -1,9 +1,6 @@ import ibis -call_outcome = ibis.table( - name="call_outcome", schema={"outcome_text": "string", "id": "int64"} -) employee = ibis.table( name="employee", schema={"first_name": "string", "last_name": "string", "id": "int64"}, @@ -18,21 +15,10 @@ "call_attempts": "int64", }, ) -innerjoin = employee.inner_join(call, employee.id == call.employee_id) +call_outcome = ibis.table( + name="call_outcome", schema={"outcome_text": "string", "id": "int64"} +) -result = innerjoin.inner_join( +result = employee.inner_join(call, employee.id == call.employee_id).inner_join( call_outcome, call.call_outcome_id == call_outcome.id -).select( - [ - innerjoin.first_name, - innerjoin.last_name, - innerjoin.id, - innerjoin.start_time, - innerjoin.end_time, - innerjoin.employee_id, - innerjoin.call_outcome_id, - innerjoin.call_attempts, - call_outcome.outcome_text, - call_outcome.id.name("id_right"), - ] ) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py index 81f627719b17..e651a29b1ad9 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py @@ -11,6 +11,6 @@ "call_attempts": "int64", }, ) -agg = call.aggregate(call.call_attempts.mean().name("mean")) +agg = call.aggregate([call.call_attempts.mean().name("mean")]) -result = call.inner_join(agg, []) +result = call.inner_join(agg, [agg.mean < call.call_attempts, ibis.literal(True)]) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py index 8fa935b7c5e4..3e4aaaf12b42 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py @@ -12,4 +12,4 @@ }, ) -result = call.aggregate(call.call_attempts.mean().name("mean")) +result = call.aggregate([call.call_attempts.mean().name("mean")]) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py index d993ac2ac040..8466d6aeb4ca 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py @@ -6,4 +6,4 @@ schema={"first_name": "string", "last_name": "string", "id": "int64"}, ) -result = employee.aggregate(employee.first_name.count().name("_col_0")) +result = employee.aggregate([employee.first_name.count().name("_col_0")]) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py index ec5df1972413..05aff9c5b4ee 100644 --- a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py @@ -1,9 +1,7 @@ import ibis -employee = ibis.table( +result = ibis.table( name="employee", schema={"first_name": "string", "last_name": "string", "id": "int64"}, ) - -result = employee.select([employee.first_name, employee.last_name, employee.id]) diff --git a/ibis/expr/tests/test_format.py b/ibis/expr/tests/test_format.py index 6ee6dbc42514..87186e24728e 100644 --- a/ibis/expr/tests/test_format.py +++ b/ibis/expr/tests/test_format.py @@ -11,17 +11,11 @@ import ibis.expr.operations as ops import ibis.legacy.udf.vectorized as udf from ibis import util -from ibis.expr.operations.relations import Projection # easier to switch implementation if needed fmt = repr -@pytest.mark.parametrize("cls", set(ops.Relation.__subclasses__()) - {Projection}) -def test_tables_have_format_rules(cls): - assert cls in ibis.expr.format.fmt.registry - - @pytest.mark.parametrize("cls", [ops.PhysicalTable, ops.Relation]) def test_tables_have_format_value_rules(cls): assert cls in ibis.expr.format.fmt.registry @@ -62,7 +56,6 @@ def test_table_type_output(snapshot): expr = foo.dept_id == foo.view().dept_id result = fmt(expr) - assert "SelfReference[r0]" in result assert "UnboundTable: foo" in result snapshot.assert_match(result, "repr.txt") @@ -77,7 +70,7 @@ def test_aggregate_arg_names(alltypes, snapshot): expr = t.group_by(by_exprs).aggregate(metrics) result = fmt(expr) assert "metrics" in result - assert "by" in result + assert "groups" in result snapshot.assert_match(result, "repr.txt") @@ -125,8 +118,6 @@ def test_memoize_filtered_table(snapshot): delay_filter = t.dest.topk(10, by=t.arrdelay.mean()) result = fmt(delay_filter) - assert result.count("Selection") == 1 - snapshot.assert_match(result, "repr.txt") @@ -167,12 +158,6 @@ def test_memoize_filtered_tables_in_join(snapshot): joined = left.join(right, cond)[left, right.total.name("right_total")] result = fmt(joined) - - # one for each aggregation - # joins are shown without the word `predicates` above them - # since joins only have predicates as arguments - assert result.count("predicates") == 2 - snapshot.assert_match(result, "repr.txt") @@ -331,9 +316,6 @@ def test_asof_join(snapshot): ) result = fmt(joined) - assert result.count("InnerJoin") == 1 - assert result.count("AsOfJoin") == 1 - snapshot.assert_match(result, "repr.txt") @@ -349,8 +331,6 @@ def test_two_inner_joins(snapshot): ) result = fmt(joined) - assert result.count("InnerJoin") == 2 - snapshot.assert_match(result, "repr.txt") @@ -382,11 +362,13 @@ def test_format_literal(literal, typ, output): def test_format_dummy_table(snapshot): +<<<<<<< HEAD t = ops.DummyTable([ibis.array([1]).cast("array").name("foo")]).to_expr() +======= + t = ops.DummyTable({"foo": ibis.array([1], type="array")}).to_expr() +>>>>>>> 2189ab71b (refactor(ir): split the relational operations) result = fmt(t) - assert "DummyTable" in result - assert "foo array" in result snapshot.assert_match(result, "repr.txt") @@ -408,6 +390,10 @@ class MyRelation(ops.Relation): def schema(self): return self.parent.schema + @property + def values(self): + return {} + table = MyRelation(alltypes, kind="foo").to_expr() expr = table[table, table.a.name("a2")] result = fmt(expr) diff --git a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py new file mode 100644 index 000000000000..1f0737c3897c --- /dev/null +++ b/ibis/expr/tests/test_newrels.py @@ -0,0 +1,1193 @@ +from __future__ import annotations + +import pytest + +import ibis +import ibis.expr.datashape as ds +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis import _ +from ibis.common.annotations import ValidationError +from ibis.common.exceptions import IbisInputError, IntegrityError +from ibis.expr.operations import ( + Aggregate, + Field, + Filter, + JoinChain, + JoinLink, + Project, + UnboundTable, +) +from ibis.expr.schema import Schema + +t = ibis.table( + name="t", + schema={ + "bool_col": "boolean", + "int_col": "int64", + "float_col": "float64", + "string_col": "string", + }, +) + + +def test_field(): + f = Field(t, "bool_col") + assert f.rel == t.op() + assert f.name == "bool_col" + assert f.shape == ds.columnar + assert f.dtype == dt.boolean + assert f.to_expr().equals(t.bool_col) + assert f.relations == frozenset([t.op()]) + + +def test_relation_coercion(): + assert ops.Relation.__coerce__(t) == t.op() + assert ops.Relation.__coerce__(t.op()) == t.op() + with pytest.raises(TypeError): + assert ops.Relation.__coerce__("invalid") + + +def test_unbound_table(): + node = t.op() + assert isinstance(t, ir.TableExpr) + assert isinstance(node, UnboundTable) + assert node.name == "t" + assert node.schema == Schema( + { + "bool_col": dt.boolean, + "int_col": dt.int64, + "float_col": dt.float64, + "string_col": dt.string, + } + ) + assert node.fields == { + "bool_col": ops.Field(node, "bool_col"), + "int_col": ops.Field(node, "int_col"), + "float_col": ops.Field(node, "float_col"), + "string_col": ops.Field(node, "string_col"), + } + assert node.values == {} + + +def test_select_fields(): + proj = t.select("int_col") + expected = Project(parent=t, values={"int_col": t.int_col}) + assert proj.op() == expected + assert proj.op().schema == Schema({"int_col": dt.int64}) + + proj = t.select(myint=t.int_col) + expected = Project(parent=t, values={"myint": t.int_col}) + assert proj.op() == expected + assert proj.op().schema == Schema({"myint": dt.int64}) + + proj = t.select(t.int_col, myint=t.int_col) + expected = Project(parent=t, values={"int_col": t.int_col, "myint": t.int_col}) + assert proj.op() == expected + assert proj.op().schema == Schema({"int_col": dt.int64, "myint": dt.int64}) + + proj = t.select(_.int_col, myint=_.int_col) + expected = Project(parent=t, values={"int_col": t.int_col, "myint": t.int_col}) + assert proj.op() == expected + + +def test_select_values(): + proj = t.select((1 + t.int_col).name("incremented")) + expected = Project(parent=t, values={"incremented": (1 + t.int_col)}) + assert proj.op() == expected + assert proj.op().schema == Schema({"incremented": dt.int64}) + + proj = t.select(ibis.literal(1), "float_col", length=t.string_col.length()) + expected = Project( + parent=t, + values={"1": 1, "float_col": t.float_col, "length": t.string_col.length()}, + ) + assert proj.op() == expected + assert proj.op().schema == Schema( + {"1": dt.int8, "float_col": dt.float64, "length": dt.int32} + ) + + assert expected.fields == { + "1": ops.Field(proj, "1"), + "float_col": ops.Field(proj, "float_col"), + "length": ops.Field(proj, "length"), + } + assert expected.values == { + "1": ibis.literal(1).op(), + "float_col": t.float_col.op(), + "length": t.string_col.length().op(), + } + + +def test_select_windowing_local_reduction(): + t1 = t.select(res=t.int_col.sum()) + assert t1.op() == Project(parent=t, values={"res": t.int_col.sum().over()}) + + +def test_select_windowizing_analytic_function(): + t1 = t.select(res=t.int_col.lag()) + assert t1.op() == Project(parent=t, values={"res": t.int_col.lag().over()}) + + +def test_subquery_integrity_check(): + t = ibis.table(name="t", schema={"a": "int64", "b": "string"}) + + msg = "Subquery must have exactly one column, got 2" + with pytest.raises(IntegrityError, match=msg): + ops.ScalarSubquery(t) + + +def test_select_turns_scalar_reduction_into_subquery(): + arr = ibis.literal([1, 2, 3]) + res = arr.unnest().sum() + t1 = t.select(res) + subquery = ops.ScalarSubquery(res.as_table()) + expected = Project(parent=t, values={"Sum((1, 2, 3))": subquery}) + assert t1.op() == expected + + +def test_select_scalar_foreign_scalar_reduction_into_subquery(): + t1 = t.filter(t.bool_col) + t2 = t.select(summary=t1.int_col.sum()) + subquery = ops.ScalarSubquery(t1.int_col.sum().as_table()) + expected = Project(parent=t, values={"summary": subquery}) + assert t2.op() == expected + + +def test_select_turns_value_with_multiple_parents_into_subquery(): + v = ibis.table(name="v", schema={"a": "int64", "b": "string"}) + v_filt = v.filter(v.a == t.int_col) + + t1 = t.select(t.int_col, max=v_filt.a.max()) + subquery = ops.ScalarSubquery(v_filt.a.max().as_table()) + expected = Project(parent=t, values={"int_col": t.int_col, "max": subquery}) + assert t1.op() == expected + + +def test_mutate(): + proj = t.select(t, other=t.int_col + 1) + expected = Project( + parent=t, + values={ + "bool_col": t.bool_col, + "int_col": t.int_col, + "float_col": t.float_col, + "string_col": t.string_col, + "other": t.int_col + 1, + }, + ) + assert proj.op() == expected + + +def test_mutate_overwrites_existing_column(): + t = ibis.table(dict(a="string", b="string")) + + mut = t.mutate(a=42) + assert mut.op() == Project(parent=t, values={"a": ibis.literal(42), "b": t.b}) + + sel = mut.select("a") + assert sel.op() == Project(parent=mut, values={"a": mut.a}) + + +def test_select_full_reprojection(): + t1 = t.select(t) + assert t1.op() == Project( + t, + { + "bool_col": t.bool_col, + "int_col": t.int_col, + "float_col": t.float_col, + "string_col": t.string_col, + }, + ) + + +def test_subsequent_selections_with_field_names(): + t1 = t.select("bool_col", "int_col", "float_col") + assert t1.op() == Project( + parent=t, + values={ + "bool_col": t.bool_col, + "int_col": t.int_col, + "float_col": t.float_col, + }, + ) + t2 = t1.select("bool_col", "int_col") + assert t2.op() == Project( + parent=t1, + values={ + "bool_col": t1.bool_col, + "int_col": t1.int_col, + }, + ) + t3 = t2.select("bool_col") + assert t3.op() == Project( + parent=t2, + values={ + "bool_col": t2.bool_col, + }, + ) + + +def test_subsequent_selections_field_dereferencing(): + t1 = t.select(t.bool_col, t.int_col, t.float_col) + assert t1.op() == Project( + parent=t, + values={ + "bool_col": t.bool_col, + "int_col": t.int_col, + "float_col": t.float_col, + }, + ) + + t2 = t1.select(t1.bool_col, t1.int_col) + assert t1.select(t1.bool_col, t.int_col).equals(t2) + assert t1.select(t.bool_col, t.int_col).equals(t2) + assert t2.op() == Project( + parent=t1, + values={ + "bool_col": t1.bool_col, + "int_col": t1.int_col, + }, + ) + + t3 = t2.select(t2.bool_col) + assert t2.select(t1.bool_col).equals(t3) + assert t2.select(t.bool_col).equals(t3) + assert t3.op() == Project( + parent=t2, + values={ + "bool_col": t2.bool_col, + }, + ) + + u1 = t.select(t.bool_col, t.int_col, t.float_col) + assert u1.op() == Project( + parent=t, + values={ + "bool_col": t.bool_col, + "int_col": t.int_col, + "float_col": t.float_col, + }, + ) + + u2 = u1.select(u1.bool_col, u1.int_col, u1.float_col) + assert u1.select(t.bool_col, u1.int_col, u1.float_col).equals(u2) + assert u1.select(t.bool_col, t.int_col, t.float_col).equals(u2) + assert u2.op() == Project( + parent=u1, + values={ + "bool_col": u1.bool_col, + "int_col": u1.int_col, + "float_col": u1.float_col, + }, + ) + + u3 = u2.select(u2.bool_col, u2.int_col, u2.float_col) + assert u2.select(u2.bool_col, u1.int_col, u2.float_col).equals(u3) + assert u2.select(u2.bool_col, u1.int_col, t.float_col).equals(u3) + assert u3.op() == Project( + parent=u2, + values={ + "bool_col": u2.bool_col, + "int_col": u2.int_col, + "float_col": u2.float_col, + }, + ) + + +def test_subsequent_selections_value_dereferencing(): + t1 = t.select( + bool_col=~t.bool_col, int_col=t.int_col + 1, float_col=t.float_col * 3 + ) + assert t1.op() == Project( + parent=t, + values={ + "bool_col": ~t.bool_col, + "int_col": t.int_col + 1, + "float_col": t.float_col * 3, + }, + ) + + t2 = t1.select(t1.bool_col, t1.int_col, t1.float_col) + assert t2.op() == Project( + parent=t1, + values={ + "bool_col": t1.bool_col, + "int_col": t1.int_col, + "float_col": t1.float_col, + }, + ) + + t3 = t2.select( + t2.bool_col, + t2.int_col, + float_col=t2.float_col * 2, + another_col=t1.float_col - 1, + ) + assert t3.op() == Project( + parent=t2, + values={ + "bool_col": t2.bool_col, + "int_col": t2.int_col, + "float_col": t2.float_col * 2, + "another_col": t2.float_col - 1, + }, + ) + + +def test_where(): + filt = t.filter(t.bool_col) + expected = Filter(parent=t, predicates=[t.bool_col]) + assert filt.op() == expected + + filt = t.filter(t.bool_col, t.int_col > 0) + expected = Filter(parent=t, predicates=[t.bool_col, t.int_col > 0]) + assert filt.op() == expected + + filt = t.filter(_.bool_col) + expected = Filter(parent=t, predicates=[t.bool_col]) + assert filt.op() == expected + + assert expected.fields == { + "bool_col": ops.Field(expected, "bool_col"), + "int_col": ops.Field(expected, "int_col"), + "float_col": ops.Field(expected, "float_col"), + "string_col": ops.Field(expected, "string_col"), + } + assert expected.values == { + "bool_col": t.bool_col.op(), + "int_col": t.int_col.op(), + "float_col": t.float_col.op(), + "string_col": t.string_col.op(), + } + + +def test_where_raies_for_empty_predicate_list(): + t = ibis.table(dict(a="string")) + with pytest.raises(IbisInputError): + t.filter() + + +def test_where_after_select(): + t1 = t.select(t.bool_col) + t2 = t1.filter(t.bool_col) + expected = Filter(parent=t1, predicates=[t1.bool_col]) + assert t2.op() == expected + + t1 = t.select(int_col=t.bool_col) + t2 = t1.filter(t.bool_col) + expected = Filter(parent=t1, predicates=[t1.int_col]) + assert t2.op() == expected + + +def test_where_with_reduction(): + with pytest.raises(IntegrityError): + Filter(t, predicates=[t.int_col.sum() > 1]) + + t1 = t.filter(t.int_col.sum() > 0) + subquery = ops.ScalarSubquery(t.int_col.sum().as_table()) + expected = Filter(parent=t, predicates=[ops.Greater(subquery, 0)]) + assert t1.op() == expected + + +def test_where_flattens_predicates(): + t1 = t.filter(t.bool_col & ((t.int_col > 0) & (t.float_col < 0))) + expected = Filter( + parent=t, + predicates=[ + t.bool_col, + t.int_col > 0, + t.float_col < 0, + ], + ) + assert t1.op() == expected + + +def test_project_filter_sort(): + expr = t.select(t.bool_col, t.int_col).filter(t.bool_col).order_by(t.int_col) + expected = ops.Sort( + parent=( + filt := ops.Filter( + parent=( + proj := ops.Project( + parent=t, + values={ + "bool_col": t.bool_col, + "int_col": t.int_col, + }, + ) + ), + predicates=[ops.Field(proj, "bool_col")], + ) + ), + keys=[ops.SortKey(ops.Field(filt, "int_col"), ascending=True)], + ) + assert expr.op() == expected + + +def test_subsequent_filter(): + f1 = t.filter(t.bool_col) + f2 = f1.filter(t.int_col > 0) + expected = Filter(f1, predicates=[f1.int_col > 0]) + assert f2.op() == expected + + +def test_project_before_and_after_filter(): + t1 = t.select( + bool_col=~t.bool_col, int_col=t.int_col + 1, float_col=t.float_col * 3 + ) + assert t1.op() == Project( + parent=t, + values={ + "bool_col": ~t.bool_col, + "int_col": t.int_col + 1, + "float_col": t.float_col * 3, + }, + ) + + t2 = t1.filter(t1.bool_col) + assert t2.op() == Filter(parent=t1, predicates=[t1.bool_col]) + + t3 = t2.filter(t2.int_col > 0) + assert t3.op() == Filter(parent=t2, predicates=[t2.int_col > 0]) + + t3_ = t2.filter(t1.int_col > 0) + assert t3_.op() == Filter(parent=t2, predicates=[t2.int_col > 0]) + + t4 = t3.select(t3.bool_col, t3.int_col) + assert t4.op() == Project( + parent=t3, + values={ + "bool_col": t3.bool_col, + "int_col": t3.int_col, + }, + ) + + +# TODO(kszucs): add test for failing integrity checks +def test_join(): + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) + t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) + + joined = t1.join(t2, [t1.a == t2.c]) + assert isinstance(joined, ir.JoinExpr) + assert isinstance(joined.op(), JoinChain) + assert isinstance(joined.op().to_expr(), ir.JoinExpr) + + result = joined._finish() + assert isinstance(joined, ir.TableExpr) + assert isinstance(joined.op(), JoinChain) + assert isinstance(joined.op().to_expr(), ir.JoinExpr) + + t2_ = joined.op().rest[0].table.to_expr() + assert result.op() == JoinChain( + first=t1, + rest=[ + JoinLink("inner", t2_, [t1.a == t2_.c]), + ], + values={ + "a": t1.a, + "b": t1.b, + "c": t2_.c, + "d": t2_.d, + }, + ) + + +def test_join_unambiguous_select(): + a = ibis.table(name="a", schema={"a_int": "int64", "a_str": "string"}) + b = ibis.table(name="b", schema={"b_int": "int64", "b_str": "string"}) + + join = a.join(b, a.a_int == b.b_int) + expr1 = join["a_int", "b_int"] + expr2 = join.select("a_int", "b_int") + assert expr1.equals(expr2) + + b_ = join.op().rest[0].table.to_expr() + assert expr1.op() == JoinChain( + first=a, + rest=[JoinLink("inner", b_, [a.a_int == b_.b_int])], + values={ + "a_int": a.a_int, + "b_int": b_.b_int, + }, + ) + + +def test_join_with_subsequent_projection(): + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) + t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) + + # a single computed value is pulled to a subsequent projection + joined = t1.join(t2, [t1.a == t2.c]) + expr = joined.select(t1.a, t1.b, col=t2.c + 1) + t2_ = joined.op().rest[0].table.to_expr() + expected = JoinChain( + first=t1, + rest=[JoinLink("inner", t2_, [t1.a == t2_.c])], + values={"a": t1.a, "b": t1.b, "col": t2_.c + 1}, + ) + assert expr.op() == expected + + # multiple computed values + joined = t1.join(t2, [t1.a == t2.c]) + expr = joined.select( + t1.a, + t1.b, + foo=t2.c + 1, + bar=t2.c + 2, + baz=t2.d.name("bar") + "3", + baz2=(t2.c + t1.a).name("foo"), + ) + t2_ = joined.op().rest[0].table.to_expr() + expected = JoinChain( + first=t1, + rest=[JoinLink("inner", t2_, [t1.a == t2_.c])], + values={ + "a": t1.a, + "b": t1.b, + "foo": t2_.c + 1, + "bar": t2_.c + 2, + "baz": t2_.d.name("bar") + "3", + "baz2": t2_.c + t1.a, + }, + ) + assert expr.op() == expected + + +def test_join_with_subsequent_projection_colliding_names(): + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) + t2 = ibis.table( + name="t2", schema={"a": "int64", "b": "string", "c": "float", "d": "string"} + ) + + joined = t1.join(t2, [t1.a == t2.a]) + expr = joined.select( + t1.a, + t1.b, + foo=t2.a + 1, + bar=t1.a + t2.a, + ) + t2_ = joined.op().rest[0].table.to_expr() + expected = JoinChain( + first=t1, + rest=[JoinLink("inner", t2_, [t1.a == t2_.a])], + values={ + "a": t1.a, + "b": t1.b, + "foo": t2_.a + 1, + "bar": t1.a + t2_.a, + }, + ) + assert expr.op() == expected + + +def test_chained_join(): + a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) + b = ibis.table(name="b", schema={"c": "int64", "d": "string"}) + c = ibis.table(name="c", schema={"e": "int64", "f": "string"}) + + joined = a.join(b, [a.a == b.c]).join(c, [a.a == c.e]) + result = joined._finish() + + b_ = joined.op().rest[0].table.to_expr() + c_ = joined.op().rest[1].table.to_expr() + assert result.op() == JoinChain( + first=a, + rest=[ + JoinLink("inner", b_, [a.a == b_.c]), + JoinLink("inner", c_, [a.a == c_.e]), + ], + values={ + "a": a.a, + "b": a.b, + "c": b_.c, + "d": b_.d, + "e": c_.e, + "f": c_.f, + }, + ) + + joined = a.join(b, [a.a == b.c]).join(c, [b.c == c.e]) + result = joined.select(a.a, b.d, c.f) + + b_ = joined.op().rest[0].table.to_expr() + c_ = joined.op().rest[1].table.to_expr() + assert result.op() == JoinChain( + first=a, + rest=[ + JoinLink("inner", b_, [a.a == b_.c]), + JoinLink("inner", c_, [b_.c == c_.e]), + ], + values={ + "a": a.a, + "d": b_.d, + "f": c_.f, + }, + ) + + +def test_chained_join_referencing_intermediate_table(): + a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) + b = ibis.table(name="b", schema={"c": "int64", "d": "string"}) + c = ibis.table(name="c", schema={"e": "int64", "f": "string"}) + + ab = a.join(b, [a.a == b.c]) + assert isinstance(ab, ir.JoinExpr) + + # assert ab.a.op() == Field(ab, "a") + abc = ab.join(c, [ab.a == c.e]) + assert isinstance(abc, ir.JoinExpr) + + result = abc._finish() + + b_ = abc.op().rest[0].table.to_expr() + c_ = abc.op().rest[1].table.to_expr() + assert result.op() == JoinChain( + first=a, + rest=[ + JoinLink("inner", b_, [a.a == b_.c]), + JoinLink("inner", c_, [a.a == c_.e]), + ], + values={"a": a.a, "b": a.b, "c": b_.c, "d": b_.d, "e": c_.e, "f": c_.f}, + ) + + +def test_join_predicate_dereferencing(): + # See #790, predicate pushdown in joins not supported + + # Star schema with fact table + table = ibis.table({"c": int, "f": float, "foo_id": str, "bar_id": str}) + table2 = ibis.table({"foo_id": str, "value1": float, "value3": float}) + table3 = ibis.table({"bar_id": str, "value2": float}) + + filtered = table[table["f"] > 0] + + # dereference table.foo_id to filtered.foo_id + j1 = filtered.left_join(table2, table["foo_id"] == table2["foo_id"]) + + table2_ = j1.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=filtered, + rest=[ + ops.JoinLink("left", table2_, [filtered.foo_id == table2_.foo_id]), + ], + values={ + "c": filtered.c, + "f": filtered.f, + "foo_id": filtered.foo_id, + "bar_id": filtered.bar_id, + "foo_id_right": table2_.foo_id, + "value1": table2_.value1, + "value3": table2_.value3, + }, + ) + assert j1.op() == expected + + j2 = j1.inner_join(table3, filtered["bar_id"] == table3["bar_id"]) + + table2_ = j2.op().rest[0].table.to_expr() + table3_ = j2.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=filtered, + rest=[ + ops.JoinLink("left", table2_, [filtered.foo_id == table2_.foo_id]), + ops.JoinLink("inner", table3_, [filtered.bar_id == table3_.bar_id]), + ], + values={ + "c": filtered.c, + "f": filtered.f, + "foo_id": filtered.foo_id, + "bar_id": filtered.bar_id, + "foo_id_right": table2_.foo_id, + "value1": table2_.value1, + "value3": table2_.value3, + "bar_id_right": table3_.bar_id, + "value2": table3_.value2, + }, + ) + assert j2.op() == expected + + # Project out the desired fields + view = j2[[filtered, table2["value1"], table3["value2"]]] + expected = ops.JoinChain( + first=filtered, + rest=[ + ops.JoinLink("left", table2_, [filtered.foo_id == table2_.foo_id]), + ops.JoinLink("inner", table3_, [filtered.bar_id == table3_.bar_id]), + ], + values={ + "c": filtered.c, + "f": filtered.f, + "foo_id": filtered.foo_id, + "bar_id": filtered.bar_id, + "value1": table2_.value1, + "value2": table3_.value2, + }, + ) + assert view.op() == expected + + +def test_aggregate(): + agg = t.aggregate(by=[t.bool_col], metrics=[t.int_col.sum()]) + expected = Aggregate( + parent=t, + groups={ + "bool_col": t.bool_col, + }, + metrics={ + "Sum(int_col)": t.int_col.sum(), + }, + ) + assert agg.op() == expected + + +def test_aggregate_having(): + table = ibis.table(name="table", schema={"g": "string", "f": "double"}) + + metrics = [table.f.sum().name("total")] + by = ["g"] + + expr = table.aggregate(metrics, by=by, having=(table.f.sum() > 0).name("cond")) + expected = table.aggregate(metrics, by=by).filter(_.total > 0) + assert expr.equals(expected) + + with pytest.raises(ValidationError): + # non boolean + table.aggregate(metrics, by=by, having=table.f.sum()) + + with pytest.raises(IntegrityError): + # non scalar + table.aggregate(metrics, by=by, having=table.f > 2) + + +def test_select_with_uncorrelated_scalar_subquery(): + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) + t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) + + # Create a subquery + t2_filt = t2.filter(t2.d == "value") + + # Non-reduction won't be turned into a subquery + with pytest.raises(IntegrityError): + t1.select(t2_filt.c) + + # Construct the projection using the subquery + sub = t1.select(t1.a, summary=t2_filt.c.sum()) + expected = Project( + parent=t1, + values={ + "a": t1.a, + "summary": ops.ScalarSubquery(t2_filt.c.sum().as_table()), + }, + ) + assert sub.op() == expected + + +def test_select_with_reduction_turns_into_window_function(): + # Define your tables + employees = ibis.table( + name="employees", schema={"name": "string", "salary": "double"} + ) + + # Use the subquery in a select operation + expr = employees.select(employees.name, average_salary=employees.salary.mean()) + expected = Project( + parent=employees, + values={ + "name": employees.name, + "average_salary": employees.salary.mean().over(), + }, + ) + assert expr.op() == expected + + +def test_select_with_correlated_scalar_subquery(): + # Define your tables + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) + t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) + + # Create a subquery + filt = t2.filter(t2.d == t1.b) + summary = filt.c.sum().name("summary") + + # Use the subquery in a select operation + expr = t1.select(t1.a, summary) + expected = Project( + parent=t1, + values={ + "a": t1.a, + "summary": ops.ScalarSubquery(filt.c.sum().as_table()), + }, + ) + assert expr.op() == expected + + +def test_aggregate_field_dereferencing(): + t = ibis.table( + { + "l_orderkey": "int32", + "l_partkey": "int32", + "l_suppkey": "int32", + "l_linenumber": "int32", + "l_quantity": "decimal(15, 2)", + "l_extendedprice": "decimal(15, 2)", + "l_discount": "decimal(15, 2)", + "l_tax": "decimal(15, 2)", + "l_returnflag": "string", + "l_linestatus": "string", + "l_shipdate": "date", + "l_commitdate": "date", + "l_receiptdate": "date", + "l_shipinstruct": "string", + "l_shipmode": "string", + "l_comment": "string", + } + ) + + f = t.filter(t.l_shipdate <= ibis.date("1998-09-01")) + assert f.op() == Filter( + parent=t, predicates=[t.l_shipdate <= ibis.date("1998-09-01")] + ) + + discount_price = t.l_extendedprice * (1 - t.l_discount) + charge = discount_price * (1 + t.l_tax) + a = f.group_by(["l_returnflag", "l_linestatus"]).aggregate( + sum_qty=t.l_quantity.sum(), + sum_base_price=t.l_extendedprice.sum(), + sum_disc_price=discount_price.sum(), + sum_charge=charge.sum(), + avg_qty=t.l_quantity.mean(), + avg_price=t.l_extendedprice.mean(), + avg_disc=t.l_discount.mean(), + count_order=f.count(), # note that this is f.count() not t.count() + ) + + discount_price_ = f.l_extendedprice * (1 - f.l_discount) + charge_ = discount_price_ * (1 + f.l_tax) + assert a.op() == Aggregate( + parent=f, + groups={ + "l_returnflag": f.l_returnflag, + "l_linestatus": f.l_linestatus, + }, + metrics={ + "sum_qty": f.l_quantity.sum(), + "sum_base_price": f.l_extendedprice.sum(), + "sum_disc_price": discount_price_.sum(), + "sum_charge": charge_.sum(), + "avg_qty": f.l_quantity.mean(), + "avg_price": f.l_extendedprice.mean(), + "avg_disc": f.l_discount.mean(), + "count_order": f.count(), + }, + ) + + s = a.order_by(["l_returnflag", "l_linestatus"]) + assert s.op() == ops.Sort( + parent=a, + keys=[a.l_returnflag, a.l_linestatus], + ) + + +def test_isin_subquery(): + t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"}) + t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"}) + + t2_filt = t2.filter(t2.d == "value") + + expr = t1.filter(t1.a.isin(t2_filt.c)) + subquery = Project(t2_filt, values={"c": t2_filt.c}) + expected = Filter(parent=t1, predicates=[ops.InSubquery(rel=subquery, needle=t1.a)]) + assert expr.op() == expected + + +def test_filter_condition_referencing_agg_without_groupby_turns_it_into_a_subquery(): + r1 = ibis.table( + name="r3", schema={"name": str, "key": str, "int_col": int, "float_col": float} + ) + r2 = r1.filter(r1.name == "GERMANY") + r3 = r2.aggregate(by=[r2.key], value=(r2.float_col * r2.int_col).sum()) + r4 = r2.aggregate(total=(r2.float_col * r2.int_col).sum()) + r5 = r3.filter(r3.value > r4.total * 0.0001) + + total = (r2.float_col * r2.int_col).sum() + subquery = ops.ScalarSubquery( + ops.Aggregate(r2, groups={}, metrics={total.get_name(): total}) + ).to_expr() + expected = Filter(parent=r3, predicates=[r3.value > subquery * 0.0001]) + + assert r5.op() == expected + + +def test_self_join(): + t0 = ibis.table(schema=ibis.schema(dict(key="int")), name="leaf") + t1 = t0.filter(ibis.literal(True)) + t2 = t1[["key"]] + + t3 = t2.join(t2, ["key"]) + t2_ = t3.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=t2, + rest=[ + ops.JoinLink("inner", t2_, [t2.key == t2_.key]), + ], + values={"key": t2.key, "key_right": t2_.key}, + ) + assert t3.op() == expected + + t4 = t3.join(t3, ["key"]) + t3_ = t4.op().rest[1].table.to_expr() + + expected = ops.JoinChain( + first=t2, + rest=[ + ops.JoinLink("inner", t2_, [t2.key == t2_.key]), + ops.JoinLink("inner", t3_, [t2.key == t3_.key]), + ], + values={ + "key": t2.key, + "key_right": t2_.key, + "key_right_right": t3_.key_right, + }, + ) + assert t4.op() == expected + + +def test_self_join_view(): + t = ibis.memtable({"x": [1, 2], "y": [2, 1], "z": ["a", "b"]}) + t_view = t.view() + expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right") + + t_view_ = expr.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=t, + rest=[ + ops.JoinLink("inner", t_view_, [t.x == t_view_.y]), + ], + values={"x": t.x, "y": t.y, "z": t.z, "z_right": t_view_.z}, + ) + assert expr.op() == expected + + +def test_self_join_with_view_projection(): + t1 = ibis.memtable({"x": [1, 2], "y": [2, 1], "z": ["a", "b"]}) + t2 = t1.view() + expr = t1.inner_join(t2, ["x"])[[t1]] + + t2_ = expr.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=t1, + rest=[ + ops.JoinLink("inner", t2_, [t1.x == t2_.x]), + ], + values={"x": t1.x, "y": t1.y, "z": t1.z}, + ) + assert expr.op() == expected + + +def test_joining_same_table_twice(): + left = ibis.table(name="left", schema={"time1": int, "value": float, "a": str}) + right = ibis.table(name="right", schema={"time2": int, "value2": float, "b": str}) + + joined = left.inner_join(right, left.a == right.b).inner_join( + right, left.value == right.value2 + ) + + right_ = joined.op().rest[0].table.to_expr() + right__ = joined.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=left, + rest=[ + ops.JoinLink("inner", right_, [left.a == right_.b]), + ops.JoinLink("inner", right__, [left.value == right__.value2]), + ], + values={ + "time1": left.time1, + "value": left.value, + "a": left.a, + "time2": right_.time2, + "value2": right_.value2, + "b": right_.b, + "time2_right": right__.time2, + "value2_right": right__.value2, + "b_right": right__.b, + }, + ) + assert joined.op() == expected + + +def test_join_chain_gets_reused_and_continued_after_a_select(): + a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) + b = ibis.table(name="b", schema={"c": "int64", "d": "string"}) + c = ibis.table(name="c", schema={"e": "int64", "f": "string"}) + + ab = a.join(b, [a.a == b.c]) + abc = ab[a.b, b.d].join(c, [a.a == c.e]) + + b_ = abc.op().rest[0].table.to_expr() + c_ = abc.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=a, + rest=[ + ops.JoinLink("inner", b_, [a.a == b_.c]), + ops.JoinLink("inner", c_, [a.a == c_.e]), + ], + values={ + "b": a.b, + "d": b_.d, + "e": c_.e, + "f": c_.f, + }, + ) + assert abc.op() == expected + assert abc._finish().op() == expected + + +def test_self_join_extensive(): + a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) + + aa = a.join(a, [a.a == a.a]) + aa_ = a.join(a, "a") + aa__ = a.join(a, [("a", "a")]) + for join in [aa, aa_, aa__]: + a1 = join.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=a, + rest=[ + ops.JoinLink("inner", a1, [a.a == a1.a]), + ], + values={ + "a": a.a, + "b": a.b, + "a_right": a1.a, + "b_right": a1.b, + }, + ) + assert join.op() == expected + + aaa = a.join(a, [a.a == a.a]).join(a, [a.a == a.a]) + a0 = a + a1 = aaa.op().rest[0].table.to_expr() + a2 = aaa.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=a0, + rest=[ + ops.JoinLink("inner", a1, [a0.a == a1.a]), + ops.JoinLink("inner", a2, [a0.a == a2.a]), + ], + values={ + "a": a0.a, + "b": a0.b, + "a_right": a1.a, + "b_right": a1.b, + }, + ) + + aaa = aa.join(a, [aa.a == a.a]) + aaa_ = aa.join(a, "a") + aaa__ = aa.join(a, [("a", "a")]) + for join in [aaa, aaa_, aaa__]: + a1 = join.op().rest[0].table.to_expr() + a2 = join.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=a, + rest=[ + ops.JoinLink("inner", a1, [a.a == a1.a]), + ops.JoinLink("inner", a2, [a.a == a2.a]), + ], + values={ + "a": a.a, + "b": a.b, + "a_right": a1.a, + "b_right": a1.b, + }, + ) + assert join.op() == expected + + +def test_self_join_with_intermediate_selection(): + a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) + + join = a[["b", "a"]].join(a, [a.a == a.a]) + a0 = a[["b", "a"]] + a1 = join.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=a0, + rest=[ + ops.JoinLink("inner", a1, [a0.a == a1.a]), + ], + values={ + "b": a0.b, + "a": a0.a, + "a_right": a1.a, + "b_right": a1.b, + }, + ) + assert join.op() == expected + + aa_ = a.join(a, [a.a == a.a])["a", "b_right"] + aaa_ = aa_.join(a, [aa_.a == a.a]) + a0 = a + a1 = aaa_.op().rest[0].table.to_expr() + a2 = aaa_.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=a0, + rest=[ + ops.JoinLink("inner", a1, [a0.a == a1.a]), + ops.JoinLink("inner", a2, [a0.a == a2.a]), + ], + values={ + "a": a0.a, + "b_right": a1.b, + "a_right": a2.a, + "b": a2.b, + }, + ) + assert aaa_.op() == expected + + # TODO(kszucs): this use case could be supported if `_get_column` gets + # overridden to return underlying column reference, but that would mean + # that `aa.a` returns with `a.a` instead of `aa.a` which breaks other + # things + # aa = a.join(a, [a.a == a.a]) + # aaa = aa["a", "b_right"].join(a, [aa.a == a.a]) + # a0 = a + # a1 = aaa.op().rest[0].table.to_expr() + # a2 = aaa.op().rest[1].table.to_expr() + # expected = ops.JoinChain( + # first=a0, + # rest=[ + # ops.JoinLink("inner", a1, [a0.a == a1.a]), + # ops.JoinLink("inner", a2, [a0.a == a2.a]), + # ], + # values={ + # "a": a0.a, + # "b_right": a1.b, + # "a_right": a2.a, + # "b": a2.b, + # }, + # ) + # assert aaa.op() == expected + + +def test_name_collisions_raise(): + a = ibis.table(name="a", schema={"a": "int64", "b": "string"}) + b = ibis.table(name="b", schema={"a": "int64", "b": "string"}) + c = ibis.table(name="c", schema={"a": "int64", "b": "string"}) + + ab = a.join(b, [a.a == b.a]) + filt = ab.filter(ab.a < 1) + expected = ops.Filter( + parent=ab, + predicates=[ + ops.Less(ops.Field(ab, "a"), 1), + ], + ) + assert filt.op() == expected + + abc = a.join(b, [a.a == b.a]).join(c, [a.a == c.a]) + with pytest.raises(IntegrityError): + abc.filter(abc.a < 1) diff --git a/ibis/expr/tests/test_rewrites.py b/ibis/expr/tests/test_rewrites.py new file mode 100644 index 000000000000..ca54f2216006 --- /dev/null +++ b/ibis/expr/tests/test_rewrites.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import ibis +import ibis.expr.operations as ops +from ibis.expr.rewrites import simplify + +t = ibis.table( + name="t", + schema={ + "bool_col": "boolean", + "int_col": "int64", + "float_col": "float64", + "string_col": "string", + }, +) + + +def test_simplify_full_reprojection(): + t1 = t.select(t) + t1_opt = simplify(t1.op()) + assert t1_opt == t.op() + + +def test_simplify_subsequent_field_selections(): + t1 = t.select(t.bool_col, t.int_col, t.float_col) + assert t1.op() == ops.Project( + parent=t, + values={ + "bool_col": t.bool_col, + "int_col": t.int_col, + "float_col": t.float_col, + }, + ) + + t2 = t1.select(t1.bool_col, t1.int_col) + t2_opt = simplify(t2.op()) + assert t2_opt == ops.Project( + parent=t, + values={ + "bool_col": t.bool_col, + "int_col": t.int_col, + }, + ) + + t3 = t2.select(t2.bool_col) + t3_opt = simplify(t3.op()) + assert t3_opt == ops.Project(parent=t, values={"bool_col": t.bool_col}) + + +def test_simplify_subsequent_value_selections(): + t1 = t.select( + bool_col=~t.bool_col, int_col=t.int_col + 1, float_col=t.float_col * 3 + ) + t2 = t1.select(t1.bool_col, t1.int_col, t1.float_col) + t2_opt = simplify(t2.op()) + assert t2_opt == ops.Project( + parent=t, + values={ + "bool_col": ~t.bool_col, + "int_col": t.int_col + 1, + "float_col": t.float_col * 3, + }, + ) + + t3 = t2.select( + t2.bool_col, + t2.int_col, + float_col=t2.float_col * 2, + another_col=t1.float_col - 1, + ) + t3_opt = simplify(t3.op()) + assert t3_opt == ops.Project( + parent=t, + values={ + "bool_col": ~t.bool_col, + "int_col": t.int_col + 1, + "float_col": (t.float_col * 3) * 2, + "another_col": (t.float_col * 3) - 1, + }, + ) + + +def test_simplify_subsequent_filters(): + f1 = t.filter(t.bool_col) + f2 = f1.filter(t.int_col > 0) + f2_opt = simplify(f2.op()) + assert f2_opt == ops.Filter(t, predicates=[t.bool_col, t.int_col > 0]) + + +def test_simplify_project_filter_project(): + t1 = t.select( + bool_col=~t.bool_col, int_col=t.int_col + 1, float_col=t.float_col * 3 + ) + t2 = t1.filter(t1.bool_col) + t3 = t2.filter(t2.int_col > 0) + t4 = t3.select(t3.bool_col, t3.int_col) + + filt = ops.Filter(parent=t, predicates=[~t.bool_col, t.int_col + 1 > 0]).to_expr() + proj = ops.Project( + parent=filt, values={"bool_col": ~filt.bool_col, "int_col": filt.int_col + 1} + ).to_expr() + + t4_opt = simplify(t4.op()) + assert t4_opt == proj.op() diff --git a/ibis/expr/types/__init__.py b/ibis/expr/types/__init__.py index 610504d43989..99bd54d2f6e4 100644 --- a/ibis/expr/types/__init__.py +++ b/ibis/expr/types/__init__.py @@ -12,6 +12,7 @@ from ibis.expr.types.maps import * # noqa: F403 from ibis.expr.types.numeric import * # noqa: F403 from ibis.expr.types.relations import * # noqa: F403 +from ibis.expr.types.joins import * # noqa: F403 from ibis.expr.types.strings import * # noqa: F403 from ibis.expr.types.structs import * # noqa: F403 from ibis.expr.types.temporal import * # noqa: F403 diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index 043f8385474d..dd215962f68e 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -245,9 +245,10 @@ def _find_backends(self) -> tuple[list[BaseBackend], bool]: list[BaseBackend] A list of the backends found. """ + backends = set() has_unbound = False - node_types = (ops.DatabaseTable, ops.SQLQueryResult, ops.UnboundTable) + node_types = (ops.UnboundTable, ops.DatabaseTable, ops.SQLQueryResult) for table in self.op().find(node_types): if isinstance(table, ops.UnboundTable): has_unbound = True diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 03af02ef5ff4..dfda547a49a1 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -7,10 +7,12 @@ import ibis import ibis.common.exceptions as com +import ibis.expr.builders as bl import ibis.expr.datatypes as dt import ibis.expr.operations as ops -from ibis.common.deferred import Deferred +from ibis.common.deferred import Deferred, _, deferrable from ibis.common.grounds import Singleton +from ibis.expr.rewrites import rewrite_window_input from ibis.expr.types.core import Expr, _binop, _FixedTextJupyterMixin from ibis.util import deprecated @@ -18,7 +20,6 @@ import pandas as pd import pyarrow as pa - import ibis.expr.builders as bl import ibis.expr.types as ir from ibis.formats.pyarrow import PyArrowData @@ -571,7 +572,7 @@ def isin(self, values: Value | Sequence[Value]) -> ir.BooleanValue: if isinstance(values, ArrayValue): return ops.ArrayContains(values, self).to_expr() elif isinstance(values, Column): - return ops.InColumn(self, values).to_expr() + return ops.InSubquery(values.as_table(), needle=self).to_expr() else: return ops.InValues(self, values).to_expr() @@ -721,11 +722,7 @@ def over( A window function expression """ - import ibis.expr.analysis as an - import ibis.expr.builders as bl - from ibis import _ - from ibis.common.deferred import Call - + node = self.op() if window is None: window = ibis.window( rows=rows, @@ -733,23 +730,30 @@ def over( group_by=group_by, order_by=order_by, ) + elif not isinstance(window, bl.WindowBuilder): + raise com.IbisTypeError("Unexpected window type: {window!r}") + if len(node.relations) == 0: + table = None + elif len(node.relations) == 1: + (table,) = node.relations + else: + raise com.RelationError("Cannot use window with multiple tables") + + @deferrable def bind(table): frame = window.bind(table) - expr = an.windowize_function(self, frame) - if expr.equals(self): + winfunc = rewrite_window_input(node, frame) + if winfunc == node: raise com.IbisTypeError( "No reduction or analytic function found to construct a window expression" ) - return expr + return winfunc.to_expr() - if isinstance(window, bl.WindowBuilder): - if table := an.find_first_base_table(self.op()): - return bind(table) - else: - return Deferred(Call(bind, _)) - else: - raise com.IbisTypeError("Unexpected window type: {window!r}") + try: + return bind(table) + except com.IbisInputError: + return bind(_) def isnull(self) -> ir.BooleanValue: """Return whether this expression is NULL. @@ -1116,9 +1120,13 @@ def __hash__(self) -> int: return super().__hash__() def __eq__(self, other: Value) -> ir.BooleanValue: + if other is None: + return _binop(ops.IdenticalTo, self, other) return _binop(ops.Equals, self, other) def __ne__(self, other: Value) -> ir.BooleanValue: + if other is None: + return ~self.__eq__(other) return _binop(ops.NotEquals, self, other) def __ge__(self, other: Value) -> ir.BooleanValue: @@ -1158,22 +1166,20 @@ def as_table(self) -> ir.Table: >>> expr.equals(expected) True """ - from ibis.expr.analysis import find_immediate_parent_tables - - roots = find_immediate_parent_tables(self.op()) - if len(roots) > 1: + parents = self.op().relations + values = {self.get_name(): self} + + if len(parents) == 0: + return ops.DummyTable(values).to_expr() + elif len(parents) == 1: + (parent,) = parents + return parent.to_expr().select(self) + else: raise com.RelationError( - f"Cannot convert {type(self)} expression " - "involving multiple base table references " - "to a projection" + f"Cannot convert {type(self)} expression involving multiple " + "base table references to a projection" ) - if roots: - return roots[0].to_expr().select(self) - - # no child table to select from - return ops.DummyTable(values=(self,)).to_expr() - def to_pandas(self, **kwargs) -> pd.Series: """Convert a column expression to a pandas Series or scalar object. @@ -1254,20 +1260,19 @@ def as_table(self) -> ir.Table: >>> isinstance(lit, ir.Table) True """ - from ibis.expr.analysis import find_first_base_table + parents = self.op().relations - op = self.op() - table = find_first_base_table(op) - if table is not None: - return table.to_expr().aggregate(**{self.get_name(): self}) + if len(parents) == 0: + return ops.DummyTable({self.get_name(): self}).to_expr() + elif len(parents) == 1: + (parent,) = parents + return parent.to_expr().aggregate(self) else: - if isinstance(op, ops.Alias): - value = op - assert value.name == self.get_name() - else: - value = ops.Alias(op, self.get_name()) - - return ops.DummyTable(values=(value,)).to_expr() + raise com.RelationError( + f"The scalar expression {self} cannot be converted to a " + "table expression because it involves multiple base table " + "references" + ) def __deferred_repr__(self): return f"" @@ -1323,14 +1328,24 @@ def __pandas_result__(self, df: pd.DataFrame) -> pd.Series: return PandasData.convert_column(df.loc[:, column], self.type()) def _bind_reduction_filter(self, where): - import ibis.expr.analysis as an - - if where is None or not isinstance(where, Deferred): + node = self.op() + if isinstance(where, Deferred): + if len(node.relations) == 0: + raise com.IbisInputError( + "Unable to bind deferred expression to a table because " + "the expression doesn't depend on any tables" + ) + elif len(node.relations) == 1: + (table,) = node.relations + return where.resolve(table) + else: + raise com.RelationError( + "Cannot bind deferred expression to a table because the " + "expression depends on multiple tables" + ) + else: return where - table = an.find_first_base_table(self.op()).to_expr() - return where.resolve(table) - def __deferred_repr__(self): return f"" @@ -1831,16 +1846,9 @@ def value_counts(self) -> ir.Table: │ d │ 3 │ └────────┴─────────────┘ """ - from ibis.expr.analysis import find_first_base_table - name = self.get_name() - return ( - find_first_base_table(self.op()) - .to_expr() - .select(self) - .group_by(name) - .agg(**{f"{name}_count": lambda t: t.count()}) - ) + metric = _.count().name(f"{name}_count") + return self.as_table().group_by(name).aggregate(metric) def first(self, where: ir.BooleanValue | None = None) -> Value: """Return the first value of a column. @@ -1923,13 +1931,7 @@ def rank(self) -> ir.IntegerColumn: │ 3 │ 5 │ └────────┴───────┘ """ - import ibis.expr.analysis as an - - return ( - ibis.rank() - .over(order_by=self) - .resolve(an.find_first_base_table(self.op()).to_expr()) - ) + return ibis.rank().over(order_by=self) def dense_rank(self) -> ir.IntegerColumn: """Position of first element within each group of equal values. @@ -1962,33 +1964,15 @@ def dense_rank(self) -> ir.IntegerColumn: │ 3 │ 2 │ └────────┴───────┘ """ - import ibis.expr.analysis as an - - return ( - ibis.dense_rank() - .over(order_by=self) - .resolve(an.find_first_base_table(self.op()).to_expr()) - ) + return ibis.dense_rank().over(order_by=self) def percent_rank(self) -> Column: """Return the relative rank of the values in the column.""" - import ibis.expr.analysis as an - - return ( - ibis.percent_rank() - .over(order_by=self) - .resolve(an.find_first_base_table(self.op()).to_expr()) - ) + return ibis.percent_rank().over(order_by=self) def cume_dist(self) -> Column: """Return the cumulative distribution over a window.""" - import ibis.expr.analysis as an - - return ( - ibis.cume_dist() - .over(order_by=self) - .resolve(an.find_first_base_table(self.op()).to_expr()) - ) + return ibis.cume_dist().over(order_by=self) def ntile(self, buckets: int | ir.IntegerValue) -> ir.IntegerColumn: """Return the integer number of a partitioning of the column values. @@ -1998,13 +1982,7 @@ def ntile(self, buckets: int | ir.IntegerValue) -> ir.IntegerColumn: buckets Number of buckets to partition into """ - import ibis.expr.analysis as an - - return ( - ibis.ntile(buckets) - .over(order_by=self) - .resolve(an.find_first_base_table(self.op()).to_expr()) - ) + return ibis.ntile(buckets).over(order_by=self) def cummin(self, *, where=None, group_by=None, order_by=None) -> Column: """Return the cumulative min over a window.""" diff --git a/ibis/expr/types/geospatial.py b/ibis/expr/types/geospatial.py index 95372ffa9703..d4e2f25cbe16 100644 --- a/ibis/expr/types/geospatial.py +++ b/ibis/expr/types/geospatial.py @@ -1622,13 +1622,20 @@ class GeoSpatialScalar(NumericScalar, GeoSpatialValue): @public class GeoSpatialColumn(NumericColumn, GeoSpatialValue): - def unary_union(self) -> ir.GeoSpatialScalar: + def unary_union( + self, where: bool | ir.BooleanValue | None = None + ) -> ir.GeoSpatialScalar: """Aggregate a set of geometries into a union. This corresponds to the aggregate version of the union. We give it a different name (following the corresponding method in GeoPandas) to avoid name conflicts with the non-aggregate version. + Parameters + ---------- + where + Filter expression + Returns ------- GeoSpatialScalar @@ -1642,7 +1649,7 @@ def unary_union(self) -> ir.GeoSpatialScalar: >>> t.geom.unary_union() """ - return ops.GeoUnaryUnion(self).to_expr().name("union") + return ops.GeoUnaryUnion(self, where=where).to_expr() @public diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index 4f3835153f3e..8f3fba1bab60 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -16,101 +16,58 @@ from __future__ import annotations -import itertools -import types from typing import TYPE_CHECKING from public import public import ibis import ibis.common.exceptions as com -import ibis.expr.analysis as an +import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir -from ibis import util -from ibis.common.deferred import Deferred -from ibis.expr.types.relations import bind_expr -from ibis.selectors import Selector +from ibis.common.grounds import Concrete +from ibis.common.typing import VarTuple # noqa: TCH001 +from ibis.expr.rewrites import rewrite_window_input +from ibis.expr.types.relations import bind if TYPE_CHECKING: from collections.abc import Iterable, Sequence -_function_types = tuple( - filter( - None, - ( - types.BuiltinFunctionType, - types.BuiltinMethodType, - types.FunctionType, - types.LambdaType, - types.MethodType, - getattr(types, "UnboundMethodType", None), - ), - ) -) - - -def _get_group_by_key(table, value): - if isinstance(value, str): - yield table[value] - elif isinstance(value, _function_types): - yield value(table) - elif isinstance(value, Deferred): - yield value.resolve(table) - elif isinstance(value, Selector): - yield from value.expand(table) - elif isinstance(value, ir.Expr): - yield an.sub_immediate_parents(value.op(), table.op()).to_expr() - else: - yield value - - @public -class GroupedTable: +class GroupedTable(Concrete): """An intermediate table expression to hold grouping information.""" - def __init__(self, table, by, having=None, order_by=None, **expressions): - self.table = table - self.by = list( - itertools.chain( - itertools.chain.from_iterable( - _get_group_by_key(table, v) for v in util.promote_list(by) - ), - ( - expr.name(k) - for k, v in expressions.items() - for expr in _get_group_by_key(table, v) - ), - ) - ) - - if not self.by: - raise com.IbisInputError("The grouping keys list is empty") + table: ops.Relation + groupings: VarTuple[ops.Column] + orderings: VarTuple[ops.SortKey] = () + havings: VarTuple[ops.Value[dt.Boolean]] = () - self._order_by = order_by or [] - self._having = having or [] + def __init__(self, groupings, **kwargs): + if not groupings: + raise com.IbisInputError("No group keys provided") + super().__init__(groupings=groupings, **kwargs) def __getitem__(self, args): # Shortcut for projection with window functions return self.select(*args) def __getattr__(self, attr): - if hasattr(self.table, attr): - return self._column_wrapper(attr) + try: + field = getattr(self.table.to_expr(), attr) + except AttributeError as e: + raise AttributeError(f"GroupedTable has no attribute {attr}") from e - raise AttributeError("GroupBy has no attribute %r" % attr) - - def _column_wrapper(self, attr): - col = self.table[attr] - if isinstance(col, ir.NumericValue): - return GroupedNumbers(col, self) + if isinstance(field, ir.NumericValue): + return GroupedNumbers(field, self) else: - return GroupedArray(col, self) + return GroupedArray(field, self) - def aggregate(self, metrics=None, **kwds) -> ir.Table: + def aggregate(self, metrics=(), **kwds) -> ir.Table: """Compute aggregates over a group by.""" - return self.table.aggregate(metrics, by=self.by, having=self._having, **kwds) + return self.table.to_expr().aggregate( + metrics, by=self.groupings, having=self.havings, **kwds + ) agg = aggregate @@ -131,12 +88,9 @@ def having(self, expr: ir.BooleanScalar) -> GroupedTable: GroupedTable A grouped table expression """ - return self.__class__( - self.table, - self.by, - having=self._having + util.promote_list(expr), - order_by=self._order_by, - ) + table = self.table.to_expr() + havings = tuple(bind(table, expr)) + return self.copy(havings=self.havings + havings) def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: """Sort a grouped table expression by `expr`. @@ -155,12 +109,9 @@ def order_by(self, expr: ir.Value | Iterable[ir.Value]) -> GroupedTable: GroupedTable A sorted grouped GroupedTable """ - return self.__class__( - self.table, - self.by, - having=self._having, - order_by=self._order_by + util.promote_list(expr), - ) + table = self.table.to_expr() + orderings = tuple(bind(table, expr)) + return self.copy(orderings=self.orderings + orderings) def mutate( self, *exprs: ir.Value | Sequence[ir.Value], **kwexprs: ir.Value @@ -230,7 +181,7 @@ def mutate( A table expression with window functions applied """ exprs = self._selectables(*exprs, **kwexprs) - return self.table.mutate(exprs) + return self.table.to_expr().mutate(exprs) def select(self, *exprs, **kwexprs) -> ir.Table: """Project new columns out of the grouped table. @@ -240,7 +191,7 @@ def select(self, *exprs, **kwexprs) -> ir.Table: [`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate) """ exprs = self._selectables(*exprs, **kwexprs) - return self.table.select(exprs) + return self.table.to_expr().select(exprs) def _selectables(self, *exprs, **kwexprs): """Project new columns out of the grouped table. @@ -249,22 +200,14 @@ def _selectables(self, *exprs, **kwexprs): -------- [`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate) """ - table = self.table - default_frame = ops.RowsWindowFrame( + table = self.table.to_expr() + frame = ops.RowsWindowFrame( table=self.table, - group_by=bind_expr(self.table, self.by), - order_by=bind_expr(self.table, self._order_by), + group_by=self.groupings, + order_by=self.orderings, ) - return [ - an.windowize_function(e2, default_frame) - for expr in exprs - for e1 in util.promote_list(expr) - for e2 in util.promote_list(table._ensure_expr(e1)) - ] + [ - an.windowize_function(e, default_frame).name(k) - for k, expr in kwexprs.items() - for e in util.promote_list(table._ensure_expr(expr)) - ] + values = bind(table, (exprs, kwexprs)) + return [rewrite_window_input(expr.op(), frame).to_expr() for expr in values] projection = select @@ -321,8 +264,8 @@ def count(self) -> ir.Table: Table The aggregated table """ - metric = self.table.count() - return self.table.aggregate([metric], by=self.by, having=self._having) + table = self.table.to_expr() + return table.aggregate([table.count()], by=self.groupings, having=self.havings) size = count diff --git a/ibis/expr/types/joins.py b/ibis/expr/types/joins.py new file mode 100644 index 000000000000..bde607b1d3ee --- /dev/null +++ b/ibis/expr/types/joins.py @@ -0,0 +1,247 @@ +from ibis.expr.types.relations import ( + bind, + dereference_values, + unwrap_aliases, +) +from public import public +import ibis.expr.operations as ops +from ibis.expr.types import Table, ValueExpr +from typing import Any, Optional +from collections.abc import Iterator, Mapping +from ibis.common.deferred import Deferred +from ibis.expr.analysis import flatten_predicates +from ibis.expr.operations.relations import JoinKind +from ibis.common.exceptions import ExpressionError, IntegrityError +from ibis import util +import functools +from ibis.expr.types.relations import dereference_mapping +import ibis + + +def disambiguate_fields(how, left_fields, right_fields, lname, rname): + collisions = set() + + if how in ("semi", "anti"): + # discard the right fields per left semi and left anty join semantics + return left_fields, collisions + + lname = lname or "{name}" + rname = rname or "{name}" + overlap = left_fields.keys() & right_fields.keys() + + fields = {} + for name, field in left_fields.items(): + if name in overlap: + name = lname.format(name=name) + fields[name] = field + for name, field in right_fields.items(): + if name in overlap: + name = rname.format(name=name) + # only add if there is no collision + if name in fields: + collisions.add(name) + else: + fields[name] = field + + return fields, collisions + + +def dereference_targets(chain): + yield chain.first + for join in chain.rest: + if join.how not in ("semi", "anti"): + yield join.table + + +def dereference_mapping_left(chain): + rels = dereference_targets(chain) + subs = dereference_mapping(rels) + # join chain fields => link table fields + for k, v in chain.values.items(): + subs[ops.Field(chain, k)] = v + return subs + + +def dereference_mapping_right(right): + if isinstance(right, ops.SelfReference): + # no support for dereferencing, the user must use the right table + # directly in the predicates + return {}, right + + # wrap the right table in a self reference to ensure its uniqueness in the + # join chain which requires dereferencing the predicates from + # right => SelfReference(right) + right = ops.SelfReference(right) + subs = {v: ops.Field(right, k) for k, v in right.values.items()} + return subs, right + + +def dereference_sides(left, right, deref_left, deref_right): + left = left.replace(deref_left, filter=ops.Value) + right = right.replace(deref_right, filter=ops.Value) + return left, right + + +def dereference_binop(pred, deref_left, deref_right): + left, right = dereference_sides(pred.left, pred.right, deref_left, deref_right) + return pred.copy(left=left, right=right) + + +def dereference_value(pred, deref_left, deref_right): + deref_both = {**deref_left, **deref_right} + if isinstance(pred, ops.Binary) and pred.left == pred.right: + return dereference_binop(pred, deref_left, deref_right) + else: + return pred.replace(deref_both, filter=ops.Value) + + +def prepare_predicates(left, right, predicates, deref_left, deref_right, deref_both): + """Bind and dereference predicates to the left and right tables.""" + + for pred in util.promote_list(predicates): + if pred is True or pred is False: + yield ops.Literal(pred, dtype="bool") + elif isinstance(pred, ValueExpr): + node = pred.op() + yield dereference_value(node, deref_left, deref_right) + # yield node.replace(deref_both, filter=ops.Value) + elif isinstance(pred, Deferred): + # resolve deferred expressions on the left table + node = pred.resolve(left).op() + yield dereference_value(node, deref_left, deref_right) + # yield node.replace(deref_both, filter=ops.Value) + else: + if isinstance(pred, tuple): + if len(pred) != 2: + raise ExpressionError("Join key tuple must be length 2") + lk, rk = pred + else: + lk = rk = pred + + # bind the predicates to the join chain + (left_value,) = bind(left, lk) + (right_value,) = bind(right, rk) + + # dereference the left value to one of the relations in the join chain + left_value, right_value = dereference_sides( + left_value.op(), right_value.op(), deref_left, deref_right + ) + yield ops.Equals(left_value, right_value).to_expr() + + +def finished(method): + """Decorator to ensure the join chain is finished before calling a method.""" + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + return method(self._finish(), *args, **kwargs) + + return wrapper + + +@public +class JoinExpr(Table): + __slots__ = ("_collisions",) + + def __init__(self, arg, collisions=None): + super().__init__(arg) + object.__setattr__(self, "_collisions", collisions or set()) + + def _finish(self) -> Table: + """Construct a valid table expression from this join expression.""" + if self._collisions: + raise IntegrityError(f"Name collisions: {self._collisions}") + return Table(self.op()) + + def join( + self, + right, + predicates: Any, + how: JoinKind = "inner", + *, + lname: str = "", + rname: str = "{name}_right", + ): + """Join with another table.""" + import pyarrow as pa + import pandas as pd + + if isinstance(right, (pd.DataFrame, pa.Table)): + right = ibis.memtable(right) + elif not isinstance(right, Table): + raise TypeError( + f"right operand must be a Table, got {type(right).__name__}" + ) + + if how == "left_semi": + how = "semi" + + left = self.op() + right = right.op() + subs_left = dereference_mapping_left(left) + subs_right, right = dereference_mapping_right(right) + subs_both = {**subs_left, **subs_right} + + # bind and dereference the predicates + preds = prepare_predicates( + left.to_expr(), + right.to_expr(), + predicates, + deref_left=subs_left, + deref_right=subs_right, + deref_both=subs_both, + ) + preds = flatten_predicates(list(preds)) + + # calculate the fields based in lname and rname, this should be a best + # effort to avoid collisions, but does not raise if there are any + # if no disambiaution happens using a final .select() call, then + # the finish() method will raise due to the name collisions + values, collisions = disambiguate_fields( + how, left.values, right.fields, lname, rname + ) + + # construct a new join link and add it to the join chain + link = ops.JoinLink(how, table=right, predicates=preds) + left = left.copy(rest=left.rest + (link,), values=values) + + # return with a new JoinExpr wrapping the new join chain + return self.__class__(left, collisions=collisions) + + def select(self, *args, **kwargs): + """Select expressions.""" + chain = self.op() + values = bind(self, (args, kwargs)) + values = unwrap_aliases(values) + + # if there are values referencing fields from the join chain constructed + # so far, we need to replace them the fields from one of the join links + subs = dereference_mapping_left(chain) + values = {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} + + node = chain.copy(values=values) + return Table(node) + + aggregate = finished(Table.aggregate) + alias = finished(Table.alias) + cast = finished(Table.cast) + compile = finished(Table.compile) + count = finished(Table.count) + difference = finished(Table.difference) + distinct = finished(Table.distinct) + drop = finished(Table.drop) + dropna = finished(Table.dropna) + execute = finished(Table.execute) + fillna = finished(Table.fillna) + filter = finished(Table.filter) + group_by = finished(Table.group_by) + intersect = finished(Table.intersect) + limit = finished(Table.limit) + mutate = finished(Table.mutate) + nunique = finished(Table.nunique) + order_by = finished(Table.order_by) + sample = finished(Table.sample) + sql = finished(Table.sql) + unbind = finished(Table.unbind) + union = finished(Table.union) + view = finished(Table.view) diff --git a/ibis/expr/types/logical.py b/ibis/expr/types/logical.py index ab0574546a6a..09927223f2ac 100644 --- a/ibis/expr/types/logical.py +++ b/ibis/expr/types/logical.py @@ -244,6 +244,10 @@ class BooleanColumn(NumericColumn, BooleanValue): def any(self, where: BooleanValue | None = None) -> BooleanValue: """Return whether at least one element is `True`. + If the expression does not reference any foreign tables, the result + will be a scalar reduction, otherwise it will be a deferred expression + constructing an exists subquery when passed to a table method. + Parameters ---------- where @@ -254,6 +258,41 @@ def any(self, where: BooleanValue | None = None) -> BooleanValue: BooleanValue Whether at least one element is `True`. + Notes + ----- + Consider the following ibis expressions + + ```python + import ibis + + t = ibis.table(dict(a="string")) + s = ibis.table(dict(a="string")) + + cond = (t.a == s.a).any() + ``` + + Without knowing the table to use as the outer query there are two ways to + turn this expression into a SQL `EXISTS` predicate, depending on which of + `t` or `s` is filtered on. + + Filtering from `t`: + + ```sql + SELECT * + FROM t + WHERE EXISTS (SELECT 1 FROM s WHERE t.a = s.a) + ``` + + Filtering from `s`: + + ```sql + SELECT * + FROM s + WHERE EXISTS (SELECT 1 FROM t WHERE t.a = s.a) + ``` + + Notably the correlated subquery cannot stand on its own. + Examples -------- >>> import ibis @@ -267,17 +306,25 @@ def any(self, where: BooleanValue | None = None) -> BooleanValue: >>> (t.arr == None).any(where=t.arr != None) False """ - import ibis.expr.analysis as an + from ibis.common.deferred import Call, _, Deferred - tables = an.find_immediate_parent_tables(self.op()) + parents = self.op().relations - if len(tables) > 1: - op = ops.UnresolvedExistsSubquery( - tables=[t.to_expr() for t in tables], - predicates=an.find_predicates(self.op(), flatten=True), - ) - else: + def resolve_exists_subquery(outer): + """An exists subquery whose outer leaf table is unknown.""" + (inner,) = (t for t in parents if t != outer.op()) + relation = ops.Project(ops.Filter(inner, [self]), {"1": 1}) + return ops.ExistsSubquery(relation).to_expr() + + if len(parents) == 2: + return Deferred(Call(resolve_exists_subquery, _)) + elif len(parents) == 1: op = ops.Any(self, where=self._bind_reduction_filter(where)) + else: + raise NotImplementedError( + f'Cannot compute "any" for expression of type {type(self)} ' + f"with multiple foreign tables" + ) return op.to_expr() diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index c7d2b47bed65..7617d927c399 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -1,14 +1,16 @@ from __future__ import annotations -import collections -import contextlib -import functools import itertools import operator import re -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from keyword import iskeyword -from typing import TYPE_CHECKING, Callable, Literal +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, +) import toolz from public import public @@ -19,9 +21,10 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch from ibis import util -from ibis.common.deferred import Deferred, Resolver +from ibis.common.deferred import Deferred from ibis.expr.types.core import Expr, _FixedTextJupyterMixin -from ibis.expr.types.generic import literal +from ibis.expr.types.generic import TableExpr, ValueExpr, literal +from ibis.selectors import Selector if TYPE_CHECKING: import pandas as pd @@ -30,31 +33,16 @@ import ibis.expr.types as ir import ibis.selectors as s from ibis.common.typing import SupportsSchema + from ibis.expr.operations.relations import JoinKind + from ibis.expr.types import Table from ibis.expr.types.groupby import GroupedTable from ibis.expr.types.tvf import WindowedTable from ibis.formats.pyarrow import PyArrowData - from ibis.selectors import IfAnyAll, Selector + from ibis.selectors import IfAnyAll _ALIASES = (f"_ibis_view_{n:d}" for n in itertools.count()) -def _ensure_expr(table, expr): - from ibis.selectors import Selector - - # This is different than self._ensure_expr, since we don't want to - # treat `str` or `int` values as column indices - if isinstance(expr, Expr): - return expr - elif util.is_function(expr): - return expr(table) - elif isinstance(expr, Deferred): - return expr.resolve(table) - elif isinstance(expr, Selector): - return expr.expand(table) - else: - return literal(expr) - - def _regular_join_method( name: str, how: Literal[ @@ -69,8 +57,8 @@ def _regular_join_method( ], ): def f( # noqa: D417 - self: Table, - right: Table, + self: ir.Table, + right: ir.Table, predicates: str | Sequence[ str | tuple[str | ir.Column, str | ir.Column] | ir.BooleanValue @@ -78,7 +66,7 @@ def f( # noqa: D417 *, lname: str = "", rname: str = "{name}_right", - ) -> Table: + ) -> ir.Table: """Perform a join between two tables. Parameters @@ -105,6 +93,118 @@ def f( # noqa: D417 return f +# TODO(kszucs): should use (table, *args, **kwargs) instead to avoid interpreting +# nested inputs +def bind(table: TableExpr, value: Any, prefer_column=True) -> Iterator[ir.Value]: + """Bind a value to a table expression.""" + if prefer_column and isinstance(value, (str, int)): + yield table._get_column(value) + elif isinstance(value, ValueExpr): + yield value + elif isinstance(value, TableExpr): + for name in value.columns: + yield value._get_column(name) + elif isinstance(value, Deferred): + yield value.resolve(table) + elif isinstance(value, Selector): + yield from value.expand(table) + elif isinstance(value, Mapping): + for k, v in value.items(): + for val in bind(table, v, prefer_column=prefer_column): + yield val.name(k) + elif util.is_iterable(value): + for v in value: + yield from bind(table, v, prefer_column=prefer_column) + elif isinstance(value, ops.Value): + # TODO(kszucs): from certain builders, like ir.GroupedTable we pass + # operation nodes instead of expressions to table methods, it would + # be better to convert them to expressions before passing them to + # this function + yield value.to_expr() + elif callable(value): + yield value(table) + else: + yield literal(value) + + +def unwrap_aliases(values: Iterator[ir.Value]) -> Mapping[str, ir.Value]: + """Unwrap aliases into a mapping of {name: expression}.""" + result = {} + for value in values: + node = value.op() + if node.name in result: + raise com.IntegrityError( + f"Duplicate column name {node.name!r} in result set" + ) + if isinstance(node, ops.Alias): + result[node.name] = node.arg + else: + result[node.name] = node + return result + + +def dereference_mapping(parents): + mapping = {} + parents = util.promote_list(parents) + for parent in parents: + for k, v in parent.values.items(): + if isinstance(v, ops.Field): + # track down the field in the hierarchy until no modification + # is made so only follow ops.Field nodes not arbitrary values; + # also stop tracking if the field belongs to a parent which + # we want to dereference to, see the docstring of + # `dereference_values()` for more details + while isinstance(v, ops.Field) and v.rel not in parents: + mapping[v] = ops.Field(parent, k) + v = v.rel.values.get(v.name) + elif v.relations: + # do not dereference literal expressions + mapping[v] = ops.Field(parent, k) + return mapping + + +def dereference_values( + parents: Iterable[ops.Parents], values: Mapping[str, ops.Value] +) -> Mapping[str, ops.Value]: + """Trace and replace fields from earlier relations in the hierarchy. + + In order to provide a nice user experience, we need to allow expressions + from earlier relations in the hierarchy. Consider the following example: + + t = ibis.table([('a', 'int64'), ('b', 'string')], name='t') + t1 = t.select([t.a, t.b]) + t2 = t1.filter(t.a > 0) # note that not t1.a is referenced here + t3 = t2.select(t.a) # note that not t2.a is referenced here + + However the relational operations in the IR are strictly enforcing that + the expressions are referencing the immediate parent only. So we need to + track fields upwards the hierarchy to replace `t.a` with `t1.a` and `t2.a` + in the example above. This is called dereferencing. + + Whether we can treat or not a field of a relation semantically equivalent + with a field of an earlier relation in the hierarchy depends on the + `.values` mapping of the relation. Leaf relations, like `t` in the example + above, have an empty `.values` mapping, so we cannot dereference fields + from them. On the other hand a projection, like `t1` in the example above, + has a `.values` mapping like `{'a': t.a, 'b': t.b}`, so we can deduce that + `t1.a` is semantically equivalent with `t.a` and so on. + + Parameters + ---------- + parents + The relations we want the values to point to. + values + The values to dereference. + + Returns + ------- + The same mapping as `values` but with all the dereferenceable fields + replaced with the fields from the parents. + """ + subs = dereference_mapping(parents) + return {k: v.replace(subs, filter=ops.Value) for k, v in values.items()} + + @public class Table(Expr, _FixedTextJupyterMixin): """An immutable and lazy dataframe. @@ -387,6 +487,13 @@ def __interactive_rich_console__(self, console, options): raise e return console.render(table, options=options) + # TODO(kszucs): expose this method in the public API + def _get_column(self, name: str | int) -> ir.Column: + """Get a column from the table.""" + if isinstance(name, int): + name = self.schema().name_at_position(name) + return ops.Field(self, name).to_expr() + def __getitem__(self, what): """Select items from a table expression. @@ -631,33 +738,24 @@ def __getitem__(self, what): │ 36.7 │ 19.3 │ 193 │ 3450 │ └────────────────┴───────────────┴───────────────────┴─────────────┘ """ - from ibis.expr.types.generic import Column from ibis.expr.types.logical import BooleanValue if isinstance(what, (str, int)): - return ops.TableColumn(self, what).to_expr() - - if isinstance(what, slice): + return self._get_column(what) + elif isinstance(what, slice): limit, offset = util.slice_to_limit_offset(what, self.count()) return self.limit(limit, offset=offset) - - what = bind_expr(self, what) - - if isinstance(what, (list, tuple, Table)): + elif isinstance(what, (list, tuple, Table)): # Projection case return self.select(what) - elif isinstance(what, BooleanValue): - # Boolean predicate + + (what,) = bind(self, what) + if isinstance(what, BooleanValue): + # TODO(kszucs): this branch should be removed, .filter should be + # used instead return self.filter([what]) - elif isinstance(what, Column): - # Projection convenience - return self.select(what) else: - raise NotImplementedError( - "Selection rows or columns with {} objects is not supported".format( - type(what).__name__ - ) - ) + return self.select(what) def __len__(self): raise com.ExpressionError("Use .count() instead") @@ -699,8 +797,10 @@ def __getattr__(self, key: str) -> ir.Column: │ … │ └───────────┘ """ - with contextlib.suppress(com.IbisTypeError): - return ops.TableColumn(self, key).to_expr() + try: + return self._get_column(key) + except com.IbisTypeError: + pass # A mapping of common attribute typos, mapping them to the proper name common_typos = { @@ -715,6 +815,7 @@ def __getattr__(self, key: str) -> ir.Column: raise AttributeError( f"{type(self).__name__} object has no attribute {key!r}, did you mean {hint!r}" ) + raise AttributeError(f"'Table' object has no attribute {key!r}") def __dir__(self) -> list[str]: @@ -725,28 +826,6 @@ def __dir__(self) -> list[str]: def _ipython_key_completions_(self) -> list[str]: return self.columns - def _ensure_expr(self, expr): - import numpy as np - - from ibis.selectors import Selector - - if isinstance(expr, str): - # treat strings as column names - return self[expr] - elif isinstance(expr, (int, np.integer)): - # treat Python integers as a column index - return self[self.schema().name_at_position(expr)] - elif isinstance(expr, Deferred): - return expr.resolve(self) - elif isinstance(expr, Resolver): - return expr.resolve({"_": self}) - elif isinstance(expr, Selector): - return expr.expand(self) - elif callable(expr): - return expr(self) - else: - return expr - @property def columns(self) -> list[str]: """The list of column names in this table. @@ -797,7 +876,7 @@ def schema(self) -> sch.Schema: def group_by( self, - by: str | ir.Value | Iterable[str] | Iterable[ir.Value] | None = None, + by: str | ir.Value | Iterable[str] | Iterable[ir.Value] | None = (), **key_exprs: str | ir.Value | Iterable[str] | Iterable[ir.Value], ) -> GroupedTable: """Create a grouped table expression. @@ -853,8 +932,13 @@ def group_by( """ from ibis.expr.types.groupby import GroupedTable - return GroupedTable(self, by, **key_exprs) + if by is None: + by = () + groups = bind(self, (by, key_exprs)) + return GroupedTable(self, groups) + + # TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table? def rowid(self) -> ir.IntegerValue: """A unique integer per row. @@ -890,7 +974,10 @@ def view(self) -> Table: Table Table expression """ - return ops.SelfReference(self).to_expr() + if isinstance(self.op(), ops.SelfReference): + return self + else: + return ops.SelfReference(self).to_expr() def difference(self, table: Table, *rest: Table, distinct: bool = True) -> Table: """Compute the set difference of multiple table expressions. @@ -955,9 +1042,9 @@ def difference(self, table: Table, *rest: Table, distinct: bool = True) -> Table def aggregate( self, - metrics: Sequence[ir.Scalar] | None = None, - by: Sequence[ir.Value] | None = None, - having: Sequence[ir.BooleanValue] | None = None, + metrics: Sequence[ir.Scalar] | None = (), + by: Sequence[ir.Value] | None = (), + having: Sequence[ir.BooleanValue] | None = (), **kwargs: ir.Value, ) -> Table: """Aggregate a table with a given set of reductions grouping by `by`. @@ -1022,33 +1109,46 @@ def aggregate( │ orange │ 0.33 │ 0.33 │ └────────┴────────────┴──────────┘ """ - import ibis.expr.analysis as an - - metrics = itertools.chain( - itertools.chain.from_iterable( - ( - (_ensure_expr(self, m) for m in metric) - if isinstance(metric, (list, tuple)) - else util.promote_list(_ensure_expr(self, metric)) - ) - for metric in util.promote_list(metrics) - ), - ( - e.name(name) - for name, expr in kwargs.items() - for e in util.promote_list(_ensure_expr(self, expr)) - ), - ) + from ibis.common.patterns import Contains, In + from ibis.expr.rewrites import p + + node = self.op() + + groups = bind(self, by) + metrics = bind(self, (metrics, kwargs)) + having = bind(self, having) + + groups = unwrap_aliases(groups) + metrics = unwrap_aliases(metrics) + having = unwrap_aliases(having) + + groups = dereference_values(self.op(), groups) + metrics = dereference_values(self.op(), metrics) + having = dereference_values(self.op(), having) + + # the user doesn't need to specify the metrics used in the having clause + # explicitly, we implicitly add them to the metrics list by looking for + # any metrics depending on self which are not specified explicitly + pattern = p.Reduction(relations=Contains(node)) & ~In(set(metrics.values())) + original_metrics = metrics.copy() + for pred in having.values(): + for metric in pred.find_topmost(pattern): + if metric.name in metrics: + metrics[util.get_name("metric")] = metric + else: + metrics[metric.name] = metric - agg = ops.Aggregation( - self, - metrics=list(metrics), - by=bind_expr(self, util.promote_list(by)), - having=bind_expr(self, util.promote_list(having)), - ) - agg = an.simplify_aggregation(agg) + # construct the aggregate node + agg = ops.Aggregate(node, groups, metrics).to_expr() + + if having: + # apply the having clause + agg = agg.filter(*having.values()) + # remove any metrics that were only used in the having clause + if metrics != original_metrics: + agg = agg.select(*groups.keys(), *original_metrics.keys()) - return agg.to_expr() + return agg agg = aggregate @@ -1555,22 +1655,14 @@ def order_by( │ 2 │ B │ 6 │ └───────┴────────┴───────┘ """ - import ibis.selectors as s - - sort_keys = [] - for item in util.promote_list(by): - if isinstance(item, tuple): - if len(item) != 2: - raise ValueError(f"Tuple must be of length 2, got {len(item):d}") - sort_keys.append(bind_expr(self, item[0]), item[1]) - elif isinstance(item, s.Selector): - sort_keys.extend(item.expand(self)) - else: - sort_keys.append(bind_expr(self, item)) - - if not sort_keys: + keys = bind(self, by) + keys = unwrap_aliases(keys) + keys = dereference_values(self.op(), keys) + if not keys: raise com.IbisError("At least one sort key must be provided") - return self.op().order_by(sort_keys).to_expr() + + node = ops.Sort(self, keys.values()) + return node.to_expr() def union(self, table: Table, *rest: Table, distinct: bool = False) -> Table: """Compute the set union of multiple table expressions. @@ -1707,25 +1799,7 @@ def intersect(self, table: Table, *rest: Table, distinct: bool = True) -> Table: node = ops.Intersection(node, table, distinct=distinct) return node.to_expr().select(self.columns) - def to_array(self) -> ir.Column: - """View a single column table as an array. - - Returns - ------- - Value - A single column view of a table - """ - schema = self.schema() - if len(schema) != 1: - raise com.ExpressionError( - "Table must have exactly one column when viewed as array" - ) - - return ops.TableArrayView(self).to_expr() - - def mutate( - self, exprs: Sequence[ir.Expr] | None = None, **mutations: ir.Value - ) -> Table: + def mutate(self, *exprs: Sequence[ir.Expr] | None, **mutations: ir.Value) -> Table: """Add columns to a table expression. Parameters @@ -1811,28 +1885,14 @@ def mutate( │ Adelie │ 2007 │ -7.22193 │ └─────────┴───────┴────────────────┘ """ - import ibis.expr.analysis as an - - exprs = [] if exprs is None else util.promote_list(exprs) - - new_exprs = [] - - for expr in exprs: - if isinstance(expr, Mapping): - new_exprs.extend( - _ensure_expr(self, val).name(name) for name, val in expr.items() - ) - else: - new_exprs.extend(util.promote_list(_ensure_expr(self, expr))) - - new_exprs.extend( - e.name(name) - for name, expr in mutations.items() - for e in util.promote_list(_ensure_expr(self, expr)) - ) - - mutation_exprs = an.get_mutation_exprs(new_exprs, self) - return self.select(mutation_exprs) + # string and integer inputs are going to be coerced to literals instead + # of interpreted as column references like in select + node = self.op() + values = bind(self, (exprs, mutations), prefer_column=False) + values = unwrap_aliases(values) + # allow overriding of fields, hence the mutation behavior + values = {**node.fields, **values} + return self.select(**values) def select( self, @@ -2011,39 +2071,22 @@ def select( │ 43.92193 │ 17.15117 │ 200.915205 │ 4201.754386 │ └────────────────┴───────────────┴───────────────────┴─────────────┘ """ - import ibis.expr.analysis as an - from ibis.selectors import Selector + from ibis.expr.rewrites import rewrite_project_input - new_exprs = [] - - for expr in exprs: - if isinstance(expr, Selector): - new_exprs.extend(expr.expand(self)) - elif isinstance(expr, Mapping): - new_exprs.extend( - self._ensure_expr(value).name(name) for name, value in expr.items() - ) - else: - new_exprs.extend(map(self._ensure_expr, util.promote_list(expr))) - - new_exprs.extend( - self._ensure_expr(expr).name(name) for name, expr in named_exprs.items() - ) - - if not new_exprs: + values = bind(self, (exprs, named_exprs)) + values = unwrap_aliases(values) + values = dereference_values(self.op(), values) + if not values: raise com.IbisTypeError( "You must select at least one column for a valid projection" ) - for ex in new_exprs: - if not isinstance(ex, Expr): - raise com.IbisTypeError( - "All arguments to `.select` must be coerceable to " - f"expressions - got {type(ex)!r}" - ) - op = an.Projector(self, new_exprs).get_result() - - return op.to_expr() + # we need to detect reductions which are either turned into window functions + # or scalar subqueries depending on whether they are originating from self + values = { + k: rewrite_project_input(v, relation=self.op()) for k, v in values.items() + } + return ops.Project(self, values).to_expr() projection = select @@ -2355,7 +2398,7 @@ def drop(self, *fields: str | Selector) -> Table: def filter( self, - predicates: ir.BooleanValue | Sequence[ir.BooleanValue] | IfAnyAll, + *predicates: ir.BooleanValue | Sequence[ir.BooleanValue] | IfAnyAll, ) -> Table: """Select rows from `table` based on `predicates`. @@ -2404,11 +2447,17 @@ def filter( │ male │ 68 │ └────────┴───────────┘ """ - import ibis.expr.analysis as an - - resolved_predicates = _resolve_predicates(self, predicates) - relation = an.pushdown_selection_filters(self.op(), resolved_predicates) - return relation.to_expr() + from ibis.expr.analysis import flatten_predicates + from ibis.expr.rewrites import rewrite_filter_input + + preds = bind(self, predicates) + preds = unwrap_aliases(preds) + preds = dereference_values(self.op(), preds) + preds = flatten_predicates(list(preds.values())) + preds = list(map(rewrite_filter_input, preds)) + if not preds: + raise com.IbisInputError("You must pass at least one predicate to filter") + return ops.Filter(self, preds).to_expr() def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of unique rows in the table. @@ -2537,7 +2586,7 @@ def dropna( 344 """ if subset is not None: - subset = bind_expr(self, util.promote_list(subset)) + subset = bind(self, subset) return ops.DropNa(self, how, subset).to_expr() def fillna( @@ -2609,7 +2658,7 @@ def fillna( """ schema = self.schema() - if isinstance(replacements, collections.abc.Mapping): + if isinstance(replacements, Mapping): for col, val in replacements.items(): if col not in schema: columns_formatted = ", ".join(map(repr, schema.names)) @@ -2766,17 +2815,7 @@ def join( str | ir.Column | ir.Deferred, ] ] = (), - how: Literal[ - "inner", - "left", - "outer", - "right", - "semi", - "anti", - "any_inner", - "any_left", - "left_semi", - ] = "inner", + how: JoinKind = "inner", *, lname: str = "", rname: str = "{name}_right", @@ -2935,29 +2974,26 @@ def join( │ 106782 │ Leonardo DiCaprio │ 5989 │ Leonardo DiCaprio │ └─────────┴───────────────────┴───────────────┴───────────────────┘ """ + from ibis.expr.types.joins import JoinExpr + + # the first participant of the join can be any Relation, but the rest + # must be wrapped in SelfReferences so that we can join the same table + # with itself multiple times and to enable optimization passes later on + left = left.op() + if isinstance(left, ops.JoinChain): + # if the left side is already a join chain, we can reuse it, for + # example in the `a.join(b)[fields].join(c)` expression the first + # join followed by a projection `a.join(b)[...]` constructs a + # `ir.Table(ops.JoinChain())` expression, which we can reuse here + expr = left.to_expr() + else: + if isinstance(left, ops.SelfReference): + left = left.parent + # construct an empty join chain and wrap it with a JoinExpr, the + # projected fields are the fields of the starting table + expr = ops.JoinChain(left, rest=(), values=left.fields).to_expr() - _join_classes = { - "inner": ops.InnerJoin, - "left": ops.LeftJoin, - "any_inner": ops.AnyInnerJoin, - "any_left": ops.AnyLeftJoin, - "outer": ops.OuterJoin, - "right": ops.RightJoin, - "left_semi": ops.LeftSemiJoin, - "semi": ops.LeftSemiJoin, - "anti": ops.LeftAntiJoin, - "cross": ops.CrossJoin, - } - - klass = _join_classes[how.lower()] - expr = klass(left, right, predicates).to_expr() - - # semi/anti join only give access to the left table's fields, so - # there's never overlap - if how in ("left_semi", "semi", "anti"): - return expr - - return ops.relations._dedup_join_columns(expr, lname=lname, rname=rname) + return expr.join(right, predicates, how=how, lname=lname, rname=rname) def asof_join( left: Table, @@ -3000,14 +3036,22 @@ def asof_join( Table Table expression """ - op = ops.AsOfJoin( - left=left, - right=right, - predicates=predicates, - by=by, - tolerance=tolerance, - ) - return ops.relations._dedup_join_columns(op.to_expr(), lname=lname, rname=rname) + if by: + # `by` is an argument that comes from pandas, which for pandas was + # a convenient and fast way to perform a standard join before the + # asof join, so we implement the equivalent behavior here for + # consistency across backends. + left = left.join(right, by, lname=lname, rname=rname) + + if tolerance is not None: + if not isinstance(predicates, str): + raise TypeError( + "tolerance can only be specified when predicates is a string" + ) + left_key, right_key = left[predicates], right[predicates] + predicates = [left_key == right_key, left_key - right_key <= tolerance] + + return left.join(right, predicates, how="asof", lname=lname, rname=rname) def cross_join( left: Table, @@ -3083,12 +3127,12 @@ def cross_join( >>> expr.count() 344 """ - op = ops.CrossJoin( - left, - functools.reduce(Table.cross_join, rest, right), - [], - ) - return ops.relations._dedup_join_columns(op.to_expr(), lname=lname, rname=rname) + left = left.join(right, how="cross", predicates=(), lname=lname, rname=rname) + for right in rest: + left = left.join( + right, how="cross", predicates=(), lname=lname, rname=rname + ) + return left inner_join = _regular_join_method("inner_join", "inner") left_join = _regular_join_method("left_join", "left") @@ -4344,43 +4388,4 @@ def release(self): return current_backend._release_cached(self) -# TODO(kszucs): used at a single place along with an.apply_filter(), should be -# consolidated into a single function -def _resolve_predicates( - table: Table, predicates -) -> tuple[list[ir.BooleanValue], list[tuple[ir.BooleanValue, ir.Table]]]: - import ibis.expr.types as ir - from ibis.expr.analysis import flatten_predicate, p - - # TODO(kszucs): clean this up, too much flattening and resolving happens here - predicates = [ - pred.op() - for preds in map( - functools.partial(ir.relations.bind_expr, table), - util.promote_list(predicates), - ) - for pred in util.promote_list(preds) - ] - predicates = flatten_predicate(predicates) - - rules = ( - # turn reductions into table array views so that they can be used as - # WHERE t1.`a` = (SELECT max(t1.`a`) AS `Max(a)` - p.Reduction >> (lambda _: ops.TableArrayView(_.to_expr().as_table())) - | - # resolve unresolved exists subqueries to IN subqueries - p.UnresolvedExistsSubquery >> (lambda _: _.resolve(table.op())) - ) - # do not apply the rules below the following nodes - until = p.Value & ~p.WindowFunction & ~p.TableArrayView & ~p.ExistsSubquery - return [pred.replace(rules, filter=until) for pred in predicates] - - -def bind_expr(table, expr): - if util.is_iterable(expr): - return [bind_expr(table, x) for x in expr] - - return table._ensure_expr(expr) - - -public(TableExpr=Table) +public(TableExpr=Table, CachedTableExpr=CachedTable) diff --git a/ibis/expr/types/temporal_windows.py b/ibis/expr/types/temporal_windows.py index d357c127b15a..865a9922e6a2 100644 --- a/ibis/expr/types/temporal_windows.py +++ b/ibis/expr/types/temporal_windows.py @@ -10,36 +10,19 @@ import ibis.expr.types as ir from ibis.common.deferred import Deferred from ibis.selectors import Selector +from ibis.expr.types.relations import bind if TYPE_CHECKING: from ibis.expr.types import Table -def _get_window_by_key(table, value): - if isinstance(value, str): - return table[value] - elif isinstance(value, Deferred): - return value.resolve(table) - elif isinstance(value, Selector): - matches = value.expand(table) - if len(matches) != 1: - raise com.IbisInputError( - "Multiple columns match the selector; only 1 is expected" - ) - return next(iter(matches)) - elif isinstance(value, ir.Expr): - return an.sub_immediate_parents(value.op(), table.op()).to_expr() - else: - return value - - @public class WindowedTable: """An intermediate table expression to hold windowing information.""" def __init__(self, table: ir.Table, time_col: ir.Value): self.table = table - self.time_col = _get_window_by_key(table, time_col) + self.time_col = next(bind(table, time_col)) if self.time_col is None: raise com.IbisInputError( @@ -68,9 +51,10 @@ def tumble( Table Table expression after applying tumbling table-valued function. """ + time_col = next(bind(self.table, self.time_col)) return ops.TumbleWindowingTVF( table=self.table, - time_col=_get_window_by_key(self.table, self.time_col), + time_col=time_col, window_size=window_size, offset=offset, ).to_expr() @@ -106,9 +90,10 @@ def hop( Table Table expression after applying hopping table-valued function. """ + time_col = next(bind(self.table, self.time_col)) return ops.HopWindowingTVF( table=self.table, - time_col=_get_window_by_key(self.table, self.time_col), + time_col=time_col, window_size=window_size, window_slide=window_slide, offset=offset, @@ -143,9 +128,10 @@ def cumulate( Table Table expression after applying cumulate table-valued function. """ + time_col = next(bind(self.table, self.time_col)) return ops.CumulateWindowingTVF( table=self.table, - time_col=_get_window_by_key(self.table, self.time_col), + time_col=time_col, window_size=window_size, window_step=window_step, offset=offset, diff --git a/ibis/expr/visualize.py b/ibis/expr/visualize.py index 0af8f3118336..ef16463251ce 100644 --- a/ibis/expr/visualize.py +++ b/ibis/expr/visualize.py @@ -56,7 +56,7 @@ def get_label(node): node, ( ops.Literal, - ops.TableColumn, + ops.Field, ops.Alias, ops.PhysicalTable, ops.window.RangeWindowFrame, @@ -70,14 +70,14 @@ def get_label(node): label_fmt = "<{}>" label = label_fmt.format(escape(name)) else: - if isinstance(node, ops.TableNode): + if isinstance(node, ops.Relation): label_fmt = "<{}: {}{}>" else: label_fmt = '<{}: {}
:: {}>' # typename is already escaped label = label_fmt.format(escape(nodename), escape(name), typename) else: - if isinstance(node, ops.TableNode): + if isinstance(node, ops.Relation): label_fmt = "<{}{}>" else: label_fmt = '<{}
:: {}>' diff --git a/ibis/selectors.py b/ibis/selectors.py index 9bc5f1a9e654..b094b74839df 100644 --- a/ibis/selectors.py +++ b/ibis/selectors.py @@ -393,7 +393,7 @@ def c(*names: str | ir.Column) -> Predicate: names = frozenset(col if isinstance(col, str) else col.get_name() for col in names) def func(col: ir.Value) -> bool: - schema = col.op().table.schema + schema = col.op().rel.schema if extra_cols := (names - schema.keys()): raise exc.IbisInputError( f"Columns {extra_cols} are not present in {schema.names}" diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_format_sql_query_result/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_format_sql_query_result/repr.txt index 1cd75a4812d8..9589a01e618b 100644 --- a/ibis/tests/expr/snapshots/test_format_sql_operations/test_format_sql_query_result/repr.txt +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_format_sql_query_result/repr.txt @@ -36,8 +36,7 @@ r1 := SQLStringView[r0]: foo carrier string avg_arrdelay float64 -Selection[r1] - selections: - carrier: r1.carrier - avg_arrdelay: Round(r1.avg_arrdelay, digits=1) - island: Lowercase(r1.carrier) \ No newline at end of file +Project[r1] + carrier: r1.carrier + avg_arrdelay: Round(r1.avg_arrdelay, digits=1) + island: Lowercase(r1.carrier) \ No newline at end of file diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt index 6266bb50b1cc..b67141c7beda 100644 --- a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt @@ -7,16 +7,25 @@ r1 := DatabaseTable: test1 f float64 g string -r2 := Selection[r1] - predicates: - r1.f > 0 +r2 := Filter[r1] + r1.f > 0 -r3 := InnerJoin[r0, r2] r2.g == r0.key +r3 := SelfReference[r2] -Aggregation[r3] +r4 := JoinChain[r0] + JoinLink[inner, r3] + r3.g == r0.key + values: + key: r0.key + value: r0.value + c: r3.c + f: r3.f + g: r3.g + +Aggregate[r4] + groups: + g: r4.g + key: r4.key metrics: - foo: Mean(r2.f - r0.value) - bar: Sum(r2.f) - by: - g: r2.g - key: r0.key \ No newline at end of file + foo: Mean(r4.f - r4.value) + bar: Sum(r4.f) \ No newline at end of file diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt index d5e678285698..3514fa501a73 100644 --- a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_insert_sort_key/repr.txt @@ -29,23 +29,20 @@ r0 := DatabaseTable: airlines security_delay int32 late_aircraft_delay int32 -r1 := Selection[r0] - selections: - arrdelay: r0.arrdelay - dest: r0.dest +r1 := Project[r0] + arrdelay: r0.arrdelay + dest: r0.dest -r2 := Selection[r1] - selections: - r1 - dest_avg: WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) - dev: r1.arrdelay - WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) +r2 := Project[r1] + arrdelay: r1.arrdelay + dest: r1.dest + dest_avg: WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) + dev: r1.arrdelay - WindowFunction(func=Mean(r1.arrdelay), frame=RowsWindowFrame(table=r1, group_by=[r1.dest])) -r3 := Selection[r2] - predicates: - NotNull(r2.dev) +r3 := Filter[r2] + NotNull(r2.dev) -r4 := Selection[r3] - sort_keys: - desc r3.dev +r4 := Sort[r3] + desc r3.dev Limit[r4, n=10] \ No newline at end of file diff --git a/ibis/tests/expr/test_analysis.py b/ibis/tests/expr/test_analysis.py index fa430a6d7fbe..a0b52b84cd6f 100644 --- a/ibis/tests/expr/test_analysis.py +++ b/ibis/tests/expr/test_analysis.py @@ -5,16 +5,12 @@ import ibis import ibis.common.exceptions as com import ibis.expr.operations as ops -from ibis.tests.util import assert_equal +from ibis.expr.rewrites import simplify # Place to collect esoteric expression analysis bugs and tests -# TODO(kszucs): not directly using an analysis function anymore, move to a -# more appropriate test module def test_rewrite_join_projection_without_other_ops(con): - # See #790, predicate pushdown in joins not supported - # Star schema with fact table table = con.table("star1") table2 = con.table("star2") @@ -32,10 +28,32 @@ def test_rewrite_join_projection_without_other_ops(con): view = j2[[filtered, table2["value1"], table3["value2"]]] # Construct the thing we expect to obtain - ex_pred2 = table["bar_id"] == table3["bar_id"] - ex_expr = table.left_join(table2, [pred1]).inner_join(table3, [ex_pred2]) - - assert view.op().table != ex_expr.op() + table2_ref = j2.op().rest[0].table.to_expr() + table3_ref = j2.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=filtered, + rest=[ + ops.JoinLink( + how="left", + table=table2_ref, + predicates=[filtered["foo_id"] == table2_ref["foo_id"]], + ), + ops.JoinLink( + how="inner", + table=table3_ref, + predicates=[filtered["bar_id"] == table3_ref["bar_id"]], + ), + ], + values={ + "c": filtered.c, + "f": filtered.f, + "foo_id": filtered.foo_id, + "bar_id": filtered.bar_id, + "value1": table2_ref.value1, + "value2": table3_ref.value2, + }, + ) + assert view.op() == expected def test_multiple_join_deeper_reference(): @@ -86,8 +104,8 @@ def test_filter_on_projected_field(con): # Now then! Predicate pushdown here is inappropriate, so we check that # it didn't occur. - assert isinstance(result.op(), ops.Selection) - assert result.op().table == tpch.op() + assert isinstance(result.op(), ops.Filter) + assert result.op().parent == tpch.op() def test_join_predicate_from_derived_raises(): @@ -101,18 +119,18 @@ def test_join_predicate_from_derived_raises(): filter_pred = table["f"] > 0 table3 = table[filter_pred] - with pytest.raises(com.ExpressionError): + with pytest.raises(com.IntegrityError, match="they belong to another relation"): + # TODO(kszucs): could be smarter actually and rewrite the predicate + # to contain the conditions from the filter table.inner_join(table2, [table3["g"] == table2["key"]]) def test_bad_join_predicate_raises(): table = ibis.table([("c", "int32"), ("f", "double"), ("g", "string")], "foo_table") - table2 = ibis.table([("key", "string"), ("value", "double")], "bar_table") - table3 = ibis.table([("key", "string"), ("value", "double")], "baz_table") - with pytest.raises(com.ExpressionError): + with pytest.raises(com.IntegrityError): table.inner_join(table2, [table["g"] == table3["key"]]) @@ -130,9 +148,22 @@ def test_filter_self_join(): metric = purchases.amount.sum().name("total") agged = purchases.group_by(["region", "kind"]).aggregate(metric) + assert agged.op() == ops.Aggregate( + parent=purchases, + groups={"region": purchases.region, "kind": purchases.kind}, + metrics={"total": purchases.amount.sum()}, + ) left = agged[agged.kind == "foo"] right = agged[agged.kind == "bar"] + assert left.op() == ops.Filter( + parent=agged, + predicates=[agged.kind == "foo"], + ) + assert right.op() == ops.Filter( + parent=agged, + predicates=[agged.kind == "bar"], + ) cond = left.region == right.region joined = left.join(right, cond) @@ -141,11 +172,18 @@ def test_filter_self_join(): what = [left.region, metric] projected = joined.select(what) - proj_exprs = projected.op().selections - - # proj exprs unaffected by analysis - assert_equal(proj_exprs[0], left.region.op()) - assert_equal(proj_exprs[1], metric.op()) + right_ = joined.op().rest[0].table.to_expr() + join = ops.JoinChain( + first=left, + rest=[ + ops.JoinLink("inner", right_, [left.region == right_.region]), + ], + values={ + "region": left.region, + "diff": left.total - right_.total, + }, + ) + assert projected.op() == join def test_is_ancestor_analytic(): @@ -169,20 +207,17 @@ def test_mutation_fusion_no_overwrite(): result = result.mutate(col1=t["col"] + 1) result = result.mutate(col2=t["col"] + 2) result = result.mutate(col3=t["col"] + 3) - result = result.op() - - first_selection = result - - assert len(result.selections) == 4 - - col1 = (t["col"] + 1).name("col1") - assert first_selection.selections[1] == col1.op() - col2 = (t["col"] + 2).name("col2") - assert first_selection.selections[2] == col2.op() - - col3 = (t["col"] + 3).name("col3") - assert first_selection.selections[3] == col3.op() + simplified = simplify(result.op()) + assert simplified == ops.Project( + parent=t, + values={ + "col": t["col"], + "col1": t["col"] + 1, + "col2": t["col"] + 2, + "col3": t["col"] + 3, + }, + ) # Pr 2635 @@ -196,39 +231,21 @@ def test_mutation_fusion_overwrite(): result = result.mutate(col2=t["col"] + 2) result = result.mutate(col3=t["col"] + 3) result = result.mutate(col=t["col"] - 1) - result = result.mutate(col4=t["col"] + 4) - - second_selection = result.op() - first_selection = second_selection.table - - assert len(first_selection.selections) == 4 - col1 = (t["col"] + 1).name("col1").op() - assert first_selection.selections[1] == col1 - - col2 = (t["col"] + 2).name("col2").op() - assert first_selection.selections[2] == col2 - - col3 = (t["col"] + 3).name("col3").op() - assert first_selection.selections[3] == col3 - - # Since the second selection overwrites existing columns, it will - # not have the Table as the first selection - assert len(second_selection.selections) == 5 - - col = (t["col"] - 1).name("col").op() - assert second_selection.selections[0] == col - col1 = first_selection.to_expr()["col1"].op() - assert second_selection.selections[1] == col1 - - col2 = first_selection.to_expr()["col2"].op() - assert second_selection.selections[2] == col2 - - col3 = first_selection.to_expr()["col3"].op() - assert second_selection.selections[3] == col3 - - col4 = (t["col"] + 4).name("col4").op() - assert second_selection.selections[4] == col4 + with pytest.raises(com.IntegrityError): + # unable to dereference the column since result doesn't contain it anymore + result.mutate(col4=t["col"] + 4) + + simplified = simplify(result.op()) + assert simplified == ops.Project( + parent=t, + values={ + "col": t["col"] - 1, + "col1": t["col"] + 1, + "col2": t["col"] + 2, + "col3": t["col"] + 3, + }, + ) # Pr 2635 @@ -237,41 +254,21 @@ def test_select_filter_mutate_fusion(): t = ibis.table(ibis.schema([("col", "float32")]), "t") - result = t[["col"]] - result = result[result["col"].isnan()] - result = result.mutate(col=result["col"].cast("int32")) - - second_selection = result.op() - first_selection = second_selection.table - assert len(second_selection.selections) == 1 - - col = first_selection.to_expr()["col"].cast("int32").name("col").op() - assert second_selection.selections[0] == col - - # we don't look past the projection when a filter is encountered, so the - # number of selections in the first projection (`first_selection`) is 0 - # - # previously we did, but this was buggy when executing against the pandas - # backend - # - # eventually we will bring this back, but we're trading off the ability - # to remove materialize for some performance in the short term - assert len(first_selection.selections) == 1 - assert len(first_selection.predicates) == 1 + t1 = t[["col"]] + assert t1.op() == ops.Project(parent=t, values={"col": t.col}) + t2 = t1[t1["col"].isnan()] + assert t2.op() == ops.Filter(parent=t1, predicates=[t1.col.isnan()]) -def test_no_filter_means_no_selection(): - t = ibis.table(dict(a="string")) - proj = t.filter([]) - assert proj.equals(t) + t3 = t2.mutate(col=t2["col"].cast("int32")) + assert t3.op() == ops.Project(parent=t2, values={"col": t2.col.cast("int32")}) + # create the expected expression + filt = ops.Filter(parent=t, predicates=[t.col.isnan()]).to_expr() + proj = ops.Project(parent=filt, values={"col": filt.col.cast("int32")}).to_expr() -def test_mutate_overwrites_existing_column(): - t = ibis.table(dict(a="string")) - mut = t.mutate(a=42).select(["a"]) - sel = mut.op().selections[0].table.selections[0].arg - assert isinstance(sel, ops.Literal) - assert sel.value == 42 + t3_opt = simplify(t3.op()).to_expr() + assert t3_opt.equals(proj) def test_agg_selection_does_not_share_roots(): @@ -280,5 +277,5 @@ def test_agg_selection_does_not_share_roots(): gb = t.group_by("a") n = s.count() - with pytest.raises(com.RelationError, match="Selection expressions"): + with pytest.raises(com.IntegrityError, match=" they belong to another relation"): gb.aggregate(n=n) diff --git a/ibis/tests/expr/test_selectors.py b/ibis/tests/expr/test_selectors.py index 656e02e38324..a39b773fcebf 100644 --- a/ibis/tests/expr/test_selectors.py +++ b/ibis/tests/expr/test_selectors.py @@ -479,14 +479,14 @@ def test_c_error_on_misspelled_column(penguins): def test_order_by_with_selectors(penguins): expr = penguins.order_by(s.of_type("string")) - assert tuple(key.name for key in expr.op().sort_keys) == ( + assert tuple(key.name for key in expr.op().keys) == ( "species", "island", "sex", ) expr = penguins.order_by(s.all()) - assert tuple(key.name for key in expr.op().sort_keys) == tuple(expr.columns) + assert tuple(key.name for key in expr.op().keys) == tuple(expr.columns) with pytest.raises(exc.IbisError): penguins.order_by(~s.all()) diff --git a/ibis/tests/expr/test_set_operations.py b/ibis/tests/expr/test_set_operations.py index 5299852cfd60..b872520ea403 100644 --- a/ibis/tests/expr/test_set_operations.py +++ b/ibis/tests/expr/test_set_operations.py @@ -51,13 +51,13 @@ def test_operation_supports_schemas_with_different_field_order(method): assert u1.schema() == a.schema() - u1 = u1.op().table + u1 = u1.op().parent assert u1.left == a.op() assert u1.right == b.op() # a selection is added to ensure that the field order of the right table # matches the field order of the left table - u2 = u2.op().table + u2 = u2.op().parent assert u2.schema == a.schema() assert u2.left == a.op() diff --git a/ibis/tests/expr/test_struct.py b/ibis/tests/expr/test_struct.py index 6911f0d0765f..c960fefbd126 100644 --- a/ibis/tests/expr/test_struct.py +++ b/ibis/tests/expr/test_struct.py @@ -71,8 +71,16 @@ def test_unpack_from_table(t): def test_lift_join(t, s): join = t.join(s, t.d == s.a.g) result = join.a_right.lift() - expected = join[_.a_right.f, _.a_right.g] - assert result.equals(expected) + + s_ = join.op().rest[0].table.to_expr() + join = ops.JoinChain( + first=t, + rest=[ + ops.JoinLink("inner", s_, [t.d == s_.a.g]), + ], + values={"f": s_.a.f, "g": s_.a.g}, + ) + assert result.op() == join def test_unpack_join_from_table(t, s): diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index dd873f95bf99..85bffdfc5668 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -11,17 +11,17 @@ import ibis import ibis.common.exceptions as com -import ibis.expr.analysis as an import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir import ibis.selectors as s from ibis import _ -from ibis import literal as L from ibis.common.annotations import ValidationError -from ibis.common.exceptions import RelationError +from ibis.common.deferred import Deferred +from ibis.common.exceptions import ExpressionError, IntegrityError, RelationError from ibis.expr import api +from ibis.expr.rewrites import simplify from ibis.expr.types import Column, Table from ibis.tests.util import assert_equal, assert_pickle_roundtrip @@ -75,11 +75,19 @@ def test_view_new_relation(table): # # This thing is not exactly a projection, since it has no semantic # meaning when it comes to execution - tview = table.view() + tview1 = table.view() + tview2 = table.view() + tview2_ = tview2.view() - roots = an.find_immediate_parent_tables(tview.op()) - assert len(roots) == 1 - assert roots[0] is tview.op() + node1 = tview1.op() + node2 = tview2.op() + node2_ = tview2_.op() + + assert isinstance(node1, ops.SelfReference) + assert isinstance(node2, ops.SelfReference) + assert node1.parent is node2.parent + assert node1 != node2 + assert node2_ is node2 def test_getitem_column_select(table): @@ -136,7 +144,7 @@ def test_projection(table): proj = table[cols] assert isinstance(proj, Table) - assert isinstance(proj.op(), ops.Selection) + assert isinstance(proj.op(), ops.Project) assert proj.schema().names == tuple(cols) for c in cols: @@ -181,7 +189,7 @@ def test_projection_invalid_root(table): right = api.table(schema1, name="bar") exprs = [right["foo"], right["bar"]] - with pytest.raises(RelationError): + with pytest.raises(IntegrityError): left.select(exprs) @@ -199,7 +207,7 @@ def test_projection_with_star_expr(table): # cannot pass an invalid table expression t2 = t.aggregate([t["a"].sum().name("sum(a)")], by=["g"]) - with pytest.raises(RelationError): + with pytest.raises(IntegrityError): t[[t2]] # TODO: there may be some ways this can be invalid @@ -242,14 +250,16 @@ def test_projection_no_expr(table, empty): table.select(empty) -def test_projection_invalid_nested_list(table): - errmsg = "must be coerceable to expressions" - with pytest.raises(com.IbisTypeError, match=errmsg): - table.select(["a", ["b"]]) - with pytest.raises(com.IbisTypeError, match=errmsg): - table[["a", ["b"]]] - with pytest.raises(com.IbisTypeError, match=errmsg): - table["a", ["b"]] +# FIXME(kszucs): currently bind() flattens the list of expressions, so arbitrary +# nesting is allowed, need to revisit +# def test_projection_invalid_nested_list(table): +# errmsg = "must be coerceable to expressions" +# with pytest.raises(com.IbisTypeError, match=errmsg): +# table.select(["a", ["b"]]) +# with pytest.raises(com.IbisTypeError, match=errmsg): +# table[["a", ["b"]]] +# with pytest.raises(com.IbisTypeError, match=errmsg): +# table["a", ["b"]] def test_mutate(table): @@ -331,14 +341,14 @@ def test_filter_no_list(table): def test_add_predicate(table): pred = table["a"] > 5 result = table[pred] - assert isinstance(result.op(), ops.Selection) + assert isinstance(result.op(), ops.Filter) def test_invalid_predicate(table, schema): # a lookalike table2 = api.table(schema, name="bar") predicate = table2.a > 5 - with pytest.raises(RelationError): + with pytest.raises(IntegrityError): table.filter(predicate) @@ -349,13 +359,13 @@ def test_add_predicate_coalesce(table): pred1 = table["a"] > 5 pred2 = table["b"] > 0 - result = table[pred1][pred2] + result = simplify(table[pred1][pred2].op()).to_expr() expected = table.filter([pred1, pred2]) assert_equal(result, expected) # 59, if we are not careful, we can obtain broken refs subset = table[pred1] - result = subset.filter([subset["b"] > 0]) + result = simplify(subset.filter([subset["b"] > 0]).op()).to_expr() assert_equal(result, expected) @@ -496,7 +506,7 @@ def test_limit(table): def test_order_by(table): result = table.order_by(["f"]).op() - sort_key = result.sort_keys[0] + sort_key = result.keys[0] assert_equal(sort_key.expr, table.f.op()) assert sort_key.ascending @@ -505,7 +515,7 @@ def test_order_by(table): result2 = table.order_by("f").op() assert_equal(result, result2) - key2 = result2.sort_keys[0] + key2 = result2.keys[0] assert key2.descending is False @@ -534,24 +544,24 @@ def test_order_by_asc_deferred_sort_key(table): [ param(ibis.NA, ibis.NA.op(), id="na"), param(ibis.random(), ibis.random().op(), id="random"), - param(1.0, L(1.0).op(), id="float"), - param(L("a"), L("a").op(), id="string"), - param(L([1, 2, 3]), L([1, 2, 3]).op(), id="array"), + param(1.0, ibis.literal(1.0).op(), id="float"), + param(ibis.literal("a"), ibis.literal("a").op(), id="string"), + param(ibis.literal([1, 2, 3]), ibis.literal([1, 2, 3]).op(), id="array"), ], ) def test_order_by_scalar(table, key, expected): result = table.order_by(key) - assert result.op().sort_keys == (ops.SortKey(expected),) + assert result.op().keys == (ops.SortKey(expected),) @pytest.mark.parametrize( ("key", "exc_type"), [ ("bogus", com.IbisTypeError), - (("bogus", False), com.IbisTypeError), + # (("bogus", False), com.IbisTypeError), (ibis.desc("bogus"), com.IbisTypeError), (1000, IndexError), - ((1000, False), IndexError), + # ((1000, False), IndexError), (_.bogus, AttributeError), (_.bogus.desc(), AttributeError), ], @@ -652,15 +662,51 @@ def test_aggregate_keys_basic(table): repr(result) -def test_aggregate_non_list_inputs(table): - # per #150 +def test_aggregate_having_implicit_metric(table): metric = table.f.sum().name("total") by = "g" having = table.c.sum() > 10 - result = table.aggregate(metric, by=by, having=having) - expected = table.aggregate([metric], by=[by], having=[having]) - assert_equal(result, expected) + implicit_having_metric = table.aggregate(metric, by=by, having=having) + expected_aggregate = ops.Aggregate( + parent=table, + groups={"g": table.g}, + metrics={"total": table.f.sum(), table.c.sum().get_name(): table.c.sum()}, + ) + expected_filter = ops.Filter( + parent=expected_aggregate, + predicates=[ + ops.Greater(ops.Field(expected_aggregate, table.c.sum().get_name()), 10) + ], + ) + expected_project = ops.Project( + parent=expected_filter, + values={ + "g": ops.Field(expected_filter, "g"), + "total": ops.Field(expected_filter, "total"), + }, + ) + assert implicit_having_metric.op() == expected_project + + +def test_agg_having_explicit_metric(table): + metric = table.f.sum().name("total") + by = "g" + having = table.c.sum() > 10 + + explicit_having_metric = table.aggregate( + [metric, table.c.sum().name("sum")], by=by, having=having + ) + expected_aggregate = ops.Aggregate( + parent=table, + groups={"g": table.g}, + metrics={"total": table.f.sum(), "sum": table.c.sum()}, + ) + expected_filter = ops.Filter( + parent=expected_aggregate, + predicates=[ops.Greater(ops.Field(expected_aggregate, "sum"), 10)], + ) + assert explicit_having_metric.op() == expected_filter def test_aggregate_keywords(table): @@ -674,56 +720,32 @@ def test_aggregate_keywords(table): assert_equal(expr2, expected) -def test_filter_aggregate_pushdown_predicate(table): - # In the case where we want to add a predicate to an aggregate - # expression after the fact, rather than having to backpedal and add it - # before calling aggregate. - # - # TODO (design decision): This could happen automatically when adding a - # predicate originating from the same root table; if an expression is - # created from field references from the aggregated table then it - # becomes a filter predicate applied on top of a view - - pred = table.f > 0 - metrics = [table.a.sum().name("total")] - agged = table.aggregate(metrics, by=["g"]) - filtered = agged.filter([pred]) - expected = table[pred].aggregate(metrics, by=["g"]) - assert_equal(filtered, expected) - - def test_filter_on_literal_then_aggregate(table): # Mostly just a smoketest, this used to error on construction expr = table.filter(ibis.literal(True)).agg(lambda t: t.a.sum().name("total")) assert expr.columns == ["total"] -@pytest.mark.parametrize( - "case_fn", - [ - param(lambda t: t.f.sum(), id="non_boolean"), - param(lambda t: t.f > 2, id="non_scalar"), - ], -) -def test_aggregate_post_predicate(table, case_fn): - # Test invalid having clause - metrics = [table.f.sum().name("total")] - by = ["g"] - having = [case_fn(table)] - - with pytest.raises(ValidationError): - table.aggregate(metrics, by=by, having=having) - - def test_group_by_having_api(table): # #154, add a HAVING post-predicate in a composable way metric = table.f.sum().name("foo") postp = table.d.mean() > 1 - expr = table.group_by("g").having(postp).aggregate(metric) - expected = table.aggregate(metric, by="g", having=postp) - assert_equal(expr, expected) + agg = ops.Aggregate( + parent=table, + groups={"g": table.g}, + metrics={"foo": table.f.sum(), "Mean(d)": table.d.mean()}, + ).to_expr() + filt = ops.Filter( + parent=agg, + predicates=[agg["Mean(d)"] > 1], + ).to_expr() + proj = ops.Project( + parent=filt, + values={"g": filt.g, "foo": filt.foo}, + ) + assert expr.op() == proj def test_group_by_kwargs(table): @@ -756,6 +778,12 @@ def test_groupby_convenience(table): assert_equal(expr, expected) +@pytest.mark.parametrize("group", [[], (), None]) +def test_group_by_nothing(table, group): + with pytest.raises(com.IbisInputError): + table.group_by(group) + + def test_group_by_count_size(table): # #148, convenience for interactive use, and so forth result1 = table.group_by("g").size() @@ -820,16 +848,56 @@ def test_join_no_predicate_list(con): pred = region.r_regionkey == nation.n_regionkey joined = region.inner_join(nation, pred) - expected = region.inner_join(nation, [pred]) - assert_equal(joined, expected) + + nation_ = joined.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=region, + rest=[ + ops.JoinLink("inner", nation_, [region.r_regionkey == nation_.n_regionkey]) + ], + values={ + "r_regionkey": region.r_regionkey, + "r_name": region.r_name, + "r_comment": region.r_comment, + "n_nationkey": nation_.n_nationkey, + "n_name": nation_.n_name, + "n_regionkey": nation_.n_regionkey, + "n_comment": nation_.n_comment, + }, + ) + assert joined.op() == expected def test_join_deferred(con): region = con.table("tpch_region") nation = con.table("tpch_nation") res = region.join(nation, _.r_regionkey == nation.n_regionkey) - exp = region.join(nation, region.r_regionkey == nation.n_regionkey) - assert_equal(res, exp) + + nation_ = res.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=region, + rest=[ + ops.JoinLink("inner", nation_, [region.r_regionkey == nation_.n_regionkey]) + ], + values={ + "r_regionkey": region.r_regionkey, + "r_name": region.r_name, + "r_comment": region.r_comment, + "n_nationkey": nation_.n_nationkey, + "n_name": nation_.n_name, + "n_regionkey": nation_.n_regionkey, + "n_comment": nation_.n_comment, + }, + ) + assert res.op() == expected + + +def test_join_invalid_predicate(con): + region = con.table("tpch_region") + nation = con.table("tpch_nation") + + with pytest.raises(com.InputTypeError): + region.inner_join(nation, object()) def test_asof_join(): @@ -843,24 +911,51 @@ def test_asof_join(): "time_right", "value2", ] - pred = joined.op().table.predicates[0] + pred = joined.op().rest[0].predicates[0] assert pred.left.name == pred.right.name == "time" +# TODO(kszucs): ensure the correctness of the pd.merge_asof(by=...) argument emulation def test_asof_join_with_by(): left = ibis.table([("time", "int32"), ("key", "int32"), ("value", "double")]) right = ibis.table([("time", "int32"), ("key", "int32"), ("value2", "double")]) - joined = api.asof_join(left, right, "time", by="key") - assert joined.columns == [ - "time", - "key", - "value", - "time_right", - "key_right", - "value2", - ] - by = joined.op().table.by[0] - assert by.left.name == by.right.name == "key" + + join_without_by = api.asof_join(left, right, "time") + right_ = join_without_by.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=left, + rest=[ops.JoinLink("asof", right_, [left.time == right_.time])], + values={ + "time": left.time, + "key": left.key, + "value": left.value, + "time_right": right_.time, + "key_right": right_.key, + "value2": right_.value2, + }, + ) + assert join_without_by.op() == expected + + join_with_by = api.asof_join(left, right, "time", by="key") + right_ = join_with_by.op().rest[0].table.to_expr() + right__ = join_with_by.op().rest[1].table.to_expr() + expected = ops.JoinChain( + first=left, + rest=[ + ops.JoinLink("inner", right_, [left.key == right_.key]), + ops.JoinLink("asof", right__, [left.time == right__.time]), + ], + values={ + "time": left.time, + "key": left.key, + "value": left.value, + "time_right": right_.time, + "key_right": right_.key, + "value2": right_.value2, + "value2_right": right__.value2, + }, + ) + assert join_with_by.op() == expected @pytest.mark.parametrize( @@ -885,14 +980,28 @@ def test_asof_join_with_tolerance(ibis_interval, timedelta_interval): left = ibis.table([("time", "int32"), ("key", "int32"), ("value", "double")]) right = ibis.table([("time", "int32"), ("key", "int32"), ("value2", "double")]) - joined = api.asof_join(left, right, "time", tolerance=ibis_interval).op() - tolerance = joined.table.tolerance - assert_equal(tolerance, ibis_interval.op()) - - joined = api.asof_join(left, right, "time", tolerance=timedelta_interval).op() - tolerance = joined.table.tolerance - assert isinstance(tolerance.to_expr(), ir.IntervalScalar) - assert isinstance(tolerance, ops.Literal) + for interval in [ibis_interval, timedelta_interval]: + joined = api.asof_join(left, right, "time", tolerance=interval) + right_ = joined.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=left, + rest=[ + ops.JoinLink( + "asof", + right_, + [left.time == right_.time, (left.time - right_.time) <= interval], + ) + ], + values={ + "time": left.time, + "key": left.key, + "value": left.value, + "time_right": right_.time, + "key_right": right_.key, + "value2": right_.value2, + }, + ) + assert joined.op() == expected def test_equijoin_schema_merge(): @@ -976,7 +1085,9 @@ def test_self_join_no_view_convenience(table): result = table.join(table, [("g", "g")]) expected_cols = list(table.columns) - expected_cols.extend(f"{c}_right" for c in table.columns if c != "g") + # TODO(kszucs): the inner join convenience to don't duplicate the + # equivalent columns from the right table is not implemented yet + expected_cols.extend(f"{c}_right" for c in table.columns) # if c != "g") assert result.columns == expected_cols @@ -1050,8 +1161,26 @@ def test_cross_join_multiple(table): c = table["f", "h"] joined = ibis.cross_join(a, b, c) - expected = a.cross_join(b.cross_join(c)) - assert joined.equals(expected) + b_ = joined.op().rest[0].table.to_expr() + c_ = joined.op().rest[1].table.to_expr() + assert joined.op() == ops.JoinChain( + first=a, + rest=[ + ops.JoinLink("cross", b_, []), + ops.JoinLink("cross", c_, []), + ], + values={ + "a": a.a, + "b": a.b, + "c": a.c, + "d": b_.d, + "e": b_.e, + "f": c_.f, + "h": c_.h, + }, + ) + # TODO(kszucs): it must be simplified first using an appropriate rewrite rule + assert not joined.equals(a.cross_join(b.cross_join(c))) def test_filter_join(): @@ -1064,41 +1193,43 @@ def test_filter_join(): repr(filtered) -def test_inner_join_overlapping_column_names(): - t1 = ibis.table([("foo", "string"), ("bar", "string"), ("value1", "double")]) - t2 = ibis.table([("foo", "string"), ("bar", "string"), ("value2", "double")]) - - joined = t1.join(t2, "foo") - expected = t1.join(t2, t1.foo == t2.foo) - assert_equal(joined, expected) - assert joined.columns == ["foo", "bar", "value1", "bar_right", "value2"] - - joined = t1.join(t2, ["foo", "bar"]) - expected = t1.join(t2, [t1.foo == t2.foo, t1.bar == t2.bar]) - assert_equal(joined, expected) - assert joined.columns == ["foo", "bar", "value1", "value2"] - - # Equality predicates don't have same name, need to rename - joined = t1.join(t2, t1.foo == t2.bar) - assert joined.columns == [ - "foo", - "bar", - "value1", - "foo_right", - "bar_right", - "value2", - ] - - # Not all predicates are equality, still need to rename - joined = t1.join(t2, ["foo", t1.value1 < t2.value2]) - assert joined.columns == [ - "foo", - "bar", - "value1", - "foo_right", - "bar_right", - "value2", - ] +# TODO(kszucs): the inner join convenience to don't duplicate the equivalent +# columns from the right table is not implemented yet +# def test_inner_join_overlapping_column_names(): +# t1 = ibis.table([("foo", "string"), ("bar", "string"), ("value1", "double")]) +# t2 = ibis.table([("foo", "string"), ("bar", "string"), ("value2", "double")]) + +# joined = t1.join(t2, "foo") +# expected = t1.join(t2, t1.foo == t2.foo) +# assert_equal(joined, expected) +# assert joined.columns == ["foo", "bar", "value1", "bar_right", "value2"] + +# joined = t1.join(t2, ["foo", "bar"]) +# expected = t1.join(t2, [t1.foo == t2.foo, t1.bar == t2.bar]) +# assert_equal(joined, expected) +# assert joined.columns == ["foo", "bar", "value1", "value2"] + +# # Equality predicates don't have same name, need to rename +# joined = t1.join(t2, t1.foo == t2.bar) +# assert joined.columns == [ +# "foo", +# "bar", +# "value1", +# "foo_right", +# "bar_right", +# "value2", +# ] + +# # Not all predicates are equality, still need to rename +# joined = t1.join(t2, ["foo", t1.value1 < t2.value2]) +# assert joined.columns == [ +# "foo", +# "bar", +# "value1", +# "foo_right", +# "bar_right", +# "value2", +# ] @pytest.mark.parametrize( @@ -1116,24 +1247,38 @@ def test_inner_join_overlapping_column_names(): def test_join_key_alternatives(con, key_maker): t1 = con.table("star1") t2 = con.table("star2") - expected = t1.inner_join(t2, [t1.foo_id == t2.foo_id]) key = key_maker(t1, t2) + joined = t1.inner_join(t2, key) - assert_equal(joined, expected) + t2_ = joined.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=t1, + rest=[ + ops.JoinLink("inner", t2_, [t1.foo_id == t2_.foo_id]), + ], + values={ + "c": t1.c, + "f": t1.f, + "foo_id": t1.foo_id, + "bar_id": t1.bar_id, + "foo_id_right": t2_.foo_id, + "value1": t2_.value1, + "value3": t2_.value3, + }, + ) + assert joined.op() == expected -@pytest.mark.parametrize( - "key,error", - [ - ([("foo_id", "foo_id", "foo_id")], com.ExpressionError), - ([(s.c("foo_id"), s.c("foo_id"))], ValueError), - ], -) -def test_join_key_invalid(con, key, error): + +def test_join_key_invalid(con): t1 = con.table("star1") t2 = con.table("star2") - with pytest.raises(error): - t1.inner_join(t2, key) + + with pytest.raises(ExpressionError): + t1.inner_join(t2, [("foo_id", "foo_id", "foo_id")]) + + # it is working now + t1.inner_join(t2, [(s.c("foo_id"), s.c("foo_id"))]) def test_join_invalid_refs(con): @@ -1142,7 +1287,7 @@ def test_join_invalid_refs(con): t3 = con.table("star3") predicate = t1.bar_id == t3.bar_id - with pytest.raises(com.RelationError): + with pytest.raises(com.IntegrityError): t1.inner_join(t2, [predicate]) @@ -1151,7 +1296,7 @@ def test_join_invalid_expr_type(con): invalid_right = left.foo_id join_key = ["bar_id"] - with pytest.raises(ValidationError): + with pytest.raises(TypeError): left.inner_join(invalid_right, join_key) @@ -1161,7 +1306,7 @@ def test_join_non_boolean_expr(con): # oops predicate = t1.f * t2.value1 - with pytest.raises(com.ExpressionError): + with pytest.raises(ValidationError): t1.inner_join(t2, [predicate]) @@ -1191,8 +1336,28 @@ def test_unravel_compound_equijoin(table): p3 = t1.key3 == t2.key3 joined = t1.inner_join(t2, [p1 & p2 & p3]) - expected = t1.inner_join(t2, [p1, p2, p3]) - assert_equal(joined, expected) + t2_ = joined.op().rest[0].table.to_expr() + expected = ops.JoinChain( + first=t1, + rest=[ + ops.JoinLink( + "inner", + t2_, + [t1.key1 == t2_.key1, t1.key2 == t2_.key2, t1.key3 == t2_.key3], + ) + ], + values={ + "key1": t1.key1, + "key2": t1.key2, + "key3": t1.key3, + "value1": t1.value1, + "key1_right": t2_.key1, + "key2_right": t2_.key2, + "key3_right": t2_.key3, + "value2": t2_.value2, + }, + ) + assert joined.op() == expected def test_union( @@ -1202,11 +1367,11 @@ def test_union( setops_relation_error_message, ): result = setops_table_foo.union(setops_table_bar) - assert isinstance(result.op().table, ops.Union) - assert not result.op().table.distinct + assert isinstance(result.op().parent, ops.Union) + assert not result.op().parent.distinct result = setops_table_foo.union(setops_table_bar, distinct=True) - assert result.op().table.distinct + assert result.op().parent.distinct with pytest.raises(RelationError, match=setops_relation_error_message): setops_table_foo.union(setops_table_baz) @@ -1219,7 +1384,7 @@ def test_intersection( setops_relation_error_message, ): result = setops_table_foo.intersect(setops_table_bar) - assert isinstance(result.op().table, ops.Intersection) + assert isinstance(result.op().parent, ops.Intersection) with pytest.raises(RelationError, match=setops_relation_error_message): setops_table_foo.intersect(setops_table_baz) @@ -1232,7 +1397,7 @@ def test_difference( setops_relation_error_message, ): result = setops_table_foo.difference(setops_table_bar) - assert isinstance(result.op().table, ops.Difference) + assert isinstance(result.op().parent, ops.Difference) with pytest.raises(RelationError, match=setops_relation_error_message): setops_table_foo.difference(setops_table_baz) @@ -1274,14 +1439,23 @@ def t2(): def test_unresolved_existence_predicate(t1, t2): expr = (t1.key1 == t2.key1).any() - assert isinstance(expr, ir.BooleanColumn) - assert isinstance(expr.op(), ops.UnresolvedExistsSubquery) + assert isinstance(expr, Deferred) + + filtered = t2.filter(t1.key1 == t2.key1).select(ibis.literal(1)) + subquery = ops.ExistsSubquery(filtered) + expected = ops.Filter(parent=t1, predicates=[subquery]) + assert t1[expr].op() == expected + + filtered = t1.filter(t1.key1 == t2.key1).select(ibis.literal(1)) + subquery = ops.ExistsSubquery(filtered) + expected = ops.Filter(parent=t2, predicates=[subquery]) + assert t2[expr].op() == expected def test_resolve_existence_predicate(t1, t2): expr = t1[(t1.key1 == t2.key1).any()] op = expr.op() - assert isinstance(op, ops.Selection) + assert isinstance(op, ops.Filter) pred = op.predicates[0].to_expr() assert isinstance(pred.op(), ops.ExistsSubquery) @@ -1317,11 +1491,23 @@ def test_group_by_keys(table): def test_having(table): m = table.mutate(foo=table.f * 2, bar=table.e / 2) - expr = m.group_by("foo").having(lambda x: x.foo.sum() > 10).size() - expected = m.group_by("foo").having(m.foo.sum() > 10).size() - assert_equal(expr, expected) + agg = ops.Aggregate( + parent=m, + groups={"foo": m.foo}, + metrics={"CountStar()": ops.CountStar(m), "Sum(foo)": ops.Sum(m.foo)}, + ).to_expr() + filt = ops.Filter( + parent=agg, + predicates=[agg["Sum(foo)"] > 10], + ).to_expr() + proj = ops.Project( + parent=filt, + values={"foo": filt.foo, "CountStar()": filt["CountStar()"]}, + ).to_expr() + + assert expr.equals(proj) def test_filter(table): @@ -1494,16 +1680,20 @@ def test_mutate_chain(): one = ibis.table([("a", "string"), ("b", "string")], name="t") two = one.mutate(b=lambda t: t.b.fillna("Short Term")) three = two.mutate(a=lambda t: t.a.fillna("Short Term")) - a, b = three.op().selections - # we can't fuse these correctly yet - assert isinstance(a, ops.Alias) - assert isinstance(a.arg, ops.Coalesce) - assert isinstance(b, ops.TableColumn) - - expr = b.table.selections[1] - assert isinstance(expr, ops.Alias) - assert isinstance(expr.arg, ops.Coalesce) + values = three.op().values + assert isinstance(values["a"], ops.Coalesce) + assert isinstance(values["b"], ops.Field) + assert values["b"].rel == two.op() + + three_opt = simplify(three.op()) + assert three_opt == ops.Project( + parent=one, + values={ + "a": one.a.fillna("Short Term"), + "b": one.b.fillna("Short Term"), + }, + ) # TODO(kszucs): move this test case to ibis/tests/sql since it requires the @@ -1613,11 +1803,11 @@ def test_join_lname_rname_still_collide(): t2 = ibis.table({"id": "int64", "col1": "int64", "col2": "int64"}) t3 = ibis.table({"id": "int64", "col1": "int64", "col2": "int64"}) - with pytest.raises(com.IntegrityError) as rec: - t1.left_join(t2, "id").left_join(t3, "id") + with pytest.raises(com.IntegrityError): + t1.left_join(t2, "id").left_join(t3, "id")._finish() - assert "`['col1_right', 'col2_right', 'id_right']`" in str(rec.value) - assert "`lname='', rname='{name}_right'`" in str(rec.value) + # assert "`['col1_right', 'col2_right', 'id_right']`" in str(rec.value) + # assert "`lname='', rname='{name}_right'`" in str(rec.value) def test_drop(): @@ -1690,22 +1880,15 @@ def test_array_string_compare(): @pytest.mark.parametrize("value", [True, False]) -@pytest.mark.parametrize( - "api", - [ - param(lambda t, value: t[value], id="getitem"), - param(lambda t, value: t.filter(value), id="filter"), - ], -) -def test_filter_with_literal(value, api): +def test_filter_with_literal(value): t = ibis.table(dict(a="string")) - filt = api(t, ibis.literal(value)) - assert filt is not None + filt = t.filter(ibis.literal(value)) + assert filt.op() == ops.Filter(parent=t, predicates=[ibis.literal(value)]) # ints are invalid predicates int_val = ibis.literal(int(value)) - with pytest.raises((NotImplementedError, ValidationError, com.IbisTypeError)): - api(t, int_val) + with pytest.raises(ValidationError): + t.filter(int_val) def test_cast(): diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 4e597f435834..bb9e3e8ba8f6 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -308,7 +308,7 @@ def test_distinct_table(functional_alltypes): expr = functional_alltypes.distinct() assert isinstance(expr.op(), ops.Distinct) assert isinstance(expr, ir.Table) - assert expr.op().table == functional_alltypes.op() + assert expr.op().parent == functional_alltypes.op() def test_nunique(functional_alltypes): @@ -1465,10 +1465,9 @@ def test_deferred_r_ops(op_name, expected_left, expected_right): op = getattr(operator, op_name) expr = t[op(left, right).name("b")] - - op = expr.op().selections[0].arg - assert op.left.equals(expected_left(t).op()) - assert op.right.equals(expected_right(t).op()) + node = expr.op().values["b"] + assert node.left.equals(expected_left(t).op()) + assert node.right.equals(expected_right(t).op()) @pytest.mark.parametrize( @@ -1671,9 +1670,9 @@ def test_quantile_shape(): projs = [b1] expr = t.select(projs) - (b1,) = expr.op().selections + b1 = expr.br2 - assert b1.shape.is_columnar() + assert b1.op().shape.is_columnar() def test_sample(): diff --git a/ibis/tests/expr/test_window_frames.py b/ibis/tests/expr/test_window_frames.py index 1b0d8cd7268f..31585fc28ef5 100644 --- a/ibis/tests/expr/test_window_frames.py +++ b/ibis/tests/expr/test_window_frames.py @@ -502,8 +502,7 @@ def test_window_analysis_combine_preserves_existing_window(): ) w = ibis.cumulative_window(order_by=t.one) mut = t.group_by(t.three).mutate(four=t.two.sum().over(w)) - - assert mut.op().selections[1].arg.frame.start is None + assert mut.op().values["four"].frame.start is None def test_window_analysis_auto_windowize_bug(): @@ -552,19 +551,18 @@ def test_group_by_with_window_function_preserves_range(alltypes): w = ibis.cumulative_window(order_by=t.one) expr = t.group_by(t.three).mutate(four=t.two.sum().over(w)) - expected = ops.Selection( - t, - [ - t, - ops.Alias( - ops.WindowFunction( - func=ops.Sum(t.two), - frame=ops.RowsWindowFrame( - table=t, end=0, group_by=[t.three], order_by=[t.one] - ), + expected = ops.Project( + parent=t, + values={ + "one": t.one, + "two": t.two, + "three": t.three, + "four": ops.WindowFunction( + func=ops.Sum(t.two), + frame=ops.RowsWindowFrame( + table=t, end=0, group_by=[t.three], order_by=[t.one] ), - name="four", ), - ], + }, ) assert expr.op() == expected diff --git a/ibis/tests/expr/test_window_functions.py b/ibis/tests/expr/test_window_functions.py index 1c2fd6110468..fdc7be22f385 100644 --- a/ibis/tests/expr/test_window_functions.py +++ b/ibis/tests/expr/test_window_functions.py @@ -40,9 +40,10 @@ def test_mutate_with_analytic_functions(alltypes): exprs = [expr.name("e%d" % i) for i, expr in enumerate(exprs)] proj = g.mutate(exprs) - for field in proj.op().selections[1:]: - assert isinstance(field, ops.Alias) - assert isinstance(field.arg, ops.WindowFunction) + + values = list(proj.op().values.values()) + for field in values[len(t.schema()) :]: + assert isinstance(field, ops.WindowFunction) def test_value_over_api(alltypes): @@ -70,5 +71,5 @@ def test_conflicting_window_boundaries(alltypes): def test_rank_followed_by_over_call_merge_frames(alltypes): t = alltypes expr1 = t.f.percent_rank().over(ibis.window(group_by=t.f.notnull())) - expr2 = ibis.percent_rank().over(group_by=t.f.notnull(), order_by=t.f).resolve(t) + expr2 = ibis.percent_rank().over(group_by=t.f.notnull(), order_by=t.f) assert expr1.equals(expr2)