Skip to content

Commit

Permalink
class flattening and many smaller fixes and features
Browse files Browse the repository at this point in the history
  • Loading branch information
acalotoiu authored and acalotoiu committed Nov 11, 2024
1 parent bcb9399 commit 412385e
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 91 deletions.
147 changes: 134 additions & 13 deletions dace/frontend/fortran/ast_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,19 @@ def __init__(self, ast: FASTNode, tables: symbol_table.SymbolTable):
"End_Type_Stmt": self.end_type_stmt,
"Data_Ref": self.data_ref,
"Cycle_Stmt": self.cycle_stmt,
"Deferred_Shape_Spec": self.deferred_shape_spec,
"Deferred_Shape_Spec_List": self.deferred_shape_spec_list,
"Component_Initialization": self.component_initialization,
"Case_Selector": self.case_selector,
"Case_Value_Range_List": self.case_value_range_list,
"Procedure_Designator": self.procedure_designator,
"Specific_Binding": self.specific_binding,
#"Component_Decl_List": self.component_decl_list,
#"Component_Decl": self.component_decl,
}
self.type_arbitrary_array_variable_count = 0


def fortran_intrinsics(self) -> "FortranIntrinsics":
return self.intrinsic_handler

Expand Down Expand Up @@ -323,6 +331,7 @@ def create_children(self, node: FASTNode):
def cycle_stmt(self, node: FASTNode):
line = get_line(node)
return ast_internal_classes.Continue_Node( line_number=line)


