Skip to content

Commit

Permalink
Merge branch 'skip_offset_provider' into fix_symbolic_grid_sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Feb 1, 2024
2 parents a2a6362 + a951b82 commit 9aa34d2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
30 changes: 22 additions & 8 deletions model/common/src/icon4py/model/common/grid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property
from typing import Callable, Dict

import numpy as np
Expand All @@ -26,6 +28,10 @@
from icon4py.model.common.utils import builder


class MissingConnectivity(ValueError):
pass


@dataclass(
frozen=True,
)
Expand Down Expand Up @@ -80,6 +86,18 @@ def num_edges(self) -> int:
def num_levels(self) -> int:
pass

@cached_property
def offset_providers(self):
offset_providers = {}
for key, value in self.offset_provider_mapping.items():
try:
method, *args = value
offset_providers[key] = method(*args) if args else method()
except MissingConnectivity:
warnings.warn(f"{key} connectivity is missing from grid.", stacklevel=2)

return offset_providers

@builder
def with_connectivities(self, connectivity: Dict[Dimension, np.ndarray]):
self.connectivities.update({d: k.astype(int) for d, k in connectivity.items()})
Expand All @@ -97,11 +115,15 @@ def _update_size(self):
self.size[KDim] = self.config.num_levels

def _get_offset_provider(self, dim, from_dim, to_dim):
if dim not in self.connectivities:
raise MissingConnectivity()
return NeighborTableOffsetProvider(
self.connectivities[dim], from_dim, to_dim, self.size[dim]
)

def _get_offset_provider_for_sparse_fields(self, dim, from_dim, to_dim):
if dim not in self.connectivities:
raise MissingConnectivity()
return neighbortable_offset_provider_for_1d_sparse_fields(
self.connectivities[dim].shape, from_dim, to_dim
)
Expand All @@ -113,13 +135,5 @@ def get_offset_provider(self, name):
else:
raise Exception(f"Offset provider for {name} not found.")

def get_all_offset_providers(self):
offset_providers = {}
for key, value in self.offset_provider_mapping.items():
method, *args = value
offset_providers[key] = method(*args) if args else method()

return offset_providers

def update_size_connectivities(self, new_sizes):
self.size.update(new_sizes)
4 changes: 2 additions & 2 deletions model/common/src/icon4py/model/common/test_utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _test_validation(self, grid, backend, input_data):

self.PROGRAM.with_backend(backend)(
**input_data,
offset_provider=grid.get_all_offset_providers(),
offset_provider=grid.offset_providers,
)
for out in self.OUTPUTS:
name, refslice, gtslice = (
Expand All @@ -195,7 +195,7 @@ def _test_execution_benchmark(self, pytestconfig, grid, backend, input_data, ben
benchmark(
self.PROGRAM.with_backend(backend),
**input_data,
offset_provider=grid.get_all_offset_providers(),
offset_provider=grid.offset_providers,
)

else:
Expand Down

0 comments on commit 9aa34d2

Please sign in to comment.