diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index 1537f14e47..a0d7925475 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -1,13 +1,12 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -from fparser.two import Fortran2008 as f08 +from typing import Any, List, Optional, Type, TypeVar, Union, overload, TYPE_CHECKING + from fparser.two import Fortran2003 as f03 +from fparser.two import Fortran2008 as f08 from fparser.two import symbol_table -import copy from dace.frontend.fortran import ast_internal_classes -from dace.frontend.fortran.ast_internal_classes import FNode, Name_Node -from typing import Any, List, Optional, Tuple, Type, TypeVar, Union, overload, TYPE_CHECKING - +from dace.frontend.fortran.ast_internal_classes import Name_Node if TYPE_CHECKING: from dace.frontend.fortran.intrinsics import FortranIntrinsics @@ -206,7 +205,6 @@ def __init__(self, ast: FASTNode, tables: symbol_table.SymbolTable): "Assignment_Stmt": self.assignment_stmt, "Pointer_Assignment_Stmt": self.pointer_assignment_stmt, "Where_Stmt": self.where_stmt, - "Forall_Stmt": self.forall_stmt, "Where_Construct": self.where_construct, "Where_Construct_Stmt": self.where_construct_stmt, "Masked_Elsewhere_Stmt": self.masked_elsewhere_stmt, @@ -239,7 +237,6 @@ def __init__(self, ast: FASTNode, tables: symbol_table.SymbolTable): "End_Interface_Stmt": self.end_interface_stmt, "Generic_Spec": self.generic_spec, "Name": self.name, - "Rename": self.rename, "Type_Name": self.type_name, "Specification_Part": self.specification_part, "Intrinsic_Type_Spec": self.intrinsic_type_spec, @@ -317,10 +314,10 @@ def data_pointer_object(self, node: FASTNode): return ast_internal_classes.Data_Ref_Node(parent_ref=children[0], part_ref=children[2],type="VOID") else: raise NotImplementedError("Data pointer object not supported yet") - + def add_name_list_for_module(self, module: str, name_list: List[str]): - self.name_list[module] = name_list + self.name_list[module] = name_list def create_children(self, node: FASTNode): return [self.create_ast(child) @@ -353,7 +350,7 @@ def create_ast(self, node=None): if self.unsupported_fortran_syntax.get(self.current_ast) is None: self.unsupported_fortran_syntax[self.current_ast] = [] if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]: - if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]: + if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]: self.unsupported_fortran_syntax[self.current_ast].append(type(node).__name__) for i in node.children: self.create_ast(i) @@ -369,7 +366,7 @@ def suffix(self, node: FASTNode): children = self.create_children(node) name = children[0] return ast_internal_classes.Suffix_Node(name=name) - + def data_ref(self, node: FASTNode): children = self.create_children(node) idx=len(children)-1 @@ -379,16 +376,16 @@ def data_ref(self, node: FASTNode): #parent.isStructMember=True idx=idx-1 current=ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=part_ref,type="VOID") - + while idx>0: parent = children[idx-1] - current = ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=current,type="VOID") + current = ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=current,type="VOID") idx=idx-1 return current def end_type_stmt(self, node: FASTNode): return None - + def access_stmt(self, node: FASTNode): return None @@ -423,7 +420,7 @@ def derived_type_def(self, node: FASTNode): new_placeholder[k.replace("__f2dace_A","__f2dace_SA")]=self.placeholders[k] else: new_placeholder[k]=self.placeholders[k] - self.placeholders=new_placeholder + self.placeholders=new_placeholder for k,v in self.placeholders_offsets.items(): if "__f2dace_OA" in k: new_placeholder_offsets[k.replace("__f2dace_OA","__f2dace_SOA")]=self.placeholders_offsets[k] @@ -498,8 +495,8 @@ def program_stmt(self, node: FASTNode): return ast_internal_classes.Program_Stmt_Node(name=name, line_number=node.item.span) def subroutine_subprogram(self, node: FASTNode): - - + + children = self.create_children(node) name = get_child(children, ast_internal_classes.Subroutine_Stmt_Node) @@ -548,7 +545,7 @@ def function_subprogram(self, node: FASTNode): name = get_child(children, ast_internal_classes.Function_Stmt_Node) specification_part = get_child(children, ast_internal_classes.Specification_Part_Node) execution_part = get_child(children, ast_internal_classes.Execution_Part_Node) - + return_type = name.return_type return ast_internal_classes.Function_Subprogram_Node( name=name.name, @@ -561,7 +558,7 @@ def function_subprogram(self, node: FASTNode): elemental=name.elemental, ) - def function_stmt(self, node: FASTNode): + def function_stmt(self, node: FASTNode): children = self.create_children(node) name = get_child(children, ast_internal_classes.Name_Node) args = get_child(children, ast_internal_classes.Arg_List_Node) @@ -574,7 +571,7 @@ def function_stmt(self, node: FASTNode): if args==None: ret_args = [] else: - ret_args = args.args + ret_args = args.args return ast_internal_classes.Function_Stmt_Node(name=name, args=ret_args,return_type=ret, line_number=node.item.span,ret=ret,elemental=elemental) def subroutine_stmt(self, node: FASTNode): @@ -643,16 +640,16 @@ def structure_constructor(self, node: FASTNode): if args==None: ret_args = [] else: - ret_args = args.args + ret_args = args.args return ast_internal_classes.Structure_Constructor_Node(name=name, args=ret_args, type=None,line_number=line) - + def intrinsic_function_reference(self, node: FASTNode): children = self.create_children(node) line = get_line(node) name = get_child(children, ast_internal_classes.Name_Node) args = get_child(children, ast_internal_classes.Arg_List_Node) - + if name is None: return Name_Node(name="Error! "+node.children[0].string,type='VOID') node = self.intrinsic_handler.replace_function_reference(name, args, line,self.symbols) @@ -701,7 +698,7 @@ def module(self, node: FASTNode): specification_part = get_child(children, ast_internal_classes.Specification_Part_Node) function_definitions = [i for i in children if isinstance(i, ast_internal_classes.Function_Subprogram_Node)] - + subroutine_definitions = [i for i in children if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node)] interface_blocks = {} @@ -917,7 +914,7 @@ def type_declaration_stmt(self, node: FASTNode): basetype = "INTEGER" else: raise TypeError("Derived type not supported") - + #if derived_type: # raise TypeError("Derived type not supported") if not derived_type: @@ -981,7 +978,7 @@ def type_declaration_stmt(self, node: FASTNode): attr_offset = [attr_offset] * len(names) else: attr_size, assumed_vardecls,attr_offset = self.assumed_array_shape(dimension_spec[0], names, node.item.span) - + if attr_size is None: raise RuntimeError("Couldn't parse the dimension attribute specification!") @@ -1051,8 +1048,8 @@ def type_declaration_stmt(self, node: FASTNode): raw_init = comp_init[0].children[1] init = self.create_ast(raw_init) #if size_later: - # size.append(len(init)) - if testtype!="INTEGER": symbol=False + # size.append(len(init)) + if testtype!="INTEGER": symbol=False if symbol == False: if attr_size is None: @@ -1227,10 +1224,10 @@ def level_2_expr(self, node: FASTNode): children = self.create_children(node) line = get_line(node) if children[1]=="==": type="LOGICAL" - else: + else: type="VOID" if hasattr(children[0],"type"): - type=children[0].type + type=children[0].type if len(children) == 3: return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line, type=type) else: @@ -1330,7 +1327,7 @@ def if_construct(self, node: FASTNode): if else_mode: body_else.append(i) else: - + body.append(i) currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body) currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=body_else) @@ -1527,7 +1524,7 @@ def call_stmt(self, node: FASTNode): #if node.item is None: # line_number = 42 #else: - # line_number = node.item.span + # line_number = node.item.span return ast_internal_classes.Call_Expr_Node(name=name, args=ret_args, type="VOID", line_number=line_number) def return_stmt(self, node: FASTNode): @@ -1607,8 +1604,6 @@ def block_nonlabel_do_construct(self, node: FASTNode): body=ast_internal_classes.Execution_Part_Node(execution=body), line_number=do.line_number) - - def subscript_triplet(self, node: FASTNode): if node.string == ":": return ast_internal_classes.ParDecl_Node(type="ALL") @@ -1620,7 +1615,7 @@ def section_subscript_list(self, node: FASTNode): return ast_internal_classes.Section_Subscript_List_Node(list=children) def specification_part(self, node: FASTNode): - + #TODO this can be refactored to consider more fortran declaration options. Currently limited to what is encountered in code. others = [self.create_ast(i) for i in node.children if not isinstance(i, f08.Type_Declaration_Stmt)] @@ -1654,7 +1649,7 @@ def specification_part(self, node: FASTNode): for i in decls: names_filtered.extend(ii.name for ii in i.vardecl if j.name == ii.name) decl_filtered = [] - + for i in decls: if i is None: continue @@ -1681,7 +1676,7 @@ def int_literal_constant(self, node: FASTNode): x=value.split("_") value=x[0] return ast_internal_classes.Int_Literal_Node(value=value,type="INTEGER") - + def hex_constant(self, node: FASTNode): return ast_internal_classes.Int_Literal_Node(value=str(int(node.string[2:-1],16)),type="INTEGER") @@ -1703,8 +1698,8 @@ def real_literal_constant(self, node: FASTNode): if (x[1]=="wp"): return ast_internal_classes.Double_Literal_Node(value=value,type="DOUBLE") return ast_internal_classes.Real_Literal_Node(value=value,type="REAL") - - + + def char_literal_constant(self, node: FASTNode): return ast_internal_classes.Char_Literal_Node(value=node.string,type="CHAR") @@ -1713,7 +1708,7 @@ def actual_arg_spec(self, node: FASTNode): if len(children) != 2: raise ValueError("Actual arg spec must have two children") return ast_internal_classes.Actual_Arg_Spec_Node(arg_name=children[0], arg=children[1],type="VOID") - + def actual_arg_spec_list(self, node: FASTNode): children = self.create_children(node) return ast_internal_classes.Arg_List_Node(args=children) @@ -1723,8 +1718,8 @@ def initialization(self, node: FASTNode): def name(self, node: FASTNode): return ast_internal_classes.Name_Node(name=node.string.lower(),type="VOID") - - + + def rename(self, node: FASTNode): return ast_internal_classes.Rename_Node(oldname=node.children[2].string.lower(),newname=node.children[1].string.lower()) diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index 317b18d481..21fcd2ead2 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -91,6 +91,18 @@ def iter_fields(node: ast_internal_classes.FNode): except AttributeError: pass +def iter_attributes(node: ast_internal_classes.FNode): + """ + Yield a tuple of ``(fieldname, value)`` for each field in ``node._attributes`` + that is present on *node*. + """ + if not hasattr(node, "_attributes"): + a = 1 + for field in node._attributes: + try: + yield field, getattr(node, field) + except AttributeError: + pass def iter_child_nodes(node: ast_internal_classes.FNode): """ @@ -2461,3 +2473,138 @@ def visit_Specification_Part_Node(self, node: ast_internal_classes.Specification typedecls=node.typedecls, uses=node.uses ) + +class ArgumentPruner(NodeVisitor): + + def __init__(self, funcs): + + self.funcs = funcs + + self.parsed_funcs: Dict[str, List[int]] = {} + + self.used_names = set() + self.declaration_names = set() + + self.used_in_all_functions: Set[str] = set() + + def visit_Name_Node(self, node: ast_internal_classes.Name_Node): + if node.name not in self.used_names: + print(f"Used name {node.name}") + self.used_names.add(node.name) + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + self.declaration_names.add(node.name) + + # visit also sizes and offsets + self.generic_visit(node) + + def generic_visit(self, node: ast_internal_classes.FNode): + """Called if no explicit visitor function exists for a node.""" + for field, value in iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast_internal_classes.FNode): + self.visit(item) + elif isinstance(value, ast_internal_classes.FNode): + self.visit(value) + + for field, value in iter_attributes(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast_internal_classes.FNode): + self.visit(item) + elif isinstance(value, ast_internal_classes.FNode): + self.visit(value) + + def _visit_function(self, node: ast_internal_classes.FNode): + + old_used_names = self.used_names + self.used_names = set() + self.declaration_names = set() + + self.visit(node.specification_part) + + self.visit(node.execution_part) + + new_args = [] + removed_args = [] + for idx, arg in enumerate(node.args): + + if not isinstance(arg, ast_internal_classes.Name_Node): + raise NotImplementedError() + + if arg.name not in self.used_names: + print(f"Pruning argument {arg.name} of function {node.name.name}") + removed_args.append(idx) + else: + print(f"Leaving used argument {arg.name} of function {node.name.name}") + new_args.append(arg) + self.parsed_funcs[node.name.name] = removed_args + + declarations_to_remove = set() + for x in self.declaration_names: + if x not in self.used_names: + print(f"Marking removal variable {x}") + declarations_to_remove.add(x) + else: + print(f"Keeping used variable {x}") + + for decl_stmt_node in node.specification_part.specifications: + + newdecl = [] + for decl in decl_stmt_node.vardecl: + + if not isinstance(decl, ast_internal_classes.Var_Decl_Node): + raise NotImplementedError() + + if decl.name not in declarations_to_remove: + print(f"Readding declared variable {decl.name}") + newdecl.append(decl) + else: + print(f"Pruning unused but declared variable {decl.name}") + decl_stmt_node.vardecl = newdecl + + self.used_in_all_functions.update(self.used_names) + self.used_names = old_used_names + + def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): + + if node.name.name not in self.parsed_funcs: + self._visit_function(node) + + to_remove = self.parsed_funcs[node.name.name] + for idx in reversed(to_remove): + print(f"Prune argument {node.args[idx].name} in {node.name.name}") + del node.args[idx] + + def visit_Function_Subprogram_Node(self, node: ast_internal_classes.Function_Subprogram_Node): + + if node.name.name not in self.parsed_funcs: + self._visit_function(node) + + to_remove = self.parsed_funcs[node.name.name] + for idx in reversed(to_remove): + del node.args[idx] + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + + if node.name.name not in self.parsed_funcs: + + if node.name.name in self.funcs: + self._visit_function(self.funcs[node.name.name]) + else: + + # now add actual arguments to the list of used names + for arg in node.args: + self.visit(arg) + + return + + to_remove = self.parsed_funcs[node.name.name] + for idx in reversed(to_remove): + del node.args[idx] + + # now add actual arguments to the list of used names + for arg in node.args: + self.visit(arg) + diff --git a/dace/frontend/fortran/ast_utils.py b/dace/frontend/fortran/ast_utils.py index 5e6b89d151..69dae8a6cb 100644 --- a/dace/frontend/fortran/ast_utils.py +++ b/dace/frontend/fortran/ast_utils.py @@ -1,31 +1,22 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. -from fparser.api import parse -import os -import sys -from fparser.common.readfortran import FortranStringReader, FortranFileReader +from typing import List, Set -from dace.frontend.fortran import ast_components +import networkx as nx +from numpy import finfo as finf +from numpy import float64 as fl -#dace imports -from dace import subsets -from dace.data import Scalar -from dace.sdfg import SDFG, SDFGState, InterstateEdge +from dace import DebugInfo as di +from dace import Language as lang from dace import Memlet -from dace.sdfg.nodes import Tasklet -from dace import dtypes from dace import data as dat +from dace import dtypes +# dace imports +from dace import subsets from dace import symbolic as sym -from dace import DebugInfo as di -from dace import Language as lang -from dace.properties import CodeBlock -from numpy import finfo as finf -from numpy import float64 as fl - from dace.frontend.fortran import ast_internal_classes -from typing import List, Set -import networkx as nx -from dace.frontend.fortran import ast_transforms +from dace.sdfg import SDFG, SDFGState, InterstateEdge +from dace.sdfg.nodes import Tasklet fortrantypes2dacetypes = { "DOUBLE": dtypes.float64, @@ -515,9 +506,32 @@ def intlit2string(self, node: ast_internal_classes.Int_Literal_Node): return "".join(map(str, node.value)) def floatlit2string(self, node: ast_internal_classes.Real_Literal_Node): + # Typecheck and crash early if unexpected. + assert hasattr(node, 'value') + lit = node.value + assert isinstance(lit, str) + + # Fortran "real literals" may have an additional suffix at the end. + # Examples: + # valid: 1.0 => 1 + # valid: 1. => 1 + # valid: 1.e5 => 1e5 + # valid: 1.d5 => 1e5 + # valid: 1._kinder => 1 (precondition: somewhere earlier, `integer, parameter :: kinder=8`) + # valid: 1.e5_kinder => 1e5 + # not valid: 1.d5_kinder => 1e5 + # TODO: Is there a complete spec of the structure of real literals? + if '_' in lit: + # First, deal with kind specification and remove it altogether, since we know the type anyway. + parts = lit.split('_') + assert 1 <= len(parts) <= 2, f"{lit} is not a valid fortran literal." + lit = parts[0] + assert 'd' not in lit, f"{lit} is not a valid fortran literal." + if 'd' in lit: + # Again, since we know the type anyway, here we just make the s/d/e/ replacement. + lit = lit.replace('d', 'e') + return f"{float(lit)}" - return "".join(map(str, node.value)) - def doublelit2string(self, node: ast_internal_classes.Double_Literal_Node): return "".join(map(str, node.value)) @@ -676,32 +690,6 @@ def namerange2string(self, node: ast_internal_classes.Name_Range_Node): -class UseModuleLister: - def __init__(self): - self.list_of_modules = [] - self.objects_in_use={} - - def get_used_modules(self, node): - if node is None: - return - if not hasattr(node, "children"): - return - for i in node.children: - if i.__class__.__name__ == "Use_Stmt": - if i.children[0] is not None: - if i.children[0].string.lower()=="intrinsic": - continue - for j in i.children: - if j.__class__.__name__ == "Name": - self.list_of_modules.append(j.string) - for k in i.children: - if k.__class__.__name__ == "Only_List": - self.objects_in_use[j.string] = k - - else: - self.get_used_modules(i) - - class Context: def __init__(self, name): self.name = name diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index f92b1858d6..2a92118e50 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -231,6 +231,7 @@ def __init__( startpoint=None, sdfg_path=None, toplevel_subroutine: Optional[str] = None, + subroutine_used_names: Optional[Set[str]] = None, normalize_offsets = False ): """ @@ -275,6 +276,7 @@ def __init__( #self.iblocks=ast.iblocks self.replace_names = {} self.toplevel_subroutine = toplevel_subroutine + self.subroutine_used_names = subroutine_used_names self.normalize_offsets = normalize_offsets self.ast_elements = { ast_internal_classes.If_Stmt_Node: self.ifstmt2sdfg, @@ -409,6 +411,10 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): for k in j.vardecl: self.module_vars.append((k.name, i.name)) if i.specification_part is not None: + + # this works with CloudSC + # unsure about ICON + self.transient_mode=False for j in i.specification_part.symbols: self.translate(j, sdfg) if isinstance(j, ast_internal_classes.Symbol_Array_Decl_Node): @@ -421,6 +427,9 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): self.translate(j, sdfg) for k in j.vardecl: self.module_vars.append((k.name, i.name)) + # this works with CloudSC + # unsure about ICON + self.transient_mode=True ast_utils.add_simple_state_to_sdfg(self, sdfg, "GlobalDefEnd") if node.main_program is not None: @@ -446,7 +455,9 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): else: if self.startpoint.specification_part is not None: - self.transient_mode=True + # this works with CloudSC + # unsure about ICON + self.transient_mode=False for i in self.startpoint.specification_part.typedecls: self.translate(i, sdfg) @@ -464,7 +475,16 @@ def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): continue add_deferred_shape_assigns_for_structs(self.structures,decl, sdfg, assign_state, decl.name,decl.name,self.placeholders,self.placeholders_offsets,sdfg.arrays[self.name_mapping[sdfg][decl.name]],self.replace_names,self.actual_offsets_per_sdfg[sdfg]) - self.transient_mode=True + # this works with CloudSC + # unsure about ICON + arg_names = [ast_utils.get_name(i) for i in self.startpoint.args] + for arr_name, arr in sdfg.arrays.items(): + + if arr.transient and arr_name in arg_names: + print(f"Changing the transient status to false of {arr_name} because it's a function argument") + arr.transient = False + + self.transient_mode=True self.translate(self.startpoint.execution_part.execution, sdfg) def pointerassignment2sdfg(self, node: ast_internal_classes.Pointer_Assignment_Stmt_Node, sdfg: SDFG): @@ -777,6 +797,8 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if node.execution_part is None: return + print("TRANSLATE SUBROUTINE", node.name.name) + # First get the list of read and written variables inputnodefinder = ast_transforms.FindInputs() inputnodefinder.visit(node) @@ -805,6 +827,7 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, if not ((len(variables_in_call) == len(parameters)) or (len(variables_in_call) == len(parameters) + 1 and not isinstance(node.result_type, ast_internal_classes.Void))): + print("Subroutine", node.name.name) print('Variables in call', len(variables_in_call)) print('Parameters', len(parameters)) #for i in variables_in_call: @@ -1844,11 +1867,9 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node, pass old_mode=self.transient_mode - print("For ",sdfg_name," old mode is ",old_mode) + #print("For ",sdfg_name," old mode is ",old_mode) self.transient_mode=True for j in node.specification_part.specifications: - - self.declstmt2sdfg(j, new_sdfg) self.transient_mode=old_mode @@ -2095,6 +2116,23 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): if name in self.local_not_transient_because_assign[sdfg.name]: is_arg=False break + + # if this is a variable declared in the module, + # then we will not add it unless it is used by the functions. + # It would be sufficient to check the main entry function, + # since it must pass this variable through call + # to other functions. + # However, I am not completely sure how to determine which function is the main one. + # + # we ignore the variable that is not used at all in all functions + # this is a module variaable that can be removed + if not is_arg: + if self.subroutine_used_names is not None: + + if node.name not in self.subroutine_used_names: + print(f"Ignoring module variable {node.name} because it is not used in the the top level subroutine") + return + if is_arg: transient=False else: @@ -2135,7 +2173,7 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): sizes.append(sym.pystr_to_symbolic(text)) actual_offset_value=node.offsets[node.sizes.index(i)] if isinstance(actual_offset_value,ast_internal_classes.Array_Subscript_Node): - print(node.name,actual_offset_value.name.name) + #print(node.name,actual_offset_value.name.name) raise NotImplementedError("Array subscript in offset not implemented") if isinstance(actual_offset_value,int): actual_offset_value=ast_internal_classes.Int_Literal_Node(value=str(actual_offset_value)) @@ -2168,9 +2206,9 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): #here we must replace local placeholder sizes that have already made it to tasklets via size and ubound calls if sizes is not None: actual_sizes=sdfg.arrays[self.name_mapping[sdfg][node.name]].shape - print(node.name,sdfg.name,self.names_of_object_in_parent_sdfg.get(sdfg).get(node.name)) - print(sdfg.parent_sdfg.name,self.actual_offsets_per_sdfg[sdfg.parent_sdfg].get(self.names_of_object_in_parent_sdfg[sdfg][node.name])) - print(sdfg.parent_sdfg.arrays.get(self.name_mapping[sdfg.parent_sdfg].get(self.names_of_object_in_parent_sdfg.get(sdfg).get(node.name)))) + #print(node.name,sdfg.name,self.names_of_object_in_parent_sdfg.get(sdfg).get(node.name)) + #print(sdfg.parent_sdfg.name,self.actual_offsets_per_sdfg[sdfg.parent_sdfg].get(self.names_of_object_in_parent_sdfg[sdfg][node.name])) + #print(sdfg.parent_sdfg.arrays.get(self.name_mapping[sdfg.parent_sdfg].get(self.names_of_object_in_parent_sdfg.get(sdfg).get(node.name)))) if self.actual_offsets_per_sdfg[sdfg.parent_sdfg].get(self.names_of_object_in_parent_sdfg[sdfg][node.name]) is not None: actual_offsets=self.actual_offsets_per_sdfg[sdfg.parent_sdfg][self.names_of_object_in_parent_sdfg[sdfg][node.name]] else: @@ -2691,6 +2729,10 @@ def create_sdfg_from_string( raise NameError("Structs have cyclic dependencies") own_ast.tables = own_ast.symbols + #program = + print(dir(functions_and_subroutines_builder)) + ast_transforms.ArgumentPruner(functions_and_subroutines_builder.nodes).visit(program) + ast2sdfg = AST_translator(own_ast, __file__, multiple_sdfgs=multiple_sdfgs, toplevel_subroutine=sdfg_name, normalize_offsets=normalize_offsets) sdfg = SDFG(sdfg_name) ast2sdfg.actual_offsets_per_sdfg[sdfg]={} @@ -3345,6 +3387,10 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, program.placeholders_offsets=partial_ast.placeholders_offsets program.functions_and_subroutines=partial_ast.functions_and_subroutines unordered_modules=program.modules + + arg_pruner = ast_transforms.ArgumentPruner(functions_and_subroutines_builder.nodes) + arg_pruner.visit(program) + program.modules=[] for i in parse_order: for j in unordered_modules: @@ -3417,15 +3463,20 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, continue print(f"Building SDFG {j.name.name}") startpoint = j - ast2sdfg = AST_translator(program, __file__,multiple_sdfgs=False,startpoint=startpoint,sdfg_path=icon_sdfgs_dir, normalize_offsets=normalize_offsets) + ast2sdfg = AST_translator( + program, __file__, + multiple_sdfgs=False, + startpoint=startpoint, + sdfg_path=icon_sdfgs_dir, + #toplevel_subroutine_arg_names=arg_pruner.visited_funcs[toplevel_subroutine], + subroutine_used_names=arg_pruner.used_in_all_functions, + normalize_offsets=normalize_offsets + ) sdfg = SDFG(j.name.name) ast2sdfg.actual_offsets_per_sdfg[sdfg]={} ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg - ast2sdfg.translate(program, sdfg) - - sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz"),compress=True) diff --git a/tests/fortran/advanced_optional_args.py b/tests/fortran/advanced_optional_args_test.py similarity index 100% rename from tests/fortran/advanced_optional_args.py rename to tests/fortran/advanced_optional_args_test.py diff --git a/tests/fortran/array_to_loop_offset.py b/tests/fortran/array_to_loop_offset_test.py similarity index 100% rename from tests/fortran/array_to_loop_offset.py rename to tests/fortran/array_to_loop_offset_test.py diff --git a/tests/fortran/ast_utils_test.py b/tests/fortran/ast_utils_test.py new file mode 100644 index 0000000000..4ab7b87f35 --- /dev/null +++ b/tests/fortran/ast_utils_test.py @@ -0,0 +1,28 @@ +import pytest + +from dace.frontend.fortran.ast_internal_classes import Real_Literal_Node + +from dace.frontend.fortran.ast_utils import TaskletWriter + + +def test_floatlit2string(): + def parse(fl: str) -> float: + t = TaskletWriter([], []) # The parameters won't matter. + return t.floatlit2string(Real_Literal_Node(value=fl)) + + assert parse('1.0') == '1.0' + assert parse('1.') == '1.0' + assert parse('1.e5') == '100000.0' + assert parse('1.d5') == '100000.0' + assert parse('1._kinder') == '1.0' + assert parse('1.e5_kinder') == '100000.0' + with pytest.raises(AssertionError): + parse('1.d5_kinder') + with pytest.raises(AssertionError): + parse('1._kinder_kinder') + with pytest.raises(ValueError, match="could not convert string to float"): + parse('1.2.0') + with pytest.raises(ValueError, match="could not convert string to float"): + parse('1.d0d0') + with pytest.raises(ValueError, match="could not convert string to float"): + parse('foo') diff --git a/tests/fortran/call_extract.py b/tests/fortran/call_extract_test.py similarity index 100% rename from tests/fortran/call_extract.py rename to tests/fortran/call_extract_test.py diff --git a/tests/fortran/fortran_language_test.py b/tests/fortran/fortran_language_test.py index 32ab23714b..9034ba20f3 100644 --- a/tests/fortran/fortran_language_test.py +++ b/tests/fortran/fortran_language_test.py @@ -1,22 +1,8 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. -from fparser.common.readfortran import FortranStringReader -from fparser.common.readfortran import FortranFileReader -from fparser.two.parser import ParserFactory -import sys, os import numpy as np -import pytest - -from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic from dace.frontend.fortran import fortran_parser -from fparser.two.symbol_table import SymbolTable -from dace.sdfg import utils as sdutil - -import dace.frontend.fortran.ast_components as ast_components -import dace.frontend.fortran.ast_transforms as ast_transforms -import dace.frontend.fortran.ast_utils as ast_utils -import dace.frontend.fortran.ast_internal_classes as ast_internal_classes def test_fortran_frontend_real_kind_selector(): @@ -24,23 +10,25 @@ def test_fortran_frontend_real_kind_selector(): Tests that the size intrinsics are correctly parsed and translated to DaCe. """ test_string = """ - PROGRAM real_kind_selector_test - implicit none - INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300) - INTEGER, PARAMETER :: JPIM = SELECTED_INT_KIND(9) - REAL(KIND=JPRB) d(4) - CALL real_kind_selector_test_function(d) - end - - SUBROUTINE real_kind_selector_test_function(d) - REAL(KIND=JPRB) d(4) - INTEGER(KIND=JPIM) i - - i=7 - d(2)=5.5+i - - END SUBROUTINE real_kind_selector_test_function - """ +program real_kind_selector_test + implicit none + integer, parameter :: JPRB = selected_real_kind(13, 300) + real(KIND=JPRB) d(4) + call real_kind_selector_test_function(d) +end + +subroutine real_kind_selector_test_function(d) + implicit none + integer, parameter :: JPRB = selected_real_kind(13, 300) + integer, parameter :: JPIM = selected_int_kind(9) + real(KIND=JPRB) d(4) + integer(KIND=JPIM) i + + i = 7 + d(2) = 5.5 + i + +end subroutine real_kind_selector_test_function +""" sdfg = fortran_parser.create_sdfg_from_string(test_string, "real_kind_selector_test") sdfg.simplify(verbose=True) a = np.full([4], 42, order="F", dtype=np.float64) @@ -129,30 +117,30 @@ def test_fortran_frontend_function_statement1(): """ Tests that the function statement are correctly removed recursively. """ - - test_string = """ - PROGRAM function_statement1_test - implicit none - double precision d(3,4,5) - CALL function_statement1_test_function(d) - end - SUBROUTINE function_statement1_test_function(d) - double precision d(3,4,5) - double precision :: PTARE,RTT(2),FOEDELTA,FOELDCP - double precision :: RALVDCP(2),RALSDCP(2),RES - - FOEDELTA (PTARE) = MAX (0.0,SIGN(1.0,PTARE-RTT(1))) - FOELDCP ( PTARE ) = FOEDELTA(PTARE)*RALVDCP(1) + (1.0-FOEDELTA(PTARE))*RALSDCP(1) - - RTT(1)=4.5 - RALVDCP(1)=4.9 - RALSDCP(1)=5.1 - d(1,1,1)=FOELDCP(3.0) - RES=FOELDCP(3.0) - d(1,1,2)=RES - END SUBROUTINE function_statement1_test_function - """ + test_string = """ +program function_statement1_test + implicit none + double precision d(3, 4, 5) + call function_statement1_test_function(d) +end + +subroutine function_statement1_test_function(d) + double precision d(3, 4, 5) + double precision :: PTARE, RTT(2), FOEDELTA, FOELDCP + double precision :: RALVDCP(2), RALSDCP(2), RES + + FOEDELTA(PTARE) = max(0.0, sign(1.d0, PTARE - RTT(1))) + FOELDCP(PTARE) = FOEDELTA(PTARE)*RALVDCP(1) + (1.0 - FOEDELTA(PTARE))*RALSDCP(1) + + RTT(1) = 4.5 + RALVDCP(1) = 4.9 + RALSDCP(1) = 5.1 + d(1, 1, 1) = FOELDCP(3.d0) + RES = FOELDCP(3.d0) + d(1, 1, 2) = RES +end subroutine function_statement1_test_function +""" sdfg = fortran_parser.create_sdfg_from_string(test_string, "function_statement1_test") sdfg.simplify(verbose=True) d = np.full([3, 4, 5], 42, order="F", dtype=np.float64) diff --git a/tests/fortran/future/fortran_class.py b/tests/fortran/future/fortran_class_test.py similarity index 100% rename from tests/fortran/future/fortran_class.py rename to tests/fortran/future/fortran_class_test.py diff --git a/tests/fortran/intrinsic_all.py b/tests/fortran/intrinsic_all_test.py similarity index 100% rename from tests/fortran/intrinsic_all.py rename to tests/fortran/intrinsic_all_test.py diff --git a/tests/fortran/intrinsic_any.py b/tests/fortran/intrinsic_any_test.py similarity index 100% rename from tests/fortran/intrinsic_any.py rename to tests/fortran/intrinsic_any_test.py diff --git a/tests/fortran/intrinsic_basic.py b/tests/fortran/intrinsic_basic_test.py similarity index 100% rename from tests/fortran/intrinsic_basic.py rename to tests/fortran/intrinsic_basic_test.py diff --git a/tests/fortran/intrinsic_blas.py b/tests/fortran/intrinsic_blas_test.py similarity index 100% rename from tests/fortran/intrinsic_blas.py rename to tests/fortran/intrinsic_blas_test.py diff --git a/tests/fortran/intrinsic_count.py b/tests/fortran/intrinsic_count_test.py similarity index 100% rename from tests/fortran/intrinsic_count.py rename to tests/fortran/intrinsic_count_test.py diff --git a/tests/fortran/intrinsic_math.py b/tests/fortran/intrinsic_math_test.py similarity index 100% rename from tests/fortran/intrinsic_math.py rename to tests/fortran/intrinsic_math_test.py diff --git a/tests/fortran/intrinsic_merge.py b/tests/fortran/intrinsic_merge_test.py similarity index 100% rename from tests/fortran/intrinsic_merge.py rename to tests/fortran/intrinsic_merge_test.py diff --git a/tests/fortran/intrinsic_minmaxval.py b/tests/fortran/intrinsic_minmaxval_test.py similarity index 100% rename from tests/fortran/intrinsic_minmaxval.py rename to tests/fortran/intrinsic_minmaxval_test.py diff --git a/tests/fortran/intrinsic_product.py b/tests/fortran/intrinsic_product_test.py similarity index 100% rename from tests/fortran/intrinsic_product.py rename to tests/fortran/intrinsic_product_test.py diff --git a/tests/fortran/intrinsic_sum.py b/tests/fortran/intrinsic_sum_test.py similarity index 100% rename from tests/fortran/intrinsic_sum.py rename to tests/fortran/intrinsic_sum_test.py diff --git a/tests/fortran/long_tasklet.py b/tests/fortran/long_tasklet_test.py similarity index 100% rename from tests/fortran/long_tasklet.py rename to tests/fortran/long_tasklet_test.py diff --git a/tests/fortran/missing_func.py b/tests/fortran/missing_func_test.py similarity index 100% rename from tests/fortran/missing_func.py rename to tests/fortran/missing_func_test.py diff --git a/tests/fortran/non-interactive/fortran_int_init.py b/tests/fortran/non-interactive/fortran_int_init_test.py similarity index 100% rename from tests/fortran/non-interactive/fortran_int_init.py rename to tests/fortran/non-interactive/fortran_int_init_test.py diff --git a/tests/fortran/offset_normalizer.py b/tests/fortran/offset_normalizer_test.py similarity index 100% rename from tests/fortran/offset_normalizer.py rename to tests/fortran/offset_normalizer_test.py diff --git a/tests/fortran/optional_args.py b/tests/fortran/optional_args_test.py similarity index 100% rename from tests/fortran/optional_args.py rename to tests/fortran/optional_args_test.py diff --git a/tests/fortran/pointer_removal.py b/tests/fortran/pointer_removal_test.py similarity index 100% rename from tests/fortran/pointer_removal.py rename to tests/fortran/pointer_removal_test.py diff --git a/tests/fortran/prune_test.py b/tests/fortran/prune_test.py new file mode 100644 index 0000000000..a1e43e411b --- /dev/null +++ b/tests/fortran/prune_test.py @@ -0,0 +1,147 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. + +from fparser.common.readfortran import FortranStringReader +from fparser.common.readfortran import FortranFileReader +from fparser.two.parser import ParserFactory +import sys, os +import numpy as np +import pytest + +from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace.frontend.fortran import fortran_parser +from fparser.two.symbol_table import SymbolTable +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import AccessNode + +import dace.frontend.fortran.ast_components as ast_components +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_utils as ast_utils +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + +def test_fortran_frontend_prune_simple(): + test_string = """ + PROGRAM init_test + implicit none + double precision d(4) + double precision dx(4) + CALL test_function(d, dx) + end + + SUBROUTINE test_function(d, dx) + + double precision dx(4) + double precision d(4) + + d(2) = d(1) + 3.14 + + END SUBROUTINE test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test", False) + print('a', flush=True) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + print(a) + sdfg(d=a,outside_init=0) + print(a) + assert (a[0] == 42) + assert (a[1] == 42 + 3.14) + assert (a[2] == 42) + + +def test_fortran_frontend_prune_complex(): + # Test we can detect recursively unused arguments + # Test we can change names and it does not affect pruning + # Test we can use two different ignored args in the same function + test_string = """ + PROGRAM init_test + implicit none + double precision d(4) + double precision dx(1) + double precision dy(1) + CALL test_function(dy, d, dx) + end + + SUBROUTINE test_function(dy, d, dx) + + double precision dx(4) + double precision d(1) + double precision dy(1) + + d(2) = d(1) + 3.14 + + CALL test_function_another(d, dx) + CALL test_function_another(d, dy) + + END SUBROUTINE test_function + + SUBROUTINE test_function_another(dx, dz) + + double precision dx(4) + double precision dz(1) + + dx(3) = dx(3) - 1 + + END SUBROUTINE test_function_another + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test", False) + print('a', flush=True) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + print(a) + sdfg(d=a,outside_init=0) + print(a) + assert (a[0] == 42) + assert (a[1] == 42 + 3.14) + assert (a[2] == 40) + +def test_fortran_frontend_prune_actual_param(): + # Test we do not remove a variable that is passed along + # but not used in the function. + test_string = """ + PROGRAM init_test + implicit none + double precision d(4) + double precision dx(1) + double precision dy(1) + CALL test_function(dy, d, dx) + end + + SUBROUTINE test_function(dy, d, dx) + + double precision d(4) + double precision dx(1) + double precision dy(1) + + CALL test_function_another(d, dx) + CALL test_function_another(d, dy) + + END SUBROUTINE test_function + + SUBROUTINE test_function_another(dx, dz) + + double precision dx(4) + double precision dz(1) + + dx(3) = dx(3) - 1 + + END SUBROUTINE test_function_another + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "init_test", False) + print('a', flush=True) + sdfg.simplify(verbose=True) + a = np.full([4], 42, order="F", dtype=np.float64) + print(a) + sdfg(d=a,outside_init=0) + print(a) + assert (a[0] == 42) + assert (a[1] == 42) + assert (a[2] == 40) + +if __name__ == "__main__": + + test_fortran_frontend_prune_simple() + test_fortran_frontend_prune_complex() + test_fortran_frontend_prune_actual_param() diff --git a/tests/fortran/scope_arrays.py b/tests/fortran/scope_arrays_test.py similarity index 96% rename from tests/fortran/scope_arrays.py rename to tests/fortran/scope_arrays_test.py index 0eb0cf44b2..5dd5b806a8 100644 --- a/tests/fortran/scope_arrays.py +++ b/tests/fortran/scope_arrays_test.py @@ -30,7 +30,7 @@ def test_fortran_frontend_parent(): ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test") ast_transforms.ParentScopeAssigner().visit(ast) - visitor = ast_transforms.ScopeVarsDeclarations() + visitor = ast_transforms.ScopeVarsDeclarations(ast) visitor.visit(ast) for var in ['d', 'arr', 'arr3']: diff --git a/tests/fortran/sum_to_loop_offset.py b/tests/fortran/sum_to_loop_offset_test.py similarity index 100% rename from tests/fortran/sum_to_loop_offset.py rename to tests/fortran/sum_to_loop_offset_test.py