diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py index 1c92a21217..137bb20db9 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py @@ -13,7 +13,6 @@ import logging from typing import Final, Optional -import numpy as np from gt4py.next import as_field from gt4py.next.common import Field from gt4py.next.ffront.fbuiltins import int32 @@ -138,11 +137,13 @@ from icon4py.model.atmosphere.dycore.mo_solve_nonhydro_stencil_68 import ( mo_solve_nonhydro_stencil_68, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro -from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState -from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants -from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection +from icon4py.model.atmosphere.dycore.state_utils.states import ( + DiagnosticStateNonHydro, + InterpolationState, + MetricStateNonHydro, + PrepAdvection, +) from icon4py.model.atmosphere.dycore.state_utils.utils import ( _allocate, _allocate_indices, @@ -497,6 +498,10 @@ def time_step( vertical_start=0, vertical_end=self.grid.num_levels, offset_provider={}, + horizontal_start=0, + horizontal_end=end_cell_end, + vertical_start=0, + vertical_end=self.grid.num_levels, ) mo_solve_nonhydro_stencil_67.with_backend(backend)( @@ -663,7 +668,7 @@ def run_predictor_step( if self.config.igradp_method == 3: nhsolve_prog.predictor_stencils_4_5_6.with_backend(backend)( - wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c_dsl, + wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c, z_exner_ex_pr=self.z_exner_ex_pr, z_exner_ic=self.z_exner_ic, wgtfac_c=self.metric_state_nonhydro.wgtfac_c, @@ -708,7 +713,7 @@ def run_predictor_step( # Perturbation theta at top and surface levels nhsolve_prog.predictor_stencils_11_lower_upper.with_backend(backend)( - wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c_dsl, + wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c, z_rth_pr=self.z_rth_pr_2, theta_ref_ic=self.metric_state_nonhydro.theta_ref_ic, z_theta_v_pr_ic=self.z_theta_v_pr_ic, @@ -1073,7 +1078,7 @@ def run_predictor_step( vn_ie=diagnostic_state_nh.vn_ie, z_vt_ie=z_fields.z_vt_ie, z_kin_hor_e=z_fields.z_kin_hor_e, - wgtfacq_e_dsl=self.metric_state_nonhydro.wgtfacq_e_dsl, + wgtfacq_e_dsl=self.metric_state_nonhydro.wgtfacq_e, k_field=self.k_field, nlev=self.grid.num_levels, horizontal_start=start_edge_lb_plus4, @@ -1087,7 +1092,7 @@ def run_predictor_step( e_bln_c_s=self.interpolation_state.e_bln_c_s, z_w_concorr_me=self.z_w_concorr_me, wgtfac_c=self.metric_state_nonhydro.wgtfac_c, - wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c_dsl, + wgtfacq_c_dsl=self.metric_state_nonhydro.wgtfacq_c, w_concorr_c=diagnostic_state_nh.w_concorr_c, k_field=self.k_field, nflatlev_startindex_plus1=int32(self.vertical_params.nflatlev + 1), diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/diagnostic_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/diagnostic_state.py deleted file mode 100644 index f12aa39b55..0000000000 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/diagnostic_state.py +++ /dev/null @@ -1,88 +0,0 @@ -# ICON4Py - ICON inspired code in Python and GT4Py -# -# Copyright (c) 2022, ETH Zurich and MeteoSwiss -# All rights reserved. -# -# This file is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass - -from gt4py.next.common import Field - -from icon4py.model.common.dimension import CellDim, EdgeDim, KDim - - -@dataclass -class DiagnosticState: - # fields for 3D elements in turbdiff - hdef_ic: Field[ - [CellDim, KDim], float - ] # ! divergence at half levels(nproma,nlevp1,nblks_c) [1/s] - div_ic: Field[ - [CellDim, KDim], float - ] # ! horizontal wind field deformation (nproma,nlevp1,nblks_c) [1/s^2] - dwdx: Field[ - [CellDim, KDim], float - ] # zonal gradient of vertical wind speed (nproma,nlevp1,nblks_c) [1/s] - - dwdy: Field[ - [CellDim, KDim], float - ] # meridional gradient of vertical wind speed (nproma,nlevp1,nblks_c) - - vt: Field[[EdgeDim, KDim], float] - vn_ie: Field[ - [EdgeDim, KDim], float - ] # normal wind at half levels (nproma,nlevp1,nblks_e) [m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain? - w_concorr_c: Field[ - [CellDim, KDim], float - ] # contravariant vert correction (nproma,nlevp1,nblks_c)[m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain? - ddt_w_adv_pc: Field[[CellDim, KDim], float] - ddt_vn_apc_pc: Field[[EdgeDim, KDim], float] - ntnd: float - - -@dataclass -class DiagnosticStateNonHydro: - vt: Field[[EdgeDim, KDim], float] - vn_ie: Field[ - [EdgeDim, KDim], float - ] # normal wind at half levels (nproma,nlevp1,nblks_e) [m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain? - w_concorr_c: Field[ - [CellDim, KDim], float - ] # contravariant vert correction (nproma,nlevp1,nblks_c)[m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain? - theta_v_ic: Field[[CellDim, KDim], float] - exner_pr: Field[[CellDim, KDim], float] - rho_ic: Field[[CellDim, KDim], float] - ddt_exner_phy: Field[[CellDim, KDim], float] - grf_tend_rho: Field[[CellDim, KDim], float] - grf_tend_thv: Field[[CellDim, KDim], float] - grf_tend_w: Field[[CellDim, KDim], float] - mass_fl_e: Field[[EdgeDim, KDim], float] - ddt_vn_phy: Field[[EdgeDim, KDim], float] - grf_tend_vn: Field[[EdgeDim, KDim], float] - ddt_vn_apc_ntl1: Field[[EdgeDim, KDim], float] - ddt_vn_apc_ntl2: Field[[EdgeDim, KDim], float] - ddt_w_adv_ntl1: Field[[CellDim, KDim], float] - ddt_w_adv_ntl2: Field[[CellDim, KDim], float] - - # Analysis increments - rho_incr: Field[[EdgeDim, KDim], float] # moist density increment [kg/m^3] - vn_incr: Field[[EdgeDim, KDim], float] # normal velocity increment [m/s] - exner_incr: Field[[EdgeDim, KDim], float] # exner increment [- ] - - @property - def ddt_vn_apc_pc( - self, - ) -> tuple[Field[[EdgeDim, KDim], float], Field[[EdgeDim, KDim], float]]: - return (self.ddt_vn_apc_ntl1, self.ddt_vn_apc_ntl2) - - @property - def ddt_w_adv_pc( - self, - ) -> tuple[Field[[CellDim, KDim], float], Field[[CellDim, KDim], float]]: - return (self.ddt_w_adv_ntl1, self.ddt_w_adv_ntl2) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/interpolation_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/interpolation_state.py deleted file mode 100644 index 440c4a4b9c..0000000000 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/interpolation_state.py +++ /dev/null @@ -1,63 +0,0 @@ -# ICON4Py - ICON inspired code in Python and GT4Py -# -# Copyright (c) 2022, ETH Zurich and MeteoSwiss -# All rights reserved. -# -# This file is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from dataclasses import dataclass - -from gt4py.next.common import Field - -from icon4py.model.common.dimension import ( - C2E2CODim, - CEDim, - CellDim, - E2C2EDim, - E2C2EODim, - E2CDim, - ECDim, - EdgeDim, - V2CDim, - V2EDim, - VertexDim, -) - - -@dataclass -class InterpolationState: - """Represents the ICON interpolation state used int SolveNonHydro.""" - - e_bln_c_s: Field[[CEDim], float] # coefficent for bilinear interpolation from edge to cell () - rbf_coeff_1: Field[ - [VertexDim, V2EDim], float - ] # rbf_vec_coeff_v_1(nproma, rbf_vec_dim_v, nblks_v) - rbf_coeff_2: Field[ - [VertexDim, V2EDim], float - ] # rbf_vec_coeff_v_2(nproma, rbf_vec_dim_v, nblks_v) - - geofac_div: Field[[CEDim], float] # factor for divergence (nproma,cell_type,nblks_c) - - geofac_n2s: Field[ - [CellDim, C2E2CODim], float - ] # factor for nabla2-scalar (nproma,cell_type+1,nblks_c) - geofac_grg_x: Field[[CellDim, C2E2CODim], float] - geofac_grg_y: Field[ - [CellDim, C2E2CODim], float - ] # factors for green gauss gradient (nproma,4,nblks_c,2) - nudgecoeff_e: Field[[EdgeDim], float] # Nudgeing coeffients for edges - - c_lin_e: Field[[EdgeDim, E2CDim], float] - geofac_grdiv: Field[[EdgeDim, E2C2EODim], float] - rbf_vec_coeff_e: Field[[EdgeDim, E2C2EDim], float] - c_intp: Field[[VertexDim, V2CDim], float] - geofac_rot: Field[[VertexDim, V2EDim], float] - pos_on_tplane_e_1: Field[[ECDim], float] - pos_on_tplane_e_2: Field[[ECDim], float] - e_flx_avg: Field[[EdgeDim, E2C2EODim], float] diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/metric_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/metric_state.py deleted file mode 100644 index 3d32d340b9..0000000000 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/metric_state.py +++ /dev/null @@ -1,87 +0,0 @@ -# ICON4Py - ICON inspired code in Python and GT4Py -# -# Copyright (c) 2022, ETH Zurich and MeteoSwiss -# All rights reserved. -# -# This file is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from dataclasses import dataclass - -from gt4py.next.common import Field -from numpy import int32 - -from icon4py.model.common.dimension import CECDim, CellDim, ECDim, EdgeDim, KDim - - -@dataclass -class MetricState: - mask_hdiff: Field[[CellDim, KDim], bool] - theta_ref_mc: Field[[CellDim, KDim], float] - wgtfac_c: Field[ - [CellDim, KDim], float - ] # weighting factor for interpolation from full to half levels (nproma,nlevp1,nblks_c) - zd_vertidx: Field[[CECDim, KDim], int32] - zd_diffcoef: Field[[CellDim, KDim], float] - zd_intcoef: Field[[CECDim, KDim], float] - - coeff_gradekin: Field[[ECDim], float] - ddqz_z_full_e: Field[[EdgeDim, KDim], float] - wgtfac_e: Field[[EdgeDim, KDim], float] - wgtfacq_e_dsl: Field[[EdgeDim, KDim], float] - ddxn_z_full: Field[[EdgeDim, KDim], float] - ddxt_z_full: Field[[EdgeDim, KDim], float] - ddqz_z_half: Field[[CellDim, KDim], float] # half KDim ? - coeff1_dwdz: Field[[CellDim, KDim], float] - coeff2_dwdz: Field[[CellDim, KDim], float] - zd_vertoffset: Field[[CECDim, KDim], int32] - - -@dataclass -class MetricStateNonHydro: - bdy_halo_c: Field[[CellDim], bool] - # Finally, a mask field that excludes boundary halo points - mask_prog_halo_c: Field[[CellDim, KDim], bool] - rayleigh_w: Field[[KDim], float] - - wgtfac_c: Field[[CellDim, KDim], float] - wgtfacq_c_dsl: Field[[CellDim, KDim], float] - wgtfac_e: Field[[EdgeDim, KDim], float] - wgtfacq_e_dsl: Field[[EdgeDim, KDim], float] - - exner_exfac: Field[[CellDim, KDim], float] - exner_ref_mc: Field[[CellDim, KDim], float] - rho_ref_mc: Field[[CellDim, KDim], float] - theta_ref_mc: Field[[CellDim, KDim], float] - rho_ref_me: Field[[EdgeDim, KDim], float] - theta_ref_me: Field[[EdgeDim, KDim], float] - theta_ref_ic: Field[[CellDim, KDim], float] - - d_exner_dz_ref_ic: Field[[CellDim, KDim], float] - ddqz_z_half: Field[[CellDim, KDim], float] # half KDim ? - d2dexdz2_fac1_mc: Field[[CellDim, KDim], float] - d2dexdz2_fac2_mc: Field[[CellDim, KDim], float] - ddxn_z_full: Field[[EdgeDim, KDim], float] - ddqz_z_full_e: Field[[EdgeDim, KDim], float] - ddxt_z_full: Field[[EdgeDim, KDim], float] - inv_ddqz_z_full: Field[[CellDim, KDim], float] - - vertoffset_gradp: Field[[ECDim, KDim], float] - zdiff_gradp: Field[[ECDim, KDim], float] - ipeidx_dsl: Field[[EdgeDim, KDim], bool] - pg_exdist: Field[[EdgeDim, KDim], float] - - vwind_expl_wgt: Field[[CellDim], float] - vwind_impl_wgt: Field[[CellDim], float] - - hmask_dd3d: Field[[EdgeDim], float] - scalfac_dd3d: Field[[KDim], float] - - coeff1_dwdz: Field[[CellDim, KDim], float] - coeff2_dwdz: Field[[CellDim, KDim], float] - coeff_gradekin: Field[[ECDim], float] diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/prep_adv_state.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/prep_adv_state.py deleted file mode 100644 index 8a94ccafb4..0000000000 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/prep_adv_state.py +++ /dev/null @@ -1,25 +0,0 @@ -# ICON4Py - ICON inspired code in Python and GT4Py -# -# Copyright (c) 2022, ETH Zurich and MeteoSwiss -# All rights reserved. -# -# This file is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from dataclasses import dataclass - -from gt4py.next.common import Field - -from icon4py.model.common.dimension import CellDim, EdgeDim, KDim - - -@dataclass -class PrepAdvection: - vn_traj: Field[[EdgeDim, KDim], float] - mass_flx_me: Field[[EdgeDim, KDim], float] - mass_flx_ic: Field[[CellDim, KDim], float] diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/states.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/states.py new file mode 100644 index 0000000000..fb413282bf --- /dev/null +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/state_utils/states.py @@ -0,0 +1,157 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +from dataclasses import dataclass + +from gt4py.next.common import Field + +from icon4py.model.common.dimension import ( + C2E2CODim, + CEDim, + CellDim, + E2C2EDim, + E2C2EODim, + E2CDim, + ECDim, + EdgeDim, + KDim, + V2CDim, + V2EDim, + VertexDim, +) + + +@dataclass +class DiagnosticStateNonHydro: + vt: Field[[EdgeDim, KDim], float] + vn_ie: Field[ + [EdgeDim, KDim], float + ] # normal wind at half levels (nproma,nlevp1,nblks_e) [m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain? + w_concorr_c: Field[ + [CellDim, KDim], float + ] # contravariant vert correction (nproma,nlevp1,nblks_c)[m/s] # TODO: change this back to KHalfDim, but how do we treat it wrt to field_operators and domain? + theta_v_ic: Field[[CellDim, KDim], float] + exner_pr: Field[[CellDim, KDim], float] + rho_ic: Field[[CellDim, KDim], float] + ddt_exner_phy: Field[[CellDim, KDim], float] + grf_tend_rho: Field[[CellDim, KDim], float] + grf_tend_thv: Field[[CellDim, KDim], float] + grf_tend_w: Field[[CellDim, KDim], float] + mass_fl_e: Field[[EdgeDim, KDim], float] + ddt_vn_phy: Field[[EdgeDim, KDim], float] + grf_tend_vn: Field[[EdgeDim, KDim], float] + ddt_vn_apc_ntl1: Field[[EdgeDim, KDim], float] + ddt_vn_apc_ntl2: Field[[EdgeDim, KDim], float] + ddt_w_adv_ntl1: Field[[CellDim, KDim], float] + ddt_w_adv_ntl2: Field[[CellDim, KDim], float] + + # Analysis increments + rho_incr: Field[[EdgeDim, KDim], float] # moist density increment [kg/m^3] + vn_incr: Field[[EdgeDim, KDim], float] # normal velocity increment [m/s] + exner_incr: Field[[EdgeDim, KDim], float] # exner increment [- ] + + @property + def ddt_vn_apc_pc( + self, + ) -> tuple[Field[[EdgeDim, KDim], float], Field[[EdgeDim, KDim], float]]: + return (self.ddt_vn_apc_ntl1, self.ddt_vn_apc_ntl2) + + @property + def ddt_w_adv_pc( + self, + ) -> tuple[Field[[CellDim, KDim], float], Field[[CellDim, KDim], float]]: + return (self.ddt_w_adv_ntl1, self.ddt_w_adv_ntl2) + + +@dataclass +class InterpolationState: + """Represents the ICON interpolation state used int SolveNonHydro.""" + + e_bln_c_s: Field[[CEDim], float] # coefficent for bilinear interpolation from edge to cell () + rbf_coeff_1: Field[ + [VertexDim, V2EDim], float + ] # rbf_vec_coeff_v_1(nproma, rbf_vec_dim_v, nblks_v) + rbf_coeff_2: Field[ + [VertexDim, V2EDim], float + ] # rbf_vec_coeff_v_2(nproma, rbf_vec_dim_v, nblks_v) + + geofac_div: Field[[CEDim], float] # factor for divergence (nproma,cell_type,nblks_c) + + geofac_n2s: Field[ + [CellDim, C2E2CODim], float + ] # factor for nabla2-scalar (nproma,cell_type+1,nblks_c) + geofac_grg_x: Field[[CellDim, C2E2CODim], float] + geofac_grg_y: Field[ + [CellDim, C2E2CODim], float + ] # factors for green gauss gradient (nproma,4,nblks_c,2) + nudgecoeff_e: Field[[EdgeDim], float] # Nudgeing coeffients for edges + + c_lin_e: Field[[EdgeDim, E2CDim], float] + geofac_grdiv: Field[[EdgeDim, E2C2EODim], float] + rbf_vec_coeff_e: Field[[EdgeDim, E2C2EDim], float] + c_intp: Field[[VertexDim, V2CDim], float] + geofac_rot: Field[[VertexDim, V2EDim], float] + pos_on_tplane_e_1: Field[[ECDim], float] + pos_on_tplane_e_2: Field[[ECDim], float] + e_flx_avg: Field[[EdgeDim, E2C2EODim], float] + + +@dataclass +class MetricStateNonHydro: + bdy_halo_c: Field[[CellDim], bool] + # Finally, a mask field that excludes boundary halo points + mask_prog_halo_c: Field[[CellDim, KDim], bool] + rayleigh_w: Field[[KDim], float] + + wgtfac_c: Field[[CellDim, KDim], float] + wgtfacq_c: Field[[CellDim, KDim], float] + wgtfac_e: Field[[EdgeDim, KDim], float] + wgtfacq_e: Field[[EdgeDim, KDim], float] + + exner_exfac: Field[[CellDim, KDim], float] + exner_ref_mc: Field[[CellDim, KDim], float] + rho_ref_mc: Field[[CellDim, KDim], float] + theta_ref_mc: Field[[CellDim, KDim], float] + rho_ref_me: Field[[EdgeDim, KDim], float] + theta_ref_me: Field[[EdgeDim, KDim], float] + theta_ref_ic: Field[[CellDim, KDim], float] + + d_exner_dz_ref_ic: Field[[CellDim, KDim], float] + ddqz_z_half: Field[[CellDim, KDim], float] # half KDim ? + d2dexdz2_fac1_mc: Field[[CellDim, KDim], float] + d2dexdz2_fac2_mc: Field[[CellDim, KDim], float] + ddxn_z_full: Field[[EdgeDim, KDim], float] + ddqz_z_full_e: Field[[EdgeDim, KDim], float] + ddxt_z_full: Field[[EdgeDim, KDim], float] + inv_ddqz_z_full: Field[[CellDim, KDim], float] + + vertoffset_gradp: Field[[ECDim, KDim], float] + zdiff_gradp: Field[[ECDim, KDim], float] + ipeidx_dsl: Field[[EdgeDim, KDim], bool] + pg_exdist: Field[[EdgeDim, KDim], float] + + vwind_expl_wgt: Field[[CellDim], float] + vwind_impl_wgt: Field[[CellDim], float] + + hmask_dd3d: Field[[EdgeDim], float] + scalfac_dd3d: Field[[KDim], float] + + coeff1_dwdz: Field[[CellDim, KDim], float] + coeff2_dwdz: Field[[CellDim, KDim], float] + coeff_gradekin: Field[[ECDim], float] + + +@dataclass +class PrepAdvection: + vn_traj: Field[[EdgeDim, KDim], float] + mass_flx_me: Field[[EdgeDim, KDim], float] + mass_flx_ic: Field[[CellDim, KDim], float] diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py index e6340459d5..7888e09356 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity/velocity_advection.py @@ -55,9 +55,11 @@ from icon4py.model.atmosphere.dycore.mo_velocity_advection_stencil_20 import ( mo_velocity_advection_stencil_20, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro -from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState -from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro +from icon4py.model.atmosphere.dycore.state_utils.states import ( + DiagnosticStateNonHydro, + InterpolationState, + MetricStateNonHydro, +) from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate, _allocate_indices from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim from icon4py.model.common.grid.horizontal import EdgeParams, HorizontalMarkerIndex @@ -249,7 +251,7 @@ def run_predictor_step( offset_provider={}, ) velocity_prog.mo_velocity_advection_stencil_06.with_backend(backend)( - wgtfacq_e=self.metric_state.wgtfacq_e_dsl, + wgtfacq_e=self.metric_state.wgtfacq_e, vn=prognostic_state.vn, vn_ie=diagnostic_state.vn_ie, horizontal_start=start_edge_lb_plus4, diff --git a/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py b/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py index 2b2d6786db..8b27f60238 100644 --- a/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py +++ b/model/atmosphere/dycore/tests/dycore_tests/mpi_tests/test_parallel_solve_nonhydro.py @@ -20,9 +20,11 @@ NonHydrostaticParams, SolveNonhydro, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants -from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection +from icon4py.model.atmosphere.dycore.state_utils.states import ( + DiagnosticStateNonHydro, + PrepAdvection, +) from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields from icon4py.model.common.decomposition import definitions diff --git a/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py b/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py index 7d56058655..6a7fad0b3c 100644 --- a/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py +++ b/model/atmosphere/dycore/tests/dycore_tests/test_solve_nonhydro.py @@ -20,9 +20,11 @@ NonHydrostaticParams, SolveNonhydro, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants -from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection +from icon4py.model.atmosphere.dycore.state_utils.states import ( + DiagnosticStateNonHydro, + PrepAdvection, +) from icon4py.model.atmosphere.dycore.state_utils.utils import ( _allocate, _calculate_bdy_divdamp, @@ -37,6 +39,8 @@ from icon4py.model.common.states.prognostic_state import PrognosticState from icon4py.model.common.test_utils.helpers import dallclose +from .utils import construct_interpolation_state_for_nonhydro, construct_nh_metric_state + backend = run_gtfn @@ -143,8 +147,8 @@ def test_nonhydro_predictor_step( z_fields = allocate_z_fields(icon_grid) - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) cell_geometry: CellParams = grid_savepoint.construct_cell_geometry() edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry() @@ -574,8 +578,8 @@ def test_nonhydro_corrector_step( nh_constants = create_nh_constants(sp) divdamp_fac_o2 = sp.divdamp_fac_o2() - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) cell_geometry: CellParams = grid_savepoint.construct_cell_geometry() edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry() @@ -723,8 +727,8 @@ def test_run_solve_nonhydro_single_step( nh_constants = create_nh_constants(sp) - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) cell_geometry: CellParams = grid_savepoint.construct_cell_geometry() edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry() @@ -844,8 +848,8 @@ def test_run_solve_nonhydro_multi_step( nh_constants = create_nh_constants(sp) - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) cell_geometry: CellParams = grid_savepoint.construct_cell_geometry() edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry() diff --git a/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py b/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py index f1b56b5c5d..0d6e99dc90 100644 --- a/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py +++ b/model/atmosphere/dycore/tests/dycore_tests/test_velocity_advection.py @@ -14,13 +14,15 @@ import pytest from gt4py.next.ffront.fbuiltins import int32 -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro +from icon4py.model.atmosphere.dycore.state_utils.states import DiagnosticStateNonHydro from icon4py.model.atmosphere.dycore.velocity.velocity_advection import VelocityAdvection from icon4py.model.common.grid.horizontal import CellParams, EdgeParams from icon4py.model.common.grid.vertical import VerticalModelParams from icon4py.model.common.states.prognostic_state import PrognosticState from icon4py.model.common.test_utils.helpers import dallclose +from .utils import construct_interpolation_state_for_nonhydro, construct_nh_metric_state + @pytest.mark.datatest def test_scalfactors(savepoint_velocity_init, icon_grid): @@ -48,8 +50,8 @@ def test_velocity_init( step_date_init, damping_height, ): - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) vertical_params = VerticalModelParams( vct_a=grid_savepoint.vct_a(), @@ -87,8 +89,8 @@ def test_verify_velocity_init_against_regular_savepoint( savepoint = savepoint_velocity_init dtime = savepoint.get_metadata("dtime").get("dtime") - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) vertical_params = VerticalModelParams( vct_a=grid_savepoint.vct_a(), rayleigh_damping_height=damping_height, @@ -163,9 +165,9 @@ def test_velocity_predictor_step( rho=None, exner=None, ) - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) cell_geometry: CellParams = grid_savepoint.construct_cell_geometry() edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry() @@ -213,7 +215,7 @@ def test_velocity_predictor_step( # stencil 07 if not vn_only: assert dallclose( - icon_result_z_v_grad_w.asnumpy()[3777:31558, :], + icon_result_z_v_grad_w[3777:31558, :], velocity_advection.z_v_grad_w.asnumpy()[3777:31558, :], atol=1.0e-16, ) @@ -313,9 +315,9 @@ def test_velocity_corrector_step( exner=None, ) - interpolation_state = interpolation_savepoint.construct_interpolation_state_for_nonhydro() + interpolation_state = construct_interpolation_state_for_nonhydro(interpolation_savepoint) - metric_state_nonhydro = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) + metric_state_nonhydro = construct_nh_metric_state(metrics_savepoint, icon_grid.num_levels) cell_geometry: CellParams = grid_savepoint.construct_cell_geometry() edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry() diff --git a/model/atmosphere/dycore/tests/dycore_tests/utils.py b/model/atmosphere/dycore/tests/dycore_tests/utils.py new file mode 100644 index 0000000000..a6eee5b0f2 --- /dev/null +++ b/model/atmosphere/dycore/tests/dycore_tests/utils.py @@ -0,0 +1,82 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from icon4py.model.atmosphere.dycore.state_utils.states import ( + InterpolationState, + MetricStateNonHydro, +) +from icon4py.model.common.dimension import CEDim +from icon4py.model.common.test_utils.helpers import as_1D_sparse_field +from icon4py.model.common.test_utils.serialbox_utils import InterpolationSavepoint, MetricSavepoint + + +def construct_interpolation_state_for_nonhydro( + savepoint: InterpolationSavepoint, +) -> InterpolationState: + grg = savepoint.geofac_grg() + return InterpolationState( + c_lin_e=savepoint.c_lin_e(), + c_intp=savepoint.c_intp(), + e_flx_avg=savepoint.e_flx_avg(), + geofac_grdiv=savepoint.geofac_grdiv(), + geofac_rot=savepoint.geofac_rot(), + pos_on_tplane_e_1=savepoint.pos_on_tplane_e_x(), + pos_on_tplane_e_2=savepoint.pos_on_tplane_e_y(), + rbf_vec_coeff_e=savepoint.rbf_vec_coeff_e(), + e_bln_c_s=as_1D_sparse_field(savepoint.e_bln_c_s(), CEDim), + rbf_coeff_1=savepoint.rbf_vec_coeff_v1(), + rbf_coeff_2=savepoint.rbf_vec_coeff_v2(), + geofac_div=as_1D_sparse_field(savepoint.geofac_div(), CEDim), + geofac_n2s=savepoint.geofac_n2s(), + geofac_grg_x=grg[0], + geofac_grg_y=grg[1], + nudgecoeff_e=savepoint.nudgecoeff_e(), + ) + + +def construct_nh_metric_state(savepoint: MetricSavepoint, num_k_lev) -> MetricStateNonHydro: + return MetricStateNonHydro( + bdy_halo_c=savepoint.bdy_halo_c(), + mask_prog_halo_c=savepoint.mask_prog_halo_c(), + rayleigh_w=savepoint.rayleigh_w(), + exner_exfac=savepoint.exner_exfac(), + exner_ref_mc=savepoint.exner_ref_mc(), + wgtfac_c=savepoint.wgtfac_c(), + wgtfacq_c=savepoint.wgtfacq_c_dsl(), + inv_ddqz_z_full=savepoint.inv_ddqz_z_full(), + rho_ref_mc=savepoint.rho_ref_mc(), + theta_ref_mc=savepoint.theta_ref_mc(), + vwind_expl_wgt=savepoint.vwind_expl_wgt(), + d_exner_dz_ref_ic=savepoint.d_exner_dz_ref_ic(), + ddqz_z_half=savepoint.ddqz_z_half(), + theta_ref_ic=savepoint.theta_ref_ic(), + d2dexdz2_fac1_mc=savepoint.d2dexdz2_fac1_mc(), + d2dexdz2_fac2_mc=savepoint.d2dexdz2_fac2_mc(), + rho_ref_me=savepoint.rho_ref_me(), + theta_ref_me=savepoint.theta_ref_me(), + ddxn_z_full=savepoint.ddxn_z_full(), + zdiff_gradp=savepoint.zdiff_gradp(), + vertoffset_gradp=savepoint.vertoffset_gradp(), + ipeidx_dsl=savepoint.ipeidx_dsl(), + pg_exdist=savepoint.pg_exdist(), + ddqz_z_full_e=savepoint.ddqz_z_full_e(), + ddxt_z_full=savepoint.ddxt_z_full(), + wgtfac_e=savepoint.wgtfac_e(), + wgtfacq_e=savepoint.wgtfacq_e_dsl(num_k_lev), + vwind_impl_wgt=savepoint.vwind_impl_wgt(), + hmask_dd3d=savepoint.hmask_dd3d(), + scalfac_dd3d=savepoint.scalfac_dd3d(), + coeff1_dwdz=savepoint.coeff1_dwdz(), + coeff2_dwdz=savepoint.coeff2_dwdz(), + coeff_gradekin=savepoint.coeff_gradekin(), + ) diff --git a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py index 113aed4589..78e02152cf 100644 --- a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py +++ b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py @@ -24,9 +24,6 @@ DiffusionInterpolationState, DiffusionMetricState, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticState -from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState -from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro from icon4py.model.common import dimension from icon4py.model.common.decomposition.definitions import DecompositionInfo from icon4py.model.common.dimension import ( @@ -466,27 +463,6 @@ def construct_interpolation_state_for_diffusion( nudgecoeff_e=self.nudgecoeff_e(), ) - def construct_interpolation_state_for_nonhydro(self) -> InterpolationState: - grg = self.geofac_grg() - return InterpolationState( - c_lin_e=self.c_lin_e(), - c_intp=self.c_intp(), - e_flx_avg=self.e_flx_avg(), - geofac_grdiv=self.geofac_grdiv(), - geofac_rot=self.geofac_rot(), - pos_on_tplane_e_1=self.pos_on_tplane_e_x(), - pos_on_tplane_e_2=self.pos_on_tplane_e_y(), - rbf_vec_coeff_e=self.rbf_vec_coeff_e(), - e_bln_c_s=as_1D_sparse_field(self.e_bln_c_s(), CEDim), - rbf_coeff_1=self.rbf_vec_coeff_v1(), - rbf_coeff_2=self.rbf_vec_coeff_v2(), - geofac_div=as_1D_sparse_field(self.geofac_div(), CEDim), - geofac_n2s=self.geofac_n2s(), - geofac_grg_x=grg[0], - geofac_grg_y=grg[1], - nudgecoeff_e=self.nudgecoeff_e(), - ) - class MetricSavepoint(IconSavepoint): def bdy_halo_c(self): @@ -549,9 +525,6 @@ def vwind_impl_wgt(self): def wgtfacq_c_dsl(self): return self._get_field("wgtfacq_c_dsl", CellDim, KDim) - def wgtfacq_c(self): - return self._get_field("wgtfacq_c", CellDim, KDim) - def zdiff_gradp(self): field = self._get_field("zdiff_gradp_dsl", EdgeDim, E2CDim, KDim) return flatten_first_two_dims(ECDim, KDim, field=field) @@ -633,43 +606,6 @@ def zd_vertidx(self): def zd_indlist(self): return np.squeeze(self.serializer.read("zd_indlist", self.savepoint)) - def construct_nh_metric_state(self, num_k_lev) -> MetricStateNonHydro: - return MetricStateNonHydro( - bdy_halo_c=self.bdy_halo_c(), - mask_prog_halo_c=self.mask_prog_halo_c(), - rayleigh_w=self.rayleigh_w(), - exner_exfac=self.exner_exfac(), - exner_ref_mc=self.exner_ref_mc(), - wgtfac_c=self.wgtfac_c(), - wgtfacq_c_dsl=self.wgtfacq_c_dsl(), - inv_ddqz_z_full=self.inv_ddqz_z_full(), - rho_ref_mc=self.rho_ref_mc(), - theta_ref_mc=self.theta_ref_mc(), - vwind_expl_wgt=self.vwind_expl_wgt(), - d_exner_dz_ref_ic=self.d_exner_dz_ref_ic(), - ddqz_z_half=self.ddqz_z_half(), - theta_ref_ic=self.theta_ref_ic(), - d2dexdz2_fac1_mc=self.d2dexdz2_fac1_mc(), - d2dexdz2_fac2_mc=self.d2dexdz2_fac2_mc(), - rho_ref_me=self.rho_ref_me(), - theta_ref_me=self.theta_ref_me(), - ddxn_z_full=self.ddxn_z_full(), - zdiff_gradp=self.zdiff_gradp(), - vertoffset_gradp=self.vertoffset_gradp(), - ipeidx_dsl=self.ipeidx_dsl(), - pg_exdist=self.pg_exdist(), - ddqz_z_full_e=self.ddqz_z_full_e(), - ddxt_z_full=self.ddxt_z_full(), - wgtfac_e=self.wgtfac_e(), - wgtfacq_e_dsl=self.wgtfacq_e_dsl(num_k_lev), - vwind_impl_wgt=self.vwind_impl_wgt(), - hmask_dd3d=self.hmask_dd3d(), - scalfac_dd3d=self.scalfac_dd3d(), - coeff1_dwdz=self.coeff1_dwdz(), - coeff2_dwdz=self.coeff2_dwdz(), - coeff_gradekin=self.coeff_gradekin(), - ) - def construct_metric_state_for_diffusion(self) -> DiffusionMetricState: return DiffusionMetricState( mask_hdiff=self.mask_hdiff(), @@ -753,20 +689,6 @@ def construct_diagnostics_for_diffusion(self) -> DiffusionDiagnosticState: dwdy=self.dwdy(), ) - def construct_diagnostics(self) -> DiagnosticState: - return DiagnosticState( - hdef_ic=self.hdef_ic(), - div_ic=self.div_ic(), - dwdx=self.dwdx(), - dwdy=self.dwdy(), - vt=None, - vn_ie=None, - w_concorr_c=None, - ddt_w_adv_pc_before=None, - ddt_vn_apc_pc_before=None, - ntnd=None, - ) - class IconNonHydroInitSavepoint(IconSavepoint): def bdy_divdamp(self): diff --git a/model/driver/src/icon4py/model/driver/dycore_driver.py b/model/driver/src/icon4py/model/driver/dycore_driver.py index 9d30a2dc31..33ef96cdbc 100644 --- a/model/driver/src/icon4py/model/driver/dycore_driver.py +++ b/model/driver/src/icon4py/model/driver/dycore_driver.py @@ -24,9 +24,11 @@ NonHydrostaticParams, SolveNonhydro, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants -from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection +from icon4py.model.atmosphere.dycore.state_utils.states import ( + DiagnosticStateNonHydro, + PrepAdvection, +) from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields from icon4py.model.common.decomposition.definitions import ( ProcessProperties, @@ -151,7 +153,6 @@ def time_integration( inital_divdamp_fac_o2: float, do_prep_adv: bool, ): - log.info( f"starting time loop for dtime={self.run_config.dtime} n_timesteps={self._n_time_steps}" ) @@ -217,7 +218,6 @@ def _integrate_one_time_step( inital_divdamp_fac_o2: float, do_prep_adv: bool, ): - self._do_dyn_substepping( solve_nonhydro_diagnostic_state, prognostic_state_list, @@ -247,7 +247,6 @@ def _do_dyn_substepping( inital_divdamp_fac_o2: float, do_prep_adv: bool, ): - # TODO (Chia Rui): compute airmass for prognostic_state here do_recompute = True diff --git a/model/driver/src/icon4py/model/driver/io_utils.py b/model/driver/src/icon4py/model/driver/io_utils.py index 04cc8f0a2c..b21c206434 100644 --- a/model/driver/src/icon4py/model/driver/io_utils.py +++ b/model/driver/src/icon4py/model/driver/io_utils.py @@ -23,11 +23,13 @@ DiffusionInterpolationState, DiffusionMetricState, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro -from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState -from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants -from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection +from icon4py.model.atmosphere.dycore.state_utils.states import ( + DiagnosticStateNonHydro, + InterpolationState, + MetricStateNonHydro, + PrepAdvection, +) from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields from icon4py.model.common.decomposition.definitions import DecompositionInfo, ProcessProperties @@ -75,7 +77,6 @@ def read_icon_grid( # TODO (Chia Rui): initialization of prognostic variables and topography of Jablonowski Williamson test def model_initialization(): - # create two prognostic states, nnow and nnew? # at least two prognostic states are global because they are needed in the dycore, AND possibly nesting and restart processes in the future # one is enough for the JW test diff --git a/model/driver/tests/test_timeloop.py b/model/driver/tests/test_timeloop.py index 5a94728dd3..fa72699114 100644 --- a/model/driver/tests/test_timeloop.py +++ b/model/driver/tests/test_timeloop.py @@ -22,16 +22,20 @@ NonHydrostaticParams, SolveNonhydro, ) -from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants -from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection +from icon4py.model.atmosphere.dycore.state_utils.states import ( + DiagnosticStateNonHydro, + InterpolationState, + MetricStateNonHydro, + PrepAdvection, +) from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields -from icon4py.model.common.dimension import CellDim, EdgeDim, KDim +from icon4py.model.common.dimension import CEDim, CellDim, EdgeDim, KDim from icon4py.model.common.grid.horizontal import CellParams, EdgeParams from icon4py.model.common.grid.vertical import VerticalModelParams from icon4py.model.common.states.prognostic_state import PrognosticState -from icon4py.model.common.test_utils.helpers import dallclose +from icon4py.model.common.test_utils.helpers import as_1D_sparse_field, dallclose from icon4py.model.driver.dycore_driver import TimeLoop @@ -157,10 +161,60 @@ def test_run_timeloop_single_step( wgt_nnew_vel=sp.wgt_nnew_vel(), ) - nonhydro_interpolation_state = ( - interpolation_savepoint.construct_interpolation_state_for_nonhydro() + grg = interpolation_savepoint.geofac_grg() + nonhydro_interpolation_state = InterpolationState( + c_lin_e=interpolation_savepoint.c_lin_e(), + c_intp=interpolation_savepoint.c_intp(), + e_flx_avg=interpolation_savepoint.e_flx_avg(), + geofac_grdiv=interpolation_savepoint.geofac_grdiv(), + geofac_rot=interpolation_savepoint.geofac_rot(), + pos_on_tplane_e_1=interpolation_savepoint.pos_on_tplane_e_x(), + pos_on_tplane_e_2=interpolation_savepoint.pos_on_tplane_e_y(), + rbf_vec_coeff_e=interpolation_savepoint.rbf_vec_coeff_e(), + e_bln_c_s=as_1D_sparse_field(interpolation_savepoint.e_bln_c_s(), CEDim), + rbf_coeff_1=interpolation_savepoint.rbf_vec_coeff_v1(), + rbf_coeff_2=interpolation_savepoint.rbf_vec_coeff_v2(), + geofac_div=as_1D_sparse_field(interpolation_savepoint.geofac_div(), CEDim), + geofac_n2s=interpolation_savepoint.geofac_n2s(), + geofac_grg_x=grg[0], + geofac_grg_y=grg[1], + nudgecoeff_e=interpolation_savepoint.nudgecoeff_e(), + ) + nonhydro_metric_state = MetricStateNonHydro( + bdy_halo_c=metrics_savepoint.bdy_halo_c(), + mask_prog_halo_c=metrics_savepoint.mask_prog_halo_c(), + rayleigh_w=metrics_savepoint.rayleigh_w(), + exner_exfac=metrics_savepoint.exner_exfac(), + exner_ref_mc=metrics_savepoint.exner_ref_mc(), + wgtfac_c=metrics_savepoint.wgtfac_c(), + wgtfacq_c=metrics_savepoint.wgtfacq_c_dsl(), + inv_ddqz_z_full=metrics_savepoint.inv_ddqz_z_full(), + rho_ref_mc=metrics_savepoint.rho_ref_mc(), + theta_ref_mc=metrics_savepoint.theta_ref_mc(), + vwind_expl_wgt=metrics_savepoint.vwind_expl_wgt(), + d_exner_dz_ref_ic=metrics_savepoint.d_exner_dz_ref_ic(), + ddqz_z_half=metrics_savepoint.ddqz_z_half(), + theta_ref_ic=metrics_savepoint.theta_ref_ic(), + d2dexdz2_fac1_mc=metrics_savepoint.d2dexdz2_fac1_mc(), + d2dexdz2_fac2_mc=metrics_savepoint.d2dexdz2_fac2_mc(), + rho_ref_me=metrics_savepoint.rho_ref_me(), + theta_ref_me=metrics_savepoint.theta_ref_me(), + ddxn_z_full=metrics_savepoint.ddxn_z_full(), + zdiff_gradp=metrics_savepoint.zdiff_gradp(), + vertoffset_gradp=metrics_savepoint.vertoffset_gradp(), + ipeidx_dsl=metrics_savepoint.ipeidx_dsl(), + pg_exdist=metrics_savepoint.pg_exdist(), + ddqz_z_full_e=metrics_savepoint.ddqz_z_full_e(), + ddxt_z_full=metrics_savepoint.ddxt_z_full(), + wgtfac_e=metrics_savepoint.wgtfac_e(), + wgtfacq_e=metrics_savepoint.wgtfacq_e_dsl(icon_grid.num_levels), + vwind_impl_wgt=metrics_savepoint.vwind_impl_wgt(), + hmask_dd3d=metrics_savepoint.hmask_dd3d(), + scalfac_dd3d=metrics_savepoint.scalfac_dd3d(), + coeff1_dwdz=metrics_savepoint.coeff1_dwdz(), + coeff2_dwdz=metrics_savepoint.coeff2_dwdz(), + coeff_gradekin=metrics_savepoint.coeff_gradekin(), ) - nonhydro_metric_state = metrics_savepoint.construct_nh_metric_state(icon_grid.num_levels) cell_geometry: CellParams = grid_savepoint.construct_cell_geometry() edge_geometry: EdgeParams = grid_savepoint.construct_edge_geometry() @@ -270,7 +324,6 @@ def test_run_timeloop_single_step( ) if debug_mode: - script_dir = os.path.dirname(os.path.abspath(__file__)) base_dir = script_dir + "/"