diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py
index 43e6d5ba76..ba7e470574 100644
--- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py
+++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py
@@ -13,7 +13,6 @@
import logging
from typing import Final, Optional
-import numpy as np
from gt4py.next import as_field
from gt4py.next.common import Field
from gt4py.next.ffront.fbuiltins import int32
@@ -138,11 +137,13 @@
from icon4py.model.atmosphere.dycore.mo_solve_nonhydro_stencil_68 import (
mo_solve_nonhydro_stencil_68,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
-from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
-from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
-from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ DiagnosticStateNonHydro,
+ InterpolationState,
+ MetricStateNonHydro,
+ PrepAdvection,
+)
from icon4py.model.atmosphere.dycore.state_utils.utils import (
_allocate,
_allocate_indices,
@@ -493,6 +494,10 @@ def time_step(
rd_o_cvd=self.params.rd_o_cvd,
rd_o_p0ref=self.params.rd_o_p0ref,
offset_provider={},
+ horizontal_start=0,
+ horizontal_end=end_cell_end,
+ vertical_start=0,
+ vertical_end=self.grid.num_levels,
)
mo_solve_nonhydro_stencil_67.with_backend(backend)(
@@ -659,7 +664,7 @@ def run_predictor_step(
if self.config.igradp_method == 3:
nhsolve_prog.predictor_stencils_4_5_6.with_backend(backend)(
- wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c_dsl,
+ wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c,
z_exner_ex_pr=self.z_exner_ex_pr,
z_exner_ic=self.z_exner_ic,
wgtfac_c=self.metric_state_nonhydro.wgtfac_c,
@@ -704,7 +709,7 @@ def run_predictor_step(
# Perturbation theta at top and surface levels
nhsolve_prog.predictor_stencils_11_lower_upper.with_backend(backend)(
- wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c_dsl,
+ wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c,
z_rth_pr=self.z_rth_pr_2,
theta_ref_ic=self.metric_state_nonhydro.theta_ref_ic,
z_theta_v_pr_ic=self.z_theta_v_pr_ic,
@@ -1069,7 +1074,7 @@ def run_predictor_step(
vn_ie=diagnostic_state_nh.vn_ie,
z_vt_ie=z_fields.z_vt_ie,
z_kin_hor_e=z_fields.z_kin_hor_e,
- wgtfacq_e_dsl=self.metric_state_nonhydro.wgtfacq_e_dsl,
+ wgtfacq_e_dsl=self.metric_state_nonhydro.wgtfacq_e,
k_field=self.k_field,
nlev=self.grid.num_levels,
horizontal_start=start_edge_lb_plus4,
@@ -1083,7 +1088,7 @@ def run_predictor_step(
e_bln_c_s=self.interpolation_state.e_bln_c_s,
z_w_concorr_me=self.z_w_concorr_me,
wgtfac_c=self.metric_state_nonhydro.wgtfac_c,
- wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c_dsl,
+ wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c,
w_concorr_c=diagnostic_state_nh.w_concorr_c,
k_field=self.k_field,
nflatlev_startindex_plus1=int32(self.vertical_params.nflatlev + 1),
diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/diagnostic_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/diagnostic_state.py
deleted file mode 100644
index f12aa39b55..0000000000
--- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/diagnostic_state.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# ICON4Py - ICON inspired code in Python and GT4Py
-#
-# Copyright (c) 2022, ETH Zurich and MeteoSwiss
-# All rights reserved.
-#
-# This file is free software: you can redistribute it and/or modify it under
-# the terms of the GNU General Public License as published by the
-# Free Software Foundation, either version 3 of the License, or any later
-# version. See the LICENSE.txt file at the top-level directory of this
-# distribution for a copy of the license or check .
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-from dataclasses import dataclass
-
-from gt4py.next.common import Field
-
-from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
-
-
-@dataclass
-class DiagnosticState:
- # fields for 3D elements in turbdiff
- hdef_ic: Field[
- [CellDim, KDim], float
- ] # ! divergence at half levels(nproma,nlevp1,nblks_c) [1/s]
- div_ic: Field[
- [CellDim, KDim], float
- ] # ! horizontal wind field deformation (nproma,nlevp1,nblks_c) [1/s^2]
- dwdx: Field[
- [CellDim, KDim], float
- ] # zonal gradient of vertical wind speed (nproma,nlevp1,nblks_c) [1/s]
-
- dwdy: Field[
- [CellDim, KDim], float
- ] # meridional gradient of vertical wind speed (nproma,nlevp1,nblks_c)
-
- vt: Field[[EdgeDim, KDim], float]
- vn_ie: Field[
- [EdgeDim, KDim], float
- ] # normal wind at half levels (nproma,nlevp1,nblks_e) [m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain?
- w_concorr_c: Field[
- [CellDim, KDim], float
- ] # contravariant vert correction (nproma,nlevp1,nblks_c)[m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain?
- ddt_w_adv_pc: Field[[CellDim, KDim], float]
- ddt_vn_apc_pc: Field[[EdgeDim, KDim], float]
- ntnd: float
-
-
-@dataclass
-class DiagnosticStateNonHydro:
- vt: Field[[EdgeDim, KDim], float]
- vn_ie: Field[
- [EdgeDim, KDim], float
- ] # normal wind at half levels (nproma,nlevp1,nblks_e) [m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain?
- w_concorr_c: Field[
- [CellDim, KDim], float
- ] # contravariant vert correction (nproma,nlevp1,nblks_c)[m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain?
- theta_v_ic: Field[[CellDim, KDim], float]
- exner_pr: Field[[CellDim, KDim], float]
- rho_ic: Field[[CellDim, KDim], float]
- ddt_exner_phy: Field[[CellDim, KDim], float]
- grf_tend_rho: Field[[CellDim, KDim], float]
- grf_tend_thv: Field[[CellDim, KDim], float]
- grf_tend_w: Field[[CellDim, KDim], float]
- mass_fl_e: Field[[EdgeDim, KDim], float]
- ddt_vn_phy: Field[[EdgeDim, KDim], float]
- grf_tend_vn: Field[[EdgeDim, KDim], float]
- ddt_vn_apc_ntl1: Field[[EdgeDim, KDim], float]
- ddt_vn_apc_ntl2: Field[[EdgeDim, KDim], float]
- ddt_w_adv_ntl1: Field[[CellDim, KDim], float]
- ddt_w_adv_ntl2: Field[[CellDim, KDim], float]
-
- # Analysis increments
- rho_incr: Field[[EdgeDim, KDim], float] # moist density increment [kg/m^3]
- vn_incr: Field[[EdgeDim, KDim], float] # normal velocity increment [m/s]
- exner_incr: Field[[EdgeDim, KDim], float] # exner increment [- ]
-
- @property
- def ddt_vn_apc_pc(
- self,
- ) -> tuple[Field[[EdgeDim, KDim], float], Field[[EdgeDim, KDim], float]]:
- return (self.ddt_vn_apc_ntl1, self.ddt_vn_apc_ntl2)
-
- @property
- def ddt_w_adv_pc(
- self,
- ) -> tuple[Field[[CellDim, KDim], float], Field[[CellDim, KDim], float]]:
- return (self.ddt_w_adv_ntl1, self.ddt_w_adv_ntl2)
diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/interpolation_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/interpolation_state.py
deleted file mode 100644
index 440c4a4b9c..0000000000
--- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/interpolation_state.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# ICON4Py - ICON inspired code in Python and GT4Py
-#
-# Copyright (c) 2022, ETH Zurich and MeteoSwiss
-# All rights reserved.
-#
-# This file is free software: you can redistribute it and/or modify it under
-# the terms of the GNU General Public License as published by the
-# Free Software Foundation, either version 3 of the License, or any later
-# version. See the LICENSE.txt file at the top-level directory of this
-# distribution for a copy of the license or check .
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from dataclasses import dataclass
-
-from gt4py.next.common import Field
-
-from icon4py.model.common.dimension import (
- C2E2CODim,
- CEDim,
- CellDim,
- E2C2EDim,
- E2C2EODim,
- E2CDim,
- ECDim,
- EdgeDim,
- V2CDim,
- V2EDim,
- VertexDim,
-)
-
-
-@dataclass
-class InterpolationState:
- """Represents the ICON interpolation state used int SolveNonHydro."""
-
- e_bln_c_s: Field[[CEDim], float] # coefficent for bilinear interpolation from edge to cell ()
- rbf_coeff_1: Field[
- [VertexDim, V2EDim], float
- ] # rbf_vec_coeff_v_1(nproma, rbf_vec_dim_v, nblks_v)
- rbf_coeff_2: Field[
- [VertexDim, V2EDim], float
- ] # rbf_vec_coeff_v_2(nproma, rbf_vec_dim_v, nblks_v)
-
- geofac_div: Field[[CEDim], float] # factor for divergence (nproma,cell_type,nblks_c)
-
- geofac_n2s: Field[
- [CellDim, C2E2CODim], float
- ] # factor for nabla2-scalar (nproma,cell_type+1,nblks_c)
- geofac_grg_x: Field[[CellDim, C2E2CODim], float]
- geofac_grg_y: Field[
- [CellDim, C2E2CODim], float
- ] # factors for green gauss gradient (nproma,4,nblks_c,2)
- nudgecoeff_e: Field[[EdgeDim], float] # Nudgeing coeffients for edges
-
- c_lin_e: Field[[EdgeDim, E2CDim], float]
- geofac_grdiv: Field[[EdgeDim, E2C2EODim], float]
- rbf_vec_coeff_e: Field[[EdgeDim, E2C2EDim], float]
- c_intp: Field[[VertexDim, V2CDim], float]
- geofac_rot: Field[[VertexDim, V2EDim], float]
- pos_on_tplane_e_1: Field[[ECDim], float]
- pos_on_tplane_e_2: Field[[ECDim], float]
- e_flx_avg: Field[[EdgeDim, E2C2EODim], float]
diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/metric_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/metric_state.py
deleted file mode 100644
index 3d32d340b9..0000000000
--- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/metric_state.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# ICON4Py - ICON inspired code in Python and GT4Py
-#
-# Copyright (c) 2022, ETH Zurich and MeteoSwiss
-# All rights reserved.
-#
-# This file is free software: you can redistribute it and/or modify it under
-# the terms of the GNU General Public License as published by the
-# Free Software Foundation, either version 3 of the License, or any later
-# version. See the LICENSE.txt file at the top-level directory of this
-# distribution for a copy of the license or check .
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from dataclasses import dataclass
-
-from gt4py.next.common import Field
-from numpy import int32
-
-from icon4py.model.common.dimension import CECDim, CellDim, ECDim, EdgeDim, KDim
-
-
-@dataclass
-class MetricState:
- mask_hdiff: Field[[CellDim, KDim], bool]
- theta_ref_mc: Field[[CellDim, KDim], float]
- wgtfac_c: Field[
- [CellDim, KDim], float
- ] # weighting factor for interpolation from full to half levels (nproma,nlevp1,nblks_c)
- zd_vertidx: Field[[CECDim, KDim], int32]
- zd_diffcoef: Field[[CellDim, KDim], float]
- zd_intcoef: Field[[CECDim, KDim], float]
-
- coeff_gradekin: Field[[ECDim], float]
- ddqz_z_full_e: Field[[EdgeDim, KDim], float]
- wgtfac_e: Field[[EdgeDim, KDim], float]
- wgtfacq_e_dsl: Field[[EdgeDim, KDim], float]
- ddxn_z_full: Field[[EdgeDim, KDim], float]
- ddxt_z_full: Field[[EdgeDim, KDim], float]
- ddqz_z_half: Field[[CellDim, KDim], float] # half KDim ?
- coeff1_dwdz: Field[[CellDim, KDim], float]
- coeff2_dwdz: Field[[CellDim, KDim], float]
- zd_vertoffset: Field[[CECDim, KDim], int32]
-
-
-@dataclass
-class MetricStateNonHydro:
- bdy_halo_c: Field[[CellDim], bool]
- # Finally, a mask field that excludes boundary halo points
- mask_prog_halo_c: Field[[CellDim, KDim], bool]
- rayleigh_w: Field[[KDim], float]
-
- wgtfac_c: Field[[CellDim, KDim], float]
- wgtfacq_c_dsl: Field[[CellDim, KDim], float]
- wgtfac_e: Field[[EdgeDim, KDim], float]
- wgtfacq_e_dsl: Field[[EdgeDim, KDim], float]
-
- exner_exfac: Field[[CellDim, KDim], float]
- exner_ref_mc: Field[[CellDim, KDim], float]
- rho_ref_mc: Field[[CellDim, KDim], float]
- theta_ref_mc: Field[[CellDim, KDim], float]
- rho_ref_me: Field[[EdgeDim, KDim], float]
- theta_ref_me: Field[[EdgeDim, KDim], float]
- theta_ref_ic: Field[[CellDim, KDim], float]
-
- d_exner_dz_ref_ic: Field[[CellDim, KDim], float]
- ddqz_z_half: Field[[CellDim, KDim], float] # half KDim ?
- d2dexdz2_fac1_mc: Field[[CellDim, KDim], float]
- d2dexdz2_fac2_mc: Field[[CellDim, KDim], float]
- ddxn_z_full: Field[[EdgeDim, KDim], float]
- ddqz_z_full_e: Field[[EdgeDim, KDim], float]
- ddxt_z_full: Field[[EdgeDim, KDim], float]
- inv_ddqz_z_full: Field[[CellDim, KDim], float]
-
- vertoffset_gradp: Field[[ECDim, KDim], float]
- zdiff_gradp: Field[[ECDim, KDim], float]
- ipeidx_dsl: Field[[EdgeDim, KDim], bool]
- pg_exdist: Field[[EdgeDim, KDim], float]
-
- vwind_expl_wgt: Field[[CellDim], float]
- vwind_impl_wgt: Field[[CellDim], float]
-
- hmask_dd3d: Field[[EdgeDim], float]
- scalfac_dd3d: Field[[KDim], float]
-
- coeff1_dwdz: Field[[CellDim, KDim], float]
- coeff2_dwdz: Field[[CellDim, KDim], float]
- coeff_gradekin: Field[[ECDim], float]
diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/prep_adv_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/prep_adv_state.py
deleted file mode 100644
index 8a94ccafb4..0000000000
--- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/prep_adv_state.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# ICON4Py - ICON inspired code in Python and GT4Py
-#
-# Copyright (c) 2022, ETH Zurich and MeteoSwiss
-# All rights reserved.
-#
-# This file is free software: you can redistribute it and/or modify it under
-# the terms of the GNU General Public License as published by the
-# Free Software Foundation, either version 3 of the License, or any later
-# version. See the LICENSE.txt file at the top-level directory of this
-# distribution for a copy of the license or check .
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from dataclasses import dataclass
-
-from gt4py.next.common import Field
-
-from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
-
-
-@dataclass
-class PrepAdvection:
- vn_traj: Field[[EdgeDim, KDim], float]
- mass_flx_me: Field[[EdgeDim, KDim], float]
- mass_flx_ic: Field[[CellDim, KDim], float]
diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/states.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/states.py
new file mode 100644
index 0000000000..fb413282bf
--- /dev/null
+++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/states.py
@@ -0,0 +1,157 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+from dataclasses import dataclass
+
+from gt4py.next.common import Field
+
+from icon4py.model.common.dimension import (
+ C2E2CODim,
+ CEDim,
+ CellDim,
+ E2C2EDim,
+ E2C2EODim,
+ E2CDim,
+ ECDim,
+ EdgeDim,
+ KDim,
+ V2CDim,
+ V2EDim,
+ VertexDim,
+)
+
+
+@dataclass
+class DiagnosticStateNonHydro:
+ vt: Field[[EdgeDim, KDim], float]
+ vn_ie: Field[
+ [EdgeDim, KDim], float
+ ] # normal wind at half levels (nproma,nlevp1,nblks_e) [m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain?
+ w_concorr_c: Field[
+ [CellDim, KDim], float
+ ] # contravariant vert correction (nproma,nlevp1,nblks_c)[m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain?
+ theta_v_ic: Field[[CellDim, KDim], float]
+ exner_pr: Field[[CellDim, KDim], float]
+ rho_ic: Field[[CellDim, KDim], float]
+ ddt_exner_phy: Field[[CellDim, KDim], float]
+ grf_tend_rho: Field[[CellDim, KDim], float]
+ grf_tend_thv: Field[[CellDim, KDim], float]
+ grf_tend_w: Field[[CellDim, KDim], float]
+ mass_fl_e: Field[[EdgeDim, KDim], float]
+ ddt_vn_phy: Field[[EdgeDim, KDim], float]
+ grf_tend_vn: Field[[EdgeDim, KDim], float]
+ ddt_vn_apc_ntl1: Field[[EdgeDim, KDim], float]
+ ddt_vn_apc_ntl2: Field[[EdgeDim, KDim], float]
+ ddt_w_adv_ntl1: Field[[CellDim, KDim], float]
+ ddt_w_adv_ntl2: Field[[CellDim, KDim], float]
+
+ # Analysis increments
+ rho_incr: Field[[EdgeDim, KDim], float] # moist density increment [kg/m^3]
+ vn_incr: Field[[EdgeDim, KDim], float] # normal velocity increment [m/s]
+ exner_incr: Field[[EdgeDim, KDim], float] # exner increment [- ]
+
+ @property
+ def ddt_vn_apc_pc(
+ self,
+ ) -> tuple[Field[[EdgeDim, KDim], float], Field[[EdgeDim, KDim], float]]:
+ return (self.ddt_vn_apc_ntl1, self.ddt_vn_apc_ntl2)
+
+ @property
+ def ddt_w_adv_pc(
+ self,
+ ) -> tuple[Field[[CellDim, KDim], float], Field[[CellDim, KDim], float]]:
+ return (self.ddt_w_adv_ntl1, self.ddt_w_adv_ntl2)
+
+
+@dataclass
+class InterpolationState:
+ """Represents the ICON interpolation state used int SolveNonHydro."""
+
+ e_bln_c_s: Field[[CEDim], float] # coefficent for bilinear interpolation from edge to cell ()
+ rbf_coeff_1: Field[
+ [VertexDim, V2EDim], float
+ ] # rbf_vec_coeff_v_1(nproma, rbf_vec_dim_v, nblks_v)
+ rbf_coeff_2: Field[
+ [VertexDim, V2EDim], float
+ ] # rbf_vec_coeff_v_2(nproma, rbf_vec_dim_v, nblks_v)
+
+ geofac_div: Field[[CEDim], float] # factor for divergence (nproma,cell_type,nblks_c)
+
+ geofac_n2s: Field[
+ [CellDim, C2E2CODim], float
+ ] # factor for nabla2-scalar (nproma,cell_type+1,nblks_c)
+ geofac_grg_x: Field[[CellDim, C2E2CODim], float]
+ geofac_grg_y: Field[
+ [CellDim, C2E2CODim], float
+ ] # factors for green gauss gradient (nproma,4,nblks_c,2)
+ nudgecoeff_e: Field[[EdgeDim], float] # Nudgeing coeffients for edges
+
+ c_lin_e: Field[[EdgeDim, E2CDim], float]
+ geofac_grdiv: Field[[EdgeDim, E2C2EODim], float]
+ rbf_vec_coeff_e: Field[[EdgeDim, E2C2EDim], float]
+ c_intp: Field[[VertexDim, V2CDim], float]
+ geofac_rot: Field[[VertexDim, V2EDim], float]
+ pos_on_tplane_e_1: Field[[ECDim], float]
+ pos_on_tplane_e_2: Field[[ECDim], float]
+ e_flx_avg: Field[[EdgeDim, E2C2EODim], float]
+
+
+@dataclass
+class MetricStateNonHydro:
+ bdy_halo_c: Field[[CellDim], bool]
+ # Finally, a mask field that excludes boundary halo points
+ mask_prog_halo_c: Field[[CellDim, KDim], bool]
+ rayleigh_w: Field[[KDim], float]
+
+ wgtfac_c: Field[[CellDim, KDim], float]
+ wgtfacq_c: Field[[CellDim, KDim], float]
+ wgtfac_e: Field[[EdgeDim, KDim], float]
+ wgtfacq_e: Field[[EdgeDim, KDim], float]
+
+ exner_exfac: Field[[CellDim, KDim], float]
+ exner_ref_mc: Field[[CellDim, KDim], float]
+ rho_ref_mc: Field[[CellDim, KDim], float]
+ theta_ref_mc: Field[[CellDim, KDim], float]
+ rho_ref_me: Field[[EdgeDim, KDim], float]
+ theta_ref_me: Field[[EdgeDim, KDim], float]
+ theta_ref_ic: Field[[CellDim, KDim], float]
+
+ d_exner_dz_ref_ic: Field[[CellDim, KDim], float]
+ ddqz_z_half: Field[[CellDim, KDim], float] # half KDim ?
+ d2dexdz2_fac1_mc: Field[[CellDim, KDim], float]
+ d2dexdz2_fac2_mc: Field[[CellDim, KDim], float]
+ ddxn_z_full: Field[[EdgeDim, KDim], float]
+ ddqz_z_full_e: Field[[EdgeDim, KDim], float]
+ ddxt_z_full: Field[[EdgeDim, KDim], float]
+ inv_ddqz_z_full: Field[[CellDim, KDim], float]
+
+ vertoffset_gradp: Field[[ECDim, KDim], float]
+ zdiff_gradp: Field[[ECDim, KDim], float]
+ ipeidx_dsl: Field[[EdgeDim, KDim], bool]
+ pg_exdist: Field[[EdgeDim, KDim], float]
+
+ vwind_expl_wgt: Field[[CellDim], float]
+ vwind_impl_wgt: Field[[CellDim], float]
+
+ hmask_dd3d: Field[[EdgeDim], float]
+ scalfac_dd3d: Field[[KDim], float]
+
+ coeff1_dwdz: Field[[CellDim, KDim], float]
+ coeff2_dwdz: Field[[CellDim, KDim], float]
+ coeff_gradekin: Field[[ECDim], float]
+
+
+@dataclass
+class PrepAdvection:
+ vn_traj: Field[[EdgeDim, KDim], float]
+ mass_flx_me: Field[[EdgeDim, KDim], float]
+ mass_flx_ic: Field[[CellDim, KDim], float]
diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py
index e6340459d5..7888e09356 100644
--- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py
+++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py
@@ -55,9 +55,11 @@
from icon4py.model.atmosphere.dycore.mo_velocity_advection_stencil_20 import (
mo_velocity_advection_stencil_20,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
-from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
-from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ DiagnosticStateNonHydro,
+ InterpolationState,
+ MetricStateNonHydro,
+)
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate, _allocate_indices
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.grid.horizontal import EdgeParams, HorizontalMarkerIndex
@@ -249,7 +251,7 @@ def run_predictor_step(
offset_provider={},
)
velocity_prog.mo_velocity_advection_stencil_06.with_backend(backend)(
- wgtfacq_e=self.metric_state.wgtfacq_e_dsl,
+ wgtfacq_e=self.metric_state.wgtfacq_e,
vn=prognostic_state.vn,
vn_ie=diagnostic_state.vn_ie,
horizontal_start=start_edge_lb_plus4,
diff --git a/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py b/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py
index be1704e21a..dd8a2f54db 100644
--- a/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py
+++ b/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py
@@ -20,9 +20,11 @@
NonHydrostaticParams,
SolveNonhydro,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
-from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ DiagnosticStateNonHydro,
+ PrepAdvection,
+)
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.common.decomposition import definitions
diff --git a/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py b/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py
index 7d56058655..6a7fad0b3c 100644
--- a/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py
+++ b/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py
@@ -20,9 +20,11 @@
NonHydrostaticParams,
SolveNonhydro,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
-from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ DiagnosticStateNonHydro,
+ PrepAdvection,
+)
from icon4py.model.atmosphere.dycore.state_utils.utils import (
_allocate,
_calculate_bdy_divdamp,
@@ -37,6 +39,8 @@
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import dallclose
+from .utils import construct_interpolation_state_for_nonhydro, construct_nh_metric_state
+
backend = run_gtfn
@@ -143,8 +147,8 @@ def test_nonhydro_predictor_step(
z_fields = allocate_z_fields(icon_grid)
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
@@ -574,8 +578,8 @@ def test_nonhydro_corrector_step(
nh_constants = create_nh_constants(sp)
divdamp_fac_o2 = sp.divdamp_fac_o2()
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
@@ -723,8 +727,8 @@ def test_run_solve_nonhydro_single_step(
nh_constants = create_nh_constants(sp)
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
@@ -844,8 +848,8 @@ def test_run_solve_nonhydro_multi_step(
nh_constants = create_nh_constants(sp)
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
diff --git a/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py b/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py
index f1b56b5c5d..0d6e99dc90 100644
--- a/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py
+++ b/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py
@@ -14,13 +14,15 @@
import pytest
from gt4py.next.ffront.fbuiltins import int32
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
+from icon4py.model.atmosphere.dycore.state_utils.states import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.velocity.velocity_advection import VelocityAdvection
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import dallclose
+from .utils import construct_interpolation_state_for_nonhydro, construct_nh_metric_state
+
@pytest.mark.datatest
def test_scalfactors(savepoint_velocity_init, icon_grid):
@@ -48,8 +50,8 @@ def test_velocity_init(
step_date_init,
damping_height,
):
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
vertical_params = VerticalModelParams(
vct_a=grid_savepoint.vct_a(),
@@ -87,8 +89,8 @@ def test_verify_velocity_init_against_regular_savepoint(
savepoint = savepoint_velocity_init
dtime = savepoint.get_metadata("dtime").get("dtime")
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
vertical_params = VerticalModelParams(
vct_a=grid_savepoint.vct_a(),
rayleigh_damping_height=damping_height,
@@ -163,9 +165,9 @@ def test_velocity_predictor_step(
rho=None,
exner=None,
)
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
@@ -213,7 +215,7 @@ def test_velocity_predictor_step(
# stencil 07
if not vn_only:
assert dallclose(
- icon_result_z_v_grad_w.asnumpy()[3777:31558, :],
+ icon_result_z_v_grad_w[3777:31558, :],
velocity_advection.z_v_grad_w.asnumpy()[3777:31558, :],
atol=1.0e-16,
)
@@ -313,9 +315,9 @@ def test_velocity_corrector_step(
exner=None,
)
- interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro()
+ interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint)
- metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
+ metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels)
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
diff --git a/model/atmosphere/dycore/tests/dycore_tests/utils.py b/model/atmosphere/dycore/tests/dycore_tests/utils.py
new file mode 100644
index 0000000000..a6eee5b0f2
--- /dev/null
+++ b/model/atmosphere/dycore/tests/dycore_tests/utils.py
@@ -0,0 +1,82 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ InterpolationState,
+ MetricStateNonHydro,
+)
+from icon4py.model.common.dimension import CEDim
+from icon4py.model.common.test_utils.helpers import as_1D_sparse_field
+from icon4py.model.common.test_utils.serialbox_utils import InterpolationSavepoint, MetricSavepoint
+
+
+def construct_interpolation_state_for_nonhydro(
+ savepoint: InterpolationSavepoint,
+) -> InterpolationState:
+ grg = savepoint.geofac_grg()
+ return InterpolationState(
+ c_lin_e=savepoint.c_lin_e(),
+ c_intp=savepoint.c_intp(),
+ e_flx_avg=savepoint.e_flx_avg(),
+ geofac_grdiv=savepoint.geofac_grdiv(),
+ geofac_rot=savepoint.geofac_rot(),
+ pos_on_tplane_e_1=savepoint.pos_on_tplane_e_x(),
+ pos_on_tplane_e_2=savepoint.pos_on_tplane_e_y(),
+ rbf_vec_coeff_e=savepoint.rbf_vec_coeff_e(),
+ e_bln_c_s=as_1D_sparse_field(savepoint.e_bln_c_s(), CEDim),
+ rbf_coeff_1=savepoint.rbf_vec_coeff_v1(),
+ rbf_coeff_2=savepoint.rbf_vec_coeff_v2(),
+ geofac_div=as_1D_sparse_field(savepoint.geofac_div(), CEDim),
+ geofac_n2s=savepoint.geofac_n2s(),
+ geofac_grg_x=grg[0],
+ geofac_grg_y=grg[1],
+ nudgecoeff_e=savepoint.nudgecoeff_e(),
+ )
+
+
+def construct_nh_metric_state(savepoint: MetricSavepoint, num_k_lev) -> MetricStateNonHydro:
+ return MetricStateNonHydro(
+ bdy_halo_c=savepoint.bdy_halo_c(),
+ mask_prog_halo_c=savepoint.mask_prog_halo_c(),
+ rayleigh_w=savepoint.rayleigh_w(),
+ exner_exfac=savepoint.exner_exfac(),
+ exner_ref_mc=savepoint.exner_ref_mc(),
+ wgtfac_c=savepoint.wgtfac_c(),
+ wgtfacq_c=savepoint.wgtfacq_c_dsl(),
+ inv_ddqz_z_full=savepoint.inv_ddqz_z_full(),
+ rho_ref_mc=savepoint.rho_ref_mc(),
+ theta_ref_mc=savepoint.theta_ref_mc(),
+ vwind_expl_wgt=savepoint.vwind_expl_wgt(),
+ d_exner_dz_ref_ic=savepoint.d_exner_dz_ref_ic(),
+ ddqz_z_half=savepoint.ddqz_z_half(),
+ theta_ref_ic=savepoint.theta_ref_ic(),
+ d2dexdz2_fac1_mc=savepoint.d2dexdz2_fac1_mc(),
+ d2dexdz2_fac2_mc=savepoint.d2dexdz2_fac2_mc(),
+ rho_ref_me=savepoint.rho_ref_me(),
+ theta_ref_me=savepoint.theta_ref_me(),
+ ddxn_z_full=savepoint.ddxn_z_full(),
+ zdiff_gradp=savepoint.zdiff_gradp(),
+ vertoffset_gradp=savepoint.vertoffset_gradp(),
+ ipeidx_dsl=savepoint.ipeidx_dsl(),
+ pg_exdist=savepoint.pg_exdist(),
+ ddqz_z_full_e=savepoint.ddqz_z_full_e(),
+ ddxt_z_full=savepoint.ddxt_z_full(),
+ wgtfac_e=savepoint.wgtfac_e(),
+ wgtfacq_e=savepoint.wgtfacq_e_dsl(num_k_lev),
+ vwind_impl_wgt=savepoint.vwind_impl_wgt(),
+ hmask_dd3d=savepoint.hmask_dd3d(),
+ scalfac_dd3d=savepoint.scalfac_dd3d(),
+ coeff1_dwdz=savepoint.coeff1_dwdz(),
+ coeff2_dwdz=savepoint.coeff2_dwdz(),
+ coeff_gradekin=savepoint.coeff_gradekin(),
+ )
diff --git a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
index 113aed4589..78e02152cf 100644
--- a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
+++ b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
@@ -24,9 +24,6 @@
DiffusionInterpolationState,
DiffusionMetricState,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticState
-from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
-from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.common import dimension
from icon4py.model.common.decomposition.definitions import DecompositionInfo
from icon4py.model.common.dimension import (
@@ -466,27 +463,6 @@ def construct_interpolation_state_for_diffusion(
nudgecoeff_e=self.nudgecoeff_e(),
)
- def construct_interpolation_state_for_nonhydro(self) -> InterpolationState:
- grg = self.geofac_grg()
- return InterpolationState(
- c_lin_e=self.c_lin_e(),
- c_intp=self.c_intp(),
- e_flx_avg=self.e_flx_avg(),
- geofac_grdiv=self.geofac_grdiv(),
- geofac_rot=self.geofac_rot(),
- pos_on_tplane_e_1=self.pos_on_tplane_e_x(),
- pos_on_tplane_e_2=self.pos_on_tplane_e_y(),
- rbf_vec_coeff_e=self.rbf_vec_coeff_e(),
- e_bln_c_s=as_1D_sparse_field(self.e_bln_c_s(), CEDim),
- rbf_coeff_1=self.rbf_vec_coeff_v1(),
- rbf_coeff_2=self.rbf_vec_coeff_v2(),
- geofac_div=as_1D_sparse_field(self.geofac_div(), CEDim),
- geofac_n2s=self.geofac_n2s(),
- geofac_grg_x=grg[0],
- geofac_grg_y=grg[1],
- nudgecoeff_e=self.nudgecoeff_e(),
- )
-
class MetricSavepoint(IconSavepoint):
def bdy_halo_c(self):
@@ -549,9 +525,6 @@ def vwind_impl_wgt(self):
def wgtfacq_c_dsl(self):
return self._get_field("wgtfacq_c_dsl", CellDim, KDim)
- def wgtfacq_c(self):
- return self._get_field("wgtfacq_c", CellDim, KDim)
-
def zdiff_gradp(self):
field = self._get_field("zdiff_gradp_dsl", EdgeDim, E2CDim, KDim)
return flatten_first_two_dims(ECDim, KDim, field=field)
@@ -633,43 +606,6 @@ def zd_vertidx(self):
def zd_indlist(self):
return np.squeeze(self.serializer.read("zd_indlist", self.savepoint))
- def construct_nh_metric_state(self, num_k_lev) -> MetricStateNonHydro:
- return MetricStateNonHydro(
- bdy_halo_c=self.bdy_halo_c(),
- mask_prog_halo_c=self.mask_prog_halo_c(),
- rayleigh_w=self.rayleigh_w(),
- exner_exfac=self.exner_exfac(),
- exner_ref_mc=self.exner_ref_mc(),
- wgtfac_c=self.wgtfac_c(),
- wgtfacq_c_dsl=self.wgtfacq_c_dsl(),
- inv_ddqz_z_full=self.inv_ddqz_z_full(),
- rho_ref_mc=self.rho_ref_mc(),
- theta_ref_mc=self.theta_ref_mc(),
- vwind_expl_wgt=self.vwind_expl_wgt(),
- d_exner_dz_ref_ic=self.d_exner_dz_ref_ic(),
- ddqz_z_half=self.ddqz_z_half(),
- theta_ref_ic=self.theta_ref_ic(),
- d2dexdz2_fac1_mc=self.d2dexdz2_fac1_mc(),
- d2dexdz2_fac2_mc=self.d2dexdz2_fac2_mc(),
- rho_ref_me=self.rho_ref_me(),
- theta_ref_me=self.theta_ref_me(),
- ddxn_z_full=self.ddxn_z_full(),
- zdiff_gradp=self.zdiff_gradp(),
- vertoffset_gradp=self.vertoffset_gradp(),
- ipeidx_dsl=self.ipeidx_dsl(),
- pg_exdist=self.pg_exdist(),
- ddqz_z_full_e=self.ddqz_z_full_e(),
- ddxt_z_full=self.ddxt_z_full(),
- wgtfac_e=self.wgtfac_e(),
- wgtfacq_e_dsl=self.wgtfacq_e_dsl(num_k_lev),
- vwind_impl_wgt=self.vwind_impl_wgt(),
- hmask_dd3d=self.hmask_dd3d(),
- scalfac_dd3d=self.scalfac_dd3d(),
- coeff1_dwdz=self.coeff1_dwdz(),
- coeff2_dwdz=self.coeff2_dwdz(),
- coeff_gradekin=self.coeff_gradekin(),
- )
-
def construct_metric_state_for_diffusion(self) -> DiffusionMetricState:
return DiffusionMetricState(
mask_hdiff=self.mask_hdiff(),
@@ -753,20 +689,6 @@ def construct_diagnostics_for_diffusion(self) -> DiffusionDiagnosticState:
dwdy=self.dwdy(),
)
- def construct_diagnostics(self) -> DiagnosticState:
- return DiagnosticState(
- hdef_ic=self.hdef_ic(),
- div_ic=self.div_ic(),
- dwdx=self.dwdx(),
- dwdy=self.dwdy(),
- vt=None,
- vn_ie=None,
- w_concorr_c=None,
- ddt_w_adv_pc_before=None,
- ddt_vn_apc_pc_before=None,
- ntnd=None,
- )
-
class IconNonHydroInitSavepoint(IconSavepoint):
def bdy_divdamp(self):
diff --git a/model/driver/src/icon4py/model/driver/dycore_driver.py b/model/driver/src/icon4py/model/driver/dycore_driver.py
index 9d30a2dc31..33ef96cdbc 100644
--- a/model/driver/src/icon4py/model/driver/dycore_driver.py
+++ b/model/driver/src/icon4py/model/driver/dycore_driver.py
@@ -24,9 +24,11 @@
NonHydrostaticParams,
SolveNonhydro,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
-from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ DiagnosticStateNonHydro,
+ PrepAdvection,
+)
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.common.decomposition.definitions import (
ProcessProperties,
@@ -151,7 +153,6 @@ def time_integration(
inital_divdamp_fac_o2: float,
do_prep_adv: bool,
):
-
log.info(
f"starting time loop for dtime={self.run_config.dtime} n_timesteps={self._n_time_steps}"
)
@@ -217,7 +218,6 @@ def _integrate_one_time_step(
inital_divdamp_fac_o2: float,
do_prep_adv: bool,
):
-
self._do_dyn_substepping(
solve_nonhydro_diagnostic_state,
prognostic_state_list,
@@ -247,7 +247,6 @@ def _do_dyn_substepping(
inital_divdamp_fac_o2: float,
do_prep_adv: bool,
):
-
# TODO (Chia Rui): compute airmass for prognostic_state here
do_recompute = True
diff --git a/model/driver/src/icon4py/model/driver/io_utils.py b/model/driver/src/icon4py/model/driver/io_utils.py
index 04cc8f0a2c..b21c206434 100644
--- a/model/driver/src/icon4py/model/driver/io_utils.py
+++ b/model/driver/src/icon4py/model/driver/io_utils.py
@@ -23,11 +23,13 @@
DiffusionInterpolationState,
DiffusionMetricState,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
-from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
-from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
-from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ DiagnosticStateNonHydro,
+ InterpolationState,
+ MetricStateNonHydro,
+ PrepAdvection,
+)
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.common.decomposition.definitions import DecompositionInfo, ProcessProperties
@@ -75,7 +77,6 @@ def read_icon_grid(
# TODO (Chia Rui): initialization of prognostic variables and topography of Jablonowski Williamson test
def model_initialization():
-
# create two prognostic states, nnow and nnew?
# at least two prognostic states are global because they are needed in the dycore, AND possibly nesting and restart processes in the future
# one is enough for the JW test
diff --git a/model/driver/tests/test_timeloop.py b/model/driver/tests/test_timeloop.py
index 5a94728dd3..fa72699114 100644
--- a/model/driver/tests/test_timeloop.py
+++ b/model/driver/tests/test_timeloop.py
@@ -22,16 +22,20 @@
NonHydrostaticParams,
SolveNonhydro,
)
-from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
-from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
+from icon4py.model.atmosphere.dycore.state_utils.states import (
+ DiagnosticStateNonHydro,
+ InterpolationState,
+ MetricStateNonHydro,
+ PrepAdvection,
+)
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
-from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
+from icon4py.model.common.dimension import CEDim, CellDim, EdgeDim, KDim
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
-from icon4py.model.common.test_utils.helpers import dallclose
+from icon4py.model.common.test_utils.helpers import as_1D_sparse_field, dallclose
from icon4py.model.driver.dycore_driver import TimeLoop
@@ -157,10 +161,60 @@ def test_run_timeloop_single_step(
wgt_nnew_vel=sp.wgt_nnew_vel(),
)
- nonhydro_interpolation_state = (
- interpolation_savepoint.construct_interpolation_state_for_nonhydro()
+ grg = interpolation_savepoint.geofac_grg()
+ nonhydro_interpolation_state = InterpolationState(
+ c_lin_e=interpolation_savepoint.c_lin_e(),
+ c_intp=interpolation_savepoint.c_intp(),
+ e_flx_avg=interpolation_savepoint.e_flx_avg(),
+ geofac_grdiv=interpolation_savepoint.geofac_grdiv(),
+ geofac_rot=interpolation_savepoint.geofac_rot(),
+ pos_on_tplane_e_1=interpolation_savepoint.pos_on_tplane_e_x(),
+ pos_on_tplane_e_2=interpolation_savepoint.pos_on_tplane_e_y(),
+ rbf_vec_coeff_e=interpolation_savepoint.rbf_vec_coeff_e(),
+ e_bln_c_s=as_1D_sparse_field(interpolation_savepoint.e_bln_c_s(), CEDim),
+ rbf_coeff_1=interpolation_savepoint.rbf_vec_coeff_v1(),
+ rbf_coeff_2=interpolation_savepoint.rbf_vec_coeff_v2(),
+ geofac_div=as_1D_sparse_field(interpolation_savepoint.geofac_div(), CEDim),
+ geofac_n2s=interpolation_savepoint.geofac_n2s(),
+ geofac_grg_x=grg[0],
+ geofac_grg_y=grg[1],
+ nudgecoeff_e=interpolation_savepoint.nudgecoeff_e(),
+ )
+ nonhydro_metric_state = MetricStateNonHydro(
+ bdy_halo_c=metrics_savepoint.bdy_halo_c(),
+ mask_prog_halo_c=metrics_savepoint.mask_prog_halo_c(),
+ rayleigh_w=metrics_savepoint.rayleigh_w(),
+ exner_exfac=metrics_savepoint.exner_exfac(),
+ exner_ref_mc=metrics_savepoint.exner_ref_mc(),
+ wgtfac_c=metrics_savepoint.wgtfac_c(),
+ wgtfacq_c=metrics_savepoint.wgtfacq_c_dsl(),
+ inv_ddqz_z_full=metrics_savepoint.inv_ddqz_z_full(),
+ rho_ref_mc=metrics_savepoint.rho_ref_mc(),
+ theta_ref_mc=metrics_savepoint.theta_ref_mc(),
+ vwind_expl_wgt=metrics_savepoint.vwind_expl_wgt(),
+ d_exner_dz_ref_ic=metrics_savepoint.d_exner_dz_ref_ic(),
+ ddqz_z_half=metrics_savepoint.ddqz_z_half(),
+ theta_ref_ic=metrics_savepoint.theta_ref_ic(),
+ d2dexdz2_fac1_mc=metrics_savepoint.d2dexdz2_fac1_mc(),
+ d2dexdz2_fac2_mc=metrics_savepoint.d2dexdz2_fac2_mc(),
+ rho_ref_me=metrics_savepoint.rho_ref_me(),
+ theta_ref_me=metrics_savepoint.theta_ref_me(),
+ ddxn_z_full=metrics_savepoint.ddxn_z_full(),
+ zdiff_gradp=metrics_savepoint.zdiff_gradp(),
+ vertoffset_gradp=metrics_savepoint.vertoffset_gradp(),
+ ipeidx_dsl=metrics_savepoint.ipeidx_dsl(),
+ pg_exdist=metrics_savepoint.pg_exdist(),
+ ddqz_z_full_e=metrics_savepoint.ddqz_z_full_e(),
+ ddxt_z_full=metrics_savepoint.ddxt_z_full(),
+ wgtfac_e=metrics_savepoint.wgtfac_e(),
+ wgtfacq_e=metrics_savepoint.wgtfacq_e_dsl(icon_grid.num_levels),
+ vwind_impl_wgt=metrics_savepoint.vwind_impl_wgt(),
+ hmask_dd3d=metrics_savepoint.hmask_dd3d(),
+ scalfac_dd3d=metrics_savepoint.scalfac_dd3d(),
+ coeff1_dwdz=metrics_savepoint.coeff1_dwdz(),
+ coeff2_dwdz=metrics_savepoint.coeff2_dwdz(),
+ coeff_gradekin=metrics_savepoint.coeff_gradekin(),
)
- nonhydro_metric_state = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels)
cell_geometry: CellParams = grid_savepoint.construct_cell_geometry()
edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry()
@@ -270,7 +324,6 @@ def test_run_timeloop_single_step(
)
if debug_mode:
-
script_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = script_dir + "/"