Skip to content

Commit

Permalink
remove zero stencils from dycore/state_utils/utils.py (#425)
Browse files Browse the repository at this point in the history
- remove zero stencils from dycore/state_utils/utils.py
- remove idiv_method from SolveNonHydroConfig
  • Loading branch information
halungge authored Apr 2, 2024
1 parent a9b799b commit e68db25
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

import icon4py.model.atmosphere.dycore.nh_solve.solve_nonhydro_program as nhsolve_prog
import icon4py.model.common.constants as constants
from icon4py.model.atmosphere.dycore.set_cell_kdim_field_to_zero_wp import (
set_cell_kdim_field_to_zero_wp,
)

from icon4py.model.atmosphere.dycore.accumulate_prep_adv_fields import (
accumulate_prep_adv_fields,
)
Expand Down Expand Up @@ -136,8 +140,6 @@
_allocate_indices,
_calculate_divdamp_fields,
compute_z_raylfac,
set_zero_c_k,
set_zero_e_k,
)
from icon4py.model.atmosphere.dycore.update_dynamical_exner_time_increment import (
update_dynamical_exner_time_increment,
Expand Down Expand Up @@ -244,7 +246,6 @@ def __init__(
rayleigh_type: int = 2,
rayleigh_coeff: float = 0.05,
divdamp_order: int = 24, # the ICON default is 4,
idiv_method: int = 1,
is_iau_active: bool = False,
iau_wgt_dyn: float = 0.0,
divdamp_type: int = 3,
Expand Down Expand Up @@ -328,9 +329,6 @@ def __init__(
#: IAU weight for dynamics fields
self.iau_wgt_dyn: float = iau_wgt_dyn

#: from mo_dynamics_nml.f90
self.idiv_method: int = idiv_method

self._validate()

def _validate(self):
Expand All @@ -345,9 +343,6 @@ def _validate(self):
if self.itime_scheme != 4:
raise NotImplementedError("itime_scheme can only be 4")

if self.idiv_method != 1:
raise NotImplementedError("idiv_method can only be 1")

if self.divdamp_order != 24:
raise NotImplementedError("divdamp_order can only be 24")

Expand Down Expand Up @@ -709,6 +704,9 @@ def run_predictor_step(
start_edge_lb_plus4 = self.grid.get_start_index(
EdgeDim, HorizontalMarkerIndex.lateral_boundary(EdgeDim) + 4
)
start_edge_local_minus2 = self.grid.get_start_index(
EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - 2
)
end_edge_local_minus2 = self.grid.get_end_index(
EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - 2
)
Expand Down Expand Up @@ -898,50 +896,26 @@ def run_predictor_step(
},
)
if self.config.iadv_rhotheta <= 2:
tmp_0_0 = self.grid.get_start_index(EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - 2)
offset = 2 if self.config.idiv_method == 1 else 3
tmp_0_1 = self.grid.get_end_index(
EdgeDim, HorizontalMarkerIndex.local(EdgeDim) - offset
)

set_zero_e_k.with_backend(backend)(
field=z_fields.z_rho_e,
horizontal_start=tmp_0_0,
horizontal_end=tmp_0_1,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

set_zero_e_k.with_backend(backend)(
field=z_fields.z_theta_v_e,
horizontal_start=tmp_0_0,
horizontal_end=tmp_0_1,
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=z_fields.z_rho_e,
edge_kdim_field_to_zero_wp_2=z_fields.z_theta_v_e,
horizontal_start=start_edge_local_minus2,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

# initialize also nest boundary points with zero
if self.grid.limited_area:
set_zero_e_k.with_backend(backend)(
field=z_fields.z_rho_e,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_local_minus1,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

set_zero_e_k.with_backend(backend)(
field=z_fields.z_theta_v_e,
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=z_fields.z_rho_e,
edge_kdim_field_to_zero_wp_2=z_fields.z_theta_v_e,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_local_minus1,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if self.config.iadv_rhotheta == 2:
# Compute upwind-biased values for rho and theta starting from centered differences
# Note: the length of the backward trajectory should be 0.5*dtime*(vn,vt) in order to arrive
Expand Down Expand Up @@ -1134,20 +1108,19 @@ def run_predictor_step(
},
)

if self.config.idiv_method == 1:
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

nhsolve_prog.predictor_stencils_35_36.with_backend(backend)(
vn=prognostic_state[nnew].vn,
Expand Down Expand Up @@ -1203,22 +1176,21 @@ def run_predictor_step(
},
)

if self.config.idiv_method == 1:
compute_divergence_of_fluxes_of_rho_and_theta.with_backend(backend)(
geofac_div=self.interpolation_state.geofac_div,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
z_flxdiv_mass=self.z_flxdiv_mass,
z_flxdiv_theta=self.z_flxdiv_theta,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={
"C2E": self.grid.get_offset_provider("C2E"),
"C2CE": self.grid.get_offset_provider("C2CE"),
},
)
compute_divergence_of_fluxes_of_rho_and_theta.with_backend(backend)(
geofac_div=self.interpolation_state.geofac_div,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
z_flxdiv_mass=self.z_flxdiv_mass,
z_flxdiv_theta=self.z_flxdiv_theta,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={
"C2E": self.grid.get_offset_provider("C2E"),
"C2CE": self.grid.get_offset_provider("C2CE"),
},
)

nhsolve_prog.stencils_43_44_45_45b.with_backend(backend)(
z_w_expl=z_fields.z_w_expl,
Expand Down Expand Up @@ -1686,50 +1658,48 @@ def run_corrector_step(
},
)

if self.config.idiv_method == 1:
log.debug("corrector: start stencil 32")
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
log.debug("corrector: start stencil 32")
compute_mass_flux.with_backend(backend)(
z_rho_e=z_fields.z_rho_e,
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2, # TODO: (halungge) this is actually the second halo line
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if lprep_adv: # Preparations for tracer advection
log.debug("corrector: doing prep advection")
if lclean_mflx:
log.debug("corrector: start stencil 33")
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=prep_adv.vn_traj,
edge_kdim_field_to_zero_wp_2=prep_adv.mass_flx_me,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_end,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)
log.debug(f"corrector: start stencil 34")
accumulate_prep_adv_fields.with_backend(backend)(
z_vn_avg=self.z_vn_avg,
ddqz_z_full_e=self.metric_state_nonhydro.ddqz_z_full_e,
z_theta_v_e=z_fields.z_theta_v_e,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
z_theta_v_fl_e=self.z_theta_v_fl_e,
vn_traj=prep_adv.vn_traj,
mass_flx_me=prep_adv.mass_flx_me,
r_nsubsteps=r_nsubsteps,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if lprep_adv: # Preparations for tracer advection
log.debug("corrector: doing prep advection")
if lclean_mflx:
log.debug("corrector: start stencil 33")
set_two_edge_kdim_fields_to_zero_wp.with_backend(backend)(
edge_kdim_field_to_zero_wp_1=prep_adv.vn_traj,
edge_kdim_field_to_zero_wp_2=prep_adv.mass_flx_me,
horizontal_start=start_edge_lb,
horizontal_end=end_edge_end,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)
log.debug(f"corrector: start stencil 34")
accumulate_prep_adv_fields.with_backend(backend)(
z_vn_avg=self.z_vn_avg,
mass_fl_e=diagnostic_state_nh.mass_fl_e,
vn_traj=prep_adv.vn_traj,
mass_flx_me=prep_adv.mass_flx_me,
r_nsubsteps=r_nsubsteps,
horizontal_start=start_edge_lb_plus4,
horizontal_end=end_edge_local_minus2,
vertical_start=0,
vertical_end=self.grid.num_levels,
offset_provider={},
)

if self.config.idiv_method == 1:
# verified for e-9
log.debug(f"corrector: start stencile 41")
compute_divergence_of_fluxes_of_rho_and_theta.with_backend(backend)(
Expand Down Expand Up @@ -1961,7 +1931,7 @@ def run_corrector_step(
r_nsubsteps=r_nsubsteps,
horizontal_start=start_cell_nudging,
horizontal_end=end_cell_local,
vertical_start=0,
vertical_start=1,
vertical_end=self.grid.num_levels,
offset_provider={},
)
Expand All @@ -1982,8 +1952,8 @@ def run_corrector_step(
if lprep_adv:
if lclean_mflx:
log.debug(f"corrector set prep_adv.mass_flx_ic to zero")
set_zero_c_k.with_backend(backend)(
field=prep_adv.mass_flx_ic,
set_cell_kdim_field_to_zero_wp.with_backend(backend)(
field_to_zero_wp=prep_adv.mass_flx_ic,
horizontal_start=start_cell_lb,
horizontal_end=end_cell_nudging,
vertical_start=0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,18 @@
from icon4py.model.atmosphere.dycore.set_cell_kdim_field_to_zero_vp import (
_set_cell_kdim_field_to_zero_vp,
)
from icon4py.model.atmosphere.dycore.set_cell_kdim_field_to_zero_wp import (
_set_cell_kdim_field_to_zero_wp,
)
from icon4py.model.atmosphere.dycore.set_lower_boundary_condition_for_w_and_contravariant_correction import (
_set_lower_boundary_condition_for_w_and_contravariant_correction,
)
from icon4py.model.atmosphere.dycore.set_theta_v_prime_ic_at_lower_boundary import (
_set_theta_v_prime_ic_at_lower_boundary,
)
from icon4py.model.atmosphere.dycore.state_utils.utils import _set_zero_c_k, _set_zero_e_k
from icon4py.model.atmosphere.dycore.state_utils.utils import (
_broadcast_zero_to_three_edge_kdim_fields_wp,
)
from icon4py.model.atmosphere.dycore.update_densety_exener_wind import _update_densety_exener_wind
from icon4py.model.atmosphere.dycore.update_wind import _update_wind
from icon4py.model.common.dimension import CEDim, CellDim, ECDim, EdgeDim, KDim
Expand All @@ -92,19 +97,12 @@ def init_test_fields(
indices_cells_2: int32,
nlev: int32,
):
_set_zero_e_k(
out=z_rho_e,
domain={EdgeDim: (indices_edges_1, indices_edges_2), KDim: (0, nlev)},
)
_set_zero_e_k(
out=z_theta_v_e,
_broadcast_zero_to_three_edge_kdim_fields_wp(
out=(z_rho_e, z_theta_v_e, z_graddiv_vn),
domain={EdgeDim: (indices_edges_1, indices_edges_2), KDim: (0, nlev)},
)
_set_zero_e_k(
out=z_graddiv_vn,
domain={EdgeDim: (indices_edges_1, indices_edges_2), KDim: (0, nlev)},
)
_set_zero_c_k(

_set_cell_kdim_field_to_zero_wp(
out=z_dwdz_dd,
domain={CellDim: (indices_cells_1, indices_cells_2), KDim: (0, nlev)},
)
Expand All @@ -125,7 +123,7 @@ def _predictor_stencils_2_3(
_extrapolate_temporally_exner_pressure(exner_exfac, exner, exner_ref_mc, exner_pr),
(z_exner_ex_pr, exner_pr),
)
z_exner_ex_pr = where(k_field == nlev, _set_zero_c_k(), z_exner_ex_pr)
z_exner_ex_pr = where(k_field == nlev, _set_cell_kdim_field_to_zero_wp(), z_exner_ex_pr)

return z_exner_ex_pr, exner_pr

Expand Down
Loading

0 comments on commit e68db25

Please sign in to comment.