Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Improve fieldop fusion #1764

Draft
wants to merge 191 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
191 commits
Select commit Hold shift + click to select a range
b0d688a
Add support for IfStmt in ITIR
tehrengruber Sep 25, 2024
65a22fe
Initial version of FuseAsFieldOp
tehrengruber Sep 28, 2024
8a27d16
Merge branch 'main' into gtir_fuse_as_fieldop
tehrengruber Sep 28, 2024
3edc130
Move symbolic domain utilities out of global tmp pass into seperate m…
tehrengruber Sep 28, 2024
031e459
Merge branch 'gtir_if_stmt' into gtir_temporaries_pass
tehrengruber Sep 28, 2024
dd63122
feat[next]: gtir lowering of broadcasted scalars
havogt Oct 1, 2024
2d840cb
Merge origin/main
tehrengruber Oct 1, 2024
7172442
Second draft of temporary pass
tehrengruber Oct 1, 2024
3313191
Merge remote-tracking branch 'origin/main' into gtir_if_stmt
tehrengruber Oct 1, 2024
4a20973
Merge branch 'gtir_temporaries_pass' into gtir_temporaries2
tehrengruber Oct 2, 2024
f3cbc18
Merge branch 'gtir_fuse_as_fieldop' into gtir_temporaries2
tehrengruber Oct 2, 2024
d4f066a
Cleanup
tehrengruber Oct 2, 2024
9a69dea
Merge branch 'gtir_temporaries_pass' into gtir_temporaries2
tehrengruber Oct 2, 2024
c1038fa
Address review comments
tehrengruber Oct 3, 2024
31cd94e
Merge branch 'main' into gtir_if_stmt
tehrengruber Oct 3, 2024
160aeaf
Add type inference test for IfStmt
tehrengruber Oct 3, 2024
67682ed
Merge remote-tracking branch 'origin_tehrengruber/gtir_if_stmt' into …
tehrengruber Oct 3, 2024
881babf
Merge origin/main
tehrengruber Oct 3, 2024
2bf9231
Merge remote-tracking branch 'origin/main' into gtir_fuse_as_fieldop
tehrengruber Oct 4, 2024
19a2c5e
Cleanup
tehrengruber Oct 9, 2024
a6a010c
Merge gtir_fuse_as_fieldop
tehrengruber Oct 9, 2024
a57cb99
Merge branch 'main' into gtir_fuse_as_fieldop
tehrengruber Oct 9, 2024
b50b85a
Cleanup
tehrengruber Oct 9, 2024
e2e3fa1
Merge remote-tracking branch 'origin_tehrengruber/gtir_fuse_as_fieldo…
tehrengruber Oct 9, 2024
bb2f2b1
Cleanup
tehrengruber Oct 9, 2024
729968e
Cleanup
tehrengruber Oct 9, 2024
cdcaac0
Cleanup
tehrengruber Oct 9, 2024
9472502
Cleanup
tehrengruber Oct 9, 2024
01dd86e
Cleanup
tehrengruber Oct 9, 2024
e078b36
Cleanup
tehrengruber Oct 10, 2024
f0331cb
Cleanup
tehrengruber Oct 10, 2024
21d2fb5
Merge origin_tehrengruber/gtir_if_stmt
tehrengruber Oct 10, 2024
3bd3ad6
Merge remote-tracking branch 'origin/main' into gtir_temporaries_pass
tehrengruber Oct 10, 2024
61e9665
Cleanup
tehrengruber Oct 10, 2024
f9dff50
Cleanup
tehrengruber Oct 10, 2024
71f512f
Cleanup
tehrengruber Oct 10, 2024
6fb424f
Cleanup
tehrengruber Oct 10, 2024
7d809fb
Cleanup
tehrengruber Oct 10, 2024
c7b79c0
Cleanup
tehrengruber Oct 10, 2024
f16a961
Cleanup
tehrengruber Oct 10, 2024
3e3f9a1
Cleanup
tehrengruber Oct 10, 2024
edffd97
Cleanup
tehrengruber Oct 10, 2024
3196a11
Cleanup
tehrengruber Oct 10, 2024
6a4d227
Merge remote-tracking branch 'origin/main' into gtir_temporaries_pass
tehrengruber Oct 10, 2024
183139b
Merge branch 'gtir_temporaries_pass' into gtir_temporaries2
tehrengruber Oct 10, 2024
04f59dd
Inline lambda pass: ensure opcount preserving option works whether `i…
tehrengruber Oct 10, 2024
feab647
Retrigger CI
tehrengruber Oct 10, 2024
5869769
Merge branch 'fix_inline_lambda_opcount_preserving' into gtir_tempora…
tehrengruber Oct 10, 2024
376153f
Address review comments
tehrengruber Oct 11, 2024
b61be89
Merge branch 'gtir_fuse_as_fieldop' into gtir_temporaries2
tehrengruber Oct 11, 2024
52a1b90
Allow type inference without domain argument to `as_fieldop`
tehrengruber Oct 14, 2024
a1b4448
Cleanup
tehrengruber Oct 14, 2024
f93c8b4
Merge branch 'type_inf_before_domain_inf' into gtir_temporaries2
tehrengruber Oct 14, 2024
9684507
Merge origin/main
tehrengruber Oct 15, 2024
5b86f19
Remove comment
tehrengruber Oct 15, 2024
7784896
Merge branch 'gtir_temporaries_pass' into gtir_temporaries2
tehrengruber Oct 15, 2024
0e50214
Use GTIR in embedded and gtfn
tehrengruber Oct 15, 2024
068ff06
Cleanup
tehrengruber Oct 16, 2024
aaba729
Cleanup
tehrengruber Oct 16, 2024
6044d76
Cleanup
tehrengruber Oct 16, 2024
1dc9ebb
Cleanup
tehrengruber Oct 16, 2024
685bedb
Cleanup
tehrengruber Oct 16, 2024
378b3b3
Cleanup
tehrengruber Oct 16, 2024
83e5ce2
Cleanup
tehrengruber Oct 16, 2024
b3ae17b
Cleanup
tehrengruber Oct 16, 2024
4637b7a
Merge origin_tehrengruber/lower_broadcasted_scalars_to_fieldops
tehrengruber Oct 16, 2024
dbd71a9
Small fix
tehrengruber Oct 16, 2024
0990858
Merge branch 'main' into gtir_temporaries
tehrengruber Oct 16, 2024
da4a63c
Merge remote-tracking branch 'origin/main' into gtir_temporaries
tehrengruber Oct 16, 2024
c4b1ed8
Merge remote-tracking branch 'origin_tehrengruber/gtir_temporaries' i…
tehrengruber Oct 16, 2024
f545984
Remove superfluous test backend
tehrengruber Oct 16, 2024
0904d88
Cleanup
tehrengruber Oct 16, 2024
6f6c65b
Cleanup
tehrengruber Oct 16, 2024
606f662
Cleanup
tehrengruber Oct 16, 2024
b917011
Cleanup
tehrengruber Oct 16, 2024
6c9e8ab
Cleanup
tehrengruber Oct 16, 2024
320c7f8
Cleanup
tehrengruber Oct 17, 2024
b589346
Cleanup
tehrengruber Oct 17, 2024
309d58f
Merge branch 'gtir_temporaries_pass' into gtir_temporaries
tehrengruber Oct 17, 2024
4d2b3da
Fix format
tehrengruber Oct 17, 2024
cfb59d7
Merge branch 'main' into gtir_temporaries_pass
tehrengruber Oct 17, 2024
1fe44e0
Merge branch 'gtir_temporaries_pass' into gtir_temporaries
tehrengruber Oct 17, 2024
ab5a6a2
Merge remote-tracking branch 'origin_tehrengruber/gtir_temporaries_pa…
tehrengruber Oct 18, 2024
6072d53
Merge commit 'cb77ccb85c1d5baa83720493c626a66c0c116451' into gtir_tem…
tehrengruber Oct 18, 2024
821af59
Merge origin/main
tehrengruber Oct 18, 2024
70c0dff
Fix scalar let in gtfn
tehrengruber Oct 19, 2024
108af05
Fix some tests
tehrengruber Oct 20, 2024
7380d6e
Fix
tehrengruber Oct 20, 2024
c268ee1
Fix some tests
tehrengruber Oct 20, 2024
c5b0171
Address review comments
tehrengruber Oct 21, 2024
bba6aa4
Allow partial type inference in ITIR
tehrengruber Oct 24, 2024
840c004
Merge remote-tracking branch 'origin/main' into allow_partial_type_in…
tehrengruber Oct 24, 2024
f9fc5c5
Small fix
tehrengruber Oct 24, 2024
c241bc4
Add tests
tehrengruber Oct 28, 2024
e120849
Merge origin/main
tehrengruber Oct 28, 2024
45f41be
Address review comments
tehrengruber Oct 28, 2024
fccd43b
Fix tests
tehrengruber Oct 28, 2024
1874eba
Fix tests
tehrengruber Nov 5, 2024
8685731
Merge origin/main
tehrengruber Nov 5, 2024
dd5bfa7
Fix tests
tehrengruber Nov 5, 2024
8f1e84a
Fix tests
tehrengruber Nov 5, 2024
d53d3bb
Fix tests
tehrengruber Nov 5, 2024
4bfef54
Fix tests
tehrengruber Nov 5, 2024
78b7a98
Fix tests
tehrengruber Nov 5, 2024
88a7660
Fix tests
tehrengruber Nov 5, 2024
1b79b3a
Fix tests
tehrengruber Nov 5, 2024
b6b603e
Bump CI version to 22.04
tehrengruber Nov 5, 2024
5beccf0
Bump CI version to 22.04
tehrengruber Nov 5, 2024
deca907
Bump CI version to 22.04
tehrengruber Nov 5, 2024
af9d776
Bump CI version to 22.04
tehrengruber Nov 5, 2024
8cb36da
Merge origin/main
tehrengruber Nov 5, 2024
ee0b94a
Fix failing dace tests
tehrengruber Nov 5, 2024
270e173
Address review comments
tehrengruber Nov 5, 2024
8e2ba0c
Revert CI changes & skip failing test
tehrengruber Nov 8, 2024
229d3ac
Merge origin/main
tehrengruber Nov 8, 2024
580cb79
Fix format
tehrengruber Nov 8, 2024
c67d355
Ugly fix
tehrengruber Nov 8, 2024
79ab838
Format
tehrengruber Nov 8, 2024
d53fda9
Small fix
tehrengruber Nov 8, 2024
24c3e87
Bump gridtools-cpp to 2.3.7 in preperation of #1648
tehrengruber Nov 11, 2024
04f110f
Merge branch 'bump_gridtools' into gtir_temporaries
tehrengruber Nov 11, 2024
c399f65
Merge remote-tracking branch 'origin/main' into gtir_temporaries
tehrengruber Nov 11, 2024
0faa7ef
Test tuple fix in gridtools
tehrengruber Nov 11, 2024
8e8a0a1
Fix typo
tehrengruber Nov 11, 2024
e03dd38
Fix ITIR program hash stability
tehrengruber Nov 12, 2024
7a4c692
Merge branch 'fix_itir_hash_stability' into gtir_temporaries
tehrengruber Nov 12, 2024
405cbb0
Revert "Fix typo"
tehrengruber Nov 13, 2024
4c27279
Revert "Test tuple fix in gridtools"
tehrengruber Nov 13, 2024
a666eef
Test 2.3.8
tehrengruber Nov 13, 2024
8402e4f
Fix type preservation in CSE
tehrengruber Nov 13, 2024
4547295
Fix type preservation in CSE
tehrengruber Nov 13, 2024
2d6464f
Fix test skip matrix
tehrengruber Nov 14, 2024
4b67a99
Fix format
tehrengruber Nov 14, 2024
af2ed5f
Address review comments
tehrengruber Nov 14, 2024
1188a55
Merge origin/inline_dynamic_offsets
tehrengruber Nov 14, 2024
bff17b2
Inline dynamic shifts
tehrengruber Nov 14, 2024
52df315
Merge remote-tracking branch 'origin/main' into gtir_temporaries
tehrengruber Nov 14, 2024
262ffdd
Merge origin/main
tehrengruber Nov 14, 2024
79026c7
dace-related changes
edopao Nov 15, 2024
16f143f
Address review comments
tehrengruber Nov 15, 2024
75695d9
Address review comments
tehrengruber Nov 15, 2024
f3b1c6c
Merge origin/main
tehrengruber Nov 15, 2024
1ab17ca
Merge branch 'gtir_temporaries' into inline_dynamic_offsets
tehrengruber Nov 15, 2024
9df5e4f
Merge remote-tracking branch 'origin/main' into inline_dynamic_offsets
tehrengruber Nov 15, 2024
f268ae1
Merge remote-tracking branch 'origin_tehrengruber/inline_dynamic_offs…
tehrengruber Nov 15, 2024
6e40499
Small cleanup
tehrengruber Nov 15, 2024
b0bd658
Address review comments
tehrengruber Nov 27, 2024
7b67ecb
Small cleanup
tehrengruber Nov 27, 2024
438eb6b
Small cleanup
tehrengruber Nov 27, 2024
306a82a
Merge
tehrengruber Nov 27, 2024
752e08a
Small fix
tehrengruber Nov 27, 2024
395e6ee
Small fix
tehrengruber Nov 27, 2024
ff00b55
Small fix
tehrengruber Nov 27, 2024
34d6040
Non-tree-size-increasing collapse tuple on ifs
tehrengruber Dec 1, 2024
48abc08
Fix typos
tehrengruber Dec 1, 2024
3790944
Fix typos
tehrengruber Dec 1, 2024
a8a63bf
Disable PROPAGATE_TO_IF_ON_TUPLES by default in pass manager
tehrengruber Dec 1, 2024
2c44ffc
Improve doc
tehrengruber Dec 1, 2024
42b5817
Improve typo
tehrengruber Dec 1, 2024
9da19a2
Improve typo
tehrengruber Dec 1, 2024
bcd9e48
Improve typo
tehrengruber Dec 1, 2024
0a212bd
Improve typo
tehrengruber Dec 1, 2024
70562fe
Fix typo
tehrengruber Dec 1, 2024
9cee650
Improve doc
tehrengruber Dec 1, 2024
43f5741
Format
tehrengruber Dec 1, 2024
7b37f1c
Fix type synthesizer for partially typed arithmetic ops
tehrengruber Dec 1, 2024
a0341a6
Fix test
tehrengruber Dec 1, 2024
5a892f3
Fix test
tehrengruber Dec 1, 2024
cd04fb3
Merge branch 'inline_dynamic_offsets' into improve_fieldop_fusion
tehrengruber Dec 1, 2024
bd02cc7
Merge remote-tracking branch 'origin/main' into inline_dynamic_offsets
tehrengruber Dec 1, 2024
31ebf0c
Merge branch 'inline_dynamic_offsets' into improve_fieldop_fusion
tehrengruber Dec 1, 2024
ae8a36f
Merge origin_tehrengruber/ct_cps (contains partial type inference)
tehrengruber Dec 1, 2024
93688a8
Improve field operator fusion
tehrengruber Dec 2, 2024
7c271ca
Small fix
tehrengruber Dec 2, 2024
22dbf5e
Fix inlining of multiple-use, dynamically-calculated neighbor field
tehrengruber Dec 2, 2024
bfacda6
Fix broken iterator tests containing lifts
tehrengruber Dec 2, 2024
92b7533
Support inlining into scan
tehrengruber Dec 3, 2024
1eeca75
Small fix
tehrengruber Dec 3, 2024
e3306fc
Address review comments.
tehrengruber Dec 5, 2024
2928503
Address review comments.
tehrengruber Dec 5, 2024
447b467
Address review comments.
tehrengruber Dec 5, 2024
54cd6b1
Fix type annotation
tehrengruber Dec 5, 2024
e043d0a
Add type annotation in extract_tmp pass.
tehrengruber Dec 6, 2024
20fb59e
Update src/gt4py/next/iterator/transforms/infer_domain.py
tehrengruber Dec 6, 2024
d8d3913
Address review comments
tehrengruber Dec 6, 2024
e9ea52b
Merge remote-tracking branch 'origin_tehrengruber/inline_dynamic_offs…
tehrengruber Dec 6, 2024
5832371
Address review comments
tehrengruber Dec 6, 2024
3aa6421
Merge branch 'main' into inline_dynamic_offsets
tehrengruber Dec 6, 2024
668e3ec
Fix format
tehrengruber Dec 6, 2024
da63e23
Merge origin/main
tehrengruber Dec 8, 2024
15c1212
Merge origin/main
tehrengruber Dec 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 139 additions & 23 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@
from gt4py.next.type_system import type_info, type_specifications as ts


def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr):
def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str):
"""Given a itir.FunCall return a new call with one of its argument replaced."""
return ir.FunCall(
fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)]
fun=node.fun,
args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)],
)


Expand All @@ -47,6 +48,32 @@ def _is_trivial_make_tuple_call(node: ir.Expr):
return True


def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool:
"""
Return `true` if the expr is a trivial expression or tuple thereof.

>>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b"))
True
>>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a"))
True
>>> _is_trivial_or_tuple_thereof_expr(
... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t"))
... )
True
"""
if cpm.is_call_to(node, "make_tuple"):
return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args)
if cpm.is_call_to(node, "tuple_get"):
return _is_trivial_or_tuple_thereof_expr(node.args[1])
if isinstance(node, (ir.SymRef, ir.Literal)):
return True
if cpm.is_let(node):
return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let
_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args
)
return False


# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first,
# transform each node until no transformations apply anymore, whenever a node is to be transformed
# go through all available transformation and apply them. However the final result here still
Expand Down Expand Up @@ -76,28 +103,42 @@ class Flag(enum.Flag):
#: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))`
#: -> `foo({trivial_expr1, trivial_expr2})`
INLINE_TRIVIAL_MAKE_TUPLE = enum.auto()
#: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e.
#: into the tree, allowing removal of tuple expressions across `if_` calls without
#: increasing the size of the tree. This is particularly important for `if` statements
#: in the frontend, where outwards propagation can have devastating effects on the tree
#: size, without any gained optimization potential. For example
#: ```
#: complex_lambda(if cond1
#: if cond2
#: {...}
#: else:
#: {...}
#: else
#: {...})
#: ```
#: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate
#: `complex_lambda` three times, while we only want to get rid of the tuple expressions
#: inside of the `if_`s.
#: Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`.
PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto()
#: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
PROPAGATE_TO_IF_ON_TUPLES = enum.auto()
#: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
PROPAGATE_NESTED_LET = enum.auto()
#: `let(a, 1)(a)` -> `1`
#: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)`
INLINE_TRIVIAL_LET = enum.auto()

