From 8402e4ffe68c1bdd1d443b79c5a91a6c4ca108bd Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 13 Nov 2024 20:41:04 +0100 Subject: [PATCH 1/2] Fix type preservation in CSE --- src/gt4py/next/iterator/transforms/cse.py | 14 ++++++++------ .../next/iterator/transforms/inline_lambdas.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 6 +++--- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccc1d2195f..5e3b77b062 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -14,6 +14,7 @@ import operator from typing import Callable, Iterable, TypeVar, Union, cast +import gt4py.next.iterator.ir_utils.ir_makers as im from gt4py.eve import ( NodeTranslator, NodeVisitor, @@ -241,7 +242,6 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 @@ -433,10 +433,14 @@ def predicate(subexpr: itir.Expr, num_occurences: int): if num_occurences > 1: if is_local_view: return True - else: + # condition is only necessary since typing on lambdas is not preserved during + # the pass + elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` + # only extract if subexpression is not a trivial tuple expressions, e.g., + # `make_tuple(a, b)`, as this would result in a more costly temporary. assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType) @@ -451,10 +455,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int): return self.generic_visit(node, **kwargs) # apply remapping - result = itir.FunCall( - fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr), - args=list(extracted.values()), - ) + result = im.let(*extracted.items())(new_expr) + itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True) # if the node id is ignored (because its parent is eliminated), but it occurs # multiple times then we want to visit the final result once more. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 920d628166..399a7a3dc6 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.type_system import inference as itir_inference # TODO(tehrengruber): Reduce complexity of the function by removing the different options here @@ -98,7 +99,7 @@ def new_name(name): new_expr.location = node.location return new_expr else: - return ir.FunCall( + new_expr = ir.FunCall( fun=ir.Lambda( params=[ param @@ -110,6 +111,8 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index edcb9b540c..66d8345b94 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. """ - assert isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) + assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) + _set_node_type(to, from_.type) # type: ignore[arg-type] def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: From 4547295a99b71efbe42695c627e48b879a51b103 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 13 Nov 2024 20:42:28 +0100 Subject: [PATCH 2/2] Fix type preservation in CSE --- src/gt4py/next/iterator/transforms/cse.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 5e3b77b062..4932d376ad 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -434,13 +434,11 @@ def predicate(subexpr: itir.Expr, num_occurences: int): if is_local_view: return True # condition is only necessary since typing on lambdas is not preserved during - # the pass + # the transformation elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` - # only extract if subexpression is not a trivial tuple expressions, e.g., - # `make_tuple(a, b)`, as this would result in a more costly temporary. assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType)