diff --git a/ci/default.yml b/ci/default.yml
index 4a85316030..7c17844fac 100644
--- a/ci/default.yml
+++ b/ci/default.yml
@@ -17,7 +17,7 @@ test_model_stencils:
# exclude slow test configurations
- if: $BACKEND == "roundtrip" && $GRID == "icon_grid"
when: never
- - when: always
+ - when: on_success
test_tools:
extends: .test_template
diff --git a/tools/README.md b/tools/README.md
index d6288ad2f7..72132757df 100644
--- a/tools/README.md
+++ b/tools/README.md
@@ -148,6 +148,8 @@ In addition, other optional keyword arguments are the following:
- `copies`: Takes a boolean string input, and controls whether before field copies should be made or not. If set to False only the `#ifdef __DSL_VERIFY` directive is generated. Defaults to true.
+- `optional_module`: Takes a boolean string input, and controls whether stencils is part of an optional module. Defaults to "None".
+
#### `!$DSL END STENCIL()`
This directive denotes the end of a stencil. The required argument is `name`, which must match the name of the preceding `START STENCIL` directive.
diff --git a/tools/src/icon4pytools/liskov/cli.py b/tools/src/icon4pytools/liskov/cli.py
index a90afbc09e..a118199e4d 100644
--- a/tools/src/icon4pytools/liskov/cli.py
+++ b/tools/src/icon4pytools/liskov/cli.py
@@ -12,6 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import pathlib
+from typing import Optional
import click
@@ -27,6 +28,10 @@
logger = setup_logger(__name__)
+def split_comma(ctx, param, value) -> Optional[tuple[str]]:
+ return tuple(v.strip() for v in value.split(",")) if value else None
+
+
@click.group(invoke_without_command=True)
@click.pass_context
def main(ctx: click.Context) -> None:
@@ -56,6 +61,11 @@ def main(ctx: click.Context) -> None:
default=False,
help="Adds fused or unfused stencils.",
)
+@click.option(
+ "--optional-modules-to-enable",
+ callback=split_comma,
+ help="Specify a list of comma-separated optional DSL modules to enable.",
+)
@click.option(
"--verification/--substitution",
"-v/-s",
@@ -77,10 +87,13 @@ def integrate(
verification: bool,
profile: bool,
metadatagen: bool,
+ optional_modules_to_enable: Optional[tuple[str]],
) -> None:
mode = "integration"
iface = parse_fortran_file(input_path, output_path, mode)
- iface_gt4py = process_stencils(iface, fused)
+ iface_gt4py = process_stencils(
+ iface, fused, optional_modules_to_enable=optional_modules_to_enable
+ )
run_code_generation(
input_path,
output_path,
diff --git a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py
index 54e085d685..d669f220c6 100644
--- a/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py
+++ b/tools/src/icon4pytools/liskov/codegen/integration/deserialise.py
@@ -53,6 +53,10 @@
TOLERANCE_ARGS = ["abs_tol", "rel_tol"]
DEFAULT_DECLARE_IDENT_TYPE = "REAL(wp)"
DEFAULT_DECLARE_SUFFIX = "before"
+DEFAULT_STARTSTENCIL_ACC_PRESENT = "true"
+DEFAULT_STARTSTENCIL_MERGECOPY = "false"
+DEFAULT_STARTSTENCIL_COPIES = "true"
+DEFAULT_STARTSTENCIL_OPTIONAL_MODULE = "None"
logger = setup_logger(__name__)
@@ -286,12 +290,13 @@ def create_stencil_data(
for i, directive in enumerate(directives):
named_args = parsed["content"][directive_cls.__name__][i]
additional_attrs = self._pop_additional_attributes(dtype, named_args)
- acc_present = string_to_bool(pop_item_from_dict(named_args, "accpresent", "true"))
+ acc_present = string_to_bool(
+ pop_item_from_dict(named_args, "accpresent", DEFAULT_STARTSTENCIL_ACC_PRESENT)
+ )
stencil_name = _extract_stencil_name(named_args, directive)
bounds = self._make_bounds(named_args)
fields = self._make_fields(named_args, field_dimensions)
fields_w_tolerance = self._update_tolerances(named_args, fields)
-
deserialised.append(
dtype(
name=stencil_name,
@@ -305,14 +310,27 @@ def create_stencil_data(
return deserialised
def _pop_additional_attributes(
- self, dtype: Type[StartStencilData | StartFusedStencilData], named_args: dict[str, Any]
+ self,
+ dtype: Type[StartStencilData | StartFusedStencilData],
+ named_args: dict[str, Any],
) -> dict:
"""Pop and return additional attributes specific to StartStencilData."""
additional_attrs = {}
if dtype == StartStencilData:
- mergecopy = string_to_bool(pop_item_from_dict(named_args, "mergecopy", "false"))
- copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true"))
- additional_attrs = {"mergecopy": mergecopy, "copies": copies}
+ mergecopy = string_to_bool(
+ pop_item_from_dict(named_args, "mergecopy", DEFAULT_STARTSTENCIL_MERGECOPY)
+ )
+ copies = string_to_bool(
+ pop_item_from_dict(named_args, "copies", DEFAULT_STARTSTENCIL_COPIES)
+ )
+ optional_module = pop_item_from_dict(
+ named_args, "optional_module", DEFAULT_STARTSTENCIL_OPTIONAL_MODULE
+ )
+ additional_attrs = {
+ "mergecopy": mergecopy,
+ "copies": copies,
+ "optional_module": optional_module,
+ }
return additional_attrs
@staticmethod
diff --git a/tools/src/icon4pytools/liskov/codegen/integration/generate.py b/tools/src/icon4pytools/liskov/codegen/integration/generate.py
index 077e6f437d..fb8bd5ccd4 100644
--- a/tools/src/icon4pytools/liskov/codegen/integration/generate.py
+++ b/tools/src/icon4pytools/liskov/codegen/integration/generate.py
@@ -146,6 +146,7 @@ def _generate_start_stencil(self) -> None:
acc_present=stencil.acc_present,
mergecopy=stencil.mergecopy,
copies=stencil.copies,
+ optional_module=stencil.optional_module,
)
i += 2
diff --git a/tools/src/icon4pytools/liskov/codegen/integration/interface.py b/tools/src/icon4pytools/liskov/codegen/integration/interface.py
index 93903cab78..07cb43872e 100644
--- a/tools/src/icon4pytools/liskov/codegen/integration/interface.py
+++ b/tools/src/icon4pytools/liskov/codegen/integration/interface.py
@@ -89,6 +89,7 @@ class BaseStartStencilData(CodeGenInput):
class StartStencilData(BaseStartStencilData):
mergecopy: Optional[bool]
copies: Optional[bool]
+ optional_module: Optional[str]
@dataclass
diff --git a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py
index 6860b455f5..f550b4fd09 100644
--- a/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py
+++ b/tools/src/icon4pytools/liskov/codegen/serialisation/deserialise.py
@@ -40,6 +40,7 @@
"accpresent",
"mergecopy",
"copies",
+ "optional_module",
"horizontal_lower",
"horizontal_upper",
"vertical_lower",
diff --git a/tools/src/icon4pytools/liskov/parsing/transform.py b/tools/src/icon4pytools/liskov/parsing/transform.py
index e6a0bd8ef8..ed1c8ad93b 100644
--- a/tools/src/icon4pytools/liskov/parsing/transform.py
+++ b/tools/src/icon4pytools/liskov/parsing/transform.py
@@ -10,9 +10,10 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
-from typing import Any
+from typing import Any, Optional
from icon4pytools.common.logger import setup_logger
+from icon4pytools.liskov.codegen.integration.deserialise import DEFAULT_STARTSTENCIL_OPTIONAL_MODULE
from icon4pytools.liskov.codegen.integration.interface import (
EndDeleteData,
EndFusedStencilData,
@@ -30,7 +31,18 @@
logger = setup_logger(__name__)
-class StencilTransformer(Step):
+def _remove_stencils(
+ parsed: IntegrationCodeInterface, stencils_to_remove: list[CodeGenInput]
+) -> None:
+ attributes_to_modify = ["StartStencil", "EndStencil"]
+
+ for attr_name in attributes_to_modify:
+ current_stencil_list = getattr(parsed, attr_name)
+ modified_stencil_list = [_ for _ in current_stencil_list if _ not in stencils_to_remove]
+ setattr(parsed, attr_name, modified_stencil_list)
+
+
+class FusedStencilTransformer(Step):
def __init__(self, parsed: IntegrationCodeInterface, fused: bool) -> None:
self.parsed = parsed
self.fused = fused
@@ -71,7 +83,7 @@ def _process_stencils_for_deletion(self) -> None:
self._create_delete_directives(start_single, end_single)
stencils_to_remove += [start_single, end_single]
- self._remove_stencils(stencils_to_remove)
+ _remove_stencils(self.parsed, stencils_to_remove)
def _stencil_is_removable(
self,
@@ -105,14 +117,6 @@ def _create_delete_directives(
directive.append(cls(startln=param.startln))
setattr(self.parsed, attr, directive)
- def _remove_stencils(self, stencils_to_remove: list[CodeGenInput]) -> None:
- attributes_to_modify = ["StartStencil", "EndStencil"]
-
- for attr_name in attributes_to_modify:
- current_stencil_list = getattr(self.parsed, attr_name)
- modified_stencil_list = [_ for _ in current_stencil_list if _ not in stencils_to_remove]
- setattr(self.parsed, attr_name, modified_stencil_list)
-
def _remove_fused_stencils(self) -> None:
self.parsed.StartFusedStencil = []
self.parsed.EndFusedStencil = []
@@ -120,3 +124,52 @@ def _remove_fused_stencils(self) -> None:
def _remove_delete(self) -> None:
self.parsed.StartDelete = []
self.parsed.EndDelete = []
+
+
+class OptionalModulesTransformer(Step):
+ def __init__(
+ self, parsed: IntegrationCodeInterface, optional_modules_to_enable: Optional[tuple[str]]
+ ) -> None:
+ self.parsed = parsed
+ self.optional_modules_to_enable = optional_modules_to_enable
+
+ def __call__(self, data: Any = None) -> IntegrationCodeInterface:
+ """Transform stencils in the parse tree based on 'optional_modules_to_enable', either enabling specific modules or removing them.
+
+ Args:
+ data (Any): Optional data to be passed. Defaults to None.
+
+ Returns:
+ IntegrationCodeInterface: The modified interface object.
+ """
+ if self.optional_modules_to_enable is not None:
+ action = "enabling"
+ else:
+ action = "removing"
+ logger.info(f"Transforming stencils by {action} optional modules.")
+ self._transform_stencils()
+
+ return self.parsed
+
+ def _transform_stencils(self) -> None:
+ """Identify stencils to transform based on 'optional_modules_to_enable' and applies necessary changes."""
+ stencils_to_remove = []
+ for start_stencil, end_stencil in zip(
+ self.parsed.StartStencil, self.parsed.EndStencil, strict=False
+ ):
+ if self._should_remove_stencil(start_stencil):
+ stencils_to_remove.extend([start_stencil, end_stencil])
+
+ _remove_stencils(self.parsed, stencils_to_remove)
+
+ def _should_remove_stencil(self, stencil: StartStencilData) -> bool:
+ """Determine if a stencil should be removed based on 'optional_modules_to_enable'.
+
+ Returns:
+ bool: True if the stencil should be removed, False otherwise.
+ """
+ if stencil.optional_module == DEFAULT_STARTSTENCIL_OPTIONAL_MODULE:
+ return False
+ if self.optional_modules_to_enable is None:
+ return True
+ return stencil.optional_module not in self.optional_modules_to_enable
diff --git a/tools/src/icon4pytools/liskov/pipeline/collection.py b/tools/src/icon4pytools/liskov/pipeline/collection.py
index 4d458ad1cc..8b924926b3 100644
--- a/tools/src/icon4pytools/liskov/pipeline/collection.py
+++ b/tools/src/icon4pytools/liskov/pipeline/collection.py
@@ -11,7 +11,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
from pathlib import Path
-from typing import Any
+from typing import Any, Optional
from icon4pytools.liskov.codegen.integration.deserialise import IntegrationCodeDeserialiser
from icon4pytools.liskov.codegen.integration.generate import IntegrationCodeGenerator
@@ -22,7 +22,10 @@
from icon4pytools.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils
from icon4pytools.liskov.parsing.parse import DirectivesParser
from icon4pytools.liskov.parsing.scan import DirectivesScanner
-from icon4pytools.liskov.parsing.transform import StencilTransformer
+from icon4pytools.liskov.parsing.transform import (
+ FusedStencilTransformer,
+ OptionalModulesTransformer,
+)
from icon4pytools.liskov.pipeline.definition import Step, linear_pipeline
@@ -70,21 +73,28 @@ def parse_fortran_file(
@linear_pipeline
-def process_stencils(parsed: IntegrationCodeInterface, fused: bool) -> list[Step]:
+def process_stencils(
+ parsed: IntegrationCodeInterface, fused: bool, optional_modules_to_enable: Optional[tuple[str]]
+) -> list[Step]:
"""Execute a linear pipeline to transform stencils and produce either fused or unfused execution.
This function takes an input `parsed` object of type `IntegrationCodeInterface` and a `fused` boolean flag.
- It then executes a linear pipeline, consisting of two steps: transformation of stencils for fusion or unfusion,
- and updating fields with information from GT4Py stencils.
+ It then executes a linear pipeline, consisting of three steps: transformation of stencils for fusion or unfusion,
+ enabling optional modules, and updating fields with information from GT4Py stencils.
Args:
parsed (IntegrationCodeInterface): The input object containing parsed integration code.
fused (bool): A boolean flag indicating whether to produce fused (True) or unfused (False) execution.
+ optional_modules_to_enable (Optional[tuple[str]]): A tuple of optional modules to enable.
Returns:
- The updated and transformed object with fields containing information from GT4Py stencils.
+ The updated and transformed IntegrationCodeInterface object.
"""
- return [StencilTransformer(parsed, fused), UpdateFieldsWithGt4PyStencils(parsed)]
+ return [
+ FusedStencilTransformer(parsed, fused),
+ OptionalModulesTransformer(parsed, optional_modules_to_enable),
+ UpdateFieldsWithGt4PyStencils(parsed),
+ ]
@linear_pipeline
diff --git a/tools/tests/liskov/test_external.py b/tools/tests/liskov/test_external.py
index 55f2e5f3b2..454d01e428 100644
--- a/tools/tests/liskov/test_external.py
+++ b/tools/tests/liskov/test_external.py
@@ -77,6 +77,7 @@ def test_stencil_collector_invalid_member():
acc_present=False,
mergecopy=False,
copies=True,
+ optional_module="None",
)
],
Imports=None,
diff --git a/tools/tests/liskov/test_generation.py b/tools/tests/liskov/test_generation.py
index 46b7f43dae..b715f6b643 100644
--- a/tools/tests/liskov/test_generation.py
+++ b/tools/tests/liskov/test_generation.py
@@ -79,6 +79,7 @@ def integration_code_interface():
acc_present=False,
mergecopy=False,
copies=True,
+ optional_module="None",
)
end_stencil_data = EndStencilData(
name="stencil1", startln=3, noendif=False, noprofile=False, noaccenddata=False
diff --git a/tools/tests/liskov/test_serialisation_deserialiser.py b/tools/tests/liskov/test_serialisation_deserialiser.py
index 0431086beb..518029af3a 100644
--- a/tools/tests/liskov/test_serialisation_deserialiser.py
+++ b/tools/tests/liskov/test_serialisation_deserialiser.py
@@ -105,6 +105,7 @@ def parsed_dict():
"horizontal_lower": "i_startidx",
"horizontal_upper": "i_endidx",
"accpresent": "True",
+ "optional_module": "advection",
}
],
"StartProfile": [{"name": "apply_nabla2_to_vn_in_lateral_boundary"}],
@@ -127,5 +128,7 @@ def test_savepoint_data_factory(parsed_dict):
savepoints = SavepointDataFactory()(parsed_dict)
assert len(savepoints) == 2
assert any([isinstance(sp, SavepointData) for sp in savepoints])
+ # check that unnecessary keys have been removed
+ assert not any(f.variable == "optional_module" for sp in savepoints for f in sp.fields)
assert any([isinstance(f, FieldSerialisationData) for f in savepoints[0].fields])
assert any([isinstance(m, Metadata) for m in savepoints[0].metadata])
diff --git a/tools/tests/liskov/test_transform.py b/tools/tests/liskov/test_transform.py
index 4c2a60454c..1ac5a2bcbb 100644
--- a/tools/tests/liskov/test_transform.py
+++ b/tools/tests/liskov/test_transform.py
@@ -33,7 +33,10 @@
StartProfileData,
StartStencilData,
)
-from icon4pytools.liskov.parsing.transform import StencilTransformer
+from icon4pytools.liskov.parsing.transform import (
+ FusedStencilTransformer,
+ OptionalModulesTransformer,
+)
@pytest.fixture
@@ -94,6 +97,7 @@ def integration_code_interface():
acc_present=False,
mergecopy=False,
copies=True,
+ optional_module="None",
)
end_stencil_data1 = EndStencilData(
name="stencil1", startln=3, noendif=False, noprofile=False, noaccenddata=False
@@ -126,6 +130,7 @@ def integration_code_interface():
acc_present=False,
mergecopy=False,
copies=True,
+ optional_module="advection",
)
end_stencil_data2 = EndStencilData(
name="stencil2", startln=6, noendif=False, noprofile=False, noaccenddata=False
@@ -165,20 +170,32 @@ def integration_code_interface():
@pytest.fixture
-def stencil_transform_fused(integration_code_interface):
- return StencilTransformer(integration_code_interface, fused=True)
+def fused_stencil_transform_fused(integration_code_interface):
+ return FusedStencilTransformer(integration_code_interface, fused=True)
@pytest.fixture
-def stencil_transform_unfused(integration_code_interface):
- return StencilTransformer(integration_code_interface, fused=False)
+def fused_stencil_transform_unfused(integration_code_interface):
+ return FusedStencilTransformer(integration_code_interface, fused=False)
+
+
+@pytest.fixture
+def optional_modules_transform_enabled(integration_code_interface):
+ return OptionalModulesTransformer(
+ integration_code_interface, optional_modules_to_enable=["advection"]
+ )
+
+
+@pytest.fixture
+def optional_modules_transform_disabled(integration_code_interface):
+ return OptionalModulesTransformer(integration_code_interface, optional_modules_to_enable=None)
def test_transform_fused(
- stencil_transform_fused,
+ fused_stencil_transform_fused,
):
# Check that the transformed interface is as expected
- transformed = stencil_transform_fused()
+ transformed = fused_stencil_transform_fused()
assert len(transformed.StartFusedStencil) == 1
assert len(transformed.EndFusedStencil) == 1
assert len(transformed.StartStencil) == 1
@@ -188,10 +205,10 @@ def test_transform_fused(
def test_transform_unfused(
- stencil_transform_unfused,
+ fused_stencil_transform_unfused,
):
# Check that the transformed interface is as expected
- transformed = stencil_transform_unfused()
+ transformed = fused_stencil_transform_unfused()
assert not transformed.StartFusedStencil
assert not transformed.EndFusedStencil
@@ -199,3 +216,30 @@ def test_transform_unfused(
assert len(transformed.EndStencil) == 2
assert not transformed.StartDelete
assert not transformed.EndDelete
+
+
+def test_transform_optional_enabled(
+ optional_modules_transform_enabled,
+):
+ # Check that the transformed interface is as expected
+ transformed = optional_modules_transform_enabled()
+ assert len(transformed.StartFusedStencil) == 1
+ assert len(transformed.EndFusedStencil) == 1
+ assert len(transformed.StartStencil) == 2
+ assert len(transformed.EndStencil) == 2
+ assert len(transformed.StartDelete) == 1
+ assert len(transformed.EndDelete) == 1
+
+
+def test_transform_optional_disabled(
+ optional_modules_transform_disabled,
+):
+ # Check that the transformed interface is as expected
+ transformed = optional_modules_transform_disabled()
+
+ assert len(transformed.StartFusedStencil) == 1
+ assert len(transformed.EndFusedStencil) == 1
+ assert len(transformed.StartStencil) == 1
+ assert len(transformed.EndStencil) == 1
+ assert len(transformed.StartDelete) == 1
+ assert len(transformed.EndDelete) == 1