Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fused diffusion stencils #250

Merged
merged 59 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
1e00c05
First version of fused diffusion stencils
muellch Nov 24, 2022
2abd8b9
Added prototype of fused_mo_nh_diffusion_stencil_01_02_03_rbf
muellch Nov 29, 2022
ee4e21e
Removed duplicate stencil definition from fused stencil 13 and 14
muellch Dec 7, 2022
381be77
Cleanup
muellch Jan 20, 2023
f5ad668
Cleanup
muellch Jan 23, 2023
ae10ca2
Change includes, remove 01_02_03_rbf from tox tests
muellch Jan 23, 2023
8a77b1a
Merge branch 'main' into fused_diffusion_stencils
halungge Feb 1, 2023
f30f1b0
use renamed diffusion stencils in fused stencils
halungge Feb 1, 2023
4ab35bc
readability refactoring (2): rename fused stencils
halungge Feb 1, 2023
5d0924b
rename fused_mo_nh_diffusion_stencil_07_08_09_10
halungge Feb 7, 2023
eac329f
rename fused_mo_nh_diffusion_stencil_11_12
halungge Feb 7, 2023
cd68da9
rename fused_mo_nh_diffusion_stencil_13_14
halungge Feb 7, 2023
4b225b7
Merge branch 'main' into fused_diffusion_stencils
halungge Feb 7, 2023
6f325b7
fix imports from gt4py
halungge Feb 7, 2023
e5eae6a
Merge branch 'main' into fused_diffusion_stencils
Mar 13, 2023
1e9c67a
make spelling internally consistent
Mar 13, 2023
1de4c08
Merge branch 'main' into fused_diffusion_stencils
muellch Apr 25, 2023
ce8877e
Renamed calculate_nabla2_of_theta correctly
muellch Apr 26, 2023
d38bdbe
Merge branch 'main' into fused_diffusion_stencils
muellch May 24, 2023
176dbf1
Add fused diffusion stencil containing stencil 15
muellch May 24, 2023
f5ef5e8
Add global component to the fused diffusion vn update stencil
muellch May 25, 2023
8976ec6
Renamed stencil 15, restructured fused vn stencil for scalar if
muellch May 26, 2023
e9e8cb5
Merge branch 'main' into fused_diffusion_stencils
muellch May 28, 2023
8e5d414
Substituted where with if expression
muellch May 30, 2023
adbebc2
Merge branch 'fused_diffusion_stencils' of github.com:C2SM/icon4py in…
muellch May 30, 2023
a294505
Undo renaming of test
muellch May 30, 2023
fc60b77
Magdalena feedback, lateral boundary stencil incorrectly called in gl…
muellch May 31, 2023
478688a
Drop w from argument list of fused stenil, write conditions always wi…
muellch May 31, 2023
1ed3b32
Merge branch 'main' into fused_diffusion_stencils
muellch May 31, 2023
d2f42f6
Fix domains for global case in fused vn stencil
muellch Jun 6, 2023
7df42fd
Deleted example stencil with output fields of different type
muellch Jun 6, 2023
a884480
Apply pre-commit
muellch Jun 6, 2023
084e35c
allow to uncoment with !!
huppd Jul 3, 2023
82ea0c3
Merge remote-tracking branch 'origin/main' into fused_diffusion_stenc…
huppd Jul 3, 2023
115a230
allow !$DSL to be commented out and increase verbosity of error msg
huppd Jul 4, 2023
29f6434
improve scan
huppd Jul 11, 2023
9083d69
WIP fused verification working
huppd Jul 19, 2023
c0321e2
styel up
huppd Jul 19, 2023
98e6a80
WIP substition mode
huppd Jul 19, 2023
37f4d69
cleanup style
huppd Jul 19, 2023
1bfabd5
fix import for fused
huppd Jul 19, 2023
eec5f6d
workaround for data regions
huppd Jul 19, 2023
c2685a8
fix style
huppd Jul 19, 2023
e1139a9
refactor 1
huppd Jul 25, 2023
2f4864e
more refactor
huppd Jul 25, 2023
12b6248
improve help test
huppd Jul 25, 2023
3d37a5a
test comenting out DSL
huppd Jul 26, 2023
bbf9505
more tests
huppd Jul 26, 2023
4904546
more tests
huppd Jul 26, 2023
cb80683
WIP
huppd Jul 26, 2023
3e9db8c
fix test
huppd Jul 27, 2023
0fb7dab
add one test
huppd Aug 2, 2023
29047a1
Merge remote-tracking branch 'origin/main' into fused_diffusion_stenc…
huppd Aug 2, 2023
2211331
Refactor StencilTransformer
samkellerhals Aug 8, 2023
aa9f5fb
simplify interface
samkellerhals Aug 9, 2023
53127b5
Simplify codegen template
samkellerhals Aug 9, 2023
0c645f2
Simplify deserialise step
samkellerhals Aug 9, 2023
16133fb
Further simplify deserialiser
samkellerhals Aug 9, 2023
3a9cec0
Run precommit
samkellerhals Aug 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tools/src/icon4pytools/liskov/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from icon4pytools.common.logger import setup_logger
from icon4pytools.liskov.external.exceptions import MissingCommandError
from icon4pytools.liskov.pipeline.collection import (
load_gt4py_stencils,
parse_fortran_file,
process_stencils,
run_code_generation,
)

