Skip to content

Commit

Permalink
class flattening and many smaller fixes and features
Browse files Browse the repository at this point in the history
  • Loading branch information
acalotoiu authored and acalotoiu committed Nov 11, 2024
2 parents 412385e + 4f8fd47 commit d0025e3
Show file tree
Hide file tree
Showing 29 changed files with 502 additions and 158 deletions.
79 changes: 37 additions & 42 deletions dace/frontend/fortran/ast_components.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from fparser.two import Fortran2008 as f08
from typing import Any, List, Optional, Type, TypeVar, Union, overload, TYPE_CHECKING

from fparser.two import Fortran2003 as f03
from fparser.two import Fortran2008 as f08
from fparser.two import symbol_table

import copy
from dace.frontend.fortran import ast_internal_classes
from dace.frontend.fortran.ast_internal_classes import FNode, Name_Node
from typing import Any, List, Optional, Tuple, Type, TypeVar, Union, overload, TYPE_CHECKING

from dace.frontend.fortran.ast_internal_classes import Name_Node

if TYPE_CHECKING:
from dace.frontend.fortran.intrinsics import FortranIntrinsics
Expand Down Expand Up @@ -206,7 +205,6 @@ def __init__(self, ast: FASTNode, tables: symbol_table.SymbolTable):
"Assignment_Stmt": self.assignment_stmt,
"Pointer_Assignment_Stmt": self.pointer_assignment_stmt,
"Where_Stmt": self.where_stmt,
"Forall_Stmt": self.forall_stmt,
"Where_Construct": self.where_construct,
"Where_Construct_Stmt": self.where_construct_stmt,
"Masked_Elsewhere_Stmt": self.masked_elsewhere_stmt,
Expand Down Expand Up @@ -239,7 +237,6 @@ def __init__(self, ast: FASTNode, tables: symbol_table.SymbolTable):
"End_Interface_Stmt": self.end_interface_stmt,
"Generic_Spec": self.generic_spec,
"Name": self.name,
"Rename": self.rename,
"Type_Name": self.type_name,
"Specification_Part": self.specification_part,
"Intrinsic_Type_Spec": self.intrinsic_type_spec,
Expand Down Expand Up @@ -317,10 +314,10 @@ def data_pointer_object(self, node: FASTNode):
return ast_internal_classes.Data_Ref_Node(parent_ref=children[0], part_ref=children[2],type="VOID")
else:
raise NotImplementedError("Data pointer object not supported yet")


def add_name_list_for_module(self, module: str, name_list: List[str]):
self.name_list[module] = name_list
self.name_list[module] = name_list

def create_children(self, node: FASTNode):
return [self.create_ast(child)
Expand Down Expand Up @@ -353,7 +350,7 @@ def create_ast(self, node=None):
if self.unsupported_fortran_syntax.get(self.current_ast) is None:
self.unsupported_fortran_syntax[self.current_ast] = []
if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]:
if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]:
if type(node).__name__ not in self.unsupported_fortran_syntax[self.current_ast]:
self.unsupported_fortran_syntax[self.current_ast].append(type(node).__name__)
for i in node.children:
self.create_ast(i)
Expand All @@ -369,7 +366,7 @@ def suffix(self, node: FASTNode):
children = self.create_children(node)
name = children[0]
return ast_internal_classes.Suffix_Node(name=name)

def data_ref(self, node: FASTNode):
children = self.create_children(node)
idx=len(children)-1
Expand All @@ -379,16 +376,16 @@ def data_ref(self, node: FASTNode):
#parent.isStructMember=True
idx=idx-1
current=ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=part_ref,type="VOID")

while idx>0:
parent = children[idx-1]
current = ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=current,type="VOID")
current = ast_internal_classes.Data_Ref_Node(parent_ref=parent, part_ref=current,type="VOID")
idx=idx-1
return current

def end_type_stmt(self, node: FASTNode):
return None

def access_stmt(self, node: FASTNode):
return None

