Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

settings elements migration #575

Merged
merged 146 commits into from
Dec 6, 2024
Merged
Changes from 1 commit
Commits
Show all changes
146 commits
Select commit Hold shift + click to select a range
2779282
first set of change
nfarabullini Oct 24, 2024
0b839ba
further edits
nfarabullini Oct 24, 2024
8bc21cb
edits
nfarabullini Oct 24, 2024
e5db63b
Merge branch 'main' into settings_elements_migration
nfarabullini Oct 25, 2024
b948818
edits for dace_orchestration
nfarabullini Oct 25, 2024
2be59d5
placed elements in new settings file under wrapper
nfarabullini Oct 28, 2024
55bc6ad
fixed pre-commit errors
nfarabullini Oct 28, 2024
1353a77
fixed imports
nfarabullini Oct 28, 2024
929f35b
merge with upstream
nfarabullini Oct 29, 2024
e4d0c55
changes to include tools in requirements files
nfarabullini Oct 29, 2024
e125696
small edits
nfarabullini Oct 29, 2024
61a4ac0
Merge branch 'main' into settings_elements_migration
nfarabullini Oct 30, 2024
4cac1cb
requirements clean up
nfarabullini Oct 30, 2024
ab832ac
replaced numpy imports with xp
nfarabullini Oct 30, 2024
d993546
edits for dace orchestration
nfarabullini Oct 30, 2024
6fd93e8
removed unused import
nfarabullini Oct 30, 2024
51204fc
further edits
nfarabullini Oct 30, 2024
5614831
further edits
nfarabullini Oct 30, 2024
0d8adae
import edit
nfarabullini Oct 31, 2024
f5e5de5
column_stack edit in serialbox_utils
nfarabullini Oct 31, 2024
c69f194
column_stack edit in serialbox_utils
nfarabullini Oct 31, 2024
ab05fc3
small edit to column_stack
nfarabullini Oct 31, 2024
9c9cf48
Merge branch 'main' into settings_elements_migration
nfarabullini Oct 31, 2024
1f0bd9c
small edit to column_stack
nfarabullini Oct 31, 2024
1740129
Merge branch 'settings_elements_migration' of https://github.com/C2SM…
nfarabullini Oct 31, 2024
f0c3c27
otehr edits
nfarabullini Oct 31, 2024
37497df
added xp.asarray
nfarabullini Oct 31, 2024
328e3bb
added xp.asarray to offset grids
nfarabullini Oct 31, 2024
6e8f133
added xp.asarray to fields
nfarabullini Oct 31, 2024
1979491
added xp.asarray to fields
nfarabullini Nov 1, 2024
188994e
added xp.asarray to fields
nfarabullini Nov 1, 2024
5390706
small edit
nfarabullini Nov 1, 2024
d7de72f
edited xp in the constructor
nfarabullini Nov 1, 2024
7b9f5ab
edits
nfarabullini Nov 1, 2024
2a7f4dc
small edit to allocate_data
nfarabullini Nov 1, 2024
dcfa4fe
update with upstream
nfarabullini Nov 26, 2024
a06d303
Update test_compute_antidiffusive_cell_fluxes_and_min_max.py
nfarabullini Nov 26, 2024
9827453
edits
nfarabullini Nov 26, 2024
8acb135
Merge branch 'settings_elements_migration' of https://github.com/C2SM…
nfarabullini Nov 26, 2024
bf95797
edits
nfarabullini Nov 26, 2024
3b6b3be
from xp to np edits
nfarabullini Nov 26, 2024
2a43848
some edits
nfarabullini Nov 26, 2024
14ccc0e
some edits
nfarabullini Nov 26, 2024
e08797d
some edits
nfarabullini Nov 26, 2024
eefd4ee
small edit
nfarabullini Nov 26, 2024
bdd0298
edit
nfarabullini Nov 26, 2024
6e1d005
small edit
nfarabullini Nov 26, 2024
c9e7608
added backend to field functions
nfarabullini Nov 27, 2024
2ba5dc6
some edits
nfarabullini Nov 27, 2024
6319388
some edits
nfarabullini Nov 27, 2024
2a197f7
small edit
nfarabullini Nov 27, 2024
5bab62a
small edit
nfarabullini Nov 27, 2024
513537c
small edit
nfarabullini Nov 27, 2024
15b9fbf
small edit
nfarabullini Nov 27, 2024
3e454f3
small edit to lower bound
nfarabullini Nov 27, 2024
bad9082
reverted edits
nfarabullini Nov 27, 2024
0093cd0
merge with upstream
nfarabullini Nov 27, 2024
4dbc433
small edit
nfarabullini Nov 27, 2024
371469d
edits for CPU only
nfarabullini Nov 27, 2024
3a91abd
edit import from xp tp numpy
nfarabullini Nov 27, 2024
c912ae2
re-enabled gpu and rountrip backends
nfarabullini Nov 27, 2024
fb15ea6
small error fix
nfarabullini Nov 27, 2024
07fe5da
further edits
nfarabullini Nov 27, 2024
8eeede1
cleanup
nfarabullini Nov 28, 2024
ba55221
samll edit
nfarabullini Nov 28, 2024
b4d812d
samll edit
nfarabullini Nov 28, 2024
b7a2088
samll edits
nfarabullini Nov 28, 2024
ea0fbbd
reverted common edits
nfarabullini Nov 28, 2024
052d584
Merge branch 'main' into settings_elements_migration
nfarabullini Nov 28, 2024
042bef6
further edits
nfarabullini Nov 28, 2024
9bae952
Merge branch 'settings_elements_migration' of https://github.com/C2SM…
nfarabullini Nov 28, 2024
b282959
fixed np import syntax
nfarabullini Nov 28, 2024
6bcc418
small edit
nfarabullini Nov 28, 2024
894d543
ran pre-commit
nfarabullini Nov 28, 2024
9084212
small edit to dace_orchestration
nfarabullini Nov 28, 2024
5767572
Merge branch 'main' into settings_elements_migration
nfarabullini Nov 28, 2024
0bef661
small edit for dace backend specification
nfarabullini Nov 29, 2024
5f6aff2
Merge branch 'main' into settings_elements_migration
nfarabullini Nov 29, 2024
cbc077e
Delete model/common/src/icon4py/model/common/settings.py
nfarabullini Nov 29, 2024
9e5aa0b
Update model/atmosphere/diffusion/tests/diffusion_tests/mpi_tests/tes…
nfarabullini Nov 29, 2024
664b12f
Update model/atmosphere/diffusion/tests/diffusion_tests/test_diffusio…
nfarabullini Nov 29, 2024
53bfd33
Merge branch 'settings_elements_migration' of https://github.com/C2SM…
nfarabullini Nov 29, 2024
84fbe3f
import fixes after settings file deletion
nfarabullini Nov 29, 2024
b9a7070
Update model/atmosphere/diffusion/tests/diffusion_tests/test_diffusio…
nfarabullini Dec 2, 2024
710056e
Update tools/src/icon4pytools/py2fgen/wrappers/simple.py
nfarabullini Dec 2, 2024
c704298
Update model/common/src/icon4py/model/common/grid/refinement.py
nfarabullini Dec 2, 2024
3bfdde0
Update model/common/src/icon4py/model/common/grid/vertical.py
nfarabullini Dec 2, 2024
34c3e71
Update model/common/src/icon4py/model/common/grid/vertical.py
nfarabullini Dec 2, 2024
96d2c62
Update model/common/src/icon4py/model/common/grid/vertical.py
nfarabullini Dec 2, 2024
a88409f
Update model/common/src/icon4py/model/common/grid/base.py
nfarabullini Dec 2, 2024
7230e53
Update model/atmosphere/diffusion/tests/diffusion_tests/test_diffusio…
nfarabullini Dec 2, 2024
7eae218
Update model/common/src/icon4py/model/common/states/factory.py
nfarabullini Dec 2, 2024
bc1f1aa
Update model/common/src/icon4py/model/common/states/utils.py
nfarabullini Dec 2, 2024
fe02475
Update model/common/src/icon4py/model/common/states/utils.py
nfarabullini Dec 2, 2024
fcab32a
Update tools/tests/py2fgen/utils.py
nfarabullini Dec 2, 2024
7d55ed0
Update tools/tests/py2fgen/utils.py
nfarabullini Dec 2, 2024
3feff98
Update tools/tests/py2fgen/utils.py
nfarabullini Dec 2, 2024
3f22348
Update tools/tests/py2fgen/utils.py
nfarabullini Dec 2, 2024
a3a847f
Update tools/tests/py2fgen/utils.py
nfarabullini Dec 2, 2024
4a7b59d
Update tools/src/icon4pytools/py2fgen/wrappers/debug_utils.py
nfarabullini Dec 2, 2024
c8a9e70
Update tools/src/icon4pytools/py2fgen/wrappers/common.py
nfarabullini Dec 2, 2024
87c9479
Update tools/src/icon4pytools/py2fgen/wrappers/common.py
nfarabullini Dec 2, 2024
1c5b8e5
Update tools/src/icon4pytools/py2fgen/wrappers/common.py
nfarabullini Dec 2, 2024
9c4c338
Update model/common/src/icon4py/model/common/grid/geometry.py
nfarabullini Dec 2, 2024
c62660b
Update model/common/src/icon4py/model/common/grid/refinement.py
nfarabullini Dec 2, 2024
34739eb
Update model/common/src/icon4py/model/common/grid/refinement.py
nfarabullini Dec 2, 2024
89bc0e4
Update model/common/src/icon4py/model/common/grid/refinement.py
nfarabullini Dec 2, 2024
617017e
Update model/common/src/icon4py/model/common/test_utils/helpers.py
nfarabullini Dec 2, 2024
d0471f9
Update model/common/src/icon4py/model/common/utils/gt4py_field_alloca…
nfarabullini Dec 2, 2024
a56b5fd
edits following review
nfarabullini Dec 2, 2024
125bc7d
dace edits
nfarabullini Dec 2, 2024
a2447c9
merge with upstream
nfarabullini Dec 2, 2024
ecffbf8
zero_field helper function edit
nfarabullini Dec 2, 2024
31c3bc9
further edits
nfarabullini Dec 2, 2024
a623f94
dace edits
nfarabullini Dec 2, 2024
522b5cd
dace edits
nfarabullini Dec 3, 2024
00eabab
dace edits
nfarabullini Dec 3, 2024
5d38db8
further edits
nfarabullini Dec 3, 2024
d4f1974
Merge branch 'main' into settings_elements_migration
nfarabullini Dec 3, 2024
fc0bb7f
Merge branch 'main' into settings_elements_migration
nfarabullini Dec 3, 2024
f906b89
ran pre-commit
nfarabullini Dec 3, 2024
6a288d9
Update model/common/src/icon4py/model/common/decomposition/definition…
nfarabullini Dec 4, 2024
6e1b4e5
Update model/common/src/icon4py/model/common/interpolation/interpolat…
nfarabullini Dec 4, 2024
c8b8e56
Update tools/src/icon4pytools/py2fgen/wrappers/common.py
nfarabullini Dec 4, 2024
2aeb84c
Update tools/src/icon4pytools/py2fgen/wrappers/dycore_wrapper.py
nfarabullini Dec 4, 2024
5473668
Update tools/src/icon4pytools/py2fgen/wrappers/dycore_wrapper.py
nfarabullini Dec 4, 2024
e3d3993
Update tools/src/icon4pytools/py2fgen/wrappers/dycore_wrapper.py
nfarabullini Dec 4, 2024
6c33d9e
Update tools/src/icon4pytools/py2fgen/wrappers/common.py
nfarabullini Dec 4, 2024
8c34b44
Update model/common/src/icon4py/model/common/decomposition/definition…
nfarabullini Dec 4, 2024
9275477
Update model/common/src/icon4py/model/common/decomposition/mpi_decomp…
nfarabullini Dec 4, 2024
a57932e
Update model/common/src/icon4py/model/common/test_utils/helpers.py
nfarabullini Dec 4, 2024
834a995
Update tools/src/icon4pytools/py2fgen/template.py
nfarabullini Dec 4, 2024
cd22f34
edits following review
nfarabullini Dec 4, 2024
2d506e2
Merge branch 'settings_elements_migration' of https://github.com/C2SM…
nfarabullini Dec 4, 2024
daa6fd4
edits following review
nfarabullini Dec 4, 2024
d7a051e
small fix
nfarabullini Dec 5, 2024
467bdec
small fix
nfarabullini Dec 5, 2024
5f0bd7e
small fix
nfarabullini Dec 5, 2024
66fafb5
tests fixes
nfarabullini Dec 6, 2024
cb2f01b
Merge branch 'main' into settings_elements_migration
nfarabullini Dec 6, 2024
1be66a6
merge with upstream
nfarabullini Dec 6, 2024
6bc9766
Merge branch 'settings_elements_migration' of https://github.com/C2SM…
nfarabullini Dec 6, 2024
4eac480
Merge branch 'main' into settings_elements_migration
nfarabullini Dec 6, 2024
d73e4f4
ran pre-commit
nfarabullini Dec 6, 2024
2c18429
edits following review
nfarabullini Dec 6, 2024
641f169
extract ndarray from fields when calling construct_icon_grid
halungge Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
further edits
nfarabullini committed Oct 24, 2024
commit 0b839badf30110ce19c841eba49696dd0c5a6e63
Original file line number Diff line number Diff line change
@@ -7,9 +7,10 @@
# SPDX-License-Identifier: BSD-3-Clause

