From c51bdd1b6e515b2cebff466876d51b1bc0874096 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 15 Nov 2024 10:35:42 +0100 Subject: [PATCH] fix[next]: Fix type preservation in CSE (#1736) The common subexpression elimination uses typing information to decide what expressions can be extracted. However, while extracting it creates new nodes and uses the inline lambda pass, which did not preserve the types. This was observed in PMAP and is fixed in this PR on a best effort basis. Creating a minimal reproducible example is hard and since multiple of us are considering making typing information an integral part of the IR, e.g. by attaching the computation to the node instead of having a separate pass, which would solve the problem automatically no tests have been written. --- src/gt4py/next/iterator/transforms/cse.py | 12 ++++++------ src/gt4py/next/iterator/transforms/inline_lambdas.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 6 +++--- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccc1d2195f..4932d376ad 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -14,6 +14,7 @@ import operator from typing import Callable, Iterable, TypeVar, Union, cast +import gt4py.next.iterator.ir_utils.ir_makers as im from gt4py.eve import ( NodeTranslator, NodeVisitor, @@ -241,7 +242,6 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 @@ -433,7 +433,9 @@ def predicate(subexpr: itir.Expr, num_occurences: int): if num_occurences > 1: if is_local_view: return True - else: + # condition is only necessary since typing on lambdas is not preserved during + # the transformation + elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` @@ -451,10 +453,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int): return self.generic_visit(node, **kwargs) # apply remapping - result = itir.FunCall( - fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr), - args=list(extracted.values()), - ) + result = im.let(*extracted.items())(new_expr) + itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True) # if the node id is ignored (because its parent is eliminated), but it occurs # multiple times then we want to visit the final result once more. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 920d628166..399a7a3dc6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.type_system import inference as itir_inference # TODO(tehrengruber): Reduce complexity of the function by removing the different options here @@ -98,7 +99,7 @@ def new_name(name): new_expr.location = node.location return new_expr else: - return ir.FunCall( + new_expr = ir.FunCall( fun=ir.Lambda( params=[ param @@ -110,6 +111,8 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index edcb9b540c..66d8345b94 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. """ - assert isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) + assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) # type: ignore[arg-type] def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: