Skip to content

Commit

Permalink
Unified prognostic state (#289)
Browse files Browse the repository at this point in the history
* Moving prognostics to common

* Move prognostic state module to common/states

* Modified diffusion wrapper and test_verify_diffusion_init_against_other_regular_savepoint in test_diffusion.py

* Modified the order of imported modules in prognostic_state.py
  • Loading branch information
OngChia authored Oct 10, 2023
1 parent b898987 commit 276bdde
Show file tree
Hide file tree
Showing 13 changed files with 31 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_verify_diffusion_init_against_first_regular_savepoint(


@pytest.mark.datatest
@pytest.mark.parametrize("step_date_init", ["2021-06-20T12:00:50.000"])
@pytest.mark.parametrize("step_date_init", ["2021-06-20T12:00:20.000"])
def test_verify_diffusion_init_against_other_regular_savepoint(
r04b09_diffusion_config,
grid_savepoint,
Expand Down
8 changes: 3 additions & 5 deletions model/atmosphere/diffusion/diffusion_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@

import numpy as np

from icon4py.model.atmosphere.diffusion.diffusion_states import (
DiffusionDiagnosticState,
PrognosticState,
)
from icon4py.model.atmosphere.diffusion.diffusion_states import DiffusionDiagnosticState
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.serialbox_utils import IconDiffusionExitSavepoint


Expand Down Expand Up @@ -47,7 +45,7 @@ def verify_diffusion_fields(
ref_exner = np.asarray(diffusion_savepoint.exner())
ref_theta_v = np.asarray(diffusion_savepoint.theta_v())
val_theta_v = np.asarray(prognostic_state.theta_v)
val_exner = np.asarray(prognostic_state.exner_pressure)
val_exner = np.asarray(prognostic_state.exner)
assert np.allclose(ref_theta_v, val_theta_v)
assert np.allclose(ref_exner, val_exner)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
DiffusionDiagnosticState,
DiffusionInterpolationState,
DiffusionMetricState,
PrognosticState,
)
from icon4py.model.atmosphere.diffusion.diffusion_utils import (
copy_field,
Expand Down Expand Up @@ -78,6 +77,7 @@
from icon4py.model.common.interpolation.stencils.mo_intp_rbf_rbf_vec_interpol_vertex import (
mo_intp_rbf_rbf_vec_interpol_vertex,
)
from icon4py.model.common.states.prognostic_state import PrognosticState


# flake8: noqa
Expand Down Expand Up @@ -538,7 +538,7 @@ def _sync_cell_fields(self, prognostic_state):
CellDim,
prognostic_state.w,
prognostic_state.theta_v,
prognostic_state.exner_pressure,
prognostic_state.exner,
)
handle_cell_comm.wait()
log.debug("communication of prognostic cell fields: theta, w, exner - done")
Expand Down Expand Up @@ -837,7 +837,7 @@ def _do_diffusion_step(
z_temp=self.z_temp,
area=self.cell_params.area,
theta_v=prognostic_state.theta_v,
exner=prognostic_state.exner_pressure,
exner=prognostic_state.exner,
rd_o_cvd=self.rd_o_cvd,
horizontal_start=cell_start_nudging,
horizontal_end=cell_end_local,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,3 @@ def geofac_n2s_nbh(self) -> Field[[CECDim], float]:
old_shape[0] * old_shape[1],
)
)


@dataclass
class PrognosticState:
"""Class that contains the prognostic state.
Corresponds to ICON t_nh_prog
"""

w: Field[[CellDim, KDim], float] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: Field[[EdgeDim, KDim], float] # vn(nproma, nlev, nblks_e) [m/s]
exner_pressure: Field[[CellDim, KDim], float] # exner(nrpoma, nlev, nblks_c)
theta_v: Field[[CellDim, KDim], float] # (nproma, nlev, nlbks_c) [K]
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@
from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.state_utils.utils import (
_allocate,
_allocate_indices,
Expand All @@ -159,6 +158,7 @@
from icon4py.model.common.grid.horizontal import EdgeParams, HorizontalMarkerIndex
from icon4py.model.common.grid.icon_grid import IconGrid
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState


class NonHydrostaticConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@
from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate, _allocate_indices
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim, VertexDim
from icon4py.model.common.grid.horizontal import HorizontalMarkerIndex
from icon4py.model.common.grid.icon_grid import IconGrid
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState


cached_backend = run_gtfn_cached
Expand Down
10 changes: 1 addition & 9 deletions model/atmosphere/dycore/tests/test_solve_nonhydro.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.nh_constants import NHConstants
from icon4py.model.atmosphere.dycore.state_utils.prep_adv_state import PrepAdvection
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.state_utils.utils import _allocate
from icon4py.model.atmosphere.dycore.state_utils.z_fields import ZFields
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import dallclose, random_field, zero_field
from icon4py.model.common.test_utils.simple_mesh import SimpleMesh

Expand Down Expand Up @@ -131,7 +131,6 @@ def test_nonhydro_predictor_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -140,7 +139,6 @@ def test_nonhydro_predictor_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down Expand Up @@ -526,7 +524,6 @@ def test_nonhydro_corrector_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -535,7 +532,6 @@ def test_nonhydro_corrector_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down Expand Up @@ -750,7 +746,6 @@ def test_run_solve_nonhydro_single_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -759,7 +754,6 @@ def test_run_solve_nonhydro_single_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down Expand Up @@ -927,7 +921,6 @@ def test_run_solve_nonhydro_multi_step(
prognostic_state_nnow = PrognosticState(
w=sp.w_now(),
vn=sp.vn_now(),
exner_pressure=None,
theta_v=sp.theta_v_now(),
rho=sp.rho_now(),
exner=sp.exner_now(),
Expand All @@ -936,7 +929,6 @@ def test_run_solve_nonhydro_multi_step(
prognostic_state_nnew = PrognosticState(
w=sp.w_new(),
vn=sp.vn_new(),
exner_pressure=None,
theta_v=sp.theta_v_new(),
rho=sp.rho_new(),
exner=sp.exner_new(),
Expand Down
4 changes: 1 addition & 3 deletions model/atmosphere/dycore/tests/test_velocity_advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from gt4py.next.ffront.fbuiltins import int32

from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.atmosphere.dycore.velocity.velocity_advection import VelocityAdvection
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import dallclose


Expand Down Expand Up @@ -148,7 +148,6 @@ def test_velocity_predictor_step(
prognostic_state = PrognosticState(
w=sp_v.w(),
vn=sp_v.vn(),
exner_pressure=None,
theta_v=None,
rho=None,
exner=None,
Expand Down Expand Up @@ -310,7 +309,6 @@ def test_velocity_corrector_step(
prognostic_state = PrognosticState(
w=sp_v.w(),
vn=sp_v.vn(),
exner_pressure=None,
theta_v=None,
rho=None,
exner=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@
class PrognosticState:
"""Class that contains the prognostic state.
corresponds to ICON t_nh_prog
Corresponds to ICON t_nh_prog
"""

w: Field[[CellDim, KDim], float] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: Field[[EdgeDim, KDim], float] # vn(nproma, nlev, nblks_e) [m/s]
exner_pressure: Field[[CellDim, KDim], float] # exner(nrpoma, nlev, nblks_c)
theta_v: Field[[CellDim, KDim], float] # (nproma, nlev, nlbks_c) [K]

rho: Field[[CellDim, KDim], float]
exner: Field[[CellDim, KDim], float]
rho: Field[[CellDim, KDim], float] # density, rho(nproma, nlev, nblks_c) [m/s]
w: Field[[CellDim, KDim], float] # vertical_wind field, w(nproma, nlevp1, nblks_c) [m/s]
vn: Field[
[EdgeDim, KDim], float
] # horizontal wind normal to edges, vn(nproma, nlev, nblks_e) [m/s]
exner: Field[[CellDim, KDim], float] # exner function, exner(nrpoma, nlev, nblks_c)
theta_v: Field[[CellDim, KDim], float] # virtual temperature, (nproma, nlev, nlbks_c) [K]

@property
def w_1(self) -> Field[[CellDim], float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from icon4py.model.atmosphere.dycore.state_utils.diagnostic_state import DiagnosticState
from icon4py.model.atmosphere.dycore.state_utils.interpolation_state import InterpolationState
from icon4py.model.atmosphere.dycore.state_utils.metric_state import MetricStateNonHydro
from icon4py.model.atmosphere.dycore.state_utils.prognostic_state import PrognosticState
from icon4py.model.common import dimension
from icon4py.model.common.decomposition.definitions import DecompositionInfo
from icon4py.model.common.dimension import (
Expand All @@ -52,6 +51,7 @@
)
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams, HorizontalGridSize
from icon4py.model.common.grid.icon_grid import GridConfig, IconGrid, VerticalGridSize
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils.helpers import as_1D_sparse_field, flatten_first_two_dims


Expand Down Expand Up @@ -716,10 +716,9 @@ def construct_prognostics(self) -> PrognosticState:
return PrognosticState(
w=self.w(),
vn=self.vn(),
exner_pressure=self.exner(),
exner=self.exner(),
theta_v=self.theta_v(),
rho=None,
exner=None,
rho=self.rho(),
)

def construct_diagnostics_for_diffusion(self) -> DiffusionDiagnosticState:
Expand Down
10 changes: 4 additions & 6 deletions model/driver/src/icon4py/model/driver/dycore_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn

from icon4py.model.atmosphere.diffusion.diffusion import Diffusion, DiffusionParams
from icon4py.model.atmosphere.diffusion.diffusion_states import (
DiffusionDiagnosticState,
PrognosticState,
)
from icon4py.model.atmosphere.diffusion.diffusion_states import DiffusionDiagnosticState
from icon4py.model.atmosphere.diffusion.diffusion_utils import _identity_c_k, _identity_e_k
from icon4py.model.common.decomposition.definitions import (
ProcessProperties,
Expand All @@ -34,6 +31,7 @@
get_runtype,
)
from icon4py.model.common.dimension import CellDim, EdgeDim, KDim
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils import serialbox_utils as sb
from icon4py.model.driver.icon_configuration import IconRunConfig, read_config
from icon4py.model.driver.io_utils import (
Expand Down Expand Up @@ -123,8 +121,8 @@ def do_dynamics_substepping(
prognostic_state.vn,
new_p.w,
prognostic_state.w,
new_p.exner_pressure,
prognostic_state.exner_pressure,
new_p.exner,
prognostic_state.exner,
new_p.theta_v,
prognostic_state.theta_v,
offset_provider={},
Expand Down
2 changes: 1 addition & 1 deletion model/driver/src/icon4py/model/driver/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
DiffusionDiagnosticState,
DiffusionInterpolationState,
DiffusionMetricState,
PrognosticState,
)
from icon4py.model.common.decomposition.definitions import DecompositionInfo, ProcessProperties
from icon4py.model.common.decomposition.mpi_decomposition import ParallelLogger
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams
from icon4py.model.common.grid.icon_grid import IconGrid
from icon4py.model.common.grid.vertical import VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils import serialbox_utils as sb


Expand Down
6 changes: 4 additions & 2 deletions tools/src/icon4pytools/py2f/wrappers/diffusion_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
DiffusionMetricState,
DiffusionParams,
DiffusionType,
PrognosticState,
)
from icon4py.model.common.dimension import (
C2E2CDim,
Expand All @@ -55,6 +54,7 @@
from icon4py.model.common.grid.horizontal import CellParams, EdgeParams, HorizontalGridSize
from icon4py.model.common.grid.icon_grid import GridConfig, IconGrid
from icon4py.model.common.grid.vertical import VerticalGridSize, VerticalModelParams
from icon4py.model.common.states.prognostic_state import PrognosticState

from icon4pytools.py2f.cffi_utils import CffiMethod, to_fields

Expand Down Expand Up @@ -244,12 +244,14 @@ def diffusion_run(
hdef_ic: Field[[CellDim, KDim], float],
dwdx: Field[[CellDim, KDim], float],
dwdy: Field[[CellDim, KDim], float],
rho: Field[[CellDim, KDim], float],
):
diagnostic_state = DiffusionDiagnosticState(hdef_ic, div_ic, dwdx, dwdy)
prognostic_state = PrognosticState(
rho=rho,
w=w,
vn=vn,
exner_pressure=exner,
exner=exner,
theta_v=theta_v,
)
if linit:
Expand Down

0 comments on commit 276bdde

Please sign in to comment.