Skip to content

Commit

Permalink
Enable optional stencils and modules in Liskov (#358)
Browse files Browse the repository at this point in the history

This PR allows to integrate optional stencils using Liskov. The GT4Py stencils are ignored if their module e.g. advection is not specified explicitly at compilation. The motivation for this feature is that the performance of some modules (e.g. advection and graupel) is not yet satisfying but keeping them in a separate branch is a lot of overhead.

More specifically this PR:

    adds a new argument parameter optional_module to the START STENCIL liskov directive to specify which optional module the stencil belongs to, e.g. optional_module=advection.
    adds a new option --enable_optional_stencil to the liskov CLI to enable the optional stencils and specify which module. e.g. --enable_optional_stencil=advection. Corresponding PR in icon-exclaim adds the changes in the icon_liskov preprocessor command in configure and CMakelist.txt.

---------

Co-authored-by: Nina Burgdorfer <[email protected]>
Co-authored-by: samkellerhals <[email protected]>
Co-authored-by: Daniel Hupp <[email protected]>
Co-authored-by: Nicoletta Farabullini <[email protected]>
  • Loading branch information
5 people authored and iomaganaris committed Jun 18, 2024
1 parent 86caf76 commit 8f677d9
Show file tree
Hide file tree
Showing 13 changed files with 183 additions and 35 deletions.
2 changes: 1 addition & 1 deletion ci/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ test_model_stencils:
# exclude slow test configurations
- if: $BACKEND == "roundtrip" && $GRID == "icon_grid"
when: never
- when: always
- when: on_success

test_tools:
extends: .test_template
Expand Down
2 changes: 2 additions & 0 deletions tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ In addition, other optional keyword arguments are the following:

- `copies`: Takes a boolean string input, and controls whether before field copies should be made or not. If set to False only the `#ifdef __DSL_VERIFY` directive is generated. Defaults to true.<br><br>

- `optional_module`: Takes a boolean string input, and controls whether stencils is part of an optional module. Defaults to "None".<br><br>

#### `!$DSL END STENCIL()`

This directive denotes the end of a stencil. The required argument is `name`, which must match the name of the preceding `START STENCIL` directive.
Expand Down
15 changes: 14 additions & 1 deletion tools/src/icon4pytools/liskov/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import pathlib
from typing import Optional

import click

Expand All @@ -27,6 +28,10 @@
logger = setup_logger(__name__)


def split_comma(ctx, param, value) -> Optional[tuple[str]]:
return tuple(v.strip() for v in value.split(",")) if value else None


@click.group(invoke_without_command=True)
@click.pass_context
def main(ctx: click.Context) -> None:
Expand Down Expand Up @@ -56,6 +61,11 @@ def main(ctx: click.Context) -> None:
default=False,
help="Adds fused or unfused stencils.",
)
@click.option(
"--optional-modules-to-enable",
callback=split_comma,
help="Specify a list of comma-separated optional DSL modules to enable.",
)
@click.option(
"--verification/--substitution",
"-v/-s",
Expand All @@ -77,10 +87,13 @@ def integrate(
verification: bool,
profile: bool,
metadatagen: bool,
optional_modules_to_enable: Optional[tuple[str]],
) -> None:
mode = "integration"
iface = parse_fortran_file(input_path, output_path, mode)
iface_gt4py = process_stencils(iface, fused)
iface_gt4py = process_stencils(
iface, fused, optional_modules_to_enable=optional_modules_to_enable
)
run_code_generation(
input_path,
output_path,
Expand Down
30 changes: 24 additions & 6 deletions tools/src/icon4pytools/liskov/codegen/integration/deserialise.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@
TOLERANCE_ARGS = ["abs_tol", "rel_tol"]
DEFAULT_DECLARE_IDENT_TYPE = "REAL(wp)"
DEFAULT_DECLARE_SUFFIX = "before"
DEFAULT_STARTSTENCIL_ACC_PRESENT = "true"
DEFAULT_STARTSTENCIL_MERGECOPY = "false"
DEFAULT_STARTSTENCIL_COPIES = "true"
DEFAULT_STARTSTENCIL_OPTIONAL_MODULE = "None"

logger = setup_logger(__name__)

Expand Down Expand Up @@ -286,12 +290,13 @@ def create_stencil_data(
for i, directive in enumerate(directives):
named_args = parsed["content"][directive_cls.__name__][i]
additional_attrs = self._pop_additional_attributes(dtype, named_args)
acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true"))
acc_present = string_to_bool(
pop_item_from_dict(named_args, "accpresent", DEFAULT_STARTSTENCIL_ACC_PRESENT)
)
stencil_name = _extract_stencil_name(named_args, directive)
bounds = self._make_bounds(named_args)
fields = self._make_fields(named_args, field_dimensions)
fields_w_tolerance = self._update_tolerances(named_args, fields)

deserialised.append(
dtype(
name=stencil_name,
Expand All @@ -305,14 +310,27 @@ def create_stencil_data(
return deserialised

def _pop_additional_attributes(
self, dtype: Type[StartStencilData | StartFusedStencilData], named_args: dict[str, Any]
self,
dtype: Type[StartStencilData | StartFusedStencilData],
named_args: dict[str, Any],
) -> dict:
"""Pop and return additional attributes specific to StartStencilData."""
additional_attrs = {}
if dtype == StartStencilData:
mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false"))
copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true"))
additional_attrs = {"mergecopy": mergecopy, "copies": copies}
mergecopy = string_to_bool(
pop_item_from_dict(named_args, "mergecopy", DEFAULT_STARTSTENCIL_MERGECOPY)
)
copies = string_to_bool(
pop_item_from_dict(named_args, "copies", DEFAULT_STARTSTENCIL_COPIES)
)
optional_module = pop_item_from_dict(
named_args, "optional_module", DEFAULT_STARTSTENCIL_OPTIONAL_MODULE
)
additional_attrs = {
"mergecopy": mergecopy,
"copies": copies,
"optional_module": optional_module,
}
return additional_attrs

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _generate_start_stencil(self) -> None:
acc_present=stencil.acc_present,
mergecopy=stencil.mergecopy,
copies=stencil.copies,
optional_module=stencil.optional_module,
)
i += 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class BaseStartStencilData(CodeGenInput):
class StartStencilData(BaseStartStencilData):
mergecopy: Optional[bool]
copies: Optional[bool]
optional_module: Optional[str]


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"accpresent",
"mergecopy",
"copies",
"optional_module",
"horizontal_lower",
"horizontal_upper",
"vertical_lower",
Expand Down
75 changes: 64 additions & 11 deletions tools/src/icon4pytools/liskov/parsing/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
from typing import Any
from typing import Any, Optional

from icon4pytools.common.logger import setup_logger
from icon4pytools.liskov.codegen.integration.deserialise import DEFAULT_STARTSTENCIL_OPTIONAL_MODULE
from icon4pytools.liskov.codegen.integration.interface import (
EndDeleteData,
EndFusedStencilData,
Expand All @@ -30,7 +31,18 @@
logger = setup_logger(__name__)


class StencilTransformer(Step):
def _remove_stencils(
parsed: IntegrationCodeInterface, stencils_to_remove: list[CodeGenInput]
) -> None:
attributes_to_modify = ["StartStencil", "EndStencil"]

for attr_name in attributes_to_modify:
current_stencil_list = getattr(parsed, attr_name)
modified_stencil_list = [_ for _ in current_stencil_list if _ not in stencils_to_remove]
setattr(parsed, attr_name, modified_stencil_list)


class FusedStencilTransformer(Step):
def __init__(self, parsed: IntegrationCodeInterface, fused: bool) -> None:
self.parsed = parsed
self.fused = fused
Expand Down Expand Up @@ -71,7 +83,7 @@ def _process_stencils_for_deletion(self) -> None:
self._create_delete_directives(start_single, end_single)
stencils_to_remove += [start_single, end_single]

self._remove_stencils(stencils_to_remove)
_remove_stencils(self.parsed, stencils_to_remove)

def _stencil_is_removable(
self,
Expand Down Expand Up @@ -105,18 +117,59 @@ def _create_delete_directives(
directive.append(cls(startln=param.startln))
setattr(self.parsed, attr, directive)

def _remove_stencils(self, stencils_to_remove: list[CodeGenInput]) -> None:
attributes_to_modify = ["StartStencil", "EndStencil"]

for attr_name in attributes_to_modify:
current_stencil_list = getattr(self.parsed, attr_name)
modified_stencil_list = [_ for _ in current_stencil_list if _ not in stencils_to_remove]
setattr(self.parsed, attr_name, modified_stencil_list)

def _remove_fused_stencils(self) -> None:
self.parsed.StartFusedStencil = []
self.parsed.EndFusedStencil = []

def _remove_delete(self) -> None:
self.parsed.StartDelete = []
self.parsed.EndDelete = []


class OptionalModulesTransformer(Step):
def __init__(
self, parsed: IntegrationCodeInterface, optional_modules_to_enable: Optional[tuple[str]]
) -> None:
self.parsed = parsed
self.optional_modules_to_enable = optional_modules_to_enable

def __call__(self, data: Any = None) -> IntegrationCodeInterface:
"""Transform stencils in the parse tree based on 'optional_modules_to_enable', either enabling specific modules or removing them.
Args:
data (Any): Optional data to be passed. Defaults to None.
Returns:
IntegrationCodeInterface: The modified interface object.
"""
if self.optional_modules_to_enable is not None:
action = "enabling"
else:
action = "removing"
logger.info(f"Transforming stencils by {action} optional modules.")
self._transform_stencils()

return self.parsed

def _transform_stencils(self) -> None:
"""Identify stencils to transform based on 'optional_modules_to_enable' and applies necessary changes."""
stencils_to_remove = []
for start_stencil, end_stencil in zip(
self.parsed.StartStencil, self.parsed.EndStencil, strict=False
):
if self._should_remove_stencil(start_stencil):
stencils_to_remove.extend([start_stencil, end_stencil])

_remove_stencils(self.parsed, stencils_to_remove)

def _should_remove_stencil(self, stencil: StartStencilData) -> bool:
"""Determine if a stencil should be removed based on 'optional_modules_to_enable'.
Returns:
bool: True if the stencil should be removed, False otherwise.
"""
if stencil.optional_module == DEFAULT_STARTSTENCIL_OPTIONAL_MODULE:
return False
if self.optional_modules_to_enable is None:
return True
return stencil.optional_module not in self.optional_modules_to_enable
24 changes: 17 additions & 7 deletions tools/src/icon4pytools/liskov/pipeline/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
from pathlib import Path
from typing import Any
from typing import Any, Optional

from icon4pytools.liskov.codegen.integration.deserialise import IntegrationCodeDeserialiser
from icon4pytools.liskov.codegen.integration.generate import IntegrationCodeGenerator
Expand All @@ -22,7 +22,10 @@
from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils
from icon4pytools.liskov.parsing.parse import DirectivesParser
from icon4pytools.liskov.parsing.scan import DirectivesScanner
from icon4pytools.liskov.parsing.transform import StencilTransformer
from icon4pytools.liskov.parsing.transform import (
FusedStencilTransformer,
OptionalModulesTransformer,
)
from icon4pytools.liskov.pipeline.definition import Step, linear_pipeline


Expand Down Expand Up @@ -70,21 +73,28 @@ def parse_fortran_file(


@linear_pipeline
def process_stencils(parsed: IntegrationCodeInterface, fused: bool) -> list[Step]:
def process_stencils(
parsed: IntegrationCodeInterface, fused: bool, optional_modules_to_enable: Optional[tuple[str]]
) -> list[Step]:
"""Execute a linear pipeline to transform stencils and produce either fused or unfused execution.
This function takes an input `parsed` object of type `IntegrationCodeInterface` and a `fused` boolean flag.
It then executes a linear pipeline, consisting of two steps: transformation of stencils for fusion or unfusion,
and updating fields with information from GT4Py stencils.
It then executes a linear pipeline, consisting of three steps: transformation of stencils for fusion or unfusion,
enabling optional modules, and updating fields with information from GT4Py stencils.
Args:
parsed (IntegrationCodeInterface): The input object containing parsed integration code.
fused (bool): A boolean flag indicating whether to produce fused (True) or unfused (False) execution.
optional_modules_to_enable (Optional[tuple[str]]): A tuple of optional modules to enable.
Returns:
The updated and transformed object with fields containing information from GT4Py stencils.
The updated and transformed IntegrationCodeInterface object.
"""
return [StencilTransformer(parsed, fused), UpdateFieldsWithGt4PyStencils(parsed)]
return [
FusedStencilTransformer(parsed, fused),
OptionalModulesTransformer(parsed, optional_modules_to_enable),
UpdateFieldsWithGt4PyStencils(parsed),
]


@linear_pipeline
Expand Down
1 change: 1 addition & 0 deletions tools/tests/liskov/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_stencil_collector_invalid_member():
acc_present=False,
mergecopy=False,
copies=True,
optional_module="None",
)
],
Imports=None,
Expand Down
1 change: 1 addition & 0 deletions tools/tests/liskov/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def integration_code_interface():
acc_present=False,
mergecopy=False,
copies=True,
optional_module="None",
)
end_stencil_data = EndStencilData(
name="stencil1", startln=3, noendif=False, noprofile=False, noaccenddata=False
Expand Down
3 changes: 3 additions & 0 deletions tools/tests/liskov/test_serialisation_deserialiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def parsed_dict():
"horizontal_lower": "i_startidx",
"horizontal_upper": "i_endidx",
"accpresent": "True",
"optional_module": "advection",
}
],
"StartProfile": [{"name": "apply_nabla2_to_vn_in_lateral_boundary"}],
Expand All @@ -127,5 +128,7 @@ def test_savepoint_data_factory(parsed_dict):
savepoints = SavepointDataFactory()(parsed_dict)
assert len(savepoints) == 2
assert any([isinstance(sp, SavepointData) for sp in savepoints])
# check that unnecessary keys have been removed
assert not any(f.variable == "optional_module" for sp in savepoints for f in sp.fields)
assert any([isinstance(f, FieldSerialisationData) for f in savepoints[0].fields])
assert any([isinstance(m, Metadata) for m in savepoints[0].metadata])
Loading

0 comments on commit 8f677d9

Please sign in to comment.