diff --git a/advection/src/icon4py/advection/hflx_limiter_mo_stencil_02.py b/advection/src/icon4py/advection/hflx_limiter_mo_stencil_02.py index f28297791d..c9c359621f 100644 --- a/advection/src/icon4py/advection/hflx_limiter_mo_stencil_02.py +++ b/advection/src/icon4py/advection/hflx_limiter_mo_stencil_02.py @@ -13,18 +13,18 @@ from functional.common import Field from functional.ffront.decorator import field_operator, program -from functional.ffront.fbuiltins import maximum, minimum, where +from functional.ffront.fbuiltins import int32, maximum, minimum, where from icon4py.common.dimension import CellDim, KDim @field_operator def _hflx_limiter_mo_stencil_02( - refin_ctrl: Field[[CellDim], int], + refin_ctrl: Field[[CellDim], int32], p_cc: Field[[CellDim, KDim], float], z_tracer_new_low: Field[[CellDim, KDim], float], - lo_bound: int, - hi_bound: int, + lo_bound: int32, + hi_bound: int32, ) -> tuple[Field[[CellDim, KDim], float], Field[[CellDim, KDim], float]]: condition = (refin_ctrl == lo_bound) | (refin_ctrl == hi_bound) z_tracer_new_tmp = where( @@ -32,18 +32,20 @@ def _hflx_limiter_mo_stencil_02( minimum(1.1 * p_cc, maximum(0.9 * p_cc, z_tracer_new_low)), z_tracer_new_low, ) - z_tracer_max = where(condition, maximum(p_cc, z_tracer_new_tmp), z_tracer_new_low) - z_tracer_min = where(condition, minimum(p_cc, z_tracer_new_tmp), z_tracer_new_low) - return z_tracer_max, z_tracer_min + return where( + condition, + (maximum(p_cc, z_tracer_new_tmp), minimum(p_cc, z_tracer_new_tmp)), + (z_tracer_new_low, z_tracer_new_low), + ) @program def hflx_limiter_mo_stencil_02( - refin_ctrl: Field[[CellDim], int], + refin_ctrl: Field[[CellDim], int32], p_cc: Field[[CellDim, KDim], float], z_tracer_new_low: Field[[CellDim, KDim], float], - lo_bound: int, - hi_bound: int, + lo_bound: int32, + hi_bound: int32, z_tracer_max: Field[[CellDim, KDim], float], z_tracer_min: Field[[CellDim, KDim], float], ): diff --git a/advection/tests/test_hflx_limiter_mo_stencil_02.py b/advection/tests/test_hflx_limiter_mo_stencil_02.py index 89af0f127a..13ed590a59 100644 --- a/advection/tests/test_hflx_limiter_mo_stencil_02.py +++ b/advection/tests/test_hflx_limiter_mo_stencil_02.py @@ -12,6 +12,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import numpy as np +from numpy import int32 from icon4py.advection.hflx_limiter_mo_stencil_02 import ( hflx_limiter_mo_stencil_02, @@ -30,8 +31,8 @@ def hflx_limiter_mo_stencil_02_numpy( ): refin_ctrl = np.expand_dims(refin_ctrl, axis=1) condition = np.logical_or( - np.equal(refin_ctrl, lo_bound * np.ones(refin_ctrl.shape)), - np.equal(refin_ctrl, hi_bound * np.ones(refin_ctrl.shape)), + np.equal(refin_ctrl, lo_bound * np.ones(refin_ctrl.shape, dtype=int32)), + np.equal(refin_ctrl, hi_bound * np.ones(refin_ctrl.shape, dtype=int32)), ) z_tracer_new_tmp = np.where( condition, @@ -49,10 +50,10 @@ def hflx_limiter_mo_stencil_02_numpy( def test_hflx_limiter_mo_stencil_02_some_matching_condition(): mesh = SimpleMesh() - hi_bound = 1 - lo_bound = 5 - refin_ctrl = constant_field(mesh, hi_bound, CellDim, dtype=int) - refin_ctrl[0:2] = 3 + hi_bound = np.int32(1) + lo_bound = np.int32(5) + refin_ctrl = constant_field(mesh, hi_bound, CellDim, dtype=int32) + refin_ctrl[0:2] = np.int32(3) p_cc = random_field(mesh, CellDim, KDim) z_tracer_new_low = random_field(mesh, CellDim, KDim) z_tracer_max = zero_field(mesh, CellDim, KDim) @@ -81,9 +82,9 @@ def test_hflx_limiter_mo_stencil_02_some_matching_condition(): def test_hflx_limiter_mo_stencil_02_none_matching_condition(): mesh = SimpleMesh() - hi_bound = 3 - lo_bound = 1 - refin_ctrl = constant_field(mesh, 2, CellDim, dtype=int) + hi_bound = np.int32(3) + lo_bound = np.int32(1) + refin_ctrl = constant_field(mesh, 2, CellDim, dtype=int32) p_cc = random_field(mesh, CellDim, KDim) z_tracer_new_low = random_field(mesh, CellDim, KDim) z_tracer_max = zero_field(mesh, CellDim, KDim)