From bca68768f28507cf18ac4a7513fe9ae74e5cf7a0 Mon Sep 17 00:00:00 2001 From: Daniel Hupp Date: Thu, 10 Aug 2023 11:27:29 +0200 Subject: [PATCH] Fused diffusion stencils (#250) Introduce stencil fusion to Liskov. - Adds START FUSED, END FUSED, START DELETE, END DELETE directives. - Adds a --fused and --unfused mode to icon_liskov in integration mode. Co-authored-by: samkellerhals --- tools/src/icon4pytools/liskov/cli.py | 12 +- .../liskov/codegen/integration/deserialise.py | 95 +++++-- .../liskov/codegen/integration/generate.py | 55 +++- .../liskov/codegen/integration/interface.py | 38 ++- .../liskov/codegen/integration/template.py | 203 ++++++++++----- .../src/icon4pytools/liskov/external/gt4py.py | 21 +- .../src/icon4pytools/liskov/parsing/parse.py | 20 ++ tools/src/icon4pytools/liskov/parsing/scan.py | 16 +- .../icon4pytools/liskov/parsing/transform.py | 115 +++++++++ .../liskov/pipeline/collection.py | 16 +- tools/tests/liskov/fortran_samples.py | 235 ++++++++++++++++++ tools/tests/liskov/test_cli.py | 8 +- .../liskov/test_directives_deserialiser.py | 10 + tools/tests/liskov/test_external.py | 4 + tools/tests/liskov/test_generation.py | 4 + tools/tests/liskov/test_parser.py | 16 +- tools/tests/liskov/test_validation.py | 8 +- 17 files changed, 770 insertions(+), 106 deletions(-) create mode 100644 tools/src/icon4pytools/liskov/parsing/transform.py diff --git a/tools/src/icon4pytools/liskov/cli.py b/tools/src/icon4pytools/liskov/cli.py index a16e0a428e..e6685c3ad5 100644 --- a/tools/src/icon4pytools/liskov/cli.py +++ b/tools/src/icon4pytools/liskov/cli.py @@ -18,8 +18,8 @@ from icon4pytools.common.logger import setup_logger from icon4pytools.liskov.external.exceptions import MissingCommandError from icon4pytools.liskov.pipeline.collection import ( - load_gt4py_stencils, parse_fortran_file, + process_stencils, run_code_generation, ) @@ -50,6 +50,12 @@ def main(ctx): is_flag=True, help="Add metadata header with information about program.", ) +@click.option( + "--fused/--unfused", + "-f/-u", + default=True, + help="Adds fused or unfused stencils.", +) @click.argument( "input_path", type=click.Path(exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path), @@ -58,10 +64,10 @@ def main(ctx): "output_path", type=click.Path(dir_okay=False, resolve_path=True, path_type=pathlib.Path), ) -def integrate(input_path, output_path, profile, metadatagen): +def integrate(input_path, output_path, fused, profile, metadatagen): mode = "integration" iface = parse_fortran_file(input_path, output_path, mode) - iface_gt4py = load_gt4py_stencils(iface) + iface_gt4py = process_stencils(iface, fused) run_code_generation( input_path, output_path, diff --git a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py index 4b4605c7cb..f4118c9e95 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py @@ -20,6 +20,8 @@ BoundsData, DeclareData, EndCreateData, + EndDeleteData, + EndFusedStencilData, EndIfData, EndProfileData, EndStencilData, @@ -28,6 +30,8 @@ InsertData, IntegrationCodeInterface, StartCreateData, + StartDeleteData, + StartFusedStencilData, StartProfileData, StartStencilData, UnusedDirective, @@ -134,6 +138,16 @@ class EndProfileDataFactory(OptionalMultiUseDataFactory): dtype: Type[EndProfileData] = EndProfileData +class EndDeleteDataFactory(OptionalMultiUseDataFactory): + directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndDelete + dtype: Type[EndDeleteData] = EndDeleteData + + +class StartDeleteDataFactory(OptionalMultiUseDataFactory): + directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartDelete + dtype: Type[StartDeleteData] = StartDeleteData + + class StartCreateDataFactory(DataFactoryBase): directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartCreate dtype: Type[StartCreateData] = StartCreateData @@ -227,47 +241,80 @@ def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]: return deserialised -class StartStencilDataFactory(DataFactoryBase): - directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil - dtype: Type[StartStencilData] = StartStencilData +class EndFusedStencilDataFactory(DataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.EndFusedStencil + dtype: Type[EndFusedStencilData] = EndFusedStencilData - def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]: - """Create and return a list of StartStencilData objects from the parsed directives. + def __call__(self, parsed: ts.ParsedDict) -> list[EndFusedStencilData]: + deserialised = [] + extracted = extract_directive(parsed["directives"], self.directive_cls) + for i, directive in enumerate(extracted): + named_args = parsed["content"]["EndFusedStencil"][i] + stencil_name = _extract_stencil_name(named_args, directive) + deserialised.append( + self.dtype( + name=stencil_name, + startln=directive.startln, + ) + ) + return deserialised - Args: - parsed (ParsedDict): Dictionary of parsed directives and their associated content. - Returns: - List[StartStencilData]: List of StartStencilData objects created from the parsed directives. - """ - deserialised = [] +class StartStencilDataFactoryBase(DataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = None + dtype: Type[StartFusedStencilData] = None + + def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]: field_dimensions = flatten_list_of_dicts( [DeclareDataFactory.get_field_dimensions(dim) for dim in parsed["content"]["Declare"]] ) directives = extract_directive(parsed["directives"], self.directive_cls) + return self.create_stencil_data( + parsed, field_dimensions, directives, self.directive_cls, self.dtype + ) + + def create_stencil_data( + self, + parsed: ts.ParsedDict, + field_dimensions: list[dict[str, Any]], + directives: list[ts.ParsedDirective], + directive_cls: Type[ts.ParsedDirective], + dtype: Type[StartStencilData | StartFusedStencilData], + ) -> list[StartStencilData | StartFusedStencilData]: + """Create and return a list of StartStencilData or StartFusedStencilData objects from parsed directives.""" + deserialised = [] for i, directive in enumerate(directives): - named_args = parsed["content"]["StartStencil"][i] + named_args = parsed["content"][directive_cls.__name__][i] + additional_attrs = self._pop_additional_attributes(dtype, named_args) acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true")) - mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false")) - copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true")) stencil_name = _extract_stencil_name(named_args, directive) bounds = self._make_bounds(named_args) fields = self._make_fields(named_args, field_dimensions) fields_w_tolerance = self._update_tolerances(named_args, fields) deserialised.append( - self.dtype( + dtype( name=stencil_name, fields=fields_w_tolerance, bounds=bounds, startln=directive.startln, acc_present=acc_present, - mergecopy=mergecopy, - copies=copies, + **additional_attrs, ) ) return deserialised + def _pop_additional_attributes( + self, dtype: Type[StartStencilData | StartFusedStencilData], named_args: dict[str, Any] + ) -> dict: + """Pop and return additional attributes specific to StartStencilData.""" + additional_attrs = {} + if dtype == StartStencilData: + mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false")) + copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true")) + additional_attrs = {"mergecopy": mergecopy, "copies": copies} + return additional_attrs + @staticmethod def _make_bounds(named_args: dict) -> BoundsData: """Extract stencil bounds from directive arguments.""" @@ -355,6 +402,16 @@ def _update_tolerances( return fields +class StartFusedStencilDataFactory(StartStencilDataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartFusedStencil + dtype: Type[StartFusedStencilData] = StartFusedStencilData + + +class StartStencilDataFactory(StartStencilDataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.StartStencil + dtype: Type[StartStencilData] = StartStencilData + + class InsertDataFactory(DataFactoryBase): directive_cls: Type[ts.ParsedDirective] = icon4pytools.liskov.parsing.parse.Insert dtype: Type[InsertData] = InsertData @@ -378,6 +435,10 @@ class IntegrationCodeDeserialiser(Deserialiser): "Declare": DeclareDataFactory(), "StartStencil": StartStencilDataFactory(), "EndStencil": EndStencilDataFactory(), + "StartFusedStencil": StartFusedStencilDataFactory(), + "EndFusedStencil": EndFusedStencilDataFactory(), + "StartDelete": StartDeleteDataFactory(), + "EndDelete": EndDeleteDataFactory(), "EndIf": EndIfDataFactory(), "StartProfile": StartProfileDataFactory(), "EndProfile": EndProfileDataFactory(), diff --git a/tools/src/icon4pytools/liskov/codegen/integration/generate.py b/tools/src/icon4pytools/liskov/codegen/integration/generate.py index 9705435be3..40324d138f 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/generate.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/generate.py @@ -24,6 +24,10 @@ DeclareStatementGenerator, EndCreateStatement, EndCreateStatementGenerator, + EndDeleteStatement, + EndDeleteStatementGenerator, + EndFusedStencilStatement, + EndFusedStencilStatementGenerator, EndIfStatement, EndIfStatementGenerator, EndProfileStatement, @@ -38,6 +42,10 @@ MetadataStatementGenerator, StartCreateStatement, StartCreateStatementGenerator, + StartDeleteStatement, + StartDeleteStatementGenerator, + StartFusedStencilStatement, + StartFusedStencilStatementGenerator, StartProfileStatement, StartProfileStatementGenerator, StartStencilStatement, @@ -75,6 +83,9 @@ def __call__(self, data: Any = None) -> list[GeneratedCode]: self._generate_declare() self._generate_start_stencil() self._generate_end_stencil() + self._generate_start_fused_stencil() + self._generate_end_fused_stencil() + self._generate_delete() self._generate_endif() self._generate_profile() self._generate_insert() @@ -167,6 +178,48 @@ def _generate_end_stencil(self) -> None: noaccenddata=self.interface.EndStencil[i].noaccenddata, ) + def _generate_start_fused_stencil(self) -> None: + """Generate f90 integration code surrounding a fused stencil.""" + if self.interface.StartFusedStencil != UnusedDirective: + for stencil in self.interface.StartFusedStencil: + logger.info(f"Generating START FUSED statement for {stencil.name}") + self._generate( + StartFusedStencilStatement, + StartFusedStencilStatementGenerator, + stencil.startln, + stencil_data=stencil, + ) + + def _generate_end_fused_stencil(self) -> None: + """Generate f90 integration code surrounding a fused stencil.""" + if self.interface.EndFusedStencil != UnusedDirective: + for i, stencil in enumerate(self.interface.StartFusedStencil): + logger.info(f"Generating END Fused statement for {stencil.name}") + self._generate( + EndFusedStencilStatement, + EndFusedStencilStatementGenerator, + self.interface.EndFusedStencil[i].startln, + stencil_data=stencil, + ) + + def _generate_delete(self) -> None: + """Generate f90 integration code for delete section.""" + if self.interface.StartDelete != UnusedDirective: + logger.info("Generating DELETE statement.") + for start, end in zip( + self.interface.StartDelete, self.interface.EndDelete, strict=True + ): + self._generate( + StartDeleteStatement, + StartDeleteStatementGenerator, + start.startln, + ) + self._generate( + EndDeleteStatement, + EndDeleteStatementGenerator, + end.startln, + ) + def _generate_imports(self) -> None: """Generate f90 code for import statements.""" logger.info("Generating IMPORT statement.") @@ -174,7 +227,7 @@ def _generate_imports(self) -> None: ImportsStatement, ImportsStatementGenerator, self.interface.Imports.startln, - stencils=self.interface.StartStencil, + stencils=self.interface.StartStencil + self.interface.StartFusedStencil, ) def _generate_create(self) -> None: diff --git a/tools/src/icon4pytools/liskov/codegen/integration/interface.py b/tools/src/icon4pytools/liskov/codegen/integration/interface.py index b48e0d6a6b..93903cab78 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/interface.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/interface.py @@ -78,23 +78,51 @@ class EndProfileData(CodeGenInput): @dataclass -class StartStencilData(CodeGenInput): +class BaseStartStencilData(CodeGenInput): name: str fields: list[FieldAssociationData] - bounds: BoundsData acc_present: Optional[bool] + bounds: BoundsData + + +@dataclass +class StartStencilData(BaseStartStencilData): mergecopy: Optional[bool] copies: Optional[bool] @dataclass -class EndStencilData(CodeGenInput): +class StartFusedStencilData(BaseStartStencilData): + ... + + +@dataclass +class BaseEndStencilData(CodeGenInput): name: str + + +@dataclass +class EndStencilData(BaseEndStencilData): noendif: Optional[bool] noprofile: Optional[bool] noaccenddata: Optional[bool] +@dataclass +class EndFusedStencilData(BaseEndStencilData): + ... + + +@dataclass +class StartDeleteData(CodeGenInput): + startln: int + + +@dataclass +class EndDeleteData(StartDeleteData): + ... + + @dataclass class InsertData(CodeGenInput): content: str @@ -104,6 +132,10 @@ class InsertData(CodeGenInput): class IntegrationCodeInterface: StartStencil: Sequence[StartStencilData] EndStencil: Sequence[EndStencilData] + StartFusedStencil: Sequence[StartFusedStencilData] + EndFusedStencil: Sequence[EndFusedStencilData] + StartDelete: Sequence[StartDeleteData] | UnusedDirective + EndDelete: Sequence[EndDeleteData] | UnusedDirective Declare: Sequence[DeclareData] Imports: ImportsData StartCreate: Sequence[StartCreateData] | UnusedDirective diff --git a/tools/src/icon4pytools/liskov/codegen/integration/template.py b/tools/src/icon4pytools/liskov/codegen/integration/template.py index dba9543864..2f154f82d2 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/template.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/template.py @@ -20,7 +20,12 @@ from gt4py.eve.codegen import TemplatedGenerator from icon4pytools.liskov.codegen.integration.exceptions import UndeclaredFieldError -from icon4pytools.liskov.codegen.integration.interface import DeclareData, StartStencilData +from icon4pytools.liskov.codegen.integration.interface import ( + BaseStartStencilData, + DeclareData, + StartFusedStencilData, + StartStencilData, +) from icon4pytools.liskov.external.metadata import CodeMetadata @@ -97,19 +102,21 @@ class MetadataStatementGenerator(TemplatedGenerator): ) -class EndStencilStatement(eve.Node): - stencil_data: StartStencilData - profile: bool - noendif: Optional[bool] - noprofile: Optional[bool] - noaccenddata: Optional[bool] - +class EndBasicStencilStatement(eve.Node): name: str = eve.datamodels.field(init=False) input_fields: InputFields = eve.datamodels.field(init=False) output_fields: OutputFields = eve.datamodels.field(init=False) tolerance_fields: ToleranceFields = eve.datamodels.field(init=False) bounds_fields: BoundsFields = eve.datamodels.field(init=False) + +class EndStencilStatement(EndBasicStencilStatement): + stencil_data: StartStencilData + profile: bool + noendif: Optional[bool] + noprofile: Optional[bool] + noaccenddata: Optional[bool] + def __post_init__(self) -> None: # type: ignore all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] self.bounds_fields = BoundsFields(**asdict(self.stencil_data.bounds)) @@ -121,25 +128,7 @@ def __post_init__(self) -> None: # type: ignore ) -class EndStencilStatementGenerator(TemplatedGenerator): - EndStencilStatement = as_jinja( - """ - {%- if _this_node.profile %} - {% if _this_node.noprofile %}{% else %}call nvtxEndRange(){% endif %} - {%- endif %} - {% if _this_node.noendif %}{% else %}#endif{% endif %} - call wrap_run_{{ name }}( & - {{ input_fields }} - {{ output_fields }} - {{ tolerance_fields }} - {{ bounds_fields }} - - {%- if not _this_node.noaccenddata %} - !$ACC END DATA - {%- endif %} - """ - ) - +class BaseEndStencilStatementGenerator(TemplatedGenerator): InputFields = as_jinja( """ {%- for field in _this_node.fields %} @@ -200,10 +189,82 @@ def visit_OutputFields(self, out: OutputFields) -> OutputFields: # type: ignore ) +class EndStencilStatementGenerator(BaseEndStencilStatementGenerator): + EndStencilStatement = as_jinja( + """ + {%- if _this_node.profile %} + {% if _this_node.noprofile %}{% else %}call nvtxEndRange(){% endif %} + {%- endif %} + {% if _this_node.noendif %}{% else %}#endif{% endif %} + call wrap_run_{{ name }}( & + {{ input_fields }} + {{ output_fields }} + {{ tolerance_fields }} + {{ bounds_fields }} + + {%- if not _this_node.noaccenddata %} + !$ACC END DATA + {%- endif %} + """ + ) + + +class EndFusedStencilStatementGenerator(BaseEndStencilStatementGenerator): + EndFusedStencilStatement = as_jinja( + """ + call wrap_run_{{ name }}( & + {{ input_fields }} + {{ output_fields }} + {{ tolerance_fields }} + {{ bounds_fields }} + + !$ACC EXIT DATA DELETE( & + {%- for d in _this_node.copy_declarations %} + !$ACC {{ d.variable }}_before {%- if not loop.last -%}, & {% else %} ) & {%- endif -%} + {%- endfor %} + !$ACC IF ( i_am_accel_node ) + """ + ) + + class Declaration(Assign): ... +class CopyDeclaration(Declaration): + lh_index: str + rh_index: str + + +def _make_copy_declaration(f: Field) -> CopyDeclaration: + if f.dims is None: + raise UndeclaredFieldError(f"{f.variable} was not declared!") + + lh_idx = render_index(f.dims) + + # get length of association index + association_dims = get_array_dims(f.association).split(",") + n_association_dims = len(association_dims) + + offset = len(",".join(association_dims)) + 2 + truncated_association = f.association[:-offset] + + if n_association_dims > f.dims: + rh_idx = f"{lh_idx},{association_dims[-1]}" + else: + rh_idx = f"{lh_idx}" + + lh_idx = enclose_in_parentheses(lh_idx) + rh_idx = enclose_in_parentheses(rh_idx) + + return CopyDeclaration( + variable=f.variable, + association=truncated_association, + lh_index=lh_idx, + rh_index=rh_idx, + ) + + class DeclareStatement(eve.Node): declare_data: DeclareData declarations: list[Declaration] = eve.datamodels.field(init=False) @@ -226,11 +287,6 @@ class DeclareStatementGenerator(TemplatedGenerator): ) -class CopyDeclaration(Declaration): - lh_index: str - rh_index: str - - class StartStencilStatement(eve.Node): stencil_data: StartStencilData profile: bool @@ -238,36 +294,33 @@ class StartStencilStatement(eve.Node): def __post_init__(self) -> None: # type: ignore all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] - self.copy_declarations = [self.make_copy_declaration(f) for f in all_fields if f.out] + self.copy_declarations = [_make_copy_declaration(f) for f in all_fields if f.out] self.acc_present = "PRESENT" if self.stencil_data.acc_present else "NONE" - @staticmethod - def make_copy_declaration(f: Field) -> CopyDeclaration: - if f.dims is None: - raise UndeclaredFieldError(f"{f.variable} was not declared!") - lh_idx = render_index(f.dims) - - # get length of association index - association_dims = get_array_dims(f.association).split(",") - n_association_dims = len(association_dims) +class StartFusedStencilStatement(eve.Node): + stencil_data: StartFusedStencilData + copy_declarations: list[CopyDeclaration] = eve.datamodels.field(init=False) - offset = len(",".join(association_dims)) + 2 - truncated_association = f.association[:-offset] + def __post_init__(self) -> None: # type: ignore + all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] + self.copy_declarations = [_make_copy_declaration(f) for f in all_fields if f.out] + self.acc_present = "PRESENT" if self.stencil_data.acc_present else "NONE" - if n_association_dims > f.dims: - rh_idx = f"{lh_idx},{association_dims[-1]}" - else: - rh_idx = f"{lh_idx}" - lh_idx = enclose_in_parentheses(lh_idx) - rh_idx = enclose_in_parentheses(rh_idx) +class EndFusedStencilStatement(EndBasicStencilStatement): + stencil_data: StartFusedStencilData + copy_declarations: list[CopyDeclaration] = eve.datamodels.field(init=False) - return CopyDeclaration( - variable=f.variable, - association=truncated_association, - lh_index=lh_idx, - rh_index=rh_idx, + def __post_init__(self) -> None: # type: ignore + all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] + self.copy_declarations = [_make_copy_declaration(f) for f in all_fields if f.out] + self.bounds_fields = BoundsFields(**asdict(self.stencil_data.bounds)) + self.name = self.stencil_data.name + self.input_fields = InputFields(fields=[f for f in all_fields if f.inp]) + self.output_fields = OutputFields(fields=[f for f in all_fields if f.out]) + self.tolerance_fields = ToleranceFields( + fields=[f for f in all_fields if f.rel_tol or f.abs_tol] ) @@ -314,8 +367,30 @@ class StartStencilStatementGenerator(TemplatedGenerator): ) +class StartFusedStencilStatementGenerator(TemplatedGenerator): + StartFusedStencilStatement = as_jinja( + """ + + !$ACC ENTER DATA CREATE( & + {%- for d in _this_node.copy_declarations %} + !$ACC {{ d.variable }}_before {%- if not loop.last -%}, & {% else %} ) & {%- endif -%} + {%- endfor %} + !$ACC IF ( i_am_accel_node ) + + #ifdef __DSL_VERIFY + !$ACC KERNELS IF( i_am_accel_node ) DEFAULT(PRESENT) ASYNC(1) + {%- for d in _this_node.copy_declarations %} + {{ d.variable }}_before{{ d.lh_index }} = {{ d.association }}{{ d.rh_index }} + {%- endfor %} + !$ACC END KERNELS + #endif + + """ + ) + + class ImportsStatement(eve.Node): - stencils: list[StartStencilData] + stencils: list[BaseStartStencilData] stencil_names: list[str] = eve.datamodels.field(init=False) def __post_init__(self) -> None: # type: ignore @@ -354,6 +429,22 @@ class EndCreateStatementGenerator(TemplatedGenerator): EndCreateStatement = as_jinja("!$ACC END DATA") +class StartDeleteStatement(eve.Node): + ... + + +class StartDeleteStatementGenerator(TemplatedGenerator): + StartDeleteStatement = as_jinja("#ifdef __DSL_VERIFY") + + +class EndDeleteStatement(eve.Node): + ... + + +class EndDeleteStatementGenerator(TemplatedGenerator): + EndDeleteStatement = as_jinja("#endif") + + class EndIfStatement(eve.Node): ... diff --git a/tools/src/icon4pytools/liskov/external/gt4py.py b/tools/src/icon4pytools/liskov/external/gt4py.py index 1b98413f77..53b815173f 100644 --- a/tools/src/icon4pytools/liskov/external/gt4py.py +++ b/tools/src/icon4pytools/liskov/external/gt4py.py @@ -13,14 +13,17 @@ import importlib from inspect import getmembers -from typing import Any +from typing import Any, Sequence from gt4py.next.ffront.decorator import Program from icon4pytools.common import ICON4PY_MODEL_QUALIFIED_NAME from icon4pytools.common.logger import setup_logger from icon4pytools.icon4pygen.metadata import get_stencil_info -from icon4pytools.liskov.codegen.integration.interface import IntegrationCodeInterface +from icon4pytools.liskov.codegen.integration.interface import ( + BaseStartStencilData, + IntegrationCodeInterface, +) from icon4pytools.liskov.external.exceptions import IncompatibleFieldError, UnknownStencilError from icon4pytools.liskov.pipeline.definition import Step @@ -37,7 +40,13 @@ def __init__(self, parsed: IntegrationCodeInterface): def __call__(self, data: Any = None) -> IntegrationCodeInterface: logger.info("Updating parsed fields with data from icon4py stencils...") - for s in self.parsed.StartStencil: + self._set_in_out_field(self.parsed.StartStencil) + self._set_in_out_field(self.parsed.StartFusedStencil) + + return self.parsed + + def _set_in_out_field(self, startStencil: Sequence[BaseStartStencilData]) -> None: + for s in startStencil: gt4py_program = self._collect_icon4py_stencil(s.name) gt4py_stencil_info = get_stencil_info(gt4py_program) gt4py_fields = gt4py_stencil_info.fields @@ -45,12 +54,10 @@ def __call__(self, data: Any = None) -> IntegrationCodeInterface: try: field_info = gt4py_fields[f.variable] except KeyError: - raise IncompatibleFieldError( - f"Used field variable name that is incompatible with the expected field names defined in {s.name} in icon4py." - ) + error_msg = f"Used field variable name ({f.variable}) that is incompatible with the expected field names defined in {s.name} in icon4py." + raise IncompatibleFieldError(error_msg) f.out = field_info.out f.inp = field_info.inp - return self.parsed def _collect_icon4py_stencil(self, stencil_name: str) -> Program: """Collect and return the ICON4PY stencil program with the given name.""" diff --git a/tools/src/icon4pytools/liskov/parsing/parse.py b/tools/src/icon4pytools/liskov/parsing/parse.py index 02759621a5..7d60812b08 100644 --- a/tools/src/icon4pytools/liskov/parsing/parse.py +++ b/tools/src/icon4pytools/liskov/parsing/parse.py @@ -187,6 +187,14 @@ class EndStencil(WithArguments): pattern = "END STENCIL" +class StartFusedStencil(WithArguments): + pattern = "START FUSED STENCIL" + + +class EndFusedStencil(WithArguments): + pattern = "END FUSED STENCIL" + + class Declare(WithArguments): pattern = "DECLARE" @@ -215,6 +223,14 @@ class EndProfile(WithoutArguments): pattern = "END PROFILE" +class StartDelete(WithoutArguments): + pattern = "START DELETE" + + +class EndDelete(WithoutArguments): + pattern = "END DELETE" + + class Insert(FreeForm): pattern = "INSERT" @@ -222,6 +238,10 @@ class Insert(FreeForm): SUPPORTED_DIRECTIVES: Sequence[Type[ParsedDirective]] = [ StartStencil, EndStencil, + StartFusedStencil, + EndFusedStencil, + StartDelete, + EndDelete, Imports, Declare, StartCreate, diff --git a/tools/src/icon4pytools/liskov/parsing/scan.py b/tools/src/icon4pytools/liskov/parsing/scan.py index 0fb9879749..a4da42309f 100644 --- a/tools/src/icon4pytools/liskov/parsing/scan.py +++ b/tools/src/icon4pytools/liskov/parsing/scan.py @@ -36,7 +36,9 @@ def __init__(self, input_filepath: Path) -> None: A directive must start with !$DSL ( with the directive arguments delimited by a ;. The directive if on multiple lines must include a & at the end of the line. The directive - must always be closed by a closing bracket ). + must always be closed by a closing bracket ). A directive can be + commented out by using a ! before the directive, + for example, !!$DSL means the directive is disabled. Example: !$DSL IMPORTS() @@ -61,11 +63,11 @@ def __call__(self, data: Any = None) -> list[ts.RawDirective]: with self.input_filepath.open() as f: scanned_directives = [] lines = f.readlines() - for lnumber, string in enumerate(lines): - if ts.DIRECTIVE_IDENT in string: - stripped = string.strip() + for lnumber, line in enumerate(lines): + stripped = line.strip() + if stripped.startswith(ts.DIRECTIVE_IDENT): eol = stripped[-1] - scanned = Scanned(string, lnumber) + scanned = Scanned(line, lnumber) scanned_directives.append(scanned) match eol: @@ -77,12 +79,12 @@ def __call__(self, data: Any = None) -> list[ts.RawDirective]: if ts.DIRECTIVE_IDENT not in next_line: raise DirectiveSyntaxError( f"Error in directive on line number: {lnumber + 1}\n Invalid use of & in single line " - f"directive. " + f"directive in file {self.input_filepath} ." ) continue case _: raise DirectiveSyntaxError( - f"Error in directive on line number: {lnumber + 1}\n Used invalid end of line character." + f"Error in directive on line number: {lnumber + 1}\n Used invalid end of line characterat in file {self.input_filepath} ." ) logger.info(f"Scanning for directives at {self.input_filepath}") return directives diff --git a/tools/src/icon4pytools/liskov/parsing/transform.py b/tools/src/icon4pytools/liskov/parsing/transform.py new file mode 100644 index 0000000000..12fd9c0478 --- /dev/null +++ b/tools/src/icon4pytools/liskov/parsing/transform.py @@ -0,0 +1,115 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +from typing import Any + +import icon4pytools.liskov.parsing.types as ts +from icon4pytools.common.logger import setup_logger +from icon4pytools.liskov.codegen.integration.interface import ( + EndDeleteData, + EndFusedStencilData, + EndStencilData, + IntegrationCodeInterface, + StartDeleteData, + StartFusedStencilData, + StartStencilData, + UnusedDirective, +) +from icon4pytools.liskov.pipeline.definition import Step + + +logger = setup_logger(__name__) + + +class StencilTransformer(Step): + def __init__(self, parsed: IntegrationCodeInterface, fused: bool) -> None: + self.parsed = parsed + self.fused = fused + + def __call__(self, data: Any = None) -> ts.ParsedDict: + """Transform stencils in the parse tree based on the 'fused' flag, transforming or removing as necessary. + + This method processes stencils present in the 'parsed' object according to the 'fused' + flag. If 'fused' is True, it identifies and processes stencils that are eligible for + deletion. If 'fused' is False, it removes fused stencils. + + Args: + data (Any): Optional data to be passed. Default is None. + + Returns: + ts.ParsedDict: The parsed directives along with any modifications applied. + """ + if self.fused: + logger.info("Transforming stencils for deletion.") + self._process_stencils_for_deletion() + else: + logger.info("Removing fused stencils.") + self._remove_fused_stencils() + + return self.parsed + + def _process_stencils_for_deletion(self) -> None: + stencils_to_remove = [] + + for start_fused, end_fused in zip( + self.parsed.StartFusedStencil, self.parsed.EndFusedStencil, strict=True + ): + for start_single, end_single in zip( + self.parsed.StartStencil, self.parsed.EndStencil, strict=True + ): + if self._stencil_is_removable(start_fused, end_fused, start_single, end_single): + self._create_delete_directives(start_single, end_single) + stencils_to_remove += [start_single, end_single] + + self._remove_stencils(stencils_to_remove) + + def _stencil_is_removable( + self, + start_fused: StartFusedStencilData, + end_fused: EndFusedStencilData, + start_single: StartStencilData, + end_single: EndStencilData, + ) -> bool: + return ( + start_fused.startln < start_single.startln + and start_single.startln < end_fused.startln + and start_fused.startln < end_single.startln + and end_single.startln < end_fused.startln + ) + + def _create_delete_directives( + self, start_single: StartStencilData, end_single: EndStencilData + ) -> None: + for attr, param in zip(["StartDelete", "EndDelete"], [start_single, end_single]): + directive = getattr(self.parsed, attr) + if directive == UnusedDirective: + directive = [] + + if attr == "StartDelete": + cls = StartDeleteData + elif attr == "EndDelete": + cls = EndDeleteData + + directive.append(cls(startln=param.startln)) + setattr(self.parsed, attr, directive) + + def _remove_stencils(self, stencils_to_remove: list[StartStencilData | EndStencilData]) -> None: + attributes_to_modify = ["StartStencil", "EndStencil"] + + for attr_name in attributes_to_modify: + current_stencil_list = getattr(self.parsed, attr_name) + modified_stencil_list = [_ for _ in current_stencil_list if _ not in stencils_to_remove] + setattr(self.parsed, attr_name, modified_stencil_list) + + def _remove_fused_stencils(self) -> None: + self.parsed.StartFusedStencil = [] + self.parsed.EndFusedStencil = [] diff --git a/tools/src/icon4pytools/liskov/pipeline/collection.py b/tools/src/icon4pytools/liskov/pipeline/collection.py index 006fa8a0e9..8af70e54ac 100644 --- a/tools/src/icon4pytools/liskov/pipeline/collection.py +++ b/tools/src/icon4pytools/liskov/pipeline/collection.py @@ -21,6 +21,7 @@ from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils from icon4pytools.liskov.parsing.parse import DirectivesParser from icon4pytools.liskov.parsing.scan import DirectivesScanner +from icon4pytools.liskov.parsing.transform import StencilTransformer from icon4pytools.liskov.pipeline.definition import Step, linear_pipeline @@ -68,16 +69,21 @@ def parse_fortran_file( @linear_pipeline -def load_gt4py_stencils(parsed: IntegrationCodeInterface) -> list[Step]: - """Execute a pipeline to update fields of a IntegrationCodeInterface object with GT4Py stencils. +def process_stencils(parsed: IntegrationCodeInterface, fused: bool) -> list[Step]: + """Execute a linear pipeline to transform stencils and produce either fused or unfused execution. + + This function takes an input `parsed` object of type `IntegrationCodeInterface` and a `fused` boolean flag. + It then executes a linear pipeline, consisting of two steps: transformation of stencils for fusion or unfusion, + and updating fields with information from GT4Py stencils. Args: - parsed: The input IntegrationCodeInterface object. + parsed (IntegrationCodeInterface): The input object containing parsed integration code. + fused (bool): A boolean flag indicating whether to produce fused (True) or unfused (False) execution. Returns: - The updated object with fields containing information from GT4Py stencils. + The updated and transformed object with fields containing information from GT4Py stencils. """ - return [UpdateFieldsWithGt4PyStencils(parsed)] + return [StencilTransformer(parsed, fused), UpdateFieldsWithGt4PyStencils(parsed)] @linear_pipeline diff --git a/tools/tests/liskov/fortran_samples.py b/tools/tests/liskov/fortran_samples.py index 858b3a2089..7ed24e5954 100644 --- a/tools/tests/liskov/fortran_samples.py +++ b/tools/tests/liskov/fortran_samples.py @@ -69,6 +69,55 @@ !$DSL END CREATE() """ +SINGLE_STENCIL_WITH_COMMENTS = """\ + ! Use !$DSL statements, they are great. They can be easily commented out by: + + !!$DSL IMPORTS() + + ! $DSL START CREATE() + + !$DSL IMPORTS() + + !$DSL START CREATE() + + !$DSL DECLARE(vn=nproma,p_patch%nlev,p_patch%nblks_e; suffix=dsl) + + !$DSL DECLARE(vn=nproma,p_patch%nlev,p_patch%nblks_e; a=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL b=nproma,p_patch%nlev,p_patch%nblks_e; type=REAL(vp)) + + !$DSL START STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; & + !$DSL z_nabla2_e=z_nabla2_e(:,:,1); area_edge=p_patch%edges%area_edge(:,1); & + !$DSL fac_bdydiff_v=fac_bdydiff_v; vn=p_nh_prog%vn(:,:,1); & + !$DSL vertical_lower=1; vertical_upper=nlev; & + !$DSL horizontal_lower=i_startidx; horizontal_upper=i_endidx; & + !$DSL accpresent=True) + !$OMP DO PRIVATE(je,jk,jb,i_startidx,i_endidx) ICON_OMP_DEFAULT_SCHEDULE + DO jb = i_startblk,i_endblk + + CALL get_indices_e(p_patch, jb, i_startblk, i_endblk, & + i_startidx, i_endidx, start_bdydiff_e, grf_bdywidth_e) + + !$ACC PARALLEL IF( i_am_accel_node .AND. acc_on ) DEFAULT(NONE) ASYNC(1) + vn_before(:,:,:) = p_nh_prog%vn(:,:,:) + !$ACC END PARALLEL + + !$ACC PARALLEL LOOP DEFAULT(NONE) GANG VECTOR COLLAPSE(2) ASYNC(1) IF( i_am_accel_node .AND. acc_on ) + DO jk = 1, nlev + !DIR$ IVDEP + DO je = i_startidx, i_endidx + p_nh_prog%vn(je,jk,jb) = & + p_nh_prog%vn(je,jk,jb) + & + z_nabla2_e(je,jk,jb) * & + p_patch%edges%area_edge(je,jb)*fac_bdydiff_v + ENDDO + ENDDO + !$DSL START PROFILE(name=apply_nabla2_to_vn_in_lateral_boundary) + !$ACC END PARALLEL LOOP + !$DSL END PROFILE() + !$DSL END STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; noprofile=True) + !$DSL END CREATE() + """ + MULTIPLE_STENCILS = """\ !$DSL IMPORTS() @@ -212,6 +261,192 @@ """ +SINGLE_FUSED = """\ + !$DSL IMPORTS() + + !$DSL INSERT(INTEGER :: start_interior_idx_c, end_interior_idx_c, start_nudging_idx_c, end_halo_1_idx_c) + + !$DSL DECLARE(kh_smag_e=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL kh_smag_ec=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL z_nabla2_e=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL kh_c=nproma,p_patch%nlev; & + !$DSL div=nproma,p_patch%nlev; & + !$DSL div_ic=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL hdef_ic=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL z_nabla4_e2=nproma,p_patch%nlev; & + !$DSL vn=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL z_nabla2_c=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL dwdx=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL dwdy=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL w=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL enh_diffu_3d=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL z_temp=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL theta_v=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL exner=nproma,p_patch%nlev,p_patch%nblks_c) + + !$DSL START FUSED STENCIL(name=calculate_diagnostic_quantities_for_turbulence; & + !$DSL kh_smag_ec=kh_smag_ec(:,:,1); vn=p_nh_prog%vn(:,:,1); e_bln_c_s=p_int%e_bln_c_s(:,:,1); & + !$DSL geofac_div=p_int%geofac_div(:,:,1); diff_multfac_smag=diff_multfac_smag(:); & + !$DSL wgtfac_c=p_nh_metrics%wgtfac_c(:,:,1); div_ic=p_nh_diag%div_ic(:,:,1); & + !$DSL hdef_ic=p_nh_diag%hdef_ic(:,:,1); & + !$DSL div_ic_abs_tol=1e-18_wp; vertical_lower=2; & + !$DSL vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL START STENCIL(name=temporary_fields_for_turbulence_diagnostics; kh_smag_ec=kh_smag_ec(:,:,1); vn=p_nh_prog%vn(:,:,1); & + !$DSL e_bln_c_s=p_int%e_bln_c_s(:,:,1); geofac_div=p_int%geofac_div(:,:,1); & + !$DSL diff_multfac_smag=diff_multfac_smag(:); kh_c=kh_c(:,:); div=div(:,:); & + !$DSL vertical_lower=1; vertical_upper=nlev; horizontal_lower=i_startidx; & + !$DSL horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=temporary_fields_for_turbulence_diagnostics) + + !$DSL START STENCIL(name=calculate_diagnostics_for_turbulence; div=div; kh_c=kh_c; wgtfac_c=p_nh_metrics%wgtfac_c(:,:,1); & + !$DSL div_ic=p_nh_diag%div_ic(:,:,1); hdef_ic=p_nh_diag%hdef_ic(:,:,1); div_ic_abs_tol=1e-18_wp; & + !$DSL vertical_lower=2; vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=calculate_diagnostics_for_turbulence) + + !$DSL END FUSED STENCIL(name=calculate_diagnostic_quantities_for_turbulence) + """ + + +MULTIPLE_FUSED = """\ + !$DSL IMPORTS() + !$DSL START DELETE() + !$DSL END DELETE() + !$DSL INSERT(INTEGER :: start_interior_idx_c, end_interior_idx_c, start_nudging_idx_c, end_halo_1_idx_c) + + !$DSL DECLARE(kh_smag_e=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL kh_smag_ec=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL z_nabla2_e=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL kh_c=nproma,p_patch%nlev; & + !$DSL div=nproma,p_patch%nlev; & + !$DSL div_ic=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL hdef_ic=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL z_nabla4_e2=nproma,p_patch%nlev; & + !$DSL vn=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL z_nabla2_c=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL dwdx=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL dwdy=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL w=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL enh_diffu_3d=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL z_temp=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL theta_v=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL exner=nproma,p_patch%nlev,p_patch%nblks_c) + + !$DSL INSERT(start_2nd_nudge_line_idx_e = i_startidx) + !$DSL INSERT(end_interior_idx_e = i_endidx) + ! Compute nabla4(v) + + !$DSL START FUSED STENCIL(name=apply_diffusion_to_vn; & + !$DSL u_vert=u_vert(:,:,1); & + !$DSL v_vert=v_vert(:,:,1); & + !$DSL primal_normal_vert_v1=p_patch%edges%primal_normal_vert_x(:,:,1); & + !$DSL primal_normal_vert_v2=p_patch%edges%primal_normal_vert_y(:,:,1); & + !$DSL z_nabla2_e=z_nabla2_e(:,:,1); & + !$DSL inv_vert_vert_length=p_patch%edges%inv_vert_vert_length(:,1); & + !$DSL inv_primal_edge_length=p_patch%edges%inv_primal_edge_length(:,1); & + !$DSL area_edge=p_patch%edges%area_edge(:,1); & + !$DSL kh_smag_e=kh_smag_e(:,:,1); & + !$DSL diff_multfac_vn=diff_multfac_vn(:); & + !$DSL nudgecoeff_e=p_int%nudgecoeff_e(:,1); & + !$DSL vn=p_nh_prog%vn(:,:,1); & + !$DSL horz_idx=horz_idx(:); & + !$DSL nudgezone_diff=nudgezone_diff; & + !$DSL fac_bdydiff_v=fac_bdydiff_v; & + !$DSL start_2nd_nudge_line_idx_e=start_2nd_nudge_line_idx_e-1; & + !$DSL limited_area=l_limited_area; & + !$DSL vn_rel_tol=1e-11_wp; & + !$DSL vertical_lower=1; & + !$DSL vertical_upper=nlev; & + !$DSL horizontal_lower=start_bdydiff_idx_e; & + !$DSL horizontal_upper=end_interior_idx_e) + + !$DSL START STENCIL(name=calculate_nabla4; u_vert=u_vert(:,:,1); v_vert=v_vert(:,:,1); & + !$DSL primal_normal_vert_v1=p_patch%edges%primal_normal_vert_x(:,:,1); & + !$DSL primal_normal_vert_v2=p_patch%edges%primal_normal_vert_y(:,:,1); & + !$DSL z_nabla2_e=z_nabla2_e(:,:,1); inv_vert_vert_length=p_patch%edges%inv_vert_vert_length(:,1); & + !$DSL inv_primal_edge_length=p_patch%edges%inv_primal_edge_length(:,1); z_nabla4_e2_abs_tol=1e-27_wp; & + !$DSL z_nabla4_e2=z_nabla4_e2(:, :); vertical_lower=1; vertical_upper=nlev; horizontal_lower=i_startidx; & + !$DSL horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=calculate_nabla4) + + !$DSL START STENCIL(name=apply_nabla2_and_nabla4_to_vn; nudgezone_diff=nudgezone_diff; area_edge=p_patch%edges%area_edge(:,1); & + !$DSL kh_smag_e=kh_smag_e(:,:,1); z_nabla2_e=z_nabla2_e(:,:,1); z_nabla4_e2=z_nabla4_e2(:,:); & + !$DSL diff_multfac_vn=diff_multfac_vn(:); nudgecoeff_e=p_int%nudgecoeff_e(:,1); vn=p_nh_prog%vn(:,:,1); vn_rel_tol=1e-11_wp; & + !$DSL vertical_lower=1; vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=apply_nabla2_and_nabla4_to_vn) + + !$DSL START STENCIL(name=apply_nabla2_and_nabla4_global_to_vn; area_edge=p_patch%edges%area_edge(:,1); kh_smag_e=kh_smag_e(:,:,1); & + !$DSL z_nabla2_e=z_nabla2_e(:,:,1); z_nabla4_e2=z_nabla4_e2(:,:); diff_multfac_vn=diff_multfac_vn(:); vn=p_nh_prog%vn(:,:,1); & + !$DSL vn_rel_tol=1e-10_wp; vertical_lower=1; vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=apply_nabla2_and_nabla4_global_to_vn) + + !$DSL INSERT(start_bdydiff_idx_e = i_startidx) + + !$DSL START STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; z_nabla2_e=z_nabla2_e(:,:,1); area_edge=p_patch%edges%area_edge(:,1); & + !$DSL fac_bdydiff_v=fac_bdydiff_v; vn=p_nh_prog%vn(:,:,1); vn_abs_tol=1e-14_wp; vertical_lower=1; & + !$DSL vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary) + + !$DSL END FUSED STENCIL(name=apply_diffusion_to_vn) + + + !$DSL INSERT(start_nudging_idx_c = i_startidx) + !$DSL INSERT(end_halo_1_idx_c = i_endidx) + + !$DSL INSERT(!$ACC PARALLEL IF( i_am_accel_node ) DEFAULT(PRESENT) ASYNC(1)) + !$DSL INSERT(w_old(:,:,:) = p_nh_prog%w(:,:,:)) + !$DSL INSERT(!$ACC END PARALLEL) + + !$DSL START FUSED STENCIL(name=apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulance; & + !$DSL area=p_patch%cells%area(:,1); geofac_grg_x=p_int%geofac_grg(:,:,1,1); & + !$DSL geofac_grg_y=p_int%geofac_grg(:,:,1,2); geofac_n2s=p_int%geofac_n2s(:,:,1); & + !$DSL w_old=w_old(:,:,1); w=p_nh_prog%w(:,:,1); diff_multfac_w=diff_multfac_w; & + !$DSL diff_multfac_n2w=diff_multfac_n2w(:); vert_idx=vert_idx(:); & + !$DSL horz_idx=horz_idx(:); nrdmax=nrdmax(jg); interior_idx=start_interior_idx_c-1; & + !$DSL halo_idx=end_interior_idx_c; dwdx=p_nh_diag%dwdx(:,:,1); & + !$DSL dwdy=p_nh_diag%dwdy(:,:,1); & + !$DSL w_rel_tol=1e-09_wp; dwdx_rel_tol=1e-09_wp; dwdy_abs_tol=1e-09_wp; & + !$DSL vertical_lower=1; vertical_upper=nlev; horizontal_lower=start_nudging_idx_c; & + !$DSL horizontal_upper=end_halo_1_idx_c) + + !$DSL START STENCIL(name=calculate_nabla2_for_w; w=p_nh_prog%w(:,:,1); geofac_n2s=p_int%geofac_n2s(:,:,1); & + !$DSL z_nabla2_c=z_nabla2_c(:,:,1); z_nabla2_c_abs_tol=1e-21_wp; vertical_lower=1; vertical_upper=nlev; & + !$DSL horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=calculate_nabla2_for_w) + + !$DSL START STENCIL(name=calculate_horizontal_gradients_for_turbulence; w=p_nh_prog%w(:,:,1); geofac_grg_x=p_int%geofac_grg(:,:,1,1); geofac_grg_y=p_int%geofac_grg(:,:,1,2); & + !$DSL dwdx=p_nh_diag%dwdx(:,:,1); dwdy=p_nh_diag%dwdy(:,:,1); dwdx_rel_tol=1e-09_wp; dwdy_rel_tol=1e-09_wp; vertical_lower=2; & + !$DSL vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=calculate_horizontal_gradients_for_turbulence) + + !$DSL INSERT(start_interior_idx_c = i_startidx) + !$DSL INSERT(end_interior_idx_c = i_endidx) + + !$DSL START STENCIL(name=apply_nabla2_to_w; diff_multfac_w=diff_multfac_w; area=p_patch%cells%area(:,1); & + !$DSL z_nabla2_c=z_nabla2_c(:,:,1); geofac_n2s=p_int%geofac_n2s(:,:,1); w=p_nh_prog%w(:,:,1); & + !$DSL w_abs_tol=1e-15_wp; vertical_lower=1; vertical_upper=nlev; horizontal_lower=i_startidx; & + !$DSL horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=apply_nabla2_to_w) + + !$DSL START STENCIL(name=apply_nabla2_to_w_in_upper_damping_layer; w=p_nh_prog%w(:,:,1); diff_multfac_n2w=diff_multfac_n2w(:); & + !$DSL cell_area=p_patch%cells%area(:,1); z_nabla2_c=z_nabla2_c(:,:,1); vertical_lower=2; w_abs_tol=1e-16_wp; w_rel_tol=1e-10_wp; & + !$DSL vertical_upper=nrdmax(jg); horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$DSL END STENCIL(name=apply_nabla2_to_w_in_upper_damping_layer) + + !$DSL END FUSED STENCIL(name=apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulance) + """ + + FREE_FORM_STENCIL = """\ !$DSL IMPORTS() diff --git a/tools/tests/liskov/test_cli.py b/tools/tests/liskov/test_cli.py index edd178f62e..292374ba7f 100644 --- a/tools/tests/liskov/test_cli.py +++ b/tools/tests/liskov/test_cli.py @@ -21,9 +21,12 @@ from .fortran_samples import ( CONSECUTIVE_STENCIL, FREE_FORM_STENCIL, + MULTIPLE_FUSED, MULTIPLE_STENCILS, NO_DIRECTIVES_STENCIL, + SINGLE_FUSED, SINGLE_STENCIL, + SINGLE_STENCIL_WITH_COMMENTS, ) @@ -37,12 +40,15 @@ def outfile(tmp_path): files = [ ("NO_DIRECTIVES", NO_DIRECTIVES_STENCIL), ("SINGLE", SINGLE_STENCIL), + ("COMMENTS", SINGLE_STENCIL_WITH_COMMENTS), ("CONSECUTIVE", CONSECUTIVE_STENCIL), ("FREE_FORM", FREE_FORM_STENCIL), ("MULTIPLE", MULTIPLE_STENCILS), + ("SINGLE_FUSED", SINGLE_FUSED), + ("MULTIPLE_FUSED", MULTIPLE_FUSED), ] -flags = {"serialise": ["--multinode"], "integrate": ["-p", "-m"]} +flags = {"serialise": ["--multinode"], "integrate": ["-p", "-m", "-f", "-u"]} for file_name, file_content in files: for cmd in flags.keys(): diff --git a/tools/tests/liskov/test_directives_deserialiser.py b/tools/tests/liskov/test_directives_deserialiser.py index 13a235f6f5..56fb66e57f 100644 --- a/tools/tests/liskov/test_directives_deserialiser.py +++ b/tools/tests/liskov/test_directives_deserialiser.py @@ -19,6 +19,7 @@ from icon4pytools.liskov.codegen.integration.deserialise import ( DeclareDataFactory, EndCreateDataFactory, + EndDeleteDataFactory, EndIfDataFactory, EndProfileDataFactory, EndStencilDataFactory, @@ -32,6 +33,7 @@ BoundsData, DeclareData, EndCreateData, + EndDeleteData, EndIfData, EndProfileData, EndStencilData, @@ -83,6 +85,14 @@ 5, EndProfileData, ), + ( + EndDeleteDataFactory, + ts.EndDelete, + "END DELETE", + 6, + 6, + EndDeleteData, + ), ], ) def test_data_factories_no_args(factory_class, directive_type, string, startln, endln, expected): diff --git a/tools/tests/liskov/test_external.py b/tools/tests/liskov/test_external.py index e109dcaefa..739c2ba217 100644 --- a/tools/tests/liskov/test_external.py +++ b/tools/tests/liskov/test_external.py @@ -82,6 +82,10 @@ def test_stencil_collector_invalid_member(): Imports=None, Declare=None, EndStencil=None, + StartFusedStencil=None, + EndFusedStencil=None, + StartDelete=None, + EndDelete=None, StartCreate=None, EndCreate=None, EndIf=None, diff --git a/tools/tests/liskov/test_generation.py b/tools/tests/liskov/test_generation.py index dc16b09c6a..bfe9397b34 100644 --- a/tools/tests/liskov/test_generation.py +++ b/tools/tests/liskov/test_generation.py @@ -93,6 +93,10 @@ def integration_code_interface(): return IntegrationCodeInterface( StartStencil=[start_stencil_data], EndStencil=[end_stencil_data], + StartFusedStencil=[], + EndFusedStencil=[], + StartDelete=[], + EndDelete=[], Declare=[declare_data], Imports=imports_data, StartCreate=[start_create_data], diff --git a/tools/tests/liskov/test_parser.py b/tools/tests/liskov/test_parser.py index 5bb2518e13..0b7094d56c 100644 --- a/tools/tests/liskov/test_parser.py +++ b/tools/tests/liskov/test_parser.py @@ -22,7 +22,13 @@ from icon4pytools.liskov.parsing.parse import DirectivesParser from .conftest import insert_new_lines, scan_for_directives -from .fortran_samples import MULTIPLE_STENCILS, NO_DIRECTIVES_STENCIL, SINGLE_STENCIL +from .fortran_samples import ( + MULTIPLE_STENCILS, + NO_DIRECTIVES_STENCIL, + SINGLE_FUSED, + SINGLE_STENCIL, + SINGLE_STENCIL_WITH_COMMENTS, +) def test_parse_no_input(): @@ -76,7 +82,12 @@ def test_parse_single_directive(directive, string, startln, endln, expected_cont @mark.parametrize( "stencil, num_directives, num_content", - [(SINGLE_STENCIL, 9, 8), (MULTIPLE_STENCILS, 11, 7)], + [ + (SINGLE_STENCIL, 9, 8), + (SINGLE_STENCIL_WITH_COMMENTS, 9, 8), + (MULTIPLE_STENCILS, 11, 7), + (SINGLE_FUSED, 9, 7), + ], ) def test_file_parsing(make_f90_tmpfile, stencil, num_directives, num_content): fpath = make_f90_tmpfile(content=stencil) @@ -108,6 +119,7 @@ def test_directive_parser_no_directives_found(make_f90_tmpfile): "stencil, directive", [ (SINGLE_STENCIL, "!$DSL FOO()"), + (SINGLE_STENCIL_WITH_COMMENTS, "!$DSL FOO()"), (MULTIPLE_STENCILS, "!$DSL BAR()"), ], ) diff --git a/tools/tests/liskov/test_validation.py b/tools/tests/liskov/test_validation.py index d6fbc4f927..ffa2be1bef 100644 --- a/tools/tests/liskov/test_validation.py +++ b/tools/tests/liskov/test_validation.py @@ -24,7 +24,7 @@ from icon4pytools.liskov.parsing.validation import DirectiveSyntaxValidator from .conftest import insert_new_lines, scan_for_directives -from .fortran_samples import MULTIPLE_STENCILS, SINGLE_STENCIL +from .fortran_samples import MULTIPLE_STENCILS, SINGLE_STENCIL_WITH_COMMENTS @mark.parametrize( @@ -73,7 +73,7 @@ def test_directive_syntax_validator(directive): ], ) def test_directive_semantics_validation_repeated_directives(make_f90_tmpfile, directive): - fpath = make_f90_tmpfile(content=SINGLE_STENCIL) + fpath = make_f90_tmpfile(content=SINGLE_STENCIL_WITH_COMMENTS) opath = fpath.with_suffix(".gen") insert_new_lines(fpath, [directive]) directives = scan_for_directives(fpath) @@ -93,7 +93,7 @@ def test_directive_semantics_validation_repeated_directives(make_f90_tmpfile, di ], ) def test_directive_semantics_validation_repeated_stencil(make_f90_tmpfile, directive): - fpath = make_f90_tmpfile(content=SINGLE_STENCIL) + fpath = make_f90_tmpfile(content=SINGLE_STENCIL_WITH_COMMENTS) opath = fpath.with_suffix(".gen") insert_new_lines(fpath, [directive]) directives = scan_for_directives(fpath) @@ -109,7 +109,7 @@ def test_directive_semantics_validation_repeated_stencil(make_f90_tmpfile, direc ], ) def test_directive_semantics_validation_required_directives(make_f90_tmpfile, directive): - new = SINGLE_STENCIL.replace(directive, "") + new = SINGLE_STENCIL_WITH_COMMENTS.replace(directive, "") fpath = make_f90_tmpfile(content=new) opath = fpath.with_suffix(".gen") directives = scan_for_directives(fpath)