Skip to content

Commit

Permalink
fix[next]: Fix type preservation in CSE (#1736)
Browse files Browse the repository at this point in the history
The common subexpression elimination uses typing information to decide
what expressions can be extracted. However, while extracting it creates
new nodes and uses the inline lambda pass, which did not preserve the
types. This was observed in PMAP and is fixed in this PR on a best
effort basis. Creating a minimal reproducible example is hard and since
multiple of us are considering making typing information an integral
part of the IR, e.g. by attaching the computation to the node instead of
having a separate pass, which would solve the problem automatically no
tests have been written.
  • Loading branch information
tehrengruber authored Nov 15, 2024
1 parent b60ffff commit c51bdd1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import operator
from typing import Callable, Iterable, TypeVar, Union, cast

import gt4py.next.iterator.ir_utils.ir_makers as im
from gt4py.eve import (
NodeTranslator,
NodeVisitor,
Expand Down Expand Up @@ -241,7 +242,6 @@ def extract_subexpression(
Examples:
Default case for `(x+y) + ((x+y)+z)`:
>>> import gt4py.next.iterator.ir_utils.ir_makers as im
>>> from gt4py.eve.utils import UIDGenerator
>>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z"))
>>> predicate = lambda subexpr, num_occurences: num_occurences > 1
Expand Down Expand Up @@ -433,7 +433,9 @@ def predicate(subexpr: itir.Expr, num_occurences: int):
if num_occurences > 1:
if is_local_view:
return True
else:
# condition is only necessary since typing on lambdas is not preserved during
# the transformation
elif not isinstance(subexpr, itir.Lambda):
# only extract fields outside of `as_fieldop`
# `as_fieldop(...)(field_expr, field_expr)`
# -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)`
Expand All @@ -451,10 +453,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int):
return self.generic_visit(node, **kwargs)

# apply remapping
result = itir.FunCall(
fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr),
args=list(extracted.values()),
)
result = im.let(*extracted.items())(new_expr)
itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True)

# if the node id is ignored (because its parent is eliminated), but it occurs
# multiple times then we want to visit the final result once more.
Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols
from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs
from gt4py.next.iterator.type_system import inference as itir_inference


# TODO(tehrengruber): Reduce complexity of the function by removing the different options here
Expand Down Expand Up @@ -98,7 +99,7 @@ def new_name(name):
new_expr.location = node.location
return new_expr
else:
return ir.FunCall(
new_expr = ir.FunCall(
fun=ir.Lambda(
params=[
param
Expand All @@ -110,6 +111,8 @@ def new_name(name):
args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible],
location=node.location,
)
itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True)
return new_expr


@dataclasses.dataclass
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None:
node.type = type_


def copy_type(from_: itir.Node, to: itir.Node) -> None:
def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None:
"""
Copy type from one node to another.
This function mainly exists for readability reasons.
"""
assert isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type)
assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type) # type: ignore[arg-type]


def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None:
Expand Down

0 comments on commit c51bdd1

Please sign in to comment.