Expand Down Expand Up @@ -50,6 +50,12 @@ def main(ctx):
is_flag=True,
help="Add metadata header with information about program.",
)
@click.option(
"--fused/--unfused",
"-f/-u",
default=True,
help="Adds fused or unfused stencils.",
)
@click.argument(
"input_path",
type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path),
Expand All @@ -58,10 +64,10 @@ def main(ctx):
"output_path",
type=click.Path(dir_okay=False, resolve_path=True, path_type=pathlib.Path),
)
def integrate(input_path, output_path, profile, metadatagen):
def integrate(input_path, output_path, fused, profile, metadatagen):
mode = "integration"
iface = parse_fortran_file(input_path, output_path, mode)
iface_gt4py = load_gt4py_stencils(iface)
iface_gt4py = process_stencils(iface, fused)
run_code_generation(
input_path,
output_path,
Expand Down
137 changes: 107 additions & 30 deletions tools/src/icon4pytools/liskov/codegen/integration/deserialise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
BoundsData,
DeclareData,
EndCreateData,
EndDeleteData,
EndFusedStencilData,
EndIfData,
EndProfileData,
EndStencilData,
Expand All @@ -28,6 +30,8 @@
InsertData,
IntegrationCodeInterface,
StartCreateData,
StartDeleteData,
StartFusedStencilData,
StartProfileData,
StartStencilData,
UnusedDirective,
Expand Down Expand Up @@ -134,6 +138,16 @@ class EndProfileDataFactory(OptionalMultiUseDataFactory):
dtype: Type[EndProfileData] = EndProfileData


class EndDeleteDataFactory(OptionalMultiUseDataFactory):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndDelete
dtype: Type[EndDeleteData] = EndDeleteData


class StartDeleteDataFactory(OptionalMultiUseDataFactory):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartDelete
dtype: Type[StartDeleteData] = StartDeleteData


class StartCreateDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartCreate
dtype: Type[StartCreateData] = StartCreateData
Expand Down Expand Up @@ -227,47 +241,26 @@ def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]:
return deserialised


class StartStencilDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil
dtype: Type[StartStencilData] = StartStencilData

def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]:
"""Create and return a list of StartStencilData objects from the parsed directives.

Args:
parsed (ParsedDict): Dictionary of parsed directives and their associated content.
class EndFusedStencilDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndFusedStencil
dtype: Type[EndFusedStencilData] = EndFusedStencilData

Returns:
List[StartStencilData]: List of StartStencilData objects created from the parsed directives.
"""
def __call__(self, parsed: ts.ParsedDict) -> list[EndFusedStencilData]:
deserialised = []
field_dimensions = flatten_list_of_dicts(
[DeclareDataFactory.get_field_dimensions(dim) for dim in parsed["content"]["Declare"]]
)
directives = extract_directive(parsed["directives"], self.directive_cls)
for i, directive in enumerate(directives):
named_args = parsed["content"]["StartStencil"][i]
acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true"))
mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false"))
copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true"))
extracted = extract_directive(parsed["directives"], self.directive_cls)
for i, directive in enumerate(extracted):
named_args = parsed["content"]["EndFusedStencil"][i]
stencil_name = _extract_stencil_name(named_args, directive)
bounds = self._make_bounds(named_args)
fields = self._make_fields(named_args, field_dimensions)
fields_w_tolerance = self._update_tolerances(named_args, fields)

deserialised.append(
self.dtype(
name=stencil_name,
fields=fields_w_tolerance,
bounds=bounds,
startln=directive.startln,
acc_present=acc_present,
mergecopy=mergecopy,
copies=copies,
)
)
return deserialised


