Skip to content

Commit

Permalink
Fused diffusion stencils (#250)
Browse files Browse the repository at this point in the history
Introduce stencil fusion to Liskov.

- Adds START FUSED, END FUSED, START DELETE, END DELETE directives. 
- Adds a --fused and --unfused mode to icon_liskov in integration mode.

Co-authored-by: samkellerhals <[email protected]>
  • Loading branch information
huppd and samkellerhals authored Aug 10, 2023
1 parent b4e3df3 commit bca6876
Show file tree
Hide file tree
Showing 17 changed files with 770 additions and 106 deletions.
12 changes: 9 additions & 3 deletions tools/src/icon4pytools/liskov/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from icon4pytools.common.logger import setup_logger
from icon4pytools.liskov.external.exceptions import MissingCommandError
from icon4pytools.liskov.pipeline.collection import (
load_gt4py_stencils,
parse_fortran_file,
process_stencils,
run_code_generation,
)

Expand Down Expand Up @@ -50,6 +50,12 @@ def main(ctx):
is_flag=True,
help="Add metadata header with information about program.",
)
@click.option(
"--fused/--unfused",
"-f/-u",
default=True,
help="Adds fused or unfused stencils.",
)
@click.argument(
"input_path",
type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path),
Expand All @@ -58,10 +64,10 @@ def main(ctx):
"output_path",
type=click.Path(dir_okay=False, resolve_path=True, path_type=pathlib.Path),
)
def integrate(input_path, output_path, profile, metadatagen):
def integrate(input_path, output_path, fused, profile, metadatagen):
mode = "integration"
iface = parse_fortran_file(input_path, output_path, mode)
iface_gt4py = load_gt4py_stencils(iface)
iface_gt4py = process_stencils(iface, fused)
run_code_generation(
input_path,
output_path,
Expand Down
95 changes: 78 additions & 17 deletions tools/src/icon4pytools/liskov/codegen/integration/deserialise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
BoundsData,
DeclareData,
EndCreateData,
EndDeleteData,
EndFusedStencilData,
EndIfData,
EndProfileData,
EndStencilData,
Expand All @@ -28,6 +30,8 @@
InsertData,
IntegrationCodeInterface,
StartCreateData,
StartDeleteData,
StartFusedStencilData,
StartProfileData,
StartStencilData,
UnusedDirective,
Expand Down Expand Up @@ -134,6 +138,16 @@ class EndProfileDataFactory(OptionalMultiUseDataFactory):
dtype: Type[EndProfileData] = EndProfileData


class EndDeleteDataFactory(OptionalMultiUseDataFactory):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndDelete
dtype: Type[EndDeleteData] = EndDeleteData


class StartDeleteDataFactory(OptionalMultiUseDataFactory):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartDelete
dtype: Type[StartDeleteData] = StartDeleteData


class StartCreateDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartCreate
dtype: Type[StartCreateData] = StartCreateData
Expand Down Expand Up @@ -227,47 +241,80 @@ def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]:
return deserialised


class StartStencilDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil
dtype: Type[StartStencilData] = StartStencilData
class EndFusedStencilDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndFusedStencil
dtype: Type[EndFusedStencilData] = EndFusedStencilData

def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]:
"""Create and return a list of StartStencilData objects from the parsed directives.
def __call__(self, parsed: ts.ParsedDict) -> list[EndFusedStencilData]:
deserialised = []
extracted = extract_directive(parsed["directives"], self.directive_cls)
for i, directive in enumerate(extracted):
named_args = parsed["content"]["EndFusedStencil"][i]
stencil_name = _extract_stencil_name(named_args, directive)
deserialised.append(
self.dtype(
name=stencil_name,
startln=directive.startln,
)
)
return deserialised

Args:
parsed (ParsedDict): Dictionary of parsed directives and their associated content.

Returns:
List[StartStencilData]: List of StartStencilData objects created from the parsed directives.
"""
deserialised = []
class StartStencilDataFactoryBase(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = None
dtype: Type[StartFusedStencilData] = None

def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]:
field_dimensions = flatten_list_of_dicts(
[DeclareDataFactory.get_field_dimensions(dim) for dim in parsed["content"]["Declare"]]
)
directives = extract_directive(parsed["directives"], self.directive_cls)
return self.create_stencil_data(
parsed, field_dimensions, directives, self.directive_cls, self.dtype
)

