Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix and rename helper function to recognize python backends. #430

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ci/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ test_model_datatests:
extends: .test_template
stage: test
script:
- tox -r -e run_model_tests -c model/ --verbose -- $COMPONENT
- tox -r -e run_model_tests -c model/ --verbose -- --backend=$BACKEND $COMPONENT
parallel:
matrix:
- COMPONENT: [atmosphere/diffusion/tests/diffusion_tests, atmosphere/dycore/tests/dycore_tests, common/tests, driver/tests]
BACKEND: [gtfn_cpu]
18 changes: 10 additions & 8 deletions model/common/src/icon4py/model/common/test_utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from dataclasses import dataclass, field
from typing import ClassVar, Optional

import gt4py.next.program_processors.modular_executor
import numpy as np
import numpy.typing as npt
import pytest
Expand All @@ -37,16 +36,19 @@ def backend(request):
return request.param


def is_otf(backend) -> bool:
def is_python(backend) -> bool:
# want to exclude python backends:
# - cannot run on embedded: because of slicing
# - roundtrip is very slow on large grid
if hasattr(backend, "executor"):
if isinstance(
backend.executor, gt4py.next.program_processors.modular_executor.ModularExecutor
):
return True
return False
return is_embedded(backend) or is_roundtrip(backend)


def is_embedded(backend) -> bool:
return backend is None


def is_roundtrip(backend) -> bool:
return backend.__name__ == "roundtrip" if backend else False


def _shape(
Expand Down
9 changes: 5 additions & 4 deletions model/common/tests/metric_tests/test_metric_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from icon4py.model.common.test_utils.helpers import (
StencilTest,
dallclose,
is_otf,
is_python,
is_roundtrip,
random_field,
zero_field,
)
Expand Down Expand Up @@ -65,7 +66,7 @@ def input_data(self, grid) -> dict:


def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend):
if not is_otf(backend):
if is_python(backend):
pytest.skip("skipping: unsupported backend")
ddq_z_half_ref = metrics_savepoint.ddqz_z_half()
z_ifc = metrics_savepoint.z_ifc()
Expand Down Expand Up @@ -100,8 +101,8 @@ def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend):


def test_compute_ddqz_z_full(icon_grid, metrics_savepoint, backend):
if not is_otf(backend):
pytest.skip("skipping: unsupported backend")
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
z_ifc = metrics_savepoint.z_ifc()
inv_ddqz_full_ref = metrics_savepoint.inv_ddqz_z_full()
ddqz_z_full = zero_field(icon_grid, CellDim, KDim)
Expand Down
14 changes: 9 additions & 5 deletions model/common/tests/metric_tests/test_reference_atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
compute_reference_atmosphere_cell_fields,
compute_reference_atmosphere_edge_fields,
)
from icon4py.model.common.test_utils.helpers import dallclose, is_otf, zero_field
from icon4py.model.common.test_utils.helpers import dallclose, is_roundtrip, zero_field
from icon4py.model.common.type_alias import wpfloat


Expand All @@ -38,6 +38,8 @@
def test_compute_reference_atmosphere_fields_on_full_level_masspoints(
icon_grid, metrics_savepoint, backend
):
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
exner_ref_mc_ref = metrics_savepoint.exner_ref_mc()
rho_ref_mc_ref = metrics_savepoint.rho_ref_mc()
theta_ref_mc_ref = metrics_savepoint.theta_ref_mc()
Expand Down Expand Up @@ -89,6 +91,8 @@ def test_compute_reference_atmosphere_fields_on_full_level_masspoints(
def test_compute_reference_atmsophere_on_half_level_mass_points(
icon_grid, metrics_savepoint, backend
):
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
theta_ref_ic_ref = metrics_savepoint.theta_ref_ic()
z_ifc = metrics_savepoint.z_ifc()

Expand Down Expand Up @@ -124,8 +128,8 @@ def test_compute_reference_atmsophere_on_half_level_mass_points(

@pytest.mark.datatest
def test_compute_d_exner_dz_ref_ic(icon_grid, metrics_savepoint, backend):
if not is_otf(backend):
pytest.skip("skipping: unsupported backend")
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
theta_ref_ic = metrics_savepoint.theta_ref_ic()
d_exner_dz_ref_ic_ref = metrics_savepoint.d_exner_dz_ref_ic()
d_exner_dz_ref_ic = zero_field(icon_grid, CellDim, KDim, extend={KDim: 1})
Expand All @@ -144,8 +148,8 @@ def test_compute_d_exner_dz_ref_ic(icon_grid, metrics_savepoint, backend):
def test_compute_reference_atmosphere_on_full_level_edge_fields(
icon_grid, interpolation_savepoint, metrics_savepoint, backend
):
if not is_otf(backend):
pytest.skip("skipping: unsupported backend")
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
rho_ref_me_ref = metrics_savepoint.rho_ref_me()
theta_ref_me_ref = metrics_savepoint.theta_ref_me()
rho_ref_me = metrics_savepoint.rho_ref_me()
Expand Down
Loading