Expand Down Expand Up @@ -423,7 +420,7 @@ def derived_type_def(self, node: FASTNode):
new_placeholder[k.replace("__f2dace_A","__f2dace_SA")]=self.placeholders[k]
else:
new_placeholder[k]=self.placeholders[k]
self.placeholders=new_placeholder
self.placeholders=new_placeholder
for k,v in self.placeholders_offsets.items():
if "__f2dace_OA" in k:
new_placeholder_offsets[k.replace("__f2dace_OA","__f2dace_SOA")]=self.placeholders_offsets[k]
Expand Down Expand Up @@ -498,8 +495,8 @@ def program_stmt(self, node: FASTNode):
return ast_internal_classes.Program_Stmt_Node(name=name, line_number=node.item.span)

def subroutine_subprogram(self, node: FASTNode):


children = self.create_children(node)

name = get_child(children, ast_internal_classes.Subroutine_Stmt_Node)
Expand Down Expand Up @@ -548,7 +545,7 @@ def function_subprogram(self, node: FASTNode):
name = get_child(children, ast_internal_classes.Function_Stmt_Node)
specification_part = get_child(children, ast_internal_classes.Specification_Part_Node)
execution_part = get_child(children, ast_internal_classes.Execution_Part_Node)

return_type = name.return_type
return ast_internal_classes.Function_Subprogram_Node(
name=name.name,
Expand All @@ -561,7 +558,7 @@ def function_subprogram(self, node: FASTNode):
elemental=name.elemental,
)

def function_stmt(self, node: FASTNode):
def function_stmt(self, node: FASTNode):
children = self.create_children(node)
name = get_child(children, ast_internal_classes.Name_Node)
args = get_child(children, ast_internal_classes.Arg_List_Node)
Expand All @@ -574,7 +571,7 @@ def function_stmt(self, node: FASTNode):
if args==None:
ret_args = []
else:
ret_args = args.args
ret_args = args.args
return ast_internal_classes.Function_Stmt_Node(name=name, args=ret_args,return_type=ret, line_number=node.item.span,ret=ret,elemental=elemental)

def subroutine_stmt(self, node: FASTNode):
Expand Down Expand Up @@ -643,16 +640,16 @@ def structure_constructor(self, node: FASTNode):
if args==None:
ret_args = []
else:
ret_args = args.args
ret_args = args.args
return ast_internal_classes.Structure_Constructor_Node(name=name, args=ret_args, type=None,line_number=line)


def intrinsic_function_reference(self, node: FASTNode):
children = self.create_children(node)
line = get_line(node)
name = get_child(children, ast_internal_classes.Name_Node)
args = get_child(children, ast_internal_classes.Arg_List_Node)

if name is None:
return Name_Node(name="Error! "+node.children[0].string,type='VOID')
node = self.intrinsic_handler.replace_function_reference(name, args, line,self.symbols)
Expand Down Expand Up @@ -701,7 +698,7 @@ def module(self, node: FASTNode):
specification_part = get_child(children, ast_internal_classes.Specification_Part_Node)

function_definitions = [i for i in children if isinstance(i, ast_internal_classes.Function_Subprogram_Node)]

subroutine_definitions = [i for i in children if isinstance(i, ast_internal_classes.Subroutine_Subprogram_Node)]

interface_blocks = {}
Expand Down Expand Up @@ -917,7 +914,7 @@ def type_declaration_stmt(self, node: FASTNode):
basetype = "INTEGER"
else:
raise TypeError("Derived type not supported")

#if derived_type:
# raise TypeError("Derived type not supported")
if not derived_type:
Expand Down Expand Up @@ -981,7 +978,7 @@ def type_declaration_stmt(self, node: FASTNode):
attr_offset = [attr_offset] * len(names)
else:
attr_size, assumed_vardecls,attr_offset = self.assumed_array_shape(dimension_spec[0], names, node.item.span)

if attr_size is None:
raise RuntimeError("Couldn't parse the dimension attribute specification!")

Expand Down Expand Up @@ -1051,8 +1048,8 @@ def type_declaration_stmt(self, node: FASTNode):
raw_init = comp_init[0].children[1]
init = self.create_ast(raw_init)
#if size_later:
# size.append(len(init))
if testtype!="INTEGER": symbol=False
# size.append(len(init))
if testtype!="INTEGER": symbol=False
if symbol == False:

