Skip to content

Commit

Permalink
feat[next][dace]: GTIR-to-SDFG lowering of let-lambdas (GridTools#1589)
Browse files Browse the repository at this point in the history
This PR adds lowering of let-lambdas to DaCe SDFG.
  • Loading branch information
edopao authored Jul 29, 2024
1 parent 20fae77 commit abed597
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 106 deletions.
29 changes: 29 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,32 @@ def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call:
)
)
)


def op_as_fieldop(
    op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None
) -> Callable[..., itir.FunCall]:
    """
    Lift a value-level function `op` so that it is applied as a field operator.

    Args:
        op: a function from values to value. A `str`, `itir.SymRef` or
            `itir.Lambda` is first turned into a callable by means of `call`;
            any other object is used as given.
        domain: the domain of the returned field, forwarded to `as_fieldop`.

    Returns:
        A function from Fields to Field.

    Examples:
        >>> str(op_as_fieldop("op")("a", "b"))
        '(⇑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)'
    """
    needs_wrapping = isinstance(op, (str, itir.SymRef, itir.Lambda))
    value_op = call(op) if needs_wrapping else op

    def _impl(*its: itir.Expr) -> itir.FunCall:
        # One synthetic stencil parameter per field argument.
        # TODO: `op` must not contain `SymRef(id="__argX")`
        param_names = [f"__arg{idx}" for idx, _ in enumerate(its)]
        # The stencil dereferences each parameter and hands the values to `op`.
        dereferenced = [deref(name) for name in param_names]
        stencil = lambda_(*param_names)(value_op(*dereferenced))
        return as_fieldop(stencil, domain)(*its)

    return _impl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@


IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes
LetSymbol: TypeAlias = tuple[str, ts.FieldType | ts.ScalarType]
TemporaryData: TypeAlias = tuple[dace.nodes.Node, ts.FieldType | ts.ScalarType]


