Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next]: Fix type preservation in CSE #1736

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import operator
from typing import Callable, Iterable, TypeVar, Union, cast

import gt4py.next.iterator.ir_utils.ir_makers as im
from gt4py.eve import (
NodeTranslator,
NodeVisitor,
Expand Down Expand Up @@ -241,7 +242,6 @@ def extract_subexpression(
Examples:
Default case for `(x+y) + ((x+y)+z)`:

>>> import gt4py.next.iterator.ir_utils.ir_makers as im
>>> from gt4py.eve.utils import UIDGenerator
>>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z"))
>>> predicate = lambda subexpr, num_occurences: num_occurences > 1
Expand Down Expand Up @@ -433,7 +433,9 @@ def predicate(subexpr: itir.Expr, num_occurences: int):
if num_occurences > 1:
if is_local_view:
return True
else:
# condition is only necessary since typing on lambdas is not preserved during
# the transformation
elif not isinstance(subexpr, itir.Lambda):
# only extract fields outside of `as_fieldop`
# `as_fieldop(...)(field_expr, field_expr)`
# -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)`
Expand All @@ -451,10 +453,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int):
return self.generic_visit(node, **kwargs)

# apply remapping
result = itir.FunCall(
fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr),
args=list(extracted.values()),
)
result = im.let(*extracted.items())(new_expr)
itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True)

# if the node id is ignored (because its parent is eliminated), but it occurs
# multiple times then we want to visit the final result once more.
Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols
from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs
from gt4py.next.iterator.type_system import inference as itir_inference


# TODO(tehrengruber): Reduce complexity of the function by removing the different options here
Expand Down Expand Up @@ -98,7 +99,7 @@ def new_name(name):
new_expr.location = node.location
return new_expr
else:
return ir.FunCall(
new_expr = ir.FunCall(
fun=ir.Lambda(
params=[
param
Expand All @@ -110,6 +111,8 @@ def new_name(name):
args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible],
location=node.location,
)
itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True)
return new_expr


@dataclasses.dataclass
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None:
node.type = type_


def copy_type(from_: itir.Node, to: itir.Node) -> None:
def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None:
"""
Copy type from one node to another.

This function mainly exists for readability reasons.
"""
assert isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type)
assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type) # type: ignore[arg-type]


def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None:
Expand Down
Loading