class StartStencilDataFactoryBase(DataFactoryBase):
@staticmethod
def _make_bounds(named_args: dict) -> BoundsData:
"""Extract stencil bounds from directive arguments."""
Expand Down Expand Up @@ -355,6 +348,86 @@ def _update_tolerances(
return fields


class StartStencilDataFactory(StartStencilDataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil
dtype: Type[StartStencilData] = StartStencilData

def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]:
"""Create and return a list of StartStencilData objects from the parsed directives.

Args:
parsed (ParsedDict): Dictionary of parsed directives and their associated content.

Returns:
List[StartStencilData]: List of StartStencilData objects created from the parsed directives.
"""
deserialised = []
field_dimensions = flatten_list_of_dicts(
[DeclareDataFactory.get_field_dimensions(dim) for dim in parsed["content"]["Declare"]]
)
directives = extract_directive(parsed["directives"], self.directive_cls)
for i, directive in enumerate(directives):
named_args = parsed["content"]["StartStencil"][i]
acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true"))
mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false"))
copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true"))
stencil_name = _extract_stencil_name(named_args, directive)
bounds = self._make_bounds(named_args)
fields = self._make_fields(named_args, field_dimensions)
fields_w_tolerance = self._update_tolerances(named_args, fields)

deserialised.append(
self.dtype(
name=stencil_name,
fields=fields_w_tolerance,
bounds=bounds,
startln=directive.startln,
acc_present=acc_present,
mergecopy=mergecopy,
copies=copies,
)
)
return deserialised


class StartFusedStencilDataFactory(StartStencilDataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartFusedStencil
dtype: Type[StartFusedStencilData] = StartFusedStencilData

def __call__(self, parsed: ts.ParsedDict) -> list[StartFusedStencilData]:
samkellerhals marked this conversation as resolved.
Show resolved Hide resolved
"""Create and return a list of StartStencilData objects from the parsed directives.

Args:
parsed (ParsedDict): Dictionary of parsed directives and their associated content.

Returns:
List[StartStencilData]: List of StartStencilData objects created from the parsed directives.
"""
deserialised = []
field_dimensions = flatten_list_of_dicts(
[DeclareDataFactory.get_field_dimensions(dim) for dim in parsed["content"]["Declare"]]
)
directives = extract_directive(parsed["directives"], self.directive_cls)
for i, directive in enumerate(directives):
named_args = parsed["content"]["StartFusedStencil"][i]
acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true"))
stencil_name = _extract_stencil_name(named_args, directive)
bounds = self._make_bounds(named_args)
fields = self._make_fields(named_args, field_dimensions)
fields_w_tolerance = self._update_tolerances(named_args, fields)

deserialised.append(
self.dtype(
name=stencil_name,
fields=fields_w_tolerance,
bounds=bounds,
startln=directive.startln,
acc_present=acc_present,
)
)
return deserialised


class InsertDataFactory(DataFactoryBase):
directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.Insert
dtype: Type[InsertData] = InsertData
Expand All @@ -378,6 +451,10 @@ class IntegrationCodeDeserialiser(Deserialiser):
"Declare": DeclareDataFactory(),
"StartStencil": StartStencilDataFactory(),
"EndStencil": EndStencilDataFactory(),
"StartFusedStencil": StartFusedStencilDataFactory(),
"EndFusedStencil": EndFusedStencilDataFactory(),
"StartDelete": StartDeleteDataFactory(),
"EndDelete": EndDeleteDataFactory(),
"EndIf": EndIfDataFactory(),
"StartProfile": StartProfileDataFactory(),
"EndProfile": EndProfileDataFactory(),
Expand Down
55 changes: 54 additions & 1 deletion tools/src/icon4pytools/liskov/codegen/integration/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
DeclareStatementGenerator,
EndCreateStatement,
EndCreateStatementGenerator,
EndDeleteStatement,
EndDeleteStatementGenerator,
EndFusedStencilStatement,
EndFusedStencilStatementGenerator,
EndIfStatement,
EndIfStatementGenerator,
EndProfileStatement,
Expand All @@ -38,6 +42,10 @@
MetadataStatementGenerator,
StartCreateStatement,
StartCreateStatementGenerator,
StartDeleteStatement,
StartDeleteStatementGenerator,
StartFusedStencilStatement,
StartFusedStencilStatementGenerator,
StartProfileStatement,
StartProfileStatementGenerator,
StartStencilStatement,
Expand Down Expand Up @@ -75,6 +83,9 @@ def __call__(self, data: Any = None) -> list[GeneratedCode]:
self._generate_declare()
self._generate_start_stencil()
self._generate_end_stencil()
self._generate_start_fused_stencil()
self._generate_end_fused_stencil()
self._generate_delete()
self._generate_endif()
self._generate_profile()
self._generate_insert()
Expand Down Expand Up @@ -167,14 +178,56 @@ def _generate_end_stencil(self) -> None:
noaccenddata=self.interface.EndStencil[i].noaccenddata,
)

def _generate_start_fused_stencil(self) -> None:
"""Generate f90 integration code surrounding a fused stencil."""
if self.interface.StartFusedStencil != UnusedDirective:
for stencil in self.interface.StartFusedStencil:
logger.info(f"Generating START FUSED statement for {stencil.name}")
self._generate(
StartFusedStencilStatement,
StartFusedStencilStatementGenerator,
stencil.startln,
stencil_data=stencil,
)

def _generate_end_fused_stencil(self) -> None:
"""Generate f90 integration code surrounding a fused stencil."""
if self.interface.EndFusedStencil != UnusedDirective:
for i, stencil in enumerate(self.interface.StartFusedStencil):
logger.info(f"Generating END Fused statement for {stencil.name}")
self._generate(
EndFusedStencilStatement,
EndFusedStencilStatementGenerator,
self.interface.EndFusedStencil[i].startln,
stencil_data=stencil,
)

def _generate_delete(self) -> None:
"""Generate f90 integration code for delete section."""
if self.interface.StartDelete != UnusedDirective:
logger.info("Generating DELETE statement.")
for start, end in zip(
self.interface.StartDelete, self.interface.EndDelete, strict=True
):
self._generate(
StartDeleteStatement,
StartDeleteStatementGenerator,
start.startln,
)
self._generate(
EndDeleteStatement,
EndDeleteStatementGenerator,
end.startln,
)

def _generate_imports(self) -> None:
"""Generate f90 code for import statements."""
logger.info("Generating IMPORT statement.")
self._generate(
ImportsStatement,
ImportsStatementGenerator,
self.interface.Imports.startln,
stencils=self.interface.StartStencil,
stencils=self.interface.StartStencil + self.interface.StartFusedStencil,
)

def _generate_create(self) -> None:
Expand Down
47 changes: 44 additions & 3 deletions tools/src/icon4pytools/liskov/codegen/integration/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,60 @@ class EndProfileData(CodeGenInput):


@dataclass
class StartStencilData(CodeGenInput):
class StartBasicStencilData(CodeGenInput):
samkellerhals marked this conversation as resolved.
Show resolved Hide resolved
name: str
fields: list[FieldAssociationData]
bounds: BoundsData
acc_present: Optional[bool]
bounds: BoundsData


@dataclass
class StartStencilData(StartBasicStencilData):
mergecopy: Optional[bool]
copies: Optional[bool]


@dataclass
class EndStencilData(CodeGenInput):
class StartFusedStencilData(StartBasicStencilData):
...


@dataclass
class EndStencilBaseData(CodeGenInput):
name: str


@dataclass
class EndStencilData(EndStencilBaseData):
noendif: Optional[bool]
noprofile: Optional[bool]
noaccenddata: Optional[bool]


@dataclass
class EndFusedStencilData(EndStencilBaseData):
...


@dataclass
class StartDeleteData(CodeGenInput):
def __init__(self, startStencil: StartStencilData):
self.startln = startStencil.startln
samkellerhals marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class EndDeleteData(CodeGenInput):
def __init__(self, endStencil: EndStencilData=None, startln:int=None):
samkellerhals marked this conversation as resolved.
Show resolved Hide resolved
if endStencil is None and startln is None:
self.startln = 0
elif endStencil is not None and startln is not None:
raise Exception()
elif endStencil is not None:
self.startln = endStencil.startln
elif startln is not None:
self.startln = startln


@dataclass
class InsertData(CodeGenInput):
content: str
Expand All @@ -104,6 +141,10 @@ class InsertData(CodeGenInput):
class IntegrationCodeInterface:
StartStencil: Sequence[StartStencilData]
EndStencil: Sequence[EndStencilData]
StartFusedStencil: Sequence[StartFusedStencilData]
EndFusedStencil: Sequence[EndFusedStencilData]
StartDelete: Sequence[StartDeleteData] | UnusedDirective
EndDelete: Sequence[EndDeleteData] | UnusedDirective
Declare: Sequence[DeclareData]
Imports: ImportsData
StartCreate: Sequence[StartCreateData] | UnusedDirective
Expand Down
Loading