def create_stencil_data(
self,
parsed: ts.ParsedDict,
field_dimensions: list[dict[str, Any]],
directives: list[ts.ParsedDirective],
directive_cls: Type[ts.ParsedDirective],
dtype: Type[StartStencilData | StartFusedStencilData],
) -> list[StartStencilData | StartFusedStencilData]:
"""Create and return a list of StartStencilData or StartFusedStencilData objects from parsed directives."""
deserialised = []
for i, directive in enumerate(directives):
named_args = parsed["content"]["StartStencil"][i]
named_args = parsed["content"][directive_cls.__name__][i]
additional_attrs = self._pop_additional_attributes(dtype, named_args)
acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true"))
mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false"))
copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true"))
stencil_name = _extract_stencil_name(named_args, directive)
bounds = self._make_bounds(named_args)
fields = self._make_fields(named_args, field_dimensions)
fields_w_tolerance = self._update_tolerances(named_args, fields)

deserialised.append(
self.dtype(
dtype(
name=stencil_name,
fields=fields_w_tolerance,
bounds=bounds,
startln=directive.startln,
acc_present=acc_present,
mergecopy=mergecopy,
copies=copies,
**additional_attrs,
)
)
return deserialised

def _pop_additional_attributes(
self, dtype: Type[StartStencilData | StartFusedStencilData], named_args: dict[str, Any]
) -> dict:
"""Pop and return additional attributes specific to StartStencilData."""
additional_attrs = {}
if dtype == StartStencilData:
mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false"))
copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true"))
additional_attrs = {"mergecopy": mergecopy, "copies": copies}
return additional_attrs

@staticmethod
def _make_bounds(named_args: dict) -> BoundsData:
"""Extract stencil bounds from directive arguments."""
Expand Down Expand Up @@ -355,6 +402,16 @@ def _update_tolerances(
return fields


class StartFusedStencilDataFactory(StartStencilDataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartFusedStencil
dtype: Type[StartFusedStencilData] = StartFusedStencilData


class StartStencilDataFactory(StartStencilDataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil
dtype: Type[StartStencilData] = StartStencilData


class InsertDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.Insert
dtype: Type[InsertData] = InsertData
Expand All @@ -378,6 +435,10 @@ class IntegrationCodeDeserialiser(Deserialiser):
"Declare": DeclareDataFactory(),
"StartStencil": StartStencilDataFactory(),
"EndStencil": EndStencilDataFactory(),
"StartFusedStencil": StartFusedStencilDataFactory(),
"EndFusedStencil": EndFusedStencilDataFactory(),
"StartDelete": StartDeleteDataFactory(),
"EndDelete": EndDeleteDataFactory(),
"EndIf": EndIfDataFactory(),
"StartProfile": StartProfileDataFactory(),
"EndProfile": EndProfileDataFactory(),
Expand Down
55 changes: 54 additions & 1 deletion tools/src/icon4pytools/liskov/codegen/integration/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
DeclareStatementGenerator,
EndCreateStatement,
EndCreateStatementGenerator,
EndDeleteStatement,
EndDeleteStatementGenerator,
EndFusedStencilStatement,
EndFusedStencilStatementGenerator,
EndIfStatement,
EndIfStatementGenerator,
EndProfileStatement,
Expand All @@ -38,6 +42,10 @@
MetadataStatementGenerator,
StartCreateStatement,
StartCreateStatementGenerator,
StartDeleteStatement,
StartDeleteStatementGenerator,
StartFusedStencilStatement,
StartFusedStencilStatementGenerator,
StartProfileStatement,
StartProfileStatementGenerator,
StartStencilStatement,
Expand Down Expand Up @@ -75,6 +83,9 @@ def __call__(self, data: Any = None) -> list[GeneratedCode]:
self._generate_declare()
self._generate_start_stencil()
self._generate_end_stencil()
self._generate_start_fused_stencil()
self._generate_end_fused_stencil()
self._generate_delete()
self._generate_endif()
self._generate_profile()
self._generate_insert()
Expand Down Expand Up @@ -167,14 +178,56 @@ def _generate_end_stencil(self) -> None:
noaccenddata=self.interface.EndStencil[i].noaccenddata,
)

def _generate_start_fused_stencil(self) -> None:
"""Generate f90 integration code surrounding a fused stencil."""
if self.interface.StartFusedStencil != UnusedDirective:
for stencil in self.interface.StartFusedStencil:
logger.info(f"Generating START FUSED statement for {stencil.name}")
self._generate(
StartFusedStencilStatement,
StartFusedStencilStatementGenerator,
stencil.startln,
stencil_data=stencil,
)

def _generate_end_fused_stencil(self) -> None:
"""Generate f90 integration code surrounding a fused stencil."""
if self.interface.EndFusedStencil != UnusedDirective:
for i, stencil in enumerate(self.interface.StartFusedStencil):
logger.info(f"Generating END Fused statement for {stencil.name}")
self._generate(
EndFusedStencilStatement,
EndFusedStencilStatementGenerator,
self.interface.EndFusedStencil[i].startln,
stencil_data=stencil,
)

def _generate_delete(self) -> None:
"""Generate f90 integration code for delete section."""
if self.interface.StartDelete != UnusedDirective:
logger.info("Generating DELETE statement.")
for start, end in zip(
self.interface.StartDelete, self.interface.EndDelete, strict=True
):
self._generate(
StartDeleteStatement,
StartDeleteStatementGenerator,
start.startln,
)
self._generate(
EndDeleteStatement,
EndDeleteStatementGenerator,
end.startln,
)

def _generate_imports(self) -> None:
"""Generate f90 code for import statements."""
logger.info("Generating IMPORT statement.")
self._generate(
ImportsStatement,
ImportsStatementGenerator,
self.interface.Imports.startln,
stencils=self.interface.StartStencil,
stencils=self.interface.StartStencil + self.interface.StartFusedStencil,
)

def _generate_create(self) -> None:
Expand Down
38 changes: 35 additions & 3 deletions tools/src/icon4pytools/liskov/codegen/integration/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,51 @@ class EndProfileData(CodeGenInput):


@dataclass
class StartStencilData(CodeGenInput):
class BaseStartStencilData(CodeGenInput):
name: str
fields: list[FieldAssociationData]
bounds: BoundsData
acc_present: Optional[bool]
bounds: BoundsData


@dataclass
class StartStencilData(BaseStartStencilData):
mergecopy: Optional[bool]
copies: Optional[bool]


@dataclass
class EndStencilData(CodeGenInput):
class StartFusedStencilData(BaseStartStencilData):
...


@dataclass
class BaseEndStencilData(CodeGenInput):
name: str


@dataclass
class EndStencilData(BaseEndStencilData):
noendif: Optional[bool]
noprofile: Optional[bool]
noaccenddata: Optional[bool]


@dataclass
class EndFusedStencilData(BaseEndStencilData):
...


@dataclass
class StartDeleteData(CodeGenInput):
startln: int


@dataclass
class EndDeleteData(StartDeleteData):
...


@dataclass
class InsertData(CodeGenInput):
content: str
Expand All @@ -104,6 +132,10 @@ class InsertData(CodeGenInput):
class IntegrationCodeInterface:
StartStencil: Sequence[StartStencilData]
EndStencil: Sequence[EndStencilData]
StartFusedStencil: Sequence[StartFusedStencilData]
EndFusedStencil: Sequence[EndFusedStencilData]
StartDelete: Sequence[StartDeleteData] | UnusedDirective
EndDelete: Sequence[EndDeleteData] | UnusedDirective
Declare: Sequence[DeclareData]
Imports: ImportsData
StartCreate: Sequence[StartCreateData] | UnusedDirective
Expand Down
Loading

0 comments on commit bca6876

Please sign in to comment.