Skip to content

Commit

Permalink
move rewrites to ibis/expr/rewrites.py from newrels.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Nov 21, 2023
1 parent 03c136c commit e646905
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 115 deletions.
4 changes: 3 additions & 1 deletion ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,9 @@ def match(self, value: AnyType, context: dict[str, AnyType]) -> AnyType:
"""
...

def is_match(self, value: AnyType, context: dict[str, AnyType] = None) -> bool:
def is_match(
self, value: AnyType, context: Optional[dict[str, AnyType]] = None
) -> bool:
"""Check that a value matches the pattern.
Parameters
Expand Down
2 changes: 0 additions & 2 deletions ibis/expr/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from ibis.expr.operations.json import * # noqa: F403
from ibis.expr.operations.logical import * # noqa: F403
from ibis.expr.operations.maps import * # noqa: F403

# from ibis.expr.operations.relations import * # noqa: F403
from ibis.expr.operations.newrels import * # noqa: F403
from ibis.expr.operations.numeric import * # noqa: F403
from ibis.expr.operations.reductions import * # noqa: F403
Expand Down
3 changes: 1 addition & 2 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import itertools
from typing import Annotated, Any, Optional, Union
from typing import Annotated, Any, Optional
from typing import Literal as LiteralType

from public import public
from typing_extensions import TypeVar

import ibis.common.exceptions as com
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
Expand Down
115 changes: 8 additions & 107 deletions ibis/expr/operations/newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import itertools
import typing
from abc import abstractmethod
from typing import Annotated, Any, Literal, Optional
from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional

from public import public

Expand All @@ -15,16 +15,20 @@
from ibis.common.annotations import attribute
from ibis.common.bases import Immutable
from ibis.common.collections import FrozenDict
from ibis.common.deferred import Item, deferred, var
from ibis.common.deferred import deferred
from ibis.common.exceptions import IbisTypeError, IntegrityError, RelationError
from ibis.common.grounds import Concrete
from ibis.common.patterns import Between, Check, In, InstanceOf, _, pattern, replace
from ibis.common.patterns import Between, In, InstanceOf, pattern
from ibis.common.typing import Coercible, VarTuple
from ibis.expr.operations.core import Alias, Column, Node, Scalar, Value
from ibis.expr.operations.sortkeys import SortKey
from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001
from ibis.expr.schema import Schema
from ibis.util import Namespace, gen_name, indent

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa

p = Namespace(pattern, module=__name__)
d = Namespace(deferred, module=__name__)

Expand Down Expand Up @@ -154,19 +158,6 @@ def dtype(self):
return self.rel.schema[self.name]


# TODO(kszucs): we may not need this, the pattern can define whether it is foreign or not
# @public
# class ForeignField(Value):
# rel: Relation
# name: str

# shape = ds.columnar

# @attribute
# def dtype(self):
# return self.rel.schema[self.name]


# TODO(kszucs): rename it to ForeignScalar
@public
class ForeignField(Value):
Expand All @@ -188,9 +179,6 @@ def _check_integrity(values, allowed_parents):
raise IntegrityError(
f"Cannot add {disallowed!r} to projection, they belong to another relation"
)
# add a flag about allowing foreign values
# add a flag to enfore scalar foreign values (e.g. for Project and Aggregate)
# egyebkent csak scalar lehet (e.g. scalar subquery or a value based on literals)


@public
Expand Down Expand Up @@ -533,96 +521,9 @@ def schema(self):
return self.parent.schema


# class Subquery(Relation):
# rel: Relation

# @property
# def schema(self):
# return self.rel.schema

# @property
# def fields(self):
# return self.rel.fields


################################ TYPES ################################


# class TableExpr(Expr):
# def schema(self):
# return self.op().schema

# # def __getattr__(self, key):
# # return next(bind(self, key))


@public
def table(name, schema):
return UnboundTable(name, schema).to_expr()


# TODO(kszucs): cover it with tests


################################ REWRITES ################################
from ibis.common.patterns import Each

name = var("name")

y = var("y")
values = var("values")


@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema))
def complete_reprojection(_, y):
# TODO(kszucs): this could be moved to the pattern itself but not sure how
# to express it, especially in a shorter way then the following check
for name in _.schema:
if _.values[name] != Field(y, name):
return _
return y