if attr_size is None:
Expand Down Expand Up @@ -1227,10 +1224,10 @@ def level_2_expr(self, node: FASTNode):
children = self.create_children(node)
line = get_line(node)
if children[1]=="==": type="LOGICAL"
else:
else:
type="VOID"
if hasattr(children[0],"type"):
type=children[0].type
type=children[0].type
if len(children) == 3:
return ast_internal_classes.BinOp_Node(lval=children[0], op=children[1], rval=children[2], line_number=line, type=type)
else:
Expand Down Expand Up @@ -1330,7 +1327,7 @@ def if_construct(self, node: FASTNode):
if else_mode:
body_else.append(i)
else:

body.append(i)
currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body)
currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=body_else)
Expand Down Expand Up @@ -1527,7 +1524,7 @@ def call_stmt(self, node: FASTNode):
#if node.item is None:
# line_number = 42
#else:
# line_number = node.item.span
# line_number = node.item.span
return ast_internal_classes.Call_Expr_Node(name=name, args=ret_args, type="VOID", line_number=line_number)

def return_stmt(self, node: FASTNode):
Expand Down Expand Up @@ -1607,8 +1604,6 @@ def block_nonlabel_do_construct(self, node: FASTNode):
body=ast_internal_classes.Execution_Part_Node(execution=body),
line_number=do.line_number)



def subscript_triplet(self, node: FASTNode):
if node.string == ":":
return ast_internal_classes.ParDecl_Node(type="ALL")
Expand All @@ -1620,7 +1615,7 @@ def section_subscript_list(self, node: FASTNode):
return ast_internal_classes.Section_Subscript_List_Node(list=children)

def specification_part(self, node: FASTNode):

#TODO this can be refactored to consider more fortran declaration options. Currently limited to what is encountered in code.
others = [self.create_ast(i) for i in node.children if not isinstance(i, f08.Type_Declaration_Stmt)]

Expand Down Expand Up @@ -1654,7 +1649,7 @@ def specification_part(self, node: FASTNode):
for i in decls:
names_filtered.extend(ii.name for ii in i.vardecl if j.name == ii.name)
decl_filtered = []

for i in decls:
if i is None:
continue
Expand All @@ -1681,7 +1676,7 @@ def int_literal_constant(self, node: FASTNode):
x=value.split("_")
value=x[0]
return ast_internal_classes.Int_Literal_Node(value=value,type="INTEGER")

def hex_constant(self, node: FASTNode):
return ast_internal_classes.Int_Literal_Node(value=str(int(node.string[2:-1],16)),type="INTEGER")

Expand All @@ -1703,8 +1698,8 @@ def real_literal_constant(self, node: FASTNode):
if (x[1]=="wp"):
return ast_internal_classes.Double_Literal_Node(value=value,type="DOUBLE")
return ast_internal_classes.Real_Literal_Node(value=value,type="REAL")


def char_literal_constant(self, node: FASTNode):
return ast_internal_classes.Char_Literal_Node(value=node.string,type="CHAR")

Expand All @@ -1713,7 +1708,7 @@ def actual_arg_spec(self, node: FASTNode):
if len(children) != 2:
raise ValueError("Actual arg spec must have two children")
return ast_internal_classes.Actual_Arg_Spec_Node(arg_name=children[0], arg=children[1],type="VOID")

def actual_arg_spec_list(self, node: FASTNode):
children = self.create_children(node)
return ast_internal_classes.Arg_List_Node(args=children)
Expand All @@ -1723,8 +1718,8 @@ def initialization(self, node: FASTNode):

def name(self, node: FASTNode):
return ast_internal_classes.Name_Node(name=node.string.lower(),type="VOID")


def rename(self, node: FASTNode):
return ast_internal_classes.Rename_Node(oldname=node.children[2].string.lower(),newname=node.children[1].string.lower())

Expand Down
Loading

0 comments on commit d0025e3

Please sign in to comment.