diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index 8de6024599..1537f14e47 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -291,11 +291,19 @@ def __init__(self, ast: FASTNode, tables: symbol_table.SymbolTable): "End_Type_Stmt": self.end_type_stmt, "Data_Ref": self.data_ref, "Cycle_Stmt": self.cycle_stmt, + "Deferred_Shape_Spec": self.deferred_shape_spec, + "Deferred_Shape_Spec_List": self.deferred_shape_spec_list, + "Component_Initialization": self.component_initialization, + "Case_Selector": self.case_selector, + "Case_Value_Range_List": self.case_value_range_list, + "Procedure_Designator": self.procedure_designator, + "Specific_Binding": self.specific_binding, #"Component_Decl_List": self.component_decl_list, #"Component_Decl": self.component_decl, } self.type_arbitrary_array_variable_count = 0 + def fortran_intrinsics(self) -> "FortranIntrinsics": return self.intrinsic_handler @@ -323,6 +331,7 @@ def create_children(self, node: FASTNode): def cycle_stmt(self, node: FASTNode): line = get_line(node) return ast_internal_classes.Continue_Node( line_number=line) + def create_ast(self, node=None): """ @@ -382,11 +391,27 @@ def end_type_stmt(self, node: FASTNode): def access_stmt(self, node: FASTNode): return None + + def deferred_shape_spec(self, node: FASTNode): + return ast_internal_classes.Defer_Shape_Node() + + def deferred_shape_spec_list(self, node: FASTNode): + children = self.create_children(node) + return children + + def component_initialization(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Component_Initialization_Node(init=children[1]) + + def procedure_designator(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Procedure_Separator_Node(parent_ref=children[0],part_ref=children[2]) def derived_type_def(self, node: FASTNode): children = self.create_children(node) name = children[0].name component_part = get_child(children, ast_internal_classes.Component_Part_Node) + procedure_part = get_child(children, ast_internal_classes.Bound_Procedures_Node) from dace.frontend.fortran.ast_transforms import PartialRenameVar if component_part is not None: component_part=PartialRenameVar(oldname="__f2dace_A", newname="__f2dace_SA").visit(component_part) @@ -405,7 +430,7 @@ def derived_type_def(self, node: FASTNode): else: new_placeholder_offsets[k]=self.placeholders_offsets[k] self.placeholders_offsets=new_placeholder_offsets - return ast_internal_classes.Derived_Type_Def_Node(name=name, component_part=component_part) + return ast_internal_classes.Derived_Type_Def_Node(name=name, component_part=component_part, procedure_part=procedure_part) def derived_type_stmt(self, node: FASTNode): children = self.create_children(node) @@ -432,13 +457,13 @@ def component_decl(self, node: FASTNode): return ast_internal_classes.Component_Decl_Node(name=name) def write_stmt(self, node: FASTNode): - children=[] - if node.children[0] is not None: - children = self.create_children(node.children[0]) - if node.children[1] is not None: - children = self.create_children(node.children[1]) + #children=[] + #if node.children[0] is not None: + # children = self.create_children(node.children[0]) + #if node.children[1] is not None: + # children = self.create_children(node.children[1]) line = get_line(node) - return ast_internal_classes.Write_Stmt_Node(args=children, line_number=line) + return ast_internal_classes.Write_Stmt_Node(args=node.string, line_number=line) def program(self, node: FASTNode): children = self.create_children(node) @@ -859,8 +884,8 @@ def type_declaration_stmt(self, node: FASTNode): elif isinstance(type_of_node, f03.Declaration_Type_Spec): if type_of_node.items[0].lower() == "class": basetype = "CLASS" - - derived_type = False + basetype = type_of_node.items[1].string + derived_type = True else: derived_type = True basetype = type_of_node.items[1].string @@ -886,6 +911,8 @@ def type_declaration_stmt(self, node: FASTNode): if self.symbols[kind].value == "8": basetype = "REAL8" elif basetype == "INTEGER": + while hasattr(self.symbols[kind], "name"): + kind = self.symbols[kind].name if self.symbols[kind].value == "4": basetype = "INTEGER" else: @@ -1018,6 +1045,11 @@ def type_declaration_stmt(self, node: FASTNode): raise ValueError("Initialization must have an expression") raw_init = initialization.children[1] init = self.create_ast(raw_init) + else: + comp_init = get_children(var, "Component_Initialization") + if len(comp_init) == 1: + raw_init = comp_init[0].children[1] + init = self.create_ast(raw_init) #if size_later: # size.append(len(init)) if testtype!="INTEGER": symbol=False @@ -1323,13 +1355,88 @@ def end_if_stmt(self, node: FASTNode): return node def case_construct(self, node: FASTNode): - return Name_Node(name="Error!") + children = self.create_children(node) + cond_start = children[0] + cond_end= children[1] + body = [] + body_else = [] + else_mode = False + line = get_line(node) + if line is None: + line = "Unknown:TODO" + cond=ast_internal_classes.BinOp_Node(op=cond_end.op[0],lval=cond_start,rval=cond_end.cond[0],line_number=line) + for j in range(1,len(cond_end.op)): + cond_add=ast_internal_classes.BinOp_Node(op=cond_end.op[j],lval=cond_start,rval=cond_end.cond[j],line_number=line) + cond=ast_internal_classes.BinOp_Node(op=".AND.",lval=cond,rval=cond_add,line_number=line) + + toplevelIf = ast_internal_classes.If_Stmt_Node(cond=cond, line_number=line) + currentIf = toplevelIf + for i in children[2:-1]: + if i is None: + continue + if isinstance(i, ast_internal_classes.Case_Cond_Node): + cond=ast_internal_classes.BinOp_Node(op=i.op[0],lval=cond_start,rval=i.cond[0],line_number=line) + for j in range(1,len(i.op)): + cond_add=ast_internal_classes.BinOp_Node(op=i.op[j],lval=cond_start,rval=i.cond[j],line_number=line) + cond=ast_internal_classes.BinOp_Node(op=".AND.",lval=cond,rval=cond_add,line_number=line) + + newif = ast_internal_classes.If_Stmt_Node(cond=cond, line_number=line) + currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body) + currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=[newif]) + currentIf = newif + body = [] + continue + if isinstance(i, str) and i=="__default__": + else_mode = True + continue + if else_mode: + body_else.append(i) + else: + + body.append(i) + currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body) + currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=body_else) + return toplevelIf def select_case_stmt(self, node: FASTNode): - return Name_Node(name="Error!") + children = self.create_children(node) + if len(children)!=1: + raise ValueError("CASE should have only 1 child") + return children[0] def case_stmt(self, node: FASTNode): - return Name_Node(name="Error!") + children = self.create_children(node) + children=[i for i in children if i is not None] + if len(children)==1: + return children[0] + elif len(children)==0: + return "__default__" + else: + raise ValueError("Can't parse case statement") + + def case_selector(self, node:FASTNode): + children = self.create_children(node) + if len(children)==1: + if children[0] is None: + return None + returns=ast_internal_classes.Case_Cond_Node(op=[],cond=[]) + + for i in children[0]: + returns.op.append(i[0]) + returns.cond.append(i[1]) + return returns + else: + raise ValueError("Can't parse case selector") + + def case_value_range_list(self, node:FASTNode): + children = self.create_children(node) + if len(children)==1: + return [[".EQ.", children[0]]] + if len(children)==2: + return [[".EQ.", children[0]],[".EQ.", children[1]]] + + else: + raise ValueError("Can't parse case range list") def end_select_stmt(self, node: FASTNode): return node @@ -1387,9 +1494,15 @@ def generic_spec(self, node: FASTNode): def procedure_declaration_stmt(self, node: FASTNode): return node + + def specific_binding(self, node: FASTNode): + children = self.create_children(node) + return ast_internal_classes.Specific_Binding_Node(name=children[3], args=children[0:2]+[children[4]]) + def type_bound_procedure_part(self, node: FASTNode): - return node + children = self.create_children(node) + return ast_internal_classes.Bound_Procedures_Node(procedures= children[1:]) def contains_stmt(self, node: FASTNode): return node @@ -1397,11 +1510,19 @@ def contains_stmt(self, node: FASTNode): def call_stmt(self, node: FASTNode): children = self.create_children(node) name = get_child(children, ast_internal_classes.Name_Node) + arg_addition = None + if name is None: + proc_ref = get_child(children, ast_internal_classes.Procedure_Separator_Node) + name = proc_ref.part_ref + arg_addition = proc_ref.parent_ref + args = get_child(children, ast_internal_classes.Arg_List_Node) if args==None: ret_args = [] else: ret_args = args.args + if arg_addition is not None: + ret_args.insert(0,arg_addition) line_number = get_line(node) #if node.item is None: # line_number = 42 diff --git a/dace/frontend/fortran/ast_internal_classes.py b/dace/frontend/fortran/ast_internal_classes.py index 8c8345969e..1f00cfefda 100644 --- a/dace/frontend/fortran/ast_internal_classes.py +++ b/dace/frontend/fortran/ast_internal_classes.py @@ -330,7 +330,7 @@ class Derived_Type_Stmt_Node(FNode): class Derived_Type_Def_Node(FNode): _attributes = ('name', ) - _fields = ('component_part', ) + _fields = ('component_part','procedure_part' ) class Component_Part_Node(FNode): @@ -395,11 +395,33 @@ class If_Stmt_Node(FNode): 'body_else', ) +class Defer_Shape_Node(FNode): + _attributes = () + _fields = () + +class Component_Initialization_Node(FNode): + _attributes = () + _fields = ('init') + +class Case_Cond_Node(FNode): + _fields = ('cond', 'op') + _attributes = () class Else_Separator_Node(FNode): _attributes = () _fields = () +class Procedure_Separator_Node(FNode): + _attributes = () + _fields = ('parent_ref', 'part_ref') + +class Bound_Procedures_Node(FNode): + _attributes = () + _fields = ('procedures') + +class Specific_Binding_Node(FNode): + _attributes = () + _fields = ('name', 'args') class Parenthesis_Expr_Node(FNode): _attributes = () diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index 83c75d596e..317b18d481 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -168,6 +168,63 @@ def generic_visit(self, node: ast_internal_classes.FNode): return node +class Flatten_Classes(NodeTransformer): + + def __init__(self, classes: List[ast_internal_classes.Derived_Type_Def_Node]): + self.classes=classes + self.current_class=None + + def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_Def_Node): + self.current_class=node + return_node=self.generic_visit(node) + #self.current_class=None + return return_node + + def visit_Module_Node(self, node: ast_internal_classes.Module_Node): + self.current_class=None + return self.generic_visit(node) + + def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node): + new_node=self.generic_visit(node) + if self.current_class is not None: + for i in self.classes: + if i.is_class is True: + if i.name.name==self.current_class.name.name: + for j in i.procedure_part.procedures: + if j.name.name==node.name.name: + return ast_internal_classes.Subroutine_Subprogram_Node( + name=ast_internal_classes.Name_Node(name=i.name.name+"_"+node.name.name,type=node.type), + args=new_node.args, + specification_part=new_node.specification_part, + execution_part=new_node.execution_part, + mandatory_args_count=new_node.mandatory_args_count, + optional_args_count=new_node.optional_args_count, + elemental=new_node.elemental, + line_number=new_node.line_number) + elif j.args[2] is not None: + if j.args[2].name==node.name.name: + + return ast_internal_classes.Subroutine_Subprogram_Node( + name=ast_internal_classes.Name_Node(name=i.name.name+"_"+j.name.name,type=node.type), + args=new_node.args, + specification_part=new_node.specification_part, + execution_part=new_node.execution_part, + mandatory_args_count=new_node.mandatory_args_count, + optional_args_count=new_node.optional_args_count, + elemental=new_node.elemental, + line_number=new_node.line_number) + return new_node + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + if self.current_class is not None: + for i in self.classes: + if i.is_class is True: + if i.name.name==self.current_class.name.name: + for j in i.procedure_part.procedures: + if j.name.name==node.name.name: + return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node(name=i.name.name+"_"+node.name.name,type=node.type,args=node.args,line_number=node.line_number),args=node.args,type=node.type,line_number=node.line_number) + return self.generic_visit(node) + class FindFunctionAndSubroutines(NodeVisitor): """ Finds all function and subroutine names in the AST @@ -388,8 +445,15 @@ def __init__(self): self.names= [] def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_Def_Node): - self.structs.append(node) self.names.append(node.name.name) + if node.procedure_part is not None: + if len(node.procedure_part.procedures)>0: + node.is_class=True + self.structs.append(node) + return + node.is_class=False + self.structs.append(node) + class StructDependencyLister(NodeVisitor): def __init__(self, names=None): @@ -610,7 +674,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): raise ValueError("Call_Expr_Node name is None") return ast_internal_classes.Char_Literal_Node(value="Error!", type="CHARACTER") - if node.name.name in self.excepted_funcs or node.name in self.funcs.names or node.name.name in self.funcs.iblocks: + if node.name.name in self.excepted_funcs or node.name.name in [i.name for i in self.funcs.names] or node.name.name in self.funcs.iblocks: processed_args = [] for i in node.args: arg = CallToArray(self.funcs).visit(i) diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index f08af80e7e..f92b1858d6 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -312,6 +312,8 @@ def get_dace_type(self, type): if type=="VOID": return ast_utils.fortrantypes2dacetypes["DOUBLE"] raise ValueError("Unknown type " + type) + else: + raise ValueError("Unknown type " + type) def get_name_mapping_in_context(self, sdfg: SDFG): """ @@ -625,7 +627,8 @@ def allocate2sdfg(self, node: ast_internal_classes.Allocate_Stmt_Node, sdfg: SDF def write2sdfg(self, node: ast_internal_classes.Write_Stmt_Node, sdfg: SDFG): #TODO implement - raise NotImplementedError("Fortran write statements are not implemented yet") + print("Uh oh") + #raise NotImplementedError("Fortran write statements are not implemented yet") def ifstmt2sdfg(self, node: ast_internal_classes.If_Stmt_Node, sdfg: SDFG): """ @@ -2785,10 +2788,10 @@ def recursive_ast_improver(ast, if i not in exclude_list: exclude_list.append(i) #if i not in dep_graph.nodes: - dep_graph.add_node(i,info_list=fandsl) + dep_graph.add_node(i.lower(),info_list=fandsl) for i in used_modules: if i not in dep_graph.nodes: - dep_graph.add_node(i) + dep_graph.add_node(i.lower()) weight = None if i in objects_in_modules: weight=[] @@ -2796,7 +2799,7 @@ def recursive_ast_improver(ast, for j in objects_in_modules[i].children: weight.append(j) - dep_graph.add_edge(parent_module, i, obj_list=weight) + dep_graph.add_edge(parent_module.lower(), i.lower(), obj_list=weight) #print("It's turtles all the way down: ", len(exclude_list)) modules_to_parse = [] @@ -2807,15 +2810,16 @@ def recursive_ast_improver(ast, added_modules = [] for i in modules_to_parse: found = False - name=i + name=i.lower() if i=="mo_restart_nml_and_att": name="mo_restart_nmls_and_atts" - if i=="yomhook": + if name=="yomhook": name="yomhook_dummy" for j in source_list: if name in j: fname = j.split("/") fname = fname[len(fname) - 1] + fname = fname.lower() if fname == name + ".f90" or fname == name + ".F90": found = True next_file = j @@ -2829,10 +2833,12 @@ def recursive_ast_improver(ast, continue if isinstance(source_list,dict): reader = fsr(source_list[next_file]) + next_ast = parser(reader) else: next_reader = ffr(file_candidate=next_file, include_dirs=include_list, source_only=source_list) + next_ast = parser(next_reader) next_ast = recursive_ast_improver(next_ast, @@ -2864,6 +2870,7 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, """ parser = pf().create(std="f2008") reader = ffr(file_candidate=source_string, include_dirs=include_list, source_only=source_list) + ast = parser(reader) exclude_list = [] missing_modules = [] @@ -2952,42 +2959,44 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, if deps is None: continue for k in deps: - if k.string not in parse_list[i]: - parse_list[i].append(k.string) + if k.string.lower() not in parse_list[i]: + parse_list[i].append(k.string.lower()) if res is not None: for jj in parse_list[i]: - if jj in res.list_of_functions: - if jj not in fands_list: - fands_list.append(jj) - if jj in res.list_of_subroutines: - if jj not in fands_list: - fands_list.append(jj) - if jj in res.list_of_types: - if jj not in type_list: - type_list.append(jj) + if jj.lower() in res.list_of_functions: + if jj.lower() not in fands_list: + fands_list.append(jj.lower()) + if jj.lower() in res.list_of_subroutines: + if jj.lower() not in fands_list: + fands_list.append(jj.lower()) + if jj.lower() in res.list_of_types: + if jj.lower() not in type_list: + type_list.append(jj.lower()) print("Module " + i + " used names: " + str(parse_list[i])) if len(fands_list)>0: print("Module " + i + " used fands: " + str(fands_list)) print("ACtually used: "+str(actually_used_in_module[i])) for j in actually_used_in_module[i]: if res is not None: - if j in res.list_of_functions: + if j.lower() in res.list_of_functions: - if j not in fands_list: - fands_list.append(j) + if j.lower() not in fands_list: + fands_list.append(j.lower()) - if j in res.list_of_subroutines: - if j not in fands_list: - fands_list.append(j) - if j in res.list_of_types: - if j not in type_list: - type_list.append(j) + if j.lower() in res.list_of_subroutines: + if j.lower() not in fands_list: + fands_list.append(j.lower()) + if j.lower() in res.list_of_types: + if j.lower() not in type_list: + type_list.append(j.lower()) what_to_parse_list[i]=fands_list type_to_parse_list[i]=type_list + if len(parse_order)==0: + raise ValueError("No top-level function found") top_level_ast = parse_order.pop() changes=True new_children=[] @@ -3015,9 +3024,9 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, for i in ast.children: - if i.children[0].children[1].string not in parse_order and i.children[0].children[1].string!=top_level_ast: + if i.children[0].children[1].string.lower() not in parse_order and i.children[0].children[1].string.lower()!=top_level_ast: print("Module " + i.children[0].children[1].string + " not needing parsing") - elif i.children[0].children[1].string==top_level_ast: + elif i.children[0].children[1].string.lower()==top_level_ast: new_children.append(i) else: types=[] @@ -3031,9 +3040,9 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, entity_decls=[] for k in j.children[2].children: if k.__class__.__name__=="Entity_Decl": - if k.children[0].string in actually_used_in_module[i.children[0].children[1].string]: + if k.children[0].string in actually_used_in_module[i.children[0].children[1].string.lower()]: entity_decls.append(k) - elif rename_dict[i.children[0].children[1].string].get(k.children[0].string) in actually_used_in_module[i.children[0].children[1].string]: + elif rename_dict[i.children[0].children[1].string.lower()].get(k.children[0].string) in actually_used_in_module[i.children[0].children[1].string.lower()]: entity_decls.append(k) if entity_decls==[]: continue @@ -3046,7 +3055,7 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, j.children[2].children.append(k) new_spec_children.append(j) elif j.__class__.__name__=="Derived_Type_Def": - if j.children[0].children[1].string in type_to_parse_list[i.children[0].children[1].string]: + if j.children[0].children[1].string.lower() in type_to_parse_list[i.children[0].children[1].string.lower()]: new_spec_children.append(j) else: new_spec_children.append(j) @@ -3056,11 +3065,11 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, if i.children[2].__class__.__name__=="End_Module_Stmt": new_children.append(i) continue - if i.children[0].children[1].string!=top_level_ast: + if i.children[0].children[1].string.lower()!=top_level_ast: for j in i.children[2].children: if j.__class__.__name__!="Contains_Stmt": - if j.children[0].children[1].string in what_to_parse_list[i.children[0].children[1].string]: + if j.children[0].children[1].string.lower() in what_to_parse_list[i.children[0].children[1].string.lower()]: subroutinesandfunctions.append(j) i.children[2].children.clear() for j in subroutinesandfunctions: @@ -3192,8 +3201,9 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, struct_dep_graph.add_node(j) struct_dep_graph.add_edge(name,j,pointing=pointing,point_name=point_name) + program = ast_transforms.Flatten_Classes(structs_lister.structs).visit(program) program.structures = ast_transforms.Structures(structs_lister.structs) - + functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines() functions_and_subroutines_builder.visit(program) listnames=[i.name for i in functions_and_subroutines_builder.names] @@ -3344,6 +3354,49 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, if j.name.name==top_level_ast: program.modules.append(j) + for j in program.subroutine_definitions: + #if j.name.name!="cloudscouter": + if j.name.name!="tspectralplanck_calc": + #if j.name.name!="rot_vertex_ri" and j.name.name!="cells2verts_scalar_ri" and j.name.name!="get_indices_c" and j.name.name!="get_indices_v" and j.name.name!="get_indices_e" and j.name.name!="velocity_tendencies": + #if j.name.name!="rot_vertex_ri": + #if j.name.name!="velocity_tendencies": + #if j.name.name!="cells2verts_scalar_ri": + #if j.name.name!="get_indices_c": + continue + if j.execution_part is None: + continue + print(f"Building SDFG {j.name.name}") + startpoint = j + ast2sdfg = AST_translator(program, __file__,multiple_sdfgs=False,startpoint=startpoint,sdfg_path=icon_sdfgs_dir, normalize_offsets=normalize_offsets) + sdfg = SDFG(j.name.name) + ast2sdfg.actual_offsets_per_sdfg[sdfg]={} + ast2sdfg.top_level = program + ast2sdfg.globalsdfg = sdfg + + ast2sdfg.translate(program, sdfg) + + + + sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz"),compress=True) + + sdfg.apply_transformations(IntrinsicSDFGTransformation) + + try: + sdfg.expand_library_nodes() + except: + print("Expansion failed for ", sdfg.name) + continue + + sdfg.validate() + sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_validated_f.sdfgz"),compress=True) + + sdfg.simplify(verbose=True) + print(f'Saving SDFG {os.path.join(icon_sdfgs_dir, sdfg.name + "_simplified_tr.sdfgz")}') + sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_simplified_f.sdfgz"),compress=True) + + print(f'Compiling SDFG {os.path.join(icon_sdfgs_dir, sdfg.name + "_simplifiedf.sdfgz")}') + sdfg.compile() + for i in program.modules: for path in source_list: @@ -3353,7 +3406,7 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, #copyfile(mypath, os.path.join(icon_sources_dir, i.name.name.lower()+".f90")) for j in i.subroutine_definitions: #if j.name.name!="cloudscouter": - if j.name.name!="solve_nh": + if j.name.name!="tspectralplanck_init": #if j.name.name!="rot_vertex_ri" and j.name.name!="cells2verts_scalar_ri" and j.name.name!="get_indices_c" and j.name.name!="get_indices_v" and j.name.name!="get_indices_e" and j.name.name!="velocity_tendencies": #if j.name.name!="rot_vertex_ri": #if j.name.name!="velocity_tendencies": @@ -3375,37 +3428,9 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_raw_before_intrinsics_full.sdfgz"),compress=True) - # for sd in sdfg.all_sdfgs_recursive(): - # free_symbols = sd.free_symbols - # for i in ['__f2dace_OA_iblk_d_0_s_4140', '__f2dace_OA_iidx_d_1_s_4138', '__f2dace_OA_iidx_d_2_s_4139', '__f2dace_OA_iblk_d_2_s_4142', '__f2dace_OA_iblk_d_1_s_4141', '__f2dace_OA_iidx_d_0_s_4137','__f2dace_A_iidx_d_1_s_5395', '__f2dace_A_iblk_d_0_s_5397', '__f2dace_A_iidx_d_2_s_5396', '__f2dace_A_iblk_d_1_s_5398', '__f2dace_A_iblk_d_2_s_5399', '__f2dace_A_iidx_d_0_s_5394','__f2dace_A_iidx_d_0_s_4137', '__f2dace_A_iblk_d_2_s_4142', '__f2dace_A_iblk_d_0_s_4140', '__f2dace_A_iidx_d_1_s_4138', '__f2dace_A_iidx_d_2_s_4139', '__f2dace_A_iblk_d_1_s_4141','__f2dace_OA_opt_out2_d_1_s_8111', '__f2dace_A_opt_out2_d_0_s_8110', '__f2dace_OA_opt_out2_d_2_s_8112', '__f2dace_OA_opt_out2_d_0_s_8110', '__f2dace_A_opt_out2_d_2_s_8112', '__f2dace_A_opt_out2_d_1_s_8111','__f2dace_A_opt_out2_d_1_s_8111', '__f2dace_A_ieidx_d_2_s_8121', '__f2dace_OA_opt_out2_d_1_s_8111', '__f2dace_A_ieblk_d_1_s_8123', '__f2dace_A_inidx_d_1_s_8114', '__f2dace_A_inblk_d_2_s_8118', '__f2dace_A_ieidx_d_0_s_8119', '__f2dace_A_opt_out2_d_2_s_8112', '__f2dace_A_ieidx_d_1_s_8120', '__f2dace_OA_opt_out2_d_0_s_8110', '__f2dace_A_inblk_d_0_s_8116', '__f2dace_OA_opt_out2_d_2_s_8112', '__f2dace_A_ieblk_d_2_s_8124', '__f2dace_A_inblk_d_1_s_8117', '__f2dace_A_ieblk_d_0_s_8122', '__f2dace_A_opt_out2_d_0_s_8110', '__f2dace_A_inidx_d_2_s_8115', '__f2dace_A_inidx_d_0_s_8113','__f2dace_A_iidx_d_1_s_8159', '__f2dace_A_iidx_d_0_s_8158', '__f2dace_A_iblk_d_0_s_8161', '__f2dace_A_iblk_d_1_s_8162', '__f2dace_A_iblk_d_2_s_8163', '__f2dace_A_iidx_d_2_s_8160','__f2dace_A_iblk_d_0_s_6698', '__f2dace_A_iblk_d_1_s_6699', '__f2dace_A_iidx_d_2_s_6697', '__f2dace_A_iidx_d_1_s_6696', '__f2dace_A_iblk_d_2_s_6700', '__f2dace_A_iidx_d_0_s_6695','__f2dace_A_incidx_d_2_s_8055', '__f2dace_A_iqblk_d_0_s_8044', '__f2dace_A_incblk_d_1_s_8057', '__f2dace_A_iqblk_d_1_s_8045', '__f2dace_A_iqidx_d_2_s_8043', '__f2dace_A_icidx_d_2_s_8031', '__f2dace_A_ivblk_d_2_s_8052', '__f2dace_A_incidx_d_0_s_8053', '__f2dace_A_icidx_d_0_s_8029', '__f2dace_A_incblk_d_2_s_8058', '__f2dace_A_incblk_d_0_s_8056', '__f2dace_A_ividx_d_2_s_8049', '__f2dace_A_icblk_d_0_s_8032', '__f2dace_A_ieidx_d_0_s_8035', '__f2dace_A_ivblk_d_0_s_8050', '__f2dace_A_ieidx_d_1_s_8036', '__f2dace_A_icidx_d_1_s_8030', '__f2dace_A_incidx_d_1_s_8054', '__f2dace_A_iqidx_d_0_s_8041', '__f2dace_A_icblk_d_2_s_8034', '__f2dace_A_ieblk_d_0_s_8038', '__f2dace_A_ividx_d_0_s_8047', '__f2dace_A_ieblk_d_2_s_8040', '__f2dace_A_ividx_d_1_s_8048', '__f2dace_A_icblk_d_1_s_8033', '__f2dace_A_iqidx_d_1_s_8042', '__f2dace_A_ieblk_d_1_s_8039', '__f2dace_A_iqblk_d_2_s_8046', '__f2dace_A_ieidx_d_2_s_8037', '__f2dace_A_ivblk_d_1_s_8051']: - # #print("I want to remove:", i) - # if(i in sd.symbols): - # sd.symbols.pop(i) - # print("Removed from symbols ",i) - # if sd.parent_nsdfg_node is not None: - # if i in sd.parent_nsdfg_node.symbol_mapping: - # print("Removed from symbol mapping ",i) - # sd.parent_nsdfg_node.symbol_mapping.pop(i) + sdfg.apply_transformations(IntrinsicSDFGTransformation) - # for sd in sdfg.all_sdfgs_recursive(): - # free_symbols = sd.free_symbols - # for i in free_symbols: - # #print("I want to remove:", i) - # if(i in sd.symbols): - # sd.symbols.pop(i) - # #print("Removed from symbols ",i) - # if sd.parent_nsdfg_node is not None: - # if i in sd.parent_nsdfg_node.symbol_mapping: - # #print("Removed from symbol mapping ",i) - # sd.parent_nsdfg_node.symbol_mapping.pop(i) - #try: - - # sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_raw.sdfg")) - #except: - # print("Intrinsics failed for ", sdfg.name) - # continue - try: sdfg.expand_library_nodes() except: @@ -3414,21 +3439,13 @@ def create_sdfg_from_fortran_file_with_options(source_string: str, source_list, sdfg.validate() sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_validated_f.sdfgz"),compress=True) - #try: + sdfg.simplify(verbose=True) print(f'Saving SDFG {os.path.join(icon_sdfgs_dir, sdfg.name + "_simplified_tr.sdfgz")}') sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_simplified_f.sdfgz"),compress=True) - #except Exception as e: - # print("Simplification failed for ", sdfg.name) - # print(e) - # continue - #sdfg.save(os.path.join(icon_sdfgs_dir, sdfg.name + "_simplified.sdfg")) - #try: + print(f'Compiling SDFG {os.path.join(icon_sdfgs_dir, sdfg.name + "_simplifiedf.sdfgz")}') sdfg.compile() - #except Exception as e: - # print("Compilation failed for ", sdfg.name) - # print(e) - # continue + #return sdfg