diff --git a/model/common/src/icon4py/model/common/metrics/stencils/compute_wgtfacq_c.py b/model/common/src/icon4py/model/common/metrics/stencils/compute_wgtfacq_c.py index 25bc3f2bba..4b2288c039 100644 --- a/model/common/src/icon4py/model/common/metrics/stencils/compute_wgtfacq_c.py +++ b/model/common/src/icon4py/model/common/metrics/stencils/compute_wgtfacq_c.py @@ -20,13 +20,13 @@ def compute_wgtfacq_c( ) -> np.array: nlev = nlevp1 - 1 - wgtfacq_c = np.zeros(z_ifc.shape) + wgtfacq_c = np.zeros((z_ifc.shape[0], nlevp1)) z1 = 0.5 * (z_ifc[:, nlev] - z_ifc[:, nlevp1]) z2 = 0.5 * (z_ifc[:, nlev] + z_ifc[:, nlev - 1]) - z_ifc[:, nlevp1] z3 = 0.5 * (z_ifc[:, nlev - 1] + z_ifc[:, nlev - 2]) - z_ifc[:, nlevp1] - wgtfacq_c[:, 2] = z1 * z2 / (z2 - z3) / (z1 - z3) - wgtfacq_c[:, 1] = (z1 - wgtfacq_c[:, 2] * (z1 - z3)) / (z1 - z2) - wgtfacq_c[:, 0] = 1.0 - (wgtfacq_c[:, 1] + wgtfacq_c[:, 2]) + wgtfacq_c[:, nlev - 2] = z1 * z2 / (z2 - z3) / (z1 - z3) + wgtfacq_c[:, nlev - 1] = (z1 - wgtfacq_c[:, nlev - 2] * (z1 - z3)) / (z1 - z2) + wgtfacq_c[:, nlev] = 1.0 - (wgtfacq_c[:, nlev - 1] + wgtfacq_c[:, nlev - 2]) return wgtfacq_c diff --git a/model/common/tests/metric_tests/test_compute_wgtfacq_c.py b/model/common/tests/metric_tests/test_compute_wgtfacq_c.py index 1205c9baf9..d30c563796 100644 --- a/model/common/tests/metric_tests/test_compute_wgtfacq_c.py +++ b/model/common/tests/metric_tests/test_compute_wgtfacq_c.py @@ -25,7 +25,6 @@ import pytest -from icon4py.model.common.dimension import CellDim, KDim from icon4py.model.common.metrics.stencils.compute_wgtfacq_c import compute_wgtfacq_c from icon4py.model.common.test_utils.datatest_fixtures import ( # noqa: F401 # import fixtures from test_utils package data_provider, @@ -39,27 +38,18 @@ processor_props, ranked_data_path, ) -from icon4py.model.common.test_utils.helpers import dallclose, zero_field -from icon4py.model.common.type_alias import wpfloat +from icon4py.model.common.test_utils.helpers import dallclose @pytest.mark.datatest def test_compute_wgtfacq_c(icon_grid, metrics_savepoint): # noqa: F811 # fixture - wgtfacq_c = zero_field(icon_grid, CellDim, KDim, dtype=wpfloat) - wgtfacq_c_ref = zero_field(icon_grid, CellDim, KDim, dtype=wpfloat, extend={KDim: 1}) wgtfacq_c_dsl = metrics_savepoint.wgtfacq_c_dsl() z_ifc = metrics_savepoint.z_ifc() - # wgtfacq_c not serialized -> infer from wgtfacq_c_dsl - wgtfacq_c_ref = wgtfacq_c_ref.asnumpy() - wgtfacq_c_ref[:, 0] = wgtfacq_c_dsl.asnumpy()[:, -1] - wgtfacq_c_ref[:, 1] = wgtfacq_c_dsl.asnumpy()[:, -2] - wgtfacq_c_ref[:, 2] = wgtfacq_c_dsl.asnumpy()[:, -3] - vertical_end = icon_grid.num_levels wgtfacq_c = compute_wgtfacq_c( z_ifc.asnumpy(), vertical_end, ) - assert dallclose(wgtfacq_c, wgtfacq_c_ref) + assert dallclose(wgtfacq_c, wgtfacq_c_dsl.asnumpy())