Skip to content

Commit

Permalink
Changes to solve_nonhydro stencils 57, 58 (#356)
Browse files Browse the repository at this point in the history
Co-authored-by: Christoph Müller <[email protected]>
  • Loading branch information
abishekg7 and muellch authored Feb 23, 2024
1 parent 77465d3 commit 19962f0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
from icon4py.model.atmosphere.dycore.update_dynamical_exner_time_increment import (
update_dynamical_exner_time_increment,
)
from icon4py.model.atmosphere.dycore.update_mass_flux import update_mass_flux
from icon4py.model.atmosphere.dycore.update_mass_volume_flux import update_mass_volume_flux
from icon4py.model.atmosphere.dycore.update_mass_flux_weighted import update_mass_flux_weighted
from icon4py.model.atmosphere.dycore.update_theta_v import update_theta_v
from icon4py.model.atmosphere.dycore.velocity.velocity_advection import VelocityAdvection
Expand Down Expand Up @@ -1927,7 +1927,7 @@ def run_corrector_step(
offset_provider={},
)
log.debug(f"corrector start stencil 58")
update_mass_flux.with_backend(backend)(
update_mass_volume_flux.with_backend(backend)(
z_contr_w_fl_l=z_fields.z_contr_w_fl_l,
rho_ic=diagnostic_state_nh.rho_ic,
vwind_impl_wgt=self.metric_state_nonhydro.vwind_impl_wgt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,45 @@


@field_operator
def _update_mass_flux(
def _update_mass_volume_flux(
z_contr_w_fl_l: Field[[CellDim, KDim], wpfloat],
rho_ic: Field[[CellDim, KDim], wpfloat],
vwind_impl_wgt: Field[[CellDim], wpfloat],
w: Field[[CellDim, KDim], wpfloat],
mass_flx_ic: Field[[CellDim, KDim], wpfloat],
vol_flx_ic: Field[[CellDim, KDim], wpfloat],
r_nsubsteps: wpfloat,
) -> Field[[CellDim, KDim], wpfloat]:
) -> tuple[Field[[CellDim, KDim], wpfloat], Field[[CellDim, KDim], wpfloat]]:
"""Formerly known as _mo_solve_nonhydro_stencil_58."""
mass_flx_ic_wp = mass_flx_ic + (r_nsubsteps * (z_contr_w_fl_l + rho_ic * vwind_impl_wgt * w))
return mass_flx_ic_wp
z_a = r_nsubsteps * (z_contr_w_fl_l + rho_ic * vwind_impl_wgt * w)
mass_flx_ic_wp = mass_flx_ic + z_a
vol_flx_ic_wp = vol_flx_ic + z_a / rho_ic
return mass_flx_ic_wp, vol_flx_ic_wp


@program(grid_type=GridType.UNSTRUCTURED)
def update_mass_flux(
def update_mass_volume_flux(
z_contr_w_fl_l: Field[[CellDim, KDim], wpfloat],
rho_ic: Field[[CellDim, KDim], wpfloat],
vwind_impl_wgt: Field[[CellDim], wpfloat],
w: Field[[CellDim, KDim], wpfloat],
mass_flx_ic: Field[[CellDim, KDim], wpfloat],
vol_flx_ic: Field[[CellDim, KDim], wpfloat],
r_nsubsteps: wpfloat,
horizontal_start: int32,
horizontal_end: int32,
vertical_start: int32,
vertical_end: int32,
):
_update_mass_flux(
_update_mass_volume_flux(
z_contr_w_fl_l,
rho_ic,
vwind_impl_wgt,
w,
mass_flx_ic,
vol_flx_ic,
r_nsubsteps,
out=mass_flx_ic,
out=(mass_flx_ic, vol_flx_ic),
domain={
CellDim: (horizontal_start, horizontal_end),
KDim: (vertical_start, vertical_end),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
import pytest
from gt4py.next.ffront.fbuiltins import int32

from icon4py.model.atmosphere.dycore.update_mass_flux import update_mass_flux
from icon4py.model.atmosphere.dycore.update_mass_volume_flux import update_mass_volume_flux
from icon4py.model.common.dimension import CellDim, KDim
from icon4py.model.common.test_utils.helpers import StencilTest, random_field
from icon4py.model.common.type_alias import wpfloat


class TestMoSolveNonhydroStencil58(StencilTest):
PROGRAM = update_mass_flux
OUTPUTS = ("mass_flx_ic",)
PROGRAM = update_mass_volume_flux
OUTPUTS = (
"mass_flx_ic",
"vol_flx_ic",
)

@staticmethod
def reference(
Expand All @@ -33,12 +36,15 @@ def reference(
vwind_impl_wgt: np.array,
w: np.array,
mass_flx_ic: np.array,
vol_flx_ic: np.array,
r_nsubsteps,
**kwargs,
) -> dict:
vwind_impl_wgt = np.expand_dims(vwind_impl_wgt, axis=-1)
mass_flx_ic = mass_flx_ic + (r_nsubsteps * (z_contr_w_fl_l + rho_ic * vwind_impl_wgt * w))
return dict(mass_flx_ic=mass_flx_ic)
z_a = r_nsubsteps * (z_contr_w_fl_l + rho_ic * vwind_impl_wgt * w)
mass_flx_ic = mass_flx_ic + z_a
vol_flx_ic = vol_flx_ic + z_a / rho_ic
return dict(mass_flx_ic=mass_flx_ic, vol_flx_ic=vol_flx_ic)

@pytest.fixture
def input_data(self, grid):
Expand All @@ -47,6 +53,7 @@ def input_data(self, grid):
vwind_impl_wgt = random_field(grid, CellDim, dtype=wpfloat)
w = random_field(grid, CellDim, KDim, dtype=wpfloat)
mass_flx_ic = random_field(grid, CellDim, KDim, dtype=wpfloat)
vol_flx_ic = random_field(grid, CellDim, KDim, dtype=wpfloat)
r_nsubsteps = 7.0

return dict(
Expand All @@ -55,6 +62,7 @@ def input_data(self, grid):
vwind_impl_wgt=vwind_impl_wgt,
w=w,
mass_flx_ic=mass_flx_ic,
vol_flx_ic=vol_flx_ic,
r_nsubsteps=r_nsubsteps,
horizontal_start=int32(0),
horizontal_end=int32(grid.num_cells),
Expand Down

0 comments on commit 19962f0

Please sign in to comment.