import pytest
from icon4pytools.py2fgen.wrappers.common import dace_orchestration

from icon4py.model.atmosphere.diffusion import diffusion as diffusion_, diffusion_states
from icon4py.model.common import dimension as dims, settings
from icon4py.model.common import dimension as dims
from icon4py.model.common.decomposition import definitions
from icon4py.model.common.grid import vertical as v_grid
from icon4py.model.common.test_utils import datatest_utils, helpers, parallel_helpers
@@ -171,7 +172,7 @@ def test_parallel_diffusion_multiple_steps(
backend,
diffusion_instance, # noqa: F811
):
if settings.dace_orchestration is None:
if dace_orchestration is None:
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
raise pytest.skip("This test is only executed for `--dace-orchestration=True`.")
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved

######################################################################
@@ -235,7 +236,7 @@ def test_parallel_diffusion_multiple_steps(
######################################################################
# DaCe NON-Orchestrated Backend
######################################################################
settings.dace_orchestration = None
dace_orchestration = None

diffusion = diffusion_.Diffusion(backend, exchange)

@@ -283,7 +284,7 @@ def test_parallel_diffusion_multiple_steps(
######################################################################
# DaCe Orchestrated Backend
######################################################################
settings.dace_orchestration = True
dace_orchestration = True

diffusion = diffusion_instance # the fixture makes sure that the orchestrator cache is cleared properly between pytest runs -if applicable-

Original file line number Diff line number Diff line change
@@ -8,11 +8,10 @@

import numpy as np
import pytest
from icon4pytools.py2fgen.wrappers.common import backend
from icon4pytools.py2fgen.wrappers.common import backend, dace_orchestration

import icon4py.model.common.dimension as dims
from icon4py.model.atmosphere.diffusion import diffusion, diffusion_states, diffusion_utils
from icon4py.model.common import settings
from icon4py.model.common.grid import vertical as v_grid
from icon4py.model.common.grid.geometry import CellParams, EdgeParams
from icon4py.model.common.test_utils import (
@@ -430,7 +429,7 @@ def test_run_diffusion_multiple_steps(
backend,
diffusion_instance, # noqa: F811
):
if settings.dace_orchestration is None:
if dace_orchestration is None:
raise pytest.skip("This test is only executed for `--dace-orchestration=True`.")

######################################################################
@@ -479,7 +478,7 @@ def test_run_diffusion_multiple_steps(
######################################################################
# DaCe NON-Orchestrated Backend
######################################################################
settings.dace_orchestration = None
dace_orchestration = None

diagnostic_state_dace_non_orch = diffusion_states.DiffusionDiagnosticState(
hdef_ic=savepoint_diffusion_init.hdef_ic(),
@@ -511,7 +510,7 @@ def test_run_diffusion_multiple_steps(
######################################################################
# DaCe Orchestrated Backend
######################################################################
settings.dace_orchestration = True
dace_orchestration = True

diagnostic_state_dace_orch = diffusion_states.DiffusionDiagnosticState(
hdef_ic=savepoint_diffusion_init.hdef_ic(),
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@

import numpy as np
import pytest

from icon4pytools.py2fgen.wrappers.common import backend

from icon4py.model.atmosphere.diffusion import diffusion, diffusion_utils
@@ -99,7 +100,7 @@ def test_diff_multfac_vn_smag_limit_for_loop_run_with_k4_substeps(backend):
def test_init_zero_vertex_k(backend):
grid = simple_grid.SimpleGrid()
f = helpers.random_field(grid, dims.VertexDim, dims.KDim)
diffusion_utils.init_zero_v_k.with_abckend(backend)(f, offset_provider={})
diffusion_utils.init_zero_v_k.with_backend(backend)(f, offset_provider={})
assert np.allclose(0.0, f.asnumpy())


104 changes: 0 additions & 104 deletions model/common/src/icon4py/model/common/config.py
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -7,109 +7,5 @@
# SPDX-License-Identifier: BSD-3-Clause

# Assuming this code is in a module called icon4py_config.py
import dataclasses
import os
from enum import Enum
from functools import cached_property

import numpy as np
from gt4py.next import itir_python as run_roundtrip
from gt4py.next.program_processors.runners.gtfn import (
run_gtfn_cached,
run_gtfn_gpu_cached,
)


try:
import dace
from gt4py.next.program_processors.runners.dace import (
run_dace_cpu,
run_dace_cpu_noopt,
run_dace_gpu,
run_dace_gpu_noopt,
)
except ImportError:
from types import ModuleType
from typing import Optional

dace: Optional[ModuleType] = None


class Device(Enum):
CPU = "CPU"
GPU = "GPU"


class GT4PyBackend(Enum):
CPU = "run_gtfn_cached"
GPU = "run_gtfn_gpu_cached"
ROUNDTRIP = "run_roundtrip"
DACE_CPU = "run_dace_cpu"
DACE_GPU = "run_dace_gpu"
DACE_CPU_NOOPT = "run_dace_cpu_noopt"
DACE_GPU_NOOPT = "run_dace_gpu_noopt"


@dataclasses.dataclass
class Icon4PyConfig:
@cached_property
def icon4py_backend(self):
backend = os.environ.get("ICON4PY_BACKEND", "CPU")
if hasattr(GT4PyBackend, backend):
return backend
else:
raise ValueError(
f"Invalid ICON4Py backend: {backend}. \n"
f"Available backends: {', '.join([f'{k}' for k in GT4PyBackend.__members__.keys()])}"
)

@cached_property
def icon4py_dace_orchestration(self):
# Any value other than None will be considered as True
return os.environ.get("ICON4PY_DACE_ORCHESTRATION", None)

@cached_property
def array_ns(self):
if self.device == Device.GPU:
import cupy as cp # type: ignore[import-untyped]

return cp
else:
return np

@cached_property
def gt4py_runner(self):
backend_map = {
GT4PyBackend.CPU.name: run_gtfn_cached,
GT4PyBackend.GPU.name: run_gtfn_gpu_cached,
GT4PyBackend.ROUNDTRIP.name: run_roundtrip,
}
if dace:
backend_map |= {
GT4PyBackend.DACE_CPU.name: run_dace_cpu,
GT4PyBackend.DACE_GPU.name: run_dace_gpu,
GT4PyBackend.DACE_CPU_NOOPT.name: run_dace_cpu_noopt,
GT4PyBackend.DACE_GPU_NOOPT.name: run_dace_gpu_noopt,
}
return backend_map[self.icon4py_backend]

@cached_property
def device(self):
device_map = {
GT4PyBackend.CPU.name: Device.CPU,
GT4PyBackend.GPU.name: Device.GPU,
GT4PyBackend.ROUNDTRIP.name: Device.CPU,
}
if dace:
device_map |= {
GT4PyBackend.DACE_CPU.name: Device.CPU,
GT4PyBackend.DACE_GPU.name: Device.GPU,
GT4PyBackend.DACE_CPU_NOOPT.name: Device.CPU,
GT4PyBackend.DACE_GPU_NOOPT.name: Device.GPU,
}
device = device_map[self.icon4py_backend]
return device

@cached_property
def limited_area(self):
return os.environ.get("ICON4PY_LAM", False)
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
import gt4py.next as gtx
import numpy as np
from gt4py._core import definitions as core_defs
from icon4pytools.py2fgen.wrappers.common import dace_orchestration

from icon4py.model.common import dimension as dims, settings
from icon4py.model.common.decomposition import definitions as decomposition
@@ -88,7 +89,7 @@ def orchestrate(
"""

def _decorator(fuse_func: Callable[P, R]) -> Callable[P, R]:
if settings.dace_orchestration is not None:
if dace_orchestration is not None:
orchestrator_cache = {} # Caching

if "dace" not in settings.backend.name.lower():
4 changes: 2 additions & 2 deletions tools/src/icon4pytools/py2fgen/cli.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@
import pathlib

import click
from icon4py.model.common.config import GT4PyBackend

from icon4pytools.common.utils import write_string
from icon4pytools.py2fgen.generate import (
@@ -19,6 +18,7 @@
)
from icon4pytools.py2fgen.parsing import parse
from icon4pytools.py2fgen.plugin import generate_and_compile_cffi_plugin
from icon4pytools.py2fgen.wrappers import common


def parse_comma_separated_list(ctx, param, value) -> list[str]:
@@ -43,7 +43,7 @@ def parse_comma_separated_list(ctx, param, value) -> list[str]:
@click.option(
"--backend",
"-b",
type=click.Choice([e.name for e in GT4PyBackend], case_sensitive=False),
type=click.Choice([e.name for e in common.GT4PyBackend], case_sensitive=False),
default="CPU",
help="Set the backend to use, thereby unpacking Fortran pointers into NumPy or CuPy arrays respectively.",
)
5 changes: 2 additions & 3 deletions tools/src/icon4pytools/py2fgen/template.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,6 @@
from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator
from gt4py.next import Dimension
from gt4py.next.type_system.type_specifications import ScalarKind
from icon4py.model.common.config import GT4PyBackend

from icon4pytools.icon4pygen.bindings.codegen.type_conversion import (
BUILTIN_TO_CPP_TYPE,
@@ -22,7 +21,7 @@
)
from icon4pytools.py2fgen.plugin import int_array_to_bool_array, unpack, unpack_gpu
from icon4pytools.py2fgen.utils import flatten_and_get_unique_elts
from icon4pytools.py2fgen.wrappers import wrapper_dimension
from icon4pytools.py2fgen.wrappers import common, wrapper_dimension


# these arrays are not initialised in global experiments (e.g. ape_r02b04) and are not used
@@ -107,7 +106,7 @@ class PythonWrapper(CffiPlugin):
is_gt4py_program_present: bool = datamodels.field(init=False)

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
self.gt4py_backend = GT4PyBackend[self.backend].value
self.gt4py_backend = common.GT4PyBackend[self.backend].value
self.is_gt4py_program_present = any(func.is_gt4py_program for func in self.functions)
self.uninitialised_arrays = get_uninitialised_arrays(self.limited_area)

108 changes: 106 additions & 2 deletions tools/src/icon4pytools/py2fgen/wrappers/common.py
Original file line number Diff line number Diff line change
@@ -7,16 +7,120 @@
# SPDX-License-Identifier: BSD-3-Clause
# type: ignore

import dataclasses
import logging

import os
from enum import Enum
from functools import cached_property

import numpy as np
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
from gt4py.next import itir_python as run_roundtrip
from gt4py.next.program_processors.runners.gtfn import (
run_gtfn_cached,
run_gtfn_gpu_cached,
)
from icon4py.model.common import dimension as dims
from icon4py.model.common.config import Icon4PyConfig
from icon4py.model.common.grid import base, horizontal, icon
from icon4py.model.common.settings import xp


log = logging.getLogger(__name__)

try:
import dace
from gt4py.next.program_processors.runners.dace import (
run_dace_cpu,
run_dace_cpu_noopt,
run_dace_gpu,
run_dace_gpu_noopt,
)
except ImportError:
from types import ModuleType
from typing import Optional

dace: Optional[ModuleType] = None


class Device(Enum):
CPU = "CPU"
GPU = "GPU"


class GT4PyBackend(Enum):
CPU = "run_gtfn_cached"
GPU = "run_gtfn_gpu_cached"
ROUNDTRIP = "run_roundtrip"
DACE_CPU = "run_dace_cpu"
DACE_GPU = "run_dace_gpu"
DACE_CPU_NOOPT = "run_dace_cpu_noopt"
DACE_GPU_NOOPT = "run_dace_gpu_noopt"


@dataclasses.dataclass
class Icon4PyConfig:
@cached_property
def icon4py_backend(self):
backend = os.environ.get("ICON4PY_BACKEND", "CPU")
if hasattr(GT4PyBackend, backend):
return backend
else:
raise ValueError(
f"Invalid ICON4Py backend: {backend}. \n"
f"Available backends: {', '.join([f'{k}' for k in GT4PyBackend.__members__.keys()])}"
)

@cached_property
def icon4py_dace_orchestration(self):
# Any value other than None will be considered as True
return os.environ.get("ICON4PY_DACE_ORCHESTRATION", None)

@cached_property
def array_ns(self):
if self.device == Device.GPU:
import cupy as cp # type: ignore[import-untyped]

return cp
else:
return np

@cached_property
def gt4py_runner(self):
backend_map = {
GT4PyBackend.CPU.name: run_gtfn_cached,
GT4PyBackend.GPU.name: run_gtfn_gpu_cached,
GT4PyBackend.ROUNDTRIP.name: run_roundtrip,
}
if dace:
backend_map |= {
GT4PyBackend.DACE_CPU.name: run_dace_cpu,
GT4PyBackend.DACE_GPU.name: run_dace_gpu,
GT4PyBackend.DACE_CPU_NOOPT.name: run_dace_cpu_noopt,
GT4PyBackend.DACE_GPU_NOOPT.name: run_dace_gpu_noopt,
}
return backend_map[self.icon4py_backend]

@cached_property
def device(self):
device_map = {
GT4PyBackend.CPU.name: Device.CPU,
GT4PyBackend.GPU.name: Device.GPU,
GT4PyBackend.ROUNDTRIP.name: Device.CPU,
}
if dace:
device_map |= {
GT4PyBackend.DACE_CPU.name: Device.CPU,
GT4PyBackend.DACE_GPU.name: Device.GPU,
GT4PyBackend.DACE_CPU_NOOPT.name: Device.CPU,
GT4PyBackend.DACE_GPU_NOOPT.name: Device.GPU,
}
device = device_map[self.icon4py_backend]
return device

@cached_property
def limited_area(self):
return os.environ.get("ICON4PY_LAM", False)


config = Icon4PyConfig()
backend = config.gt4py_runner
dace_orchestration = config.icon4py_dace_orchestration