diff --git a/advection/src/icon4py/advection/hflx_limiter_pd_stencil_02.py b/advection/src/icon4py/advection/hflx_limiter_pd_stencil_02.py index 857ee7584d..5d4d239f08 100644 --- a/advection/src/icon4py/advection/hflx_limiter_pd_stencil_02.py +++ b/advection/src/icon4py/advection/hflx_limiter_pd_stencil_02.py @@ -12,23 +12,23 @@ # SPDX-License-Identifier: GPL-3.0-or-later from functional.ffront.decorator import field_operator, program -from functional.ffront.fbuiltins import Field, where +from functional.ffront.fbuiltins import Field, int32, where from icon4py.common.dimension import E2C, CellDim, EdgeDim, KDim @field_operator def _hflx_limiter_pd_stencil_02( - refin_ctrl: Field[[EdgeDim], float], + refin_ctrl: Field[[EdgeDim], int32], r_m: Field[[CellDim, KDim], float], p_mflx_tracer_h_in: Field[[EdgeDim, KDim], float], - bound: float, + bound: int32, ) -> Field[[EdgeDim, KDim], float]: p_mflx_tracer_h_out = where( refin_ctrl == bound, p_mflx_tracer_h_in, where( - (p_mflx_tracer_h_in > 0.0) | (p_mflx_tracer_h_in == 0.0), + p_mflx_tracer_h_in >= 0.0, p_mflx_tracer_h_in * r_m(E2C[0]), p_mflx_tracer_h_in * r_m(E2C[1]), ), @@ -38,10 +38,10 @@ def _hflx_limiter_pd_stencil_02( @program def hflx_limiter_pd_stencil_02( - refin_ctrl: Field[[EdgeDim], float], + refin_ctrl: Field[[EdgeDim], int32], r_m: Field[[CellDim, KDim], float], p_mflx_tracer_h: Field[[EdgeDim, KDim], float], - bound: float, + bound: int32, ): _hflx_limiter_pd_stencil_02( refin_ctrl, diff --git a/advection/tests/test_hflx_limiter_pd_stencil_02.py b/advection/tests/test_hflx_limiter_pd_stencil_02.py index fdb93587cb..57eb18aa59 100644 --- a/advection/tests/test_hflx_limiter_pd_stencil_02.py +++ b/advection/tests/test_hflx_limiter_pd_stencil_02.py @@ -18,7 +18,7 @@ ) from icon4py.common.dimension import CellDim, EdgeDim, KDim from icon4py.testutils.simple_mesh import SimpleMesh -from icon4py.testutils.utils import random_field +from icon4py.testutils.utils import constant_field, random_field def hflx_limiter_pd_stencil_02_numpy( @@ -42,14 +42,12 @@ def hflx_limiter_pd_stencil_02_numpy( return p_mflx_tracer_h_out -def test_hflx_limiter_pd_stencil_02(): +def test_hflx_limiter_pd_stencil_02_nowhere_matching_refin_ctl(): mesh = SimpleMesh() - - refin_ctrl = random_field(mesh, EdgeDim) + bound = np.int32(7) + refin_ctrl = constant_field(mesh, 4, EdgeDim, dtype=np.int32) r_m = random_field(mesh, CellDim, KDim) p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim) - p_mflx_tracer_h_out = random_field(mesh, EdgeDim, KDim) - bound = np.float64(7.0) ref = hflx_limiter_pd_stencil_02_numpy( mesh.e2c, @@ -63,10 +61,67 @@ def test_hflx_limiter_pd_stencil_02(): refin_ctrl, r_m, p_mflx_tracer_h_in, - p_mflx_tracer_h_out, bound, offset_provider={ "E2C": mesh.get_e2c_offset_provider(), }, ) - assert np.allclose(p_mflx_tracer_h_out, ref) + assert np.allclose(p_mflx_tracer_h_in, ref) + + +def test_hflx_limiter_pd_stencil_02_everywhere_matching_refin_ctl(): + mesh = SimpleMesh() + bound = np.int32(7) + refin_ctrl = constant_field(mesh, bound, EdgeDim, dtype=np.int32) + r_m = random_field(mesh, CellDim, KDim) + p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim) + + hflx_limiter_pd_stencil_02( + refin_ctrl, + r_m, + p_mflx_tracer_h_in, + bound, + offset_provider={ + "E2C": mesh.get_e2c_offset_provider(), + }, + ) + assert np.allclose(p_mflx_tracer_h_in, p_mflx_tracer_h_in) + + +def test_hflx_limiter_pd_stencil_02_partly_matching_refin_ctl(): + mesh = SimpleMesh() + bound = np.int32(4) + refin_ctrl = constant_field(mesh, 5, EdgeDim, dtype=np.int32) + refin_ctrl[2:6] = bound + r_m = random_field(mesh, CellDim, KDim) + p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim) + + hflx_limiter_pd_stencil_02( + refin_ctrl, + r_m, + p_mflx_tracer_h_in, + bound, + offset_provider={ + "E2C": mesh.get_e2c_offset_provider(), + }, + ) + assert np.allclose(p_mflx_tracer_h_in, p_mflx_tracer_h_in) + + +def test_hflx_limiter_pd_stencil_02_everywhere_matching_refin_ctl_does_not_change_inout_arg(): + mesh = SimpleMesh() + bound = np.int32(7) + refin_ctrl = constant_field(mesh, bound, EdgeDim, dtype=np.int32) + r_m = random_field(mesh, CellDim, KDim) + p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim) + + hflx_limiter_pd_stencil_02( + refin_ctrl, + r_m, + p_mflx_tracer_h_in, + bound, + offset_provider={ + "E2C": mesh.get_e2c_offset_provider(), + }, + ) + assert np.allclose(p_mflx_tracer_h_in, p_mflx_tracer_h_in)