def create_ast(self, node=None):
"""
Expand Down Expand Up @@ -382,11 +391,27 @@ def end_type_stmt(self, node: FASTNode):

def access_stmt(self, node: FASTNode):
return None

def deferred_shape_spec(self, node: FASTNode):
return ast_internal_classes.Defer_Shape_Node()

def deferred_shape_spec_list(self, node: FASTNode):
children = self.create_children(node)
return children

def component_initialization(self, node: FASTNode):
children = self.create_children(node)
return ast_internal_classes.Component_Initialization_Node(init=children[1])

def procedure_designator(self, node: FASTNode):
children = self.create_children(node)
return ast_internal_classes.Procedure_Separator_Node(parent_ref=children[0],part_ref=children[2])

def derived_type_def(self, node: FASTNode):
children = self.create_children(node)
name = children[0].name
component_part = get_child(children, ast_internal_classes.Component_Part_Node)
procedure_part = get_child(children, ast_internal_classes.Bound_Procedures_Node)
from dace.frontend.fortran.ast_transforms import PartialRenameVar
if component_part is not None:
component_part=PartialRenameVar(oldname="__f2dace_A", newname="__f2dace_SA").visit(component_part)
Expand All @@ -405,7 +430,7 @@ def derived_type_def(self, node: FASTNode):
else:
new_placeholder_offsets[k]=self.placeholders_offsets[k]
self.placeholders_offsets=new_placeholder_offsets
return ast_internal_classes.Derived_Type_Def_Node(name=name, component_part=component_part)
return ast_internal_classes.Derived_Type_Def_Node(name=name, component_part=component_part, procedure_part=procedure_part)

def derived_type_stmt(self, node: FASTNode):
children = self.create_children(node)
Expand All @@ -432,13 +457,13 @@ def component_decl(self, node: FASTNode):
return ast_internal_classes.Component_Decl_Node(name=name)

def write_stmt(self, node: FASTNode):
children=[]
if node.children[0] is not None:
children = self.create_children(node.children[0])
if node.children[1] is not None:
children = self.create_children(node.children[1])
#children=[]
#if node.children[0] is not None:
# children = self.create_children(node.children[0])
#if node.children[1] is not None:
# children = self.create_children(node.children[1])
line = get_line(node)
return ast_internal_classes.Write_Stmt_Node(args=children, line_number=line)
return ast_internal_classes.Write_Stmt_Node(args=node.string, line_number=line)

def program(self, node: FASTNode):
children = self.create_children(node)
Expand Down Expand Up @@ -859,8 +884,8 @@ def type_declaration_stmt(self, node: FASTNode):
elif isinstance(type_of_node, f03.Declaration_Type_Spec):
if type_of_node.items[0].lower() == "class":
basetype = "CLASS"

derived_type = False
basetype = type_of_node.items[1].string
derived_type = True
else:
derived_type = True
basetype = type_of_node.items[1].string
Expand All @@ -886,6 +911,8 @@ def type_declaration_stmt(self, node: FASTNode):
if self.symbols[kind].value == "8":
basetype = "REAL8"
elif basetype == "INTEGER":
while hasattr(self.symbols[kind], "name"):
kind = self.symbols[kind].name
if self.symbols[kind].value == "4":
basetype = "INTEGER"
else:
Expand Down Expand Up @@ -1018,6 +1045,11 @@ def type_declaration_stmt(self, node: FASTNode):
raise ValueError("Initialization must have an expression")
raw_init = initialization.children[1]
init = self.create_ast(raw_init)
else:
comp_init = get_children(var, "Component_Initialization")
if len(comp_init) == 1:
raw_init = comp_init[0].children[1]
init = self.create_ast(raw_init)
#if size_later:
# size.append(len(init))
if testtype!="INTEGER": symbol=False
Expand Down Expand Up @@ -1323,13 +1355,88 @@ def end_if_stmt(self, node: FASTNode):
return node

def case_construct(self, node: FASTNode):
return Name_Node(name="Error!")
children = self.create_children(node)
cond_start = children[0]
cond_end= children[1]
body = []
body_else = []
else_mode = False
line = get_line(node)
if line is None:
line = "Unknown:TODO"
cond=ast_internal_classes.BinOp_Node(op=cond_end.op[0],lval=cond_start,rval=cond_end.cond[0],line_number=line)
for j in range(1,len(cond_end.op)):
cond_add=ast_internal_classes.BinOp_Node(op=cond_end.op[j],lval=cond_start,rval=cond_end.cond[j],line_number=line)
cond=ast_internal_classes.BinOp_Node(op=".AND.",lval=cond,rval=cond_add,line_number=line)

toplevelIf = ast_internal_classes.If_Stmt_Node(cond=cond, line_number=line)
currentIf = toplevelIf
for i in children[2:-1]:
if i is None:
continue
if isinstance(i, ast_internal_classes.Case_Cond_Node):
cond=ast_internal_classes.BinOp_Node(op=i.op[0],lval=cond_start,rval=i.cond[0],line_number=line)
for j in range(1,len(i.op)):
cond_add=ast_internal_classes.BinOp_Node(op=i.op[j],lval=cond_start,rval=i.cond[j],line_number=line)
cond=ast_internal_classes.BinOp_Node(op=".AND.",lval=cond,rval=cond_add,line_number=line)

newif = ast_internal_classes.If_Stmt_Node(cond=cond, line_number=line)
currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body)
currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=[newif])
currentIf = newif
body = []
continue
if isinstance(i, str) and i=="__default__":
else_mode = True
continue
if else_mode:
body_else.append(i)
else:

body.append(i)
currentIf.body = ast_internal_classes.Execution_Part_Node(execution=body)
currentIf.body_else = ast_internal_classes.Execution_Part_Node(execution=body_else)
return toplevelIf

def select_case_stmt(self, node: FASTNode):
return Name_Node(name="Error!")
children = self.create_children(node)
if len(children)!=1:
raise ValueError("CASE should have only 1 child")
return children[0]

def case_stmt(self, node: FASTNode):
return Name_Node(name="Error!")
children = self.create_children(node)
children=[i for i in children if i is not None]
if len(children)==1:
return children[0]
elif len(children)==0:
return "__default__"
else:
raise ValueError("Can't parse case statement")

def case_selector(self, node:FASTNode):
children = self.create_children(node)
if len(children)==1:
if children[0] is None:
return None
returns=ast_internal_classes.Case_Cond_Node(op=[],cond=[])

for i in children[0]:
returns.op.append(i[0])
returns.cond.append(i[1])
return returns
else:
raise ValueError("Can't parse case selector")

def case_value_range_list(self, node:FASTNode):
children = self.create_children(node)
if len(children)==1:
return [[".EQ.", children[0]]]
if len(children)==2:
return [[".EQ.", children[0]],[".EQ.", children[1]]]

else:
raise ValueError("Can't parse case range list")

def end_select_stmt(self, node: FASTNode):
return node
Expand Down Expand Up @@ -1387,21 +1494,35 @@ def generic_spec(self, node: FASTNode):

def procedure_declaration_stmt(self, node: FASTNode):
return node

def specific_binding(self, node: FASTNode):
children = self.create_children(node)
return ast_internal_classes.Specific_Binding_Node(name=children[3], args=children[0:2]+[children[4]])


def type_bound_procedure_part(self, node: FASTNode):
return node
children = self.create_children(node)
return ast_internal_classes.Bound_Procedures_Node(procedures= children[1:])

def contains_stmt(self, node: FASTNode):
return node

def call_stmt(self, node: FASTNode):
children = self.create_children(node)
name = get_child(children, ast_internal_classes.Name_Node)
arg_addition = None
if name is None:
proc_ref = get_child(children, ast_internal_classes.Procedure_Separator_Node)
name = proc_ref.part_ref
arg_addition = proc_ref.parent_ref

args = get_child(children, ast_internal_classes.Arg_List_Node)
if args==None:
ret_args = []
else:
ret_args = args.args
if arg_addition is not None:
ret_args.insert(0,arg_addition)
line_number = get_line(node)
#if node.item is None:
# line_number = 42
Expand Down
24 changes: 23 additions & 1 deletion dace/frontend/fortran/ast_internal_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ class Derived_Type_Stmt_Node(FNode):

class Derived_Type_Def_Node(FNode):
_attributes = ('name', )
_fields = ('component_part', )
_fields = ('component_part','procedure_part' )


class Component_Part_Node(FNode):
Expand Down Expand Up @@ -395,11 +395,33 @@ class If_Stmt_Node(FNode):
'body_else',
)

class Defer_Shape_Node(FNode):
_attributes = ()
_fields = ()

class Component_Initialization_Node(FNode):
_attributes = ()
_fields = ('init')

class Case_Cond_Node(FNode):
_fields = ('cond', 'op')
_attributes = ()

class Else_Separator_Node(FNode):
_attributes = ()
_fields = ()

class Procedure_Separator_Node(FNode):
_attributes = ()
_fields = ('parent_ref', 'part_ref')

class Bound_Procedures_Node(FNode):
_attributes = ()
_fields = ('procedures')

class Specific_Binding_Node(FNode):
_attributes = ()
_fields = ('name', 'args')

class Parenthesis_Expr_Node(FNode):
_attributes = ()
Expand Down
68 changes: 66 additions & 2 deletions dace/frontend/fortran/ast_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,63 @@ def generic_visit(self, node: ast_internal_classes.FNode):
return node


class Flatten_Classes(NodeTransformer):

def __init__(self, classes: List[ast_internal_classes.Derived_Type_Def_Node]):
self.classes=classes
self.current_class=None

def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_Def_Node):
self.current_class=node
return_node=self.generic_visit(node)
#self.current_class=None
return return_node

def visit_Module_Node(self, node: ast_internal_classes.Module_Node):
self.current_class=None
return self.generic_visit(node)

def visit_Subroutine_Subprogram_Node(self, node: ast_internal_classes.Subroutine_Subprogram_Node):
new_node=self.generic_visit(node)
if self.current_class is not None:
for i in self.classes:
if i.is_class is True:
if i.name.name==self.current_class.name.name:
for j in i.procedure_part.procedures:
if j.name.name==node.name.name:
return ast_internal_classes.Subroutine_Subprogram_Node(
name=ast_internal_classes.Name_Node(name=i.name.name+"_"+node.name.name,type=node.type),
args=new_node.args,
specification_part=new_node.specification_part,
execution_part=new_node.execution_part,
mandatory_args_count=new_node.mandatory_args_count,
optional_args_count=new_node.optional_args_count,
elemental=new_node.elemental,
line_number=new_node.line_number)
elif j.args[2] is not None:
if j.args[2].name==node.name.name:

return ast_internal_classes.Subroutine_Subprogram_Node(
name=ast_internal_classes.Name_Node(name=i.name.name+"_"+j.name.name,type=node.type),
args=new_node.args,
specification_part=new_node.specification_part,
execution_part=new_node.execution_part,
mandatory_args_count=new_node.mandatory_args_count,
optional_args_count=new_node.optional_args_count,
elemental=new_node.elemental,
line_number=new_node.line_number)
return new_node

def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
if self.current_class is not None:
for i in self.classes:
if i.is_class is True:
if i.name.name==self.current_class.name.name:
for j in i.procedure_part.procedures:
if j.name.name==node.name.name:
return ast_internal_classes.Call_Expr_Node(name=ast_internal_classes.Name_Node(name=i.name.name+"_"+node.name.name,type=node.type,args=node.args,line_number=node.line_number),args=node.args,type=node.type,line_number=node.line_number)
return self.generic_visit(node)

class FindFunctionAndSubroutines(NodeVisitor):
"""
Finds all function and subroutine names in the AST
Expand Down Expand Up @@ -388,8 +445,15 @@ def __init__(self):
self.names= []

def visit_Derived_Type_Def_Node(self, node: ast_internal_classes.Derived_Type_Def_Node):
self.structs.append(node)
self.names.append(node.name.name)
if node.procedure_part is not None:
if len(node.procedure_part.procedures)>0:
node.is_class=True
self.structs.append(node)
return
node.is_class=False
self.structs.append(node)


class StructDependencyLister(NodeVisitor):
def __init__(self, names=None):
Expand Down Expand Up @@ -610,7 +674,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
raise ValueError("Call_Expr_Node name is None")
return ast_internal_classes.Char_Literal_Node(value="Error!", type="CHARACTER")

if node.name.name in self.excepted_funcs or node.name in self.funcs.names or node.name.name in self.funcs.iblocks:
if node.name.name in self.excepted_funcs or node.name.name in [i.name for i in self.funcs.names] or node.name.name in self.funcs.iblocks:
processed_args = []
for i in node.args:
arg = CallToArray(self.funcs).visit(i)
Expand Down
Loading

0 comments on commit 412385e

Please sign in to comment.