@replace(p.Project(y @ p.Project))
def subsequent_projects(_, y):
rule = p.Field(y, name) >> Item(y.values, name)
values = {k: v.replace(rule) for k, v in _.values.items()}
return Project(y.parent, values)


@replace(p.Filter(y @ p.Filter))
def subsequent_filters(_, y):
rule = p.Field(y, name) >> d.Field(y.parent, name)
preds = tuple(v.replace(rule) for v in _.predicates)
return Filter(y.parent, y.predicates + preds)


@replace(p.Filter(y @ p.Project))
def reorder_filter_project(_, y):
rule = p.Field(y, name) >> Item(y.values, name)
preds = tuple(v.replace(rule) for v in _.predicates)

inner = Filter(y.parent, preds)
rule = p.Field(y.parent, name) >> d.Field(inner, name)
projs = {k: v.replace(rule) for k, v in y.values.items()}

return Project(inner, projs)


# TODO(kszucs): add a rewrite rule for nestes JoinChain objects where the
# JoinLink depends on another JoinChain, in this case the JoinLink should be
# merged into the JoinChain


# TODO(kszucs): this may work if the sort keys are not overlapping, need to revisit
# @replace(p.Sort(y @ p.Sort))
# def subsequent_sorts(_, y):
# return Sort(y.parent, y.keys + _.keys)


# TODO(kszucs): support t.select(*t) syntax by implementing TableExpr.__iter__()


# subqueries:
# 1. reduction passed to .filter() should be turned into a subquery
# 2. reduction passed to .select() with a foreign table should be turned into a subquery
54 changes: 53 additions & 1 deletion ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.deferred import Item, _, deferred, var
from ibis.common.exceptions import UnsupportedOperationError
from ibis.common.patterns import pattern, replace
from ibis.common.patterns import Check, pattern, replace
from ibis.util import Namespace

p = Namespace(pattern, module=ops)
d = Namespace(deferred, module=ops)

y = var("y")
name = var("name")


@replace(p.FillNa)
Expand Down Expand Up @@ -79,3 +84,50 @@ def rewrite_sample(_):
(ops.LessEqual(ops.RandomScalar(), _.fraction),),
(),
)


@replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema))
def complete_reprojection(_, y):
# TODO(kszucs): this could be moved to the pattern itself but not sure how
# to express it, especially in a shorter way then the following check
for name in _.schema:
if _.values[name] != ops.Field(y, name):
return _
return y


@replace(p.Project(y @ p.Project))
def subsequent_projects(_, y):
rule = p.Field(y, name) >> Item(y.values, name)
values = {k: v.replace(rule) for k, v in _.values.items()}
return ops.Project(y.parent, values)


@replace(p.Filter(y @ p.Filter))
def subsequent_filters(_, y):
rule = p.Field(y, name) >> d.Field(y.parent, name)
preds = tuple(v.replace(rule) for v in _.predicates)
return ops.Filter(y.parent, y.predicates + preds)


@replace(p.Filter(y @ p.Project))
def reorder_filter_project(_, y):
rule = p.Field(y, name) >> Item(y.values, name)
preds = tuple(v.replace(rule) for v in _.predicates)

inner = ops.Filter(y.parent, preds)
rule = p.Field(y.parent, name) >> d.Field(inner, name)
projs = {k: v.replace(rule) for k, v in y.values.items()}

return ops.Project(inner, projs)


# TODO(kszucs): add a rewrite rule for nestes JoinChain objects where the
# JoinLink depends on another JoinChain, in this case the JoinLink should be
# merged into the JoinChain


# TODO(kszucs): this may work if the sort keys are not overlapping, need to revisit
# @replace(p.Sort(y @ p.Sort))
# def subsequent_sorts(_, y):
# return Sort(y.parent, y.keys + _.keys)
2 changes: 1 addition & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4326,7 +4326,7 @@ def window_by(self, time_col: ir.Value) -> WindowedTable:
return WindowedTable(self, time_col)

def optimize(self, enable_reordering=True):
from ibis.expr.operations.newrels import (
from ibis.expr.rewrites import (
complete_reprojection,
subsequent_filters,
subsequent_projects,
Expand Down
1 change: 0 additions & 1 deletion ibis/tests/expr/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import ibis
import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.tests.util import assert_equal

# Place to collect esoteric expression analysis bugs and tests

Expand Down

0 comments on commit e646905

Please sign in to comment.