Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unified prognostic state #289

Merged
merged 5 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_verify_diffusion_init_against_first_regular_savepoint(


@pytest.mark.datatest
@pytest.mark.parametrize("step_date_init", ["2021-06-20T12:00:50.000"])
@pytest.mark.parametrize("step_date_init", ["2021-06-20T12:00:20.000"])
def test_verify_diffusion_init_against_other_regular_savepoint(
r04b09_diffusion_config,
grid_savepoint,
Expand Down
8 changes: 3 additions & 5 deletions model/atmosphere/diffusion/diffusion_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@

import numpy as np

from icon4py.model.atmosphere.diffusion.diffusion_states import (
DiffusionDiagnosticState,
PrognosticState,
)
from icon4py.model.atmosphere.diffusion.diffusion_states import DiffusionDiagnosticState
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.serialbox_utils import IconDiffusionExitSavepoint


Expand Down Expand Up @@ -47,7 +45,7 @@ def verify_diffusion_fields(
ref_exner = np.asarray(diffusion_savepoint.exner())
ref_theta_v = np.asarray(diffusion_savepoint.theta_v())
val_theta_v = np.asarray(prognostic_state.theta_v)
val_exner = np.asarray(prognostic_state.exner_pressure)
val_exner = np.asarray(prognostic_state.exner)
assert np.allclose(ref_theta_v, val_theta_v)
assert np.allclose(ref_exner, val_exner)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
DiffusionDiagnosticState,
DiffusionInterpolationState,
DiffusionMetricState,
PrognosticState,
)
from icon4py.model.atmosphere.diffusion.diffusion_utils import (
copy_field,
Expand Down Expand Up @@ -78,6 +77,7 @@
from icon4py.model.common.interpolation.stencils.mo_intp_rbf_rbf_vec_interpol_vertex import (
mo_intp_rbf_rbf_vec_interpol_vertex,
)
from icon4py.model.common.states.prognostic_state import PrognosticState


# flake8: noqa
Expand Down Expand Up @@ -538,7 +538,7 @@ def _sync_cell_fields(self, prognostic_state):
CellDim,
prognostic_state.w,
prognostic_state.theta_v,
prognostic_state.exner_pressure,
prognostic_state.exner,
)
handle_cell_comm.wait()
log.debug("communication of prognostic cell fields: theta, w, exner - done")
Expand Down Expand Up @@ -837,7 +837,7 @@ def _do_diffusion_step(
z_temp=self.z_temp,
area=self.cell_params.area,
theta_v=prognostic_state.theta_v,
exner=prognostic_state.exner_pressure,
exner=prognostic_state.exner,
rd_o_cvd=self.rd_o_cvd,
horizontal_start=cell_start_nudging,
horizontal_end=cell_end_local,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,3 @@ def geofac_n2s_nbh(self) -> Field[[CECDim], float]:
old_shape[0] * old_shape[1],
)
)


@dataclass
class PrognosticState:
"""Class that contains the prognostic state.

Corresponds to ICON t_nh_prog
"""

w: Field[[CellDim, KDim], float] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: Field[[EdgeDim, KDim], float] # vn(nproma, nlev, nblks_e) [m/s]
exner_pressure: Field[[CellDim, KDim], float] # exner(nrpoma, nlev, nblks_c)
theta_v: Field[[CellDim, KDim], float] # (nproma, nlev, nlbks_c) [K]
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@
from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.state_utils.utils import (
_allocate,
_allocate_indices,
Expand All @@ -159,6 +158,7 @@
from icon4py.model.common.grid.horizontal import EdgeParams, HorizontalMarkerIndex
from icon4py.model.common.grid.icon_grid import IconGrid
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState


class NonHydrostaticConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@
from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate, _allocate_indices
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.grid.horizontal import HorizontalMarkerIndex
from icon4py.model.common.grid.icon_grid import IconGrid
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState


cached_backend = run_gtfn_cached
Expand Down
10 changes: 1 addition & 9 deletions model/atmosphere/dycore/tests/test_solve_nonhydro.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import dallclose, random_field, zero_field
from icon4py.model.common.test_utils.simple_mesh import SimpleMesh

Expand Down Expand Up @@ -131,7 +131,6 @@ def test_nonhydro_predictor_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -140,7 +139,6 @@ def test_nonhydro_predictor_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down Expand Up @@ -526,7 +524,6 @@ def test_nonhydro_corrector_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -535,7 +532,6 @@ def test_nonhydro_corrector_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down Expand Up @@ -750,7 +746,6 @@ def test_run_solve_nonhydro_single_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -759,7 +754,6 @@ def test_run_solve_nonhydro_single_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down Expand Up @@ -927,7 +921,6 @@ def test_run_solve_nonhydro_multi_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -936,7 +929,6 @@ def test_run_solve_nonhydro_multi_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down
4 changes: 1 addition & 3 deletions model/atmosphere/dycore/tests/test_velocity_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from gt4py.next.ffront.fbuiltins import int32

from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.velocity.velocity_advection import VelocityAdvection
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import dallclose


Expand Down Expand Up @@ -148,7 +148,6 @@ def test_velocity_predictor_step(
prognostic_state = PrognosticState(
w=sp_v.w(),
vn=sp_v.vn(),
exner_pressure=None,
theta_v=None,
rho=None,
exner=None,
Expand Down Expand Up @@ -310,7 +309,6 @@ def test_velocity_corrector_step(
prognostic_state = PrognosticState(
w=sp_v.w(),
vn=sp_v.vn(),
exner_pressure=None,
theta_v=None,
rho=None,
exner=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@
class PrognosticState:
"""Class that contains the prognostic state.

corresponds to ICON t_nh_prog
Corresponds to ICON t_nh_prog
"""

w: Field[[CellDim, KDim], float] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: Field[[EdgeDim, KDim], float] # vn(nproma, nlev, nblks_e) [m/s]
exner_pressure: Field[[CellDim, KDim], float] # exner(nrpoma, nlev, nblks_c)
theta_v: Field[[CellDim, KDim], float] # (nproma, nlev, nlbks_c) [K]

rho: Field[[CellDim, KDim], float]
exner: Field[[CellDim, KDim], float]
rho: Field[[CellDim, KDim], float] # density, rho(nproma, nlev, nblks_c) [m/s]
w: Field[[CellDim, KDim], float] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: Field[
[EdgeDim, KDim], float
] # horizontal wind normal to edges, vn(nproma, nlev, nblks_e) [m/s]
exner: Field[[CellDim, KDim], float] # exner function, exner(nrpoma, nlev, nblks_c)
theta_v: Field[[CellDim, KDim], float] # virtual temperature, (nproma, nlev, nlbks_c) [K]

@property
def w_1(self) -> Field[[CellDim], float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticState
from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.common import dimension
from icon4py.model.common.decomposition.definitions import DecompositionInfo
from icon4py.model.common.dimension import (
Expand All @@ -52,6 +51,7 @@
)
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams, HorizontalGridSize
from icon4py.model.common.grid.icon_grid import GridConfig, IconGrid, VerticalGridSize
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import as_1D_sparse_field, flatten_first_two_dims


Expand Down Expand Up @@ -716,10 +716,9 @@ def construct_prognostics(self) -> PrognosticState:
return PrognosticState(
w=self.w(),
vn=self.vn(),
exner_pressure=self.exner(),
exner=self.exner(),
theta_v=self.theta_v(),
rho=None,
exner=None,
rho=self.rho(),
)

def construct_diagnostics_for_diffusion(self) -> DiffusionDiagnosticState:
Expand Down
10 changes: 4 additions & 6 deletions model/driver/src/icon4py/model/driver/dycore_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn

from icon4py.model.atmosphere.diffusion.diffusion import Diffusion, DiffusionParams
from icon4py.model.atmosphere.diffusion.diffusion_states import (
DiffusionDiagnosticState,
PrognosticState,
)
from icon4py.model.atmosphere.diffusion.diffusion_states import DiffusionDiagnosticState
from icon4py.model.atmosphere.diffusion.diffusion_utils import _identity_c_k, _identity_e_k
from icon4py.model.common.decomposition.definitions import (
ProcessProperties,
Expand All @@ -34,6 +31,7 @@
get_runtype,
)
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils import serialbox_utils as sb
from icon4py.model.driver.icon_configuration import IconRunConfig, read_config
from icon4py.model.driver.io_utils import (
Expand Down Expand Up @@ -123,8 +121,8 @@ def do_dynamics_substepping(
prognostic_state.vn,
new_p.w,
prognostic_state.w,
new_p.exner_pressure,
prognostic_state.exner_pressure,
new_p.exner,
prognostic_state.exner,
new_p.theta_v,
prognostic_state.theta_v,
offset_provider={},
Expand Down
2 changes: 1 addition & 1 deletion model/driver/src/icon4py/model/driver/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
DiffusionDiagnosticState,
DiffusionInterpolationState,
DiffusionMetricState,
PrognosticState,
)
from icon4py.model.common.decomposition.definitions import DecompositionInfo, ProcessProperties
from icon4py.model.common.decomposition.mpi_decomposition import ParallelLogger
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.icon_grid import IconGrid
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils import serialbox_utils as sb


Expand Down
6 changes: 4 additions & 2 deletions tools/src/icon4pytools/py2f/wrappers/diffusion_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
DiffusionMetricState,
DiffusionParams,
DiffusionType,
PrognosticState,
)
from icon4py.model.common.dimension import (
C2E2CDim,
Expand All @@ -55,6 +54,7 @@
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams, HorizontalGridSize
from icon4py.model.common.grid.icon_grid import GridConfig, IconGrid
from icon4py.model.common.grid.vertical import VerticalGridSize, VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState

from icon4pytools.py2f.cffi_utils import CffiMethod, to_fields

Expand Down Expand Up @@ -244,12 +244,14 @@ def diffusion_run(
hdef_ic: Field[[CellDim, KDim], float],
dwdx: Field[[CellDim, KDim], float],
dwdy: Field[[CellDim, KDim], float],
rho: Field[[CellDim, KDim], float],
):
diagnostic_state = DiffusionDiagnosticState(hdef_ic, div_ic, dwdx, dwdy)
prognostic_state = PrognosticState(
rho=rho,
w=w,
vn=vn,
exner_pressure=exner,
exner=exner,
theta_v=theta_v,
)
if linit:
Expand Down