@classmethod
def all(self) -> CollapseTuple.Flag:
return functools.reduce(operator.or_, self.__members__.values())

uids: eve_utils.UIDGenerator
ignore_tuple_size: bool
flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

PRESERVED_ANNEX_ATTRS = ("type",)

# we use one UID generator per instance such that the generated ids are
# stable across multiple runs (required for caching to properly work)
_letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field(
init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el")
)

@classmethod
def apply(
cls,
Expand All @@ -111,6 +152,7 @@ def apply(
flags: Optional[Flag] = None,
# allow sym references without a symbol declaration, mostly for testing
allow_undeclared_symbols: bool = False,
uids: Optional[eve_utils.UIDGenerator] = None,
) -> ir.Node:
"""
Simplifies `make_tuple`, `tuple_get` calls.
Expand All @@ -127,6 +169,7 @@ def apply(
"""
flags = flags or cls.flags
offset_provider_type = offset_provider_type or {}
uids = uids or eve_utils.UIDGenerator()

if isinstance(node, ir.Program):
within_stencil = False
Expand All @@ -142,10 +185,9 @@ def apply(
allow_undeclared_symbols=allow_undeclared_symbols,
)

new_node = cls(
ignore_tuple_size=ignore_tuple_size,
flags=flags,
).visit(node, within_stencil=within_stencil)
new_node = cls(ignore_tuple_size=ignore_tuple_size, flags=flags, uids=uids).visit(
node, within_stencil=within_stencil
)

# inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important
# as otherwise two equal expressions containing a tuple will not be equal anymore
Expand Down Expand Up @@ -185,6 +227,8 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert result is not node
itir_type_inference.reinfer(result)
return result
return None

Expand All @@ -206,6 +250,7 @@ def transform_collapse_make_tuple_tuple_get(
# tuple argument differs, just continue with the rest of the tree
return None

itir_type_inference.reinfer(first_expr) # type is needed so reinfer on-demand
assert self.ignore_tuple_size or isinstance(
first_expr.type, (ts.TupleType, ts.DeferredType)
)
Expand All @@ -226,7 +271,7 @@ def transform_collapse_tuple_get_make_tuple(
and isinstance(node.args[0], ir.Literal)
):
# `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
assert type_info.is_integer(node.args[0].type)
assert not node.args[0].type or type_info.is_integer(node.args[0].type)
make_tuple_call = node.args[1]
idx = int(node.args[0].value)
assert idx < len(
Expand Down Expand Up @@ -263,13 +308,13 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op
if node.fun == ir.SymRef(id="make_tuple"):
# `make_tuple(expr1, expr1)`
# -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))`
bound_vars: dict[str, ir.Expr] = {}
bound_vars: dict[ir.Sym, ir.Expr] = {}
new_args: list[ir.Expr] = []
for arg in node.args:
if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node):
el_name = self._letify_make_tuple_uids.sequential_id()
new_args.append(im.ref(el_name))
bound_vars[el_name] = arg
el_name = self.uids.sequential_id(prefix="__ct_el")
new_args.append(im.ref(el_name, arg.type))
bound_vars[im.sym(el_name, arg.type)] = arg
else:
new_args.append(arg)

Expand Down Expand Up @@ -312,6 +357,73 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt
return im.if_(cond, new_true_branch, new_false_branch)
return None

def transform_propagate_to_if_on_tuples_cps(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
if not cpm.is_call_to(node, "if_"):
for i, arg in enumerate(node.args):
if cpm.is_call_to(arg, "if_"):
itir_type_inference.reinfer(arg)
if not any(isinstance(branch.type, ts.TupleType) for branch in arg.args[1:]):
continue

cond, true_branch, false_branch = arg.args
tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above
tuple_len = len(tuple_type.types)
itir_type_inference.reinfer(node)
assert node.type

# transform function into continuation-passing-style
f_type = ts.FunctionType(
pos_only_args=tuple_type.types,
pos_or_kw_args={},
kw_only_args={},
returns=node.type,
)
f_params = [
im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_)
for type_ in tuple_type.types
]
f_args = [im.ref(param.id, param.type) for param in f_params]
f_body = _with_altered_arg(node, i, im.make_tuple(*f_args))
# simplify, e.g., inline trivial make_tuple args
new_f_body = self.fp_transform(f_body, **kwargs)
# if the function did not simplify there is nothing to gain. Skip
# transformation.
if new_f_body is f_body:
continue
# if the function is not trivial the transformation would still work, but
# inlining would result in a larger tree again and we didn't didn't gain
# anything compared to regular `propagate_to_if_on_tuples`. Not inling also
# works, but we don't want bound lambda functions in our tree (at least right
# now).
if not _is_trivial_or_tuple_thereof_expr(new_f_body):
continue
f = im.lambda_(*f_params)(new_f_body)

tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps")
f_var = self.uids.sequential_id(prefix="__ct_cont")
new_branches = []
for branch in arg.args[1:]:
new_branch = im.let(tuple_var, branch)(
im.call(im.ref(f_var, f_type))(
*(
im.tuple_get(i, im.ref(tuple_var, branch.type))
for i in range(tuple_len)
)
)
)
new_branches.append(self.fp_transform(new_branch, **kwargs))

new_node = im.let(f_var, f)(im.if_(cond, *new_branches))
new_node = inline_lambda(new_node, eligible_params=[True])
assert cpm.is_call_to(new_node, "if_")
new_node = im.if_(
cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:])
)
return new_node
return None

def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
# `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))`
Expand Down Expand Up @@ -339,9 +451,13 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional
return None

def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let
return arg
if cpm.is_let(node):
if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let
# `let(a, 1)(a)` -> `1`
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let
if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let
return arg
if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]):
return inline_lambda(node, eligible_params=trivial_args)

return None
Loading
Loading