Skip to content

Commit

Permalink
Merge origin/main
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Nov 15, 2024
2 parents 75695d9 + c51bdd1 commit f3b1c6c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
12 changes: 6 additions & 6 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import operator
from typing import Callable, Iterable, TypeVar, Union, cast

import gt4py.next.iterator.ir_utils.ir_makers as im
from gt4py.eve import (
NodeTranslator,
NodeVisitor,
Expand Down Expand Up @@ -256,7 +257,6 @@ def extract_subexpression(
Examples:
Default case for `(x+y) + ((x+y)+z)`:
>>> import gt4py.next.iterator.ir_utils.ir_makers as im
>>> from gt4py.eve.utils import UIDGenerator
>>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z"))
>>> predicate = lambda subexpr, num_occurences: num_occurences > 1
Expand Down Expand Up @@ -448,7 +448,9 @@ def predicate(subexpr: itir.Expr, num_occurences: int):
if num_occurences > 1:
if within_stencil:
return True
else:
# condition is only necessary since typing on lambdas is not preserved during
# the transformation
elif not isinstance(subexpr, itir.Lambda):
# only extract fields outside of `as_fieldop`
# `as_fieldop(...)(field_expr, field_expr)`
# -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)`
Expand All @@ -468,10 +470,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int):
return self.generic_visit(node, **kwargs)

# apply remapping
result = itir.FunCall(
fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr),
args=list(extracted.values()),
)
result = im.let(*extracted.items())(new_expr)
itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True)

# if the node id is ignored (because its parent is eliminated), but it occurs
# multiple times then we want to visit the final result once more.
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols
from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs
from gt4py.next.iterator.type_system import inference as itir_inference


# TODO(tehrengruber): Reduce complexity of the function by removing the different options here
Expand Down Expand Up @@ -113,6 +114,7 @@ def new_name(name):
for attr in ("type", "recorded_shifts", "domain"):
if hasattr(node.annex, attr):
setattr(new_expr.annex, attr, getattr(node.annex, attr))
itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True)
return new_expr


Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None:
node.type = type_


def copy_type(from_: itir.Node, to: itir.Node) -> None:
def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None:
"""
Copy type from one node to another.
This function mainly exists for readability reasons.
"""
assert isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type)
assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type) # type: ignore[arg-type]


def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None:
Expand Down

0 comments on commit f3b1c6c

Please sign in to comment.