Expand All @@ -49,6 +50,7 @@ def __call__(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
"""Creates the dataflow subgraph representing a GTIR primitive function.
Expand All @@ -60,6 +62,7 @@ def __call__(
sdfg: The SDFG where the primitive subgraph should be instantiated
state: The SDFG state where the result of the primitive function should be made available
sdfg_builder: The object responsible for visiting child nodes of the primitive node.
let_symbols: Mapping of symbols (i.e. lambda parameters) to known temporary fields.
Returns:
A list of data access nodes and the associated GT4Py data type, which provide
Expand All @@ -77,8 +80,14 @@ def _parse_arg_expr(
domain: list[
tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]
],
let_symbols: dict[str, LetSymbol],
) -> gtir_to_tasklet.IteratorExpr | gtir_to_tasklet.MemletExpr:
fields: list[TemporaryData] = sdfg_builder.visit(node, sdfg=sdfg, head_state=state)
fields: list[TemporaryData] = sdfg_builder.visit(
node,
sdfg=sdfg,
head_state=state,
let_symbols=let_symbols,
)

assert len(fields) == 1
data_node, arg_type = fields[0]
Expand Down Expand Up @@ -155,8 +164,9 @@ def translate_as_field_op(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for the `as_field_op` builtin function."""
"""Generates the dataflow subgraph for the `as_fieldop` builtin function."""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")

Expand All @@ -173,7 +183,9 @@ def translate_as_field_op(
assert isinstance(node.type, ts.FieldType)

# first visit the list of arguments and build a symbol map
stencil_args = [_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]
stencil_args = [
_parse_arg_expr(arg, sdfg, state, sdfg_builder, domain, let_symbols) for arg in node.args
]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_to_tasklet.LambdaToTasklet(sdfg, state, sdfg_builder)
Expand Down Expand Up @@ -236,6 +248,7 @@ def translate_cond(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for the `cond` builtin function."""
assert cpm.is_call_to(node, "cond")
Expand Down Expand Up @@ -273,8 +286,18 @@ def translate_cond(
sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({cond})")))
sdfg.add_edge(false_state, state, dace.InterstateEdge())

true_br_args = sdfg_builder.visit(true_expr, sdfg=sdfg, head_state=true_state)
false_br_args = sdfg_builder.visit(false_expr, sdfg=sdfg, head_state=false_state)
true_br_args = sdfg_builder.visit(
true_expr,
sdfg=sdfg,
head_state=true_state,
let_symbols=let_symbols,
)
false_br_args = sdfg_builder.visit(
false_expr,
sdfg=sdfg,
head_state=false_state,
let_symbols=let_symbols,
)

output_nodes = []
for true_br, false_br in zip(true_br_args, false_br_args, strict=True):
Expand Down Expand Up @@ -309,6 +332,7 @@ def translate_symbol_ref(
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
let_symbols: dict[str, LetSymbol],
) -> list[TemporaryData]:
"""Generates the dataflow subgraph for a `ir.SymRef` node."""
assert isinstance(node, (gtir.Literal, gtir.SymRef))
Expand All @@ -320,7 +344,16 @@ def translate_symbol_ref(
temp_name = "literal"
else:
sym_value = str(node.id)
data_type = sdfg_builder.get_symbol_type(sym_value)
if sym_value in let_symbols:
# The `let_symbols` dictionary maps a `gtir.SymRef` string to a temporary
# data container. These symbols are visited and initialized in a state
# that precedes the current state, therefore a new access node is created
# every time they are accessed. It is therefore possible that multiple access
# nodes are created in one state for the same data container. We rely
# on the simplify to remove duplicated access nodes.
sym_value, data_type = let_symbols[sym_value]
else:
data_type = sdfg_builder.get_symbol_type(sym_value)
temp_name = sym_value

if isinstance(data_type, ts.FieldType):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder):
"""

offset_provider: dict[str, gtx_common.Connectivity | gtx_common.Dimension]
symbol_types: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field(
global_symbols: dict[str, ts.FieldType | ts.ScalarType] = dataclasses.field(
default_factory=lambda: {}
)
map_uids: eve.utils.UIDGenerator = dataclasses.field(
Expand All @@ -119,12 +119,10 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder):
)

def get_offset_provider(self, offset: str) -> gtx_common.Connectivity | gtx_common.Dimension:
assert offset in self.offset_provider
return self.offset_provider[offset]

def get_symbol_type(self, symbol_name: str) -> ts.FieldType | ts.ScalarType:
assert symbol_name in self.symbol_types
return self.symbol_types[symbol_name]
return self.global_symbols[symbol_name]

def unique_map_name(self, name: str) -> str:
return f"{self.map_uids.sequential_id()}_{name}"
Expand Down Expand Up @@ -184,7 +182,7 @@ def _add_storage(

# TODO: unclear why mypy complains about incompatible types
assert isinstance(symbol_type, (ts.FieldType, ts.ScalarType))
self.symbol_types[name] = symbol_type
self.global_symbols[name] = symbol_type

def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]:
"""
Expand All @@ -210,7 +208,7 @@ def _visit_expression(
to have the same memory layout as the target array.
"""
results: list[gtir_builtin_translators.TemporaryData] = self.visit(
node, sdfg=sdfg, head_state=head_state
node, sdfg=sdfg, head_state=head_state, let_symbols={}
)

field_nodes = []
Expand Down Expand Up @@ -303,7 +301,7 @@ def visit_SetAt(self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState)
for expr_node, target_node in zip(expr_nodes, target_nodes, strict=True):
target_array = sdfg.arrays[target_node.data]
assert not target_array.transient
target_symbol_type = self.symbol_types[target_node.data]
target_symbol_type = self.global_symbols[target_node.data]

if isinstance(target_symbol_type, ts.FieldType):
subset = ",".join(
Expand All @@ -324,38 +322,102 @@ def visit_FunCall(
node: gtir.FunCall,
sdfg: dace.SDFG,
head_state: dace.SDFGState,
let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
) -> list[gtir_builtin_translators.TemporaryData]:
# use specialized dataflow builder classes for each builtin function
if cpm.is_call_to(node, "cond"):
return gtir_builtin_translators.translate_cond(node, sdfg, head_state, self)
return gtir_builtin_translators.translate_cond(
node, sdfg, head_state, self, let_symbols
)
elif cpm.is_call_to(node.fun, "as_fieldop"):
return gtir_builtin_translators.translate_as_field_op(node, sdfg, head_state, self)
return gtir_builtin_translators.translate_as_field_op(
node, sdfg, head_state, self, let_symbols
)
elif isinstance(node.fun, gtir.Lambda):
# We use a separate state to ensure that the lambda arguments are evaluated
# before the computation starts. This is required in case the let-symbols
# are used in conditional branch execution, which happens in different states.
lambda_state = sdfg.add_state_before(head_state, f"{head_state.label}_symbols")

node_args = []
for arg in node.args:
node_args.extend(
self.visit(
arg,
sdfg=sdfg,
head_state=lambda_state,
let_symbols=let_symbols,
)
)

# some cleanup: remove isolated nodes for program arguments in lambda state
isolated_node_args = [node for node, _ in node_args if lambda_state.degree(node) == 0]
assert all(
isinstance(node, dace.nodes.AccessNode) and node.data in self.global_symbols
for node in isolated_node_args
)
lambda_state.remove_nodes_from(isolated_node_args)

return self.visit(
node.fun,
sdfg=sdfg,
head_state=head_state,
let_symbols=let_symbols,
args=node_args,
)
else:
raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).")

def visit_Lambda(self, node: gtir.Lambda) -> Any:
def visit_Lambda(
self,
node: gtir.Lambda,
sdfg: dace.SDFG,
head_state: dace.SDFGState,
let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
args: list[gtir_builtin_translators.TemporaryData],
) -> list[gtir_builtin_translators.TemporaryData]:
"""
This visitor class should never encounter `itir.Lambda` expressions
because a lambda represents a stencil, which operates from iterator to values.
In fieldview, lambdas should only be arguments to field operators (`as_field_op`).
Translates a `Lambda` node to a tasklet subgraph in the current SDFG state.
All arguments to lambda functions are fields (i.e. `as_fieldop`, field or scalar `gtir.SymRef`,
nested let-lambdas thereof). The dictionary called `let_symbols` maps the lambda parameters
to symbols, e.g. temporary fields or program arguments. If the lambda has a parameter whose name
is already present in `let_symbols`, i.e. a parameter with the same name as a previously defined
symbol, the parameter will shadow the previous symbol during traversal of the lambda expression.
"""
raise RuntimeError("Unexpected 'itir.Lambda' node encountered in GTIR.")
lambda_symbols = let_symbols | {
str(p.id): (temp_node.data, type_)
for p, (temp_node, type_) in zip(node.params, args, strict=True)
}

return self.visit(
node.expr,
sdfg=sdfg,
head_state=head_state,
let_symbols=lambda_symbols,
)

def visit_Literal(
self,
node: gtir.Literal,
sdfg: dace.SDFG,
head_state: dace.SDFGState,
let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
) -> list[gtir_builtin_translators.TemporaryData]:
return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self)
return gtir_builtin_translators.translate_symbol_ref(
node, sdfg, head_state, self, let_symbols={}
)

def visit_SymRef(
self,
node: gtir.SymRef,
sdfg: dace.SDFG,
head_state: dace.SDFGState,
let_symbols: dict[str, gtir_builtin_translators.LetSymbol],
) -> list[gtir_builtin_translators.TemporaryData]:
return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self)
return gtir_builtin_translators.translate_symbol_ref(
node, sdfg, head_state, self, let_symbols
)


def build_sdfg_from_gtir(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class IteratorExpr:
class LambdaToTasklet(eve.NodeVisitor):
"""Translates an `ir.Lambda` expression to a dataflow graph.
Lambda functions should only be encountered as argument to the `as_field_op`
Lambda functions should only be encountered as argument to the `as_fieldop`
builtin function, therefore the dataflow graph generated here typically
represents the stencil function of a field operator.
"""
Expand Down
Loading

0 comments on commit abed597

Please sign in to comment.