diff --git a/ci/default.yml b/ci/default.yml index 4a85316030..7c17844fac 100644 --- a/ci/default.yml +++ b/ci/default.yml @@ -17,7 +17,7 @@ test_model_stencils: # exclude slow test configurations - if: $BACKEND == "roundtrip" && $GRID == "icon_grid" when: never - - when: always + - when: on_success test_tools: extends: .test_template diff --git a/tools/README.md b/tools/README.md index d6288ad2f7..72132757df 100644 --- a/tools/README.md +++ b/tools/README.md @@ -148,6 +148,8 @@ In addition, other optional keyword arguments are the following: - `copies`: Takes a boolean string input, and controls whether before field copies should be made or not. If set to False only the `#ifdef __DSL_VERIFY` directive is generated. Defaults to true.

+- `optional_module`: Takes a boolean string input, and controls whether stencils is part of an optional module. Defaults to "None".

+ #### `!$DSL END STENCIL()` This directive denotes the end of a stencil. The required argument is `name`, which must match the name of the preceding `START STENCIL` directive. diff --git a/tools/src/icon4pytools/liskov/cli.py b/tools/src/icon4pytools/liskov/cli.py index a90afbc09e..a118199e4d 100644 --- a/tools/src/icon4pytools/liskov/cli.py +++ b/tools/src/icon4pytools/liskov/cli.py @@ -12,6 +12,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import pathlib +from typing import Optional import click @@ -27,6 +28,10 @@ logger = setup_logger(__name__) +def split_comma(ctx, param, value) -> Optional[tuple[str]]: + return tuple(v.strip() for v in value.split(",")) if value else None + + @click.group(invoke_without_command=True) @click.pass_context def main(ctx: click.Context) -> None: @@ -56,6 +61,11 @@ def main(ctx: click.Context) -> None: default=False, help="Adds fused or unfused stencils.", ) +@click.option( + "--optional-modules-to-enable", + callback=split_comma, + help="Specify a list of comma-separated optional DSL modules to enable.", +) @click.option( "--verification/--substitution", "-v/-s", @@ -77,10 +87,13 @@ def integrate( verification: bool, profile: bool, metadatagen: bool, + optional_modules_to_enable: Optional[tuple[str]], ) -> None: mode = "integration" iface = parse_fortran_file(input_path, output_path, mode) - iface_gt4py = process_stencils(iface, fused) + iface_gt4py = process_stencils( + iface, fused, optional_modules_to_enable=optional_modules_to_enable + ) run_code_generation( input_path, output_path, diff --git a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py index 54e085d685..d669f220c6 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py @@ -53,6 +53,10 @@ TOLERANCE_ARGS = ["abs_tol", "rel_tol"] DEFAULT_DECLARE_IDENT_TYPE = "REAL(wp)" DEFAULT_DECLARE_SUFFIX = "before" +DEFAULT_STARTSTENCIL_ACC_PRESENT = "true" +DEFAULT_STARTSTENCIL_MERGECOPY = "false" +DEFAULT_STARTSTENCIL_COPIES = "true" +DEFAULT_STARTSTENCIL_OPTIONAL_MODULE = "None" logger = setup_logger(__name__) @@ -286,12 +290,13 @@ def create_stencil_data( for i, directive in enumerate(directives): named_args = parsed["content"][directive_cls.__name__][i] additional_attrs = self._pop_additional_attributes(dtype, named_args) - acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true")) + acc_present = string_to_bool( + pop_item_from_dict(named_args, "accpresent", DEFAULT_STARTSTENCIL_ACC_PRESENT) + ) stencil_name = _extract_stencil_name(named_args, directive) bounds = self._make_bounds(named_args) fields = self._make_fields(named_args, field_dimensions) fields_w_tolerance = self._update_tolerances(named_args, fields) - deserialised.append( dtype( name=stencil_name, @@ -305,14 +310,27 @@ def create_stencil_data( return deserialised def _pop_additional_attributes( - self, dtype: Type[StartStencilData | StartFusedStencilData], named_args: dict[str, Any] + self, + dtype: Type[StartStencilData | StartFusedStencilData], + named_args: dict[str, Any], ) -> dict: """Pop and return additional attributes specific to StartStencilData.""" additional_attrs = {} if dtype == StartStencilData: - mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false")) - copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true")) - additional_attrs = {"mergecopy": mergecopy, "copies": copies} + mergecopy = string_to_bool( + pop_item_from_dict(named_args, "mergecopy", DEFAULT_STARTSTENCIL_MERGECOPY) + ) + copies = string_to_bool( + pop_item_from_dict(named_args, "copies", DEFAULT_STARTSTENCIL_COPIES) + ) + optional_module = pop_item_from_dict( + named_args, "optional_module", DEFAULT_STARTSTENCIL_OPTIONAL_MODULE + ) + additional_attrs = { + "mergecopy": mergecopy, + "copies": copies, + "optional_module": optional_module, + } return additional_attrs @staticmethod diff --git a/tools/src/icon4pytools/liskov/codegen/integration/generate.py b/tools/src/icon4pytools/liskov/codegen/integration/generate.py index 077e6f437d..fb8bd5ccd4 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/generate.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/generate.py @@ -146,6 +146,7 @@ def _generate_start_stencil(self) -> None: acc_present=stencil.acc_present, mergecopy=stencil.mergecopy, copies=stencil.copies, + optional_module=stencil.optional_module, ) i += 2 diff --git a/tools/src/icon4pytools/liskov/codegen/integration/interface.py b/tools/src/icon4pytools/liskov/codegen/integration/interface.py index 93903cab78..07cb43872e 100644 --- a/tools/src/icon4pytools/liskov/codegen/integration/interface.py +++ b/tools/src/icon4pytools/liskov/codegen/integration/interface.py @@ -89,6 +89,7 @@ class BaseStartStencilData(CodeGenInput): class StartStencilData(BaseStartStencilData): mergecopy: Optional[bool] copies: Optional[bool] + optional_module: Optional[str] @dataclass diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py index 6860b455f5..f550b4fd09 100644 --- a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py +++ b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py @@ -40,6 +40,7 @@ "accpresent", "mergecopy", "copies", + "optional_module", "horizontal_lower", "horizontal_upper", "vertical_lower", diff --git a/tools/src/icon4pytools/liskov/parsing/transform.py b/tools/src/icon4pytools/liskov/parsing/transform.py index e6a0bd8ef8..ed1c8ad93b 100644 --- a/tools/src/icon4pytools/liskov/parsing/transform.py +++ b/tools/src/icon4pytools/liskov/parsing/transform.py @@ -10,9 +10,10 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any +from typing import Any, Optional from icon4pytools.common.logger import setup_logger +from icon4pytools.liskov.codegen.integration.deserialise import DEFAULT_STARTSTENCIL_OPTIONAL_MODULE from icon4pytools.liskov.codegen.integration.interface import ( EndDeleteData, EndFusedStencilData, @@ -30,7 +31,18 @@ logger = setup_logger(__name__) -class StencilTransformer(Step): +def _remove_stencils( + parsed: IntegrationCodeInterface, stencils_to_remove: list[CodeGenInput] +) -> None: + attributes_to_modify = ["StartStencil", "EndStencil"] + + for attr_name in attributes_to_modify: + current_stencil_list = getattr(parsed, attr_name) + modified_stencil_list = [_ for _ in current_stencil_list if _ not in stencils_to_remove] + setattr(parsed, attr_name, modified_stencil_list) + + +class FusedStencilTransformer(Step): def __init__(self, parsed: IntegrationCodeInterface, fused: bool) -> None: self.parsed = parsed self.fused = fused @@ -71,7 +83,7 @@ def _process_stencils_for_deletion(self) -> None: self._create_delete_directives(start_single, end_single) stencils_to_remove += [start_single, end_single] - self._remove_stencils(stencils_to_remove) + _remove_stencils(self.parsed, stencils_to_remove) def _stencil_is_removable( self, @@ -105,14 +117,6 @@ def _create_delete_directives( directive.append(cls(startln=param.startln)) setattr(self.parsed, attr, directive) - def _remove_stencils(self, stencils_to_remove: list[CodeGenInput]) -> None: - attributes_to_modify = ["StartStencil", "EndStencil"] - - for attr_name in attributes_to_modify: - current_stencil_list = getattr(self.parsed, attr_name) - modified_stencil_list = [_ for _ in current_stencil_list if _ not in stencils_to_remove] - setattr(self.parsed, attr_name, modified_stencil_list) - def _remove_fused_stencils(self) -> None: self.parsed.StartFusedStencil = [] self.parsed.EndFusedStencil = [] @@ -120,3 +124,52 @@ def _remove_fused_stencils(self) -> None: def _remove_delete(self) -> None: self.parsed.StartDelete = [] self.parsed.EndDelete = [] + + +class OptionalModulesTransformer(Step): + def __init__( + self, parsed: IntegrationCodeInterface, optional_modules_to_enable: Optional[tuple[str]] + ) -> None: + self.parsed = parsed + self.optional_modules_to_enable = optional_modules_to_enable + + def __call__(self, data: Any = None) -> IntegrationCodeInterface: + """Transform stencils in the parse tree based on 'optional_modules_to_enable', either enabling specific modules or removing them. + + Args: + data (Any): Optional data to be passed. Defaults to None. + + Returns: + IntegrationCodeInterface: The modified interface object. + """ + if self.optional_modules_to_enable is not None: + action = "enabling" + else: + action = "removing" + logger.info(f"Transforming stencils by {action} optional modules.") + self._transform_stencils() + + return self.parsed + + def _transform_stencils(self) -> None: + """Identify stencils to transform based on 'optional_modules_to_enable' and applies necessary changes.""" + stencils_to_remove = [] + for start_stencil, end_stencil in zip( + self.parsed.StartStencil, self.parsed.EndStencil, strict=False + ): + if self._should_remove_stencil(start_stencil): + stencils_to_remove.extend([start_stencil, end_stencil]) + + _remove_stencils(self.parsed, stencils_to_remove) + + def _should_remove_stencil(self, stencil: StartStencilData) -> bool: + """Determine if a stencil should be removed based on 'optional_modules_to_enable'. + + Returns: + bool: True if the stencil should be removed, False otherwise. + """ + if stencil.optional_module == DEFAULT_STARTSTENCIL_OPTIONAL_MODULE: + return False + if self.optional_modules_to_enable is None: + return True + return stencil.optional_module not in self.optional_modules_to_enable diff --git a/tools/src/icon4pytools/liskov/pipeline/collection.py b/tools/src/icon4pytools/liskov/pipeline/collection.py index 4d458ad1cc..8b924926b3 100644 --- a/tools/src/icon4pytools/liskov/pipeline/collection.py +++ b/tools/src/icon4pytools/liskov/pipeline/collection.py @@ -11,7 +11,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later from pathlib import Path -from typing import Any +from typing import Any, Optional from icon4pytools.liskov.codegen.integration.deserialise import IntegrationCodeDeserialiser from icon4pytools.liskov.codegen.integration.generate import IntegrationCodeGenerator @@ -22,7 +22,10 @@ from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils from icon4pytools.liskov.parsing.parse import DirectivesParser from icon4pytools.liskov.parsing.scan import DirectivesScanner -from icon4pytools.liskov.parsing.transform import StencilTransformer +from icon4pytools.liskov.parsing.transform import ( + FusedStencilTransformer, + OptionalModulesTransformer, +) from icon4pytools.liskov.pipeline.definition import Step, linear_pipeline @@ -70,21 +73,28 @@ def parse_fortran_file( @linear_pipeline -def process_stencils(parsed: IntegrationCodeInterface, fused: bool) -> list[Step]: +def process_stencils( + parsed: IntegrationCodeInterface, fused: bool, optional_modules_to_enable: Optional[tuple[str]] +) -> list[Step]: """Execute a linear pipeline to transform stencils and produce either fused or unfused execution. This function takes an input `parsed` object of type `IntegrationCodeInterface` and a `fused` boolean flag. - It then executes a linear pipeline, consisting of two steps: transformation of stencils for fusion or unfusion, - and updating fields with information from GT4Py stencils. + It then executes a linear pipeline, consisting of three steps: transformation of stencils for fusion or unfusion, + enabling optional modules, and updating fields with information from GT4Py stencils. Args: parsed (IntegrationCodeInterface): The input object containing parsed integration code. fused (bool): A boolean flag indicating whether to produce fused (True) or unfused (False) execution. + optional_modules_to_enable (Optional[tuple[str]]): A tuple of optional modules to enable. Returns: - The updated and transformed object with fields containing information from GT4Py stencils. + The updated and transformed IntegrationCodeInterface object. """ - return [StencilTransformer(parsed, fused), UpdateFieldsWithGt4PyStencils(parsed)] + return [ + FusedStencilTransformer(parsed, fused), + OptionalModulesTransformer(parsed, optional_modules_to_enable), + UpdateFieldsWithGt4PyStencils(parsed), + ] @linear_pipeline diff --git a/tools/tests/liskov/test_external.py b/tools/tests/liskov/test_external.py index 55f2e5f3b2..454d01e428 100644 --- a/tools/tests/liskov/test_external.py +++ b/tools/tests/liskov/test_external.py @@ -77,6 +77,7 @@ def test_stencil_collector_invalid_member(): acc_present=False, mergecopy=False, copies=True, + optional_module="None", ) ], Imports=None, diff --git a/tools/tests/liskov/test_generation.py b/tools/tests/liskov/test_generation.py index 46b7f43dae..b715f6b643 100644 --- a/tools/tests/liskov/test_generation.py +++ b/tools/tests/liskov/test_generation.py @@ -79,6 +79,7 @@ def integration_code_interface(): acc_present=False, mergecopy=False, copies=True, + optional_module="None", ) end_stencil_data = EndStencilData( name="stencil1", startln=3, noendif=False, noprofile=False, noaccenddata=False diff --git a/tools/tests/liskov/test_serialisation_deserialiser.py b/tools/tests/liskov/test_serialisation_deserialiser.py index 0431086beb..518029af3a 100644 --- a/tools/tests/liskov/test_serialisation_deserialiser.py +++ b/tools/tests/liskov/test_serialisation_deserialiser.py @@ -105,6 +105,7 @@ def parsed_dict(): "horizontal_lower": "i_startidx", "horizontal_upper": "i_endidx", "accpresent": "True", + "optional_module": "advection", } ], "StartProfile": [{"name": "apply_nabla2_to_vn_in_lateral_boundary"}], @@ -127,5 +128,7 @@ def test_savepoint_data_factory(parsed_dict): savepoints = SavepointDataFactory()(parsed_dict) assert len(savepoints) == 2 assert any([isinstance(sp, SavepointData) for sp in savepoints]) + # check that unnecessary keys have been removed + assert not any(f.variable == "optional_module" for sp in savepoints for f in sp.fields) assert any([isinstance(f, FieldSerialisationData) for f in savepoints[0].fields]) assert any([isinstance(m, Metadata) for m in savepoints[0].metadata]) diff --git a/tools/tests/liskov/test_transform.py b/tools/tests/liskov/test_transform.py index 4c2a60454c..1ac5a2bcbb 100644 --- a/tools/tests/liskov/test_transform.py +++ b/tools/tests/liskov/test_transform.py @@ -33,7 +33,10 @@ StartProfileData, StartStencilData, ) -from icon4pytools.liskov.parsing.transform import StencilTransformer +from icon4pytools.liskov.parsing.transform import ( + FusedStencilTransformer, + OptionalModulesTransformer, +) @pytest.fixture @@ -94,6 +97,7 @@ def integration_code_interface(): acc_present=False, mergecopy=False, copies=True, + optional_module="None", ) end_stencil_data1 = EndStencilData( name="stencil1", startln=3, noendif=False, noprofile=False, noaccenddata=False @@ -126,6 +130,7 @@ def integration_code_interface(): acc_present=False, mergecopy=False, copies=True, + optional_module="advection", ) end_stencil_data2 = EndStencilData( name="stencil2", startln=6, noendif=False, noprofile=False, noaccenddata=False @@ -165,20 +170,32 @@ def integration_code_interface(): @pytest.fixture -def stencil_transform_fused(integration_code_interface): - return StencilTransformer(integration_code_interface, fused=True) +def fused_stencil_transform_fused(integration_code_interface): + return FusedStencilTransformer(integration_code_interface, fused=True) @pytest.fixture -def stencil_transform_unfused(integration_code_interface): - return StencilTransformer(integration_code_interface, fused=False) +def fused_stencil_transform_unfused(integration_code_interface): + return FusedStencilTransformer(integration_code_interface, fused=False) + + +@pytest.fixture +def optional_modules_transform_enabled(integration_code_interface): + return OptionalModulesTransformer( + integration_code_interface, optional_modules_to_enable=["advection"] + ) + + +@pytest.fixture +def optional_modules_transform_disabled(integration_code_interface): + return OptionalModulesTransformer(integration_code_interface, optional_modules_to_enable=None) def test_transform_fused( - stencil_transform_fused, + fused_stencil_transform_fused, ): # Check that the transformed interface is as expected - transformed = stencil_transform_fused() + transformed = fused_stencil_transform_fused() assert len(transformed.StartFusedStencil) == 1 assert len(transformed.EndFusedStencil) == 1 assert len(transformed.StartStencil) == 1 @@ -188,10 +205,10 @@ def test_transform_fused( def test_transform_unfused( - stencil_transform_unfused, + fused_stencil_transform_unfused, ): # Check that the transformed interface is as expected - transformed = stencil_transform_unfused() + transformed = fused_stencil_transform_unfused() assert not transformed.StartFusedStencil assert not transformed.EndFusedStencil @@ -199,3 +216,30 @@ def test_transform_unfused( assert len(transformed.EndStencil) == 2 assert not transformed.StartDelete assert not transformed.EndDelete + + +def test_transform_optional_enabled( + optional_modules_transform_enabled, +): + # Check that the transformed interface is as expected + transformed = optional_modules_transform_enabled() + assert len(transformed.StartFusedStencil) == 1 + assert len(transformed.EndFusedStencil) == 1 + assert len(transformed.StartStencil) == 2 + assert len(transformed.EndStencil) == 2 + assert len(transformed.StartDelete) == 1 + assert len(transformed.EndDelete) == 1 + + +def test_transform_optional_disabled( + optional_modules_transform_disabled, +): + # Check that the transformed interface is as expected + transformed = optional_modules_transform_disabled() + + assert len(transformed.StartFusedStencil) == 1 + assert len(transformed.EndFusedStencil) == 1 + assert len(transformed.StartStencil) == 1 + assert len(transformed.EndStencil) == 1 + assert len(transformed.StartDelete) == 1 + assert len(transformed.EndDelete) == 1