Skip to content

Commit

Permalink
fix and rename helper function to recognize python backends. (#430)
Browse files Browse the repository at this point in the history
fix the is_otf:

    - recognize backend by backend.__name__
    - invert the negation: it now tests positively for python backends

in addition unskip those tests with embedded that now work, only skip for roundtrip because it is too slow and run the datatests for metric fields with gtfn_cpu .
  • Loading branch information
halungge authored Apr 3, 2024
1 parent e68db25 commit 7df1bff
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 18 deletions.
3 changes: 2 additions & 1 deletion ci/default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ test_model_datatests:
extends: .test_template
stage: test
script:
- tox -r -e run_model_tests -c model/ --verbose -- $COMPONENT
- tox -r -e run_model_tests -c model/ --verbose -- --backend=$BACKEND $COMPONENT
parallel:
matrix:
- COMPONENT: [atmosphere/diffusion/tests/diffusion_tests, atmosphere/dycore/tests/dycore_tests, common/tests, driver/tests]
BACKEND: [gtfn_cpu]
18 changes: 10 additions & 8 deletions model/common/src/icon4py/model/common/test_utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from dataclasses import dataclass, field
from typing import ClassVar, Optional

import gt4py.next.program_processors.modular_executor
import numpy as np
import numpy.typing as npt
import pytest
Expand All @@ -37,16 +36,19 @@ def backend(request):
return request.param


def is_otf(backend) -> bool:
def is_python(backend) -> bool:
# want to exclude python backends:
# - cannot run on embedded: because of slicing
# - roundtrip is very slow on large grid
if hasattr(backend, "executor"):
if isinstance(
backend.executor, gt4py.next.program_processors.modular_executor.ModularExecutor
):
return True
return False
return is_embedded(backend) or is_roundtrip(backend)


def is_embedded(backend) -> bool:
return backend is None


def is_roundtrip(backend) -> bool:
return backend.__name__ == "roundtrip" if backend else False


def _shape(
Expand Down
9 changes: 5 additions & 4 deletions model/common/tests/metric_tests/test_metric_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from icon4py.model.common.test_utils.helpers import (
StencilTest,
dallclose,
is_otf,
is_python,
is_roundtrip,
random_field,
zero_field,
)
Expand Down Expand Up @@ -65,7 +66,7 @@ def input_data(self, grid) -> dict:


def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend):
if not is_otf(backend):
if is_python(backend):
pytest.skip("skipping: unsupported backend")
ddq_z_half_ref = metrics_savepoint.ddqz_z_half()
z_ifc = metrics_savepoint.z_ifc()
Expand Down Expand Up @@ -100,8 +101,8 @@ def test_compute_ddq_z_half(icon_grid, metrics_savepoint, backend):


def test_compute_ddqz_z_full(icon_grid, metrics_savepoint, backend):
if not is_otf(backend):
pytest.skip("skipping: unsupported backend")
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
z_ifc = metrics_savepoint.z_ifc()
inv_ddqz_full_ref = metrics_savepoint.inv_ddqz_z_full()
ddqz_z_full = zero_field(icon_grid, CellDim, KDim)
Expand Down
14 changes: 9 additions & 5 deletions model/common/tests/metric_tests/test_reference_atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
compute_reference_atmosphere_cell_fields,
compute_reference_atmosphere_edge_fields,
)
from icon4py.model.common.test_utils.helpers import dallclose, is_otf, zero_field
from icon4py.model.common.test_utils.helpers import dallclose, is_roundtrip, zero_field
from icon4py.model.common.type_alias import wpfloat


Expand All @@ -38,6 +38,8 @@
def test_compute_reference_atmosphere_fields_on_full_level_masspoints(
icon_grid, metrics_savepoint, backend
):
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
exner_ref_mc_ref = metrics_savepoint.exner_ref_mc()
rho_ref_mc_ref = metrics_savepoint.rho_ref_mc()
theta_ref_mc_ref = metrics_savepoint.theta_ref_mc()
Expand Down Expand Up @@ -89,6 +91,8 @@ def test_compute_reference_atmosphere_fields_on_full_level_masspoints(
def test_compute_reference_atmsophere_on_half_level_mass_points(
icon_grid, metrics_savepoint, backend
):
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
theta_ref_ic_ref = metrics_savepoint.theta_ref_ic()
z_ifc = metrics_savepoint.z_ifc()

Expand Down Expand Up @@ -124,8 +128,8 @@ def test_compute_reference_atmsophere_on_half_level_mass_points(

@pytest.mark.datatest
def test_compute_d_exner_dz_ref_ic(icon_grid, metrics_savepoint, backend):
if not is_otf(backend):
pytest.skip("skipping: unsupported backend")
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
theta_ref_ic = metrics_savepoint.theta_ref_ic()
d_exner_dz_ref_ic_ref = metrics_savepoint.d_exner_dz_ref_ic()
d_exner_dz_ref_ic = zero_field(icon_grid, CellDim, KDim, extend={KDim: 1})
Expand All @@ -144,8 +148,8 @@ def test_compute_d_exner_dz_ref_ic(icon_grid, metrics_savepoint, backend):
def test_compute_reference_atmosphere_on_full_level_edge_fields(
icon_grid, interpolation_savepoint, metrics_savepoint, backend
):
if not is_otf(backend):
pytest.skip("skipping: unsupported backend")
if is_roundtrip(backend):
pytest.skip("skipping: slow backend")
rho_ref_me_ref = metrics_savepoint.rho_ref_me()
theta_ref_me_ref = metrics_savepoint.theta_ref_me()
rho_ref_me = metrics_savepoint.rho_ref_me()
Expand Down

0 comments on commit 7df1bff

Please sign in to comment.