From e64690525adf443fd4f24e1b88f5b360a2c8c7a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Tue, 21 Nov 2023 20:56:10 +0100 Subject: [PATCH] move rewrites to ibis/expr/rewrites.py from newrels.py --- ibis/common/patterns.py | 4 +- ibis/expr/operations/__init__.py | 2 - ibis/expr/operations/generic.py | 3 +- ibis/expr/operations/newrels.py | 115 +++---------------------------- ibis/expr/rewrites.py | 54 ++++++++++++++- ibis/expr/types/relations.py | 2 +- ibis/tests/expr/test_analysis.py | 1 - 7 files changed, 66 insertions(+), 115 deletions(-) diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 4e58118bdb0f0..8ffd3ca9ed273 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -207,7 +207,9 @@ def match(self, value: AnyType, context: dict[str, AnyType]) -> AnyType: """ ... - def is_match(self, value: AnyType, context: dict[str, AnyType] = None) -> bool: + def is_match( + self, value: AnyType, context: Optional[dict[str, AnyType]] = None + ) -> bool: """Check that a value matches the pattern. 
Parameters diff --git a/ibis/expr/operations/__init__.py b/ibis/expr/operations/__init__.py index f0644214f13ad..c525eba2198ab 100644 --- a/ibis/expr/operations/__init__.py +++ b/ibis/expr/operations/__init__.py @@ -9,8 +9,6 @@ from ibis.expr.operations.json import * # noqa: F403 from ibis.expr.operations.logical import * # noqa: F403 from ibis.expr.operations.maps import * # noqa: F403 - -# from ibis.expr.operations.relations import * # noqa: F403 from ibis.expr.operations.newrels import * # noqa: F403 from ibis.expr.operations.numeric import * # noqa: F403 from ibis.expr.operations.reductions import * # noqa: F403 diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index c1057bc261768..613310c81cb87 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -1,13 +1,12 @@ from __future__ import annotations import itertools -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any, Optional from typing import Literal as LiteralType from public import public from typing_extensions import TypeVar -import ibis.common.exceptions as com import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz diff --git a/ibis/expr/operations/newrels.py b/ibis/expr/operations/newrels.py index 23b74109b4c31..bdddd19266ec7 100644 --- a/ibis/expr/operations/newrels.py +++ b/ibis/expr/operations/newrels.py @@ -6,7 +6,7 @@ import itertools import typing from abc import abstractmethod -from typing import Annotated, Any, Literal, Optional +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional from public import public @@ -15,16 +15,20 @@ from ibis.common.annotations import attribute from ibis.common.bases import Immutable from ibis.common.collections import FrozenDict -from ibis.common.deferred import Item, deferred, var +from ibis.common.deferred import deferred from ibis.common.exceptions import IbisTypeError, IntegrityError, RelationError from 
ibis.common.grounds import Concrete -from ibis.common.patterns import Between, Check, In, InstanceOf, _, pattern, replace +from ibis.common.patterns import Between, In, InstanceOf, pattern from ibis.common.typing import Coercible, VarTuple from ibis.expr.operations.core import Alias, Column, Node, Scalar, Value -from ibis.expr.operations.sortkeys import SortKey +from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001 from ibis.expr.schema import Schema from ibis.util import Namespace, gen_name, indent +if TYPE_CHECKING: + import pandas as pd + import pyarrow as pa + p = Namespace(pattern, module=__name__) d = Namespace(deferred, module=__name__) @@ -154,19 +158,6 @@ def dtype(self): return self.rel.schema[self.name] -# TODO(kszucs): we may not need this, the pattern can define whether it is foreign or not -# @public -# class ForeignField(Value): -# rel: Relation -# name: str - -# shape = ds.columnar - -# @attribute -# def dtype(self): -# return self.rel.schema[self.name] - - # TODO(kszucs): rename it to ForeignScalar @public class ForeignField(Value): @@ -188,9 +179,6 @@ def _check_integrity(values, allowed_parents): raise IntegrityError( f"Cannot add {disallowed!r} to projection, they belong to another relation" ) - # add a flag about allowing foreign values - # add a flag to enfore scalar foreign values (e.g. for Project and Aggregate) - # egyebkent csak scalar lehet (e.g. 
scalar subquery or a value based on literals) @public @@ -533,96 +521,9 @@ def schema(self): return self.parent.schema -# class Subquery(Relation): -# rel: Relation - -# @property -# def schema(self): -# return self.rel.schema - -# @property -# def fields(self): -# return self.rel.fields - - -################################ TYPES ################################ - - -# class TableExpr(Expr): -# def schema(self): -# return self.op().schema - -# # def __getattr__(self, key): -# # return next(bind(self, key)) - - @public def table(name, schema): return UnboundTable(name, schema).to_expr() -# TODO(kszucs): cover it with tests - - -################################ REWRITES ################################ -from ibis.common.patterns import Each - -name = var("name") - -y = var("y") -values = var("values") - - -@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema)) -def complete_reprojection(_, y): - # TODO(kszucs): this could be moved to the pattern itself but not sure how - # to express it, especially in a shorter way then the following check - for name in _.schema: - if _.values[name] != Field(y, name): - return _ - return y - - -@replace(p.Project(y @ p.Project)) -def subsequent_projects(_, y): - rule = p.Field(y, name) >> Item(y.values, name) - values = {k: v.replace(rule) for k, v in _.values.items()} - return Project(y.parent, values) - - -@replace(p.Filter(y @ p.Filter)) -def subsequent_filters(_, y): - rule = p.Field(y, name) >> d.Field(y.parent, name) - preds = tuple(v.replace(rule) for v in _.predicates) - return Filter(y.parent, y.predicates + preds) - - -@replace(p.Filter(y @ p.Project)) -def reorder_filter_project(_, y): - rule = p.Field(y, name) >> Item(y.values, name) - preds = tuple(v.replace(rule) for v in _.predicates) - - inner = Filter(y.parent, preds) - rule = p.Field(y.parent, name) >> d.Field(inner, name) - projs = {k: v.replace(rule) for k, v in y.values.items()} - - return Project(inner, projs) - - -# TODO(kszucs): add a rewrite rule 
for nestes JoinChain objects where the -# JoinLink depends on another JoinChain, in this case the JoinLink should be -# merged into the JoinChain - - -# TODO(kszucs): this may work if the sort keys are not overlapping, need to revisit -# @replace(p.Sort(y @ p.Sort)) -# def subsequent_sorts(_, y): -# return Sort(y.parent, y.keys + _.keys) - - # TODO(kszucs): support t.select(*t) syntax by implementing TableExpr.__iter__() - - -# subqueries: -# 1. reduction passed to .filter() should be turned into a subquery -# 2. reduction passed to .select() with a foreign table should be turned into a subquery diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py index 4ae694f0120c2..629cf278d0326 100644 --- a/ibis/expr/rewrites.py +++ b/ibis/expr/rewrites.py @@ -6,11 +6,16 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis.common.deferred import Item, _, deferred, var from ibis.common.exceptions import UnsupportedOperationError -from ibis.common.patterns import pattern, replace +from ibis.common.patterns import Check, pattern, replace from ibis.util import Namespace p = Namespace(pattern, module=ops) +d = Namespace(deferred, module=ops) + +y = var("y") +name = var("name") @replace(p.FillNa) @@ -79,3 +84,50 @@ def rewrite_sample(_): (ops.LessEqual(ops.RandomScalar(), _.fraction),), (), ) + + +@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema)) +def complete_reprojection(_, y): + # TODO(kszucs): this could be moved to the pattern itself but not sure how + # to express it, especially in a shorter way than the following check + for name in _.schema: + if _.values[name] != ops.Field(y, name): + return _ + return y + + +@replace(p.Project(y @ p.Project)) +def subsequent_projects(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + values = {k: v.replace(rule) for k, v in _.values.items()} + return ops.Project(y.parent, values) + + +@replace(p.Filter(y @ p.Filter)) +def subsequent_filters(_, y): + rule = p.Field(y, name) >> 
d.Field(y.parent, name) + preds = tuple(v.replace(rule) for v in _.predicates) + return ops.Filter(y.parent, y.predicates + preds) + + +@replace(p.Filter(y @ p.Project)) +def reorder_filter_project(_, y): + rule = p.Field(y, name) >> Item(y.values, name) + preds = tuple(v.replace(rule) for v in _.predicates) + + inner = ops.Filter(y.parent, preds) + rule = p.Field(y.parent, name) >> d.Field(inner, name) + projs = {k: v.replace(rule) for k, v in y.values.items()} + + return ops.Project(inner, projs) + + +# TODO(kszucs): add a rewrite rule for nested JoinChain objects where the +# JoinLink depends on another JoinChain, in this case the JoinLink should be +# merged into the JoinChain + + +# TODO(kszucs): this may work if the sort keys are not overlapping, need to revisit +# @replace(p.Sort(y @ p.Sort)) +# def subsequent_sorts(_, y): +# return Sort(y.parent, y.keys + _.keys) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 4bc015c58187b..01438228ef4f9 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -4326,7 +4326,7 @@ def window_by(self, time_col: ir.Value) -> WindowedTable: return WindowedTable(self, time_col) def optimize(self, enable_reordering=True): - from ibis.expr.operations.newrels import ( + from ibis.expr.rewrites import ( complete_reprojection, subsequent_filters, subsequent_projects, diff --git a/ibis/tests/expr/test_analysis.py b/ibis/tests/expr/test_analysis.py index f22799944261d..b446a60254954 100644 --- a/ibis/tests/expr/test_analysis.py +++ b/ibis/tests/expr/test_analysis.py @@ -5,7 +5,6 @@ import ibis import ibis.common.exceptions as com import ibis.expr.operations as ops -from ibis.tests.util import assert_equal # Place to collect esoteric expression analysis bugs and tests