From 29f5c2244fbe0ebb7e2c6ebda2591929211a3f89 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 6 Dec 2024 09:42:13 +0100 Subject: [PATCH] Import changes made in #1808 Author: @romanc --- dace/frontend/python/newast.py | 32 ++++++++++++++++++++------------ dace/frontend/python/parser.py | 22 ++++++++++++---------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 1cbb8e67c9..d2813371c9 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1319,7 +1319,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List self.sdfg.replace_dict(repl_dict) propagate_states(self.sdfg) - for state, memlet, inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()): + for state, memlet, _inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()): if state is not None and state.dynamic_executions: memlet.dynamic = True @@ -2366,8 +2366,11 @@ def visit_For(self, node: ast.For): init_expr='%s = %s' % (indices[0], astutils.unparse(ast_ranges[0][0])), update_expr=incr[indices[0]], inverted=False) - _, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}', node.lineno, - extra_symbols=extra_syms, parent=loop_region, + _, first_subblock, _, _ = self._recursive_visit(node.body, + f'for_{node.lineno}', + node.lineno, + extra_symbols=extra_syms, + parent=loop_region, unconnected_last_block=False) loop_region.start_block = loop_region.node_id(first_subblock) self._connect_break_blocks(loop_region) @@ -2449,7 +2452,10 @@ def visit_While(self, node: ast.While): loop_region = self._add_loop_region(loop_cond, label=f'while_{node.lineno}', inverted=False) # Parse body - self._recursive_visit(node.body, f'while_{node.lineno}', node.lineno, parent=loop_region, + self._recursive_visit(node.body, + f'while_{node.lineno}', + node.lineno, + parent=loop_region, unconnected_last_block=False) if test_region is not None: @@ -2540,7 +2546,6 @@ def _has_loop_ancestor(self, node: ControlFlowBlock) -> bool: node = node.parent_graph return False - def visit_Break(self, node: ast.Break): if not self._has_loop_ancestor(self.cfg_target): raise DaceSyntaxError(self, node, "Break block outside loop region") @@ -2572,8 +2577,7 @@ def visit_If(self, node: ast.If): # Process 'else'/'elif' statements if len(node.orelse) > 0: - else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', - sdfg=self.sdfg) + else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) cond_block.add_branch(None, else_body) # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) @@ -2934,7 +2938,6 @@ def _add_aug_assignment(self, wsqueezed = [i for i in range(len(wtarget_subset)) if i not in wsqz] rsqueezed = [i for i in range(len(rtarget_subset)) if i not in rsqz] - if (boolarr or indirect_indices or (sqz_wsub.size() == sqz_osub.size() and sqz_wsub.size() == sqz_rsub.size())): map_range = {i: rng for i, rng in all_idx_tuples} @@ -3358,8 +3361,11 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): new_data, rng = None, None dtype_keys = tuple(dtypes.dtype_to_typeclass().keys()) - if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or - (isinstance(result, str) and any(result in x for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))): + if not ( + result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or + (isinstance(result, str) and any( + result in x + for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))): raise DaceSyntaxError( self, node, "In assignments, the rhs may only be " "data, numerical/boolean constants " @@ -3467,7 +3473,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): cname = self.sdfg.find_new_constant(f'__ind{i}_{true_name}') self.sdfg.add_constant(cname, carr) # Add constant to descriptor repository - self.sdfg.add_array(cname, carr.shape, dtypes.dtype_to_typeclass(carr.dtype.type), + self.sdfg.add_array(cname, + carr.shape, + dtypes.dtype_to_typeclass(carr.dtype.type), transient=True) if numpy.array(arr).dtype == numpy.bool_: boolarr = cname @@ -4769,7 +4777,7 @@ def visit_With(self, node: ast.With, is_async=False): evald = astutils.evalnode(node.items[0].context_expr, self.globals) if hasattr(evald, "name"): named_region_name: str = evald.name - else: + else: named_region_name = f"Named Region {node.lineno}" named_region = NamedRegion(named_region_name, debuginfo=self.current_lineinfo) self.cfg_target.add_node(named_region) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 20018effd0..b65e7c227d 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -59,9 +59,10 @@ def _get_locals_and_globals(f): result.update(f.__globals__) # grab the free variables (i.e. locals) if f.__closure__ is not None: - result.update( - {k: v - for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__])}) + result.update({ + k: v + for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__]) + }) return result @@ -142,6 +143,7 @@ def infer_symbols_from_datadescriptor(sdfg: SDFG, class DaceProgram(pycommon.SDFGConvertible): """ A data-centric program object, obtained by decorating a function with ``@dace.program``. """ + def __init__(self, f, args, @@ -405,9 +407,10 @@ def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any] # Update arguments with symbols in data shapes result.update( - infer_symbols_from_datadescriptor( - sdfg, {k: create_datadescriptor(v) - for k, v in result.items() if k not in self.constant_args})) + infer_symbols_from_datadescriptor(sdfg, { + k: create_datadescriptor(v) + for k, v in result.items() if k not in self.constant_args + })) return result def __call__(self, *args, **kwargs): @@ -487,9 +490,6 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF :param validate: If True, validates the resulting SDFG after creation. :return: The generated SDFG object. """ - # Avoid import loop - from dace.transformation.passes import scalar_to_symbol as scal2sym - from dace.transformation import helpers as xfh # Obtain DaCe program as SDFG sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) @@ -812,7 +812,9 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey: _, key = self._load_sdfg(None, *args, **kwargs) return key - def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], + def _generate_pdp(self, + args: Tuple[Any], + kwargs: Dict[str, Any], simplify: Optional[bool] = None) -> Tuple[SDFG, bool]: """ Generates the parsed AST representation of a DaCe program.