From 158ab1f0ddcabba291bac182c95f3a38b18198ce Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 1 Feb 2024 01:32:38 +0100 Subject: [PATCH 1/2] Print warning when connectivity for offset provider is missing. --- .../src/icon4py/model/common/grid/base.py | 20 ++++++++++++++++--- .../model/common/test_utils/helpers.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/base.py b/model/common/src/icon4py/model/common/grid/base.py index ab47b1c225..9b655e5982 100644 --- a/model/common/src/icon4py/model/common/grid/base.py +++ b/model/common/src/icon4py/model/common/grid/base.py @@ -11,8 +11,10 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass +from functools import cached_property from typing import Callable, Dict import numpy as np @@ -26,6 +28,10 @@ from icon4py.model.common.utils import builder +class MissingConnectivity(ValueError): + pass + + @dataclass( frozen=True, ) @@ -97,11 +103,15 @@ def _update_size(self): self.size[KDim] = self.config.num_levels def _get_offset_provider(self, dim, from_dim, to_dim): + if dim not in self.connectivities: + raise MissingConnectivity() return NeighborTableOffsetProvider( self.connectivities[dim], from_dim, to_dim, self.size[dim] ) def _get_offset_provider_for_sparse_fields(self, dim, from_dim, to_dim): + if dim not in self.connectivities: + raise MissingConnectivity() return neighbortable_offset_provider_for_1d_sparse_fields( self.connectivities[dim].shape, from_dim, to_dim ) @@ -113,11 +123,15 @@ def get_offset_provider(self, name): else: raise Exception(f"Offset provider for {name} not found.") - def get_all_offset_providers(self): + @cached_property + def offset_providers(self): offset_providers = {} for key, value in self.offset_provider_mapping.items(): - method, *args = value - offset_providers[key] = method(*args) if args else method() + try: + method, *args = value + offset_providers[key] = method(*args) if args else method() + except MissingConnectivity: + warnings.warn(f"{key} connectivity is missing from grid.", stacklevel=2) return offset_providers diff --git a/model/common/src/icon4py/model/common/test_utils/helpers.py b/model/common/src/icon4py/model/common/test_utils/helpers.py index 160141d37b..ec22cf58a8 100644 --- a/model/common/src/icon4py/model/common/test_utils/helpers.py +++ b/model/common/src/icon4py/model/common/test_utils/helpers.py @@ -169,7 +169,7 @@ def _test_validation(self, grid, backend, input_data): self.PROGRAM.with_backend(backend)( **input_data, - offset_provider=grid.get_all_offset_providers(), + offset_provider=grid.offset_providers(), ) for out in self.OUTPUTS: name, refslice, gtslice = ( From a951b826634016c7594c800c7ba4fdd609bad869 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 1 Feb 2024 01:36:17 +0100 Subject: [PATCH 2/2] Fix typo --- .../src/icon4py/model/common/grid/base.py | 24 +++++++++---------- .../model/common/test_utils/helpers.py | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/model/common/src/icon4py/model/common/grid/base.py b/model/common/src/icon4py/model/common/grid/base.py index 9b655e5982..d4df89e2d8 100644 --- a/model/common/src/icon4py/model/common/grid/base.py +++ b/model/common/src/icon4py/model/common/grid/base.py @@ -86,6 +86,18 @@ def num_edges(self) -> int: def num_levels(self) -> int: pass + @cached_property + def offset_providers(self): + offset_providers = {} + for key, value in self.offset_provider_mapping.items(): + try: + method, *args = value + offset_providers[key] = method(*args) if args else method() + except MissingConnectivity: + warnings.warn(f"{key} connectivity is missing from grid.", stacklevel=2) + + return offset_providers + @builder def with_connectivities(self, connectivity: Dict[Dimension, np.ndarray]): self.connectivities.update({d: k.astype(int) for d, k in connectivity.items()}) @@ -123,17 +135,5 @@ def get_offset_provider(self, name): else: raise Exception(f"Offset provider for {name} not found.") - @cached_property - def offset_providers(self): - offset_providers = {} - for key, value in self.offset_provider_mapping.items(): - try: - method, *args = value - offset_providers[key] = method(*args) if args else method() - except MissingConnectivity: - warnings.warn(f"{key} connectivity is missing from grid.", stacklevel=2) - - return offset_providers - def update_size_connectivities(self, new_sizes): self.size.update(new_sizes) diff --git a/model/common/src/icon4py/model/common/test_utils/helpers.py b/model/common/src/icon4py/model/common/test_utils/helpers.py index ec22cf58a8..2826042b0a 100644 --- a/model/common/src/icon4py/model/common/test_utils/helpers.py +++ b/model/common/src/icon4py/model/common/test_utils/helpers.py @@ -169,7 +169,7 @@ def _test_validation(self, grid, backend, input_data): self.PROGRAM.with_backend(backend)( **input_data, - offset_provider=grid.offset_providers(), + offset_provider=grid.offset_providers, ) for out in self.OUTPUTS: name, refslice, gtslice = ( @@ -195,7 +195,7 @@ def _test_execution_benchmark(self, pytestconfig, grid, backend, input_data, ben benchmark( self.PROGRAM.with_backend(backend), **input_data, - offset_provider=grid.get_all_offset_providers(), + offset_provider=grid.offset_providers, ) else: