Skip to content

Commit

Permalink
changes order of args and return a tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
halungge committed Sep 16, 2022
1 parent 916a1f7 commit 99638df
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
22 changes: 11 additions & 11 deletions advection/src/icon4py/advection/hflx_limiter_pd_stencil_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@

@field_operator
def _hflx_limiter_pd_stencil_01(
p_dtime: float,
dbl_eps: float,
geofac_div: Field[[CEDim], float],
p_cc: Field[[CellDim, KDim], float],
p_rhodz_now: Field[[CellDim, KDim], float],
p_mflx_tracer_h: Field[[EdgeDim, KDim], float],
p_dtime: float,
dbl_eps: float,
) -> Field[[CellDim, KDim], float]:
) -> tuple[Field[[CellDim, KDim], float]]:
zero = broadcast(0.0, (CellDim, KDim))

pm_0 = maximum(zero, p_mflx_tracer_h(C2E[0]) * geofac_div(C2CE[0]) * p_dtime)
pm_1 = maximum(zero, p_mflx_tracer_h(C2E[1]) * geofac_div(C2CE[1]) * p_dtime)
pm_2 = maximum(zero, p_mflx_tracer_h(C2E[2]) * geofac_div(C2CE[2]) * p_dtime)
Expand All @@ -36,25 +35,26 @@ def _hflx_limiter_pd_stencil_01(
broadcast(1.0, (CellDim, KDim)), (p_cc * p_rhodz_now) / (p_m + dbl_eps)
)

return r_m
return r_m, p_m


@program
def hflx_limiter_pd_stencil_01(
p_dtime: float,
dbl_eps: float,
geofac_div: Field[[CEDim], float],
p_cc: Field[[CellDim, KDim], float],
p_rhodz_now: Field[[CellDim, KDim], float],
p_mflx_tracer_h: Field[[EdgeDim, KDim], float],
r_m: Field[[CellDim, KDim], float],
p_dtime: float,
dbl_eps: float,
):
p_m: Field[[CellDim, KDim], float]
):
_hflx_limiter_pd_stencil_01(
p_dtime,
dbl_eps,
geofac_div,
p_cc,
p_rhodz_now,
p_mflx_tracer_h,
p_dtime,
dbl_eps,
out=r_m,
out=(r_m, p_m),
)
25 changes: 14 additions & 11 deletions advection/tests/test_hflx_limiter_pd_stencil_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@

def hflx_limiter_pd_stencil_01_numpy(
c2e: np.array,
p_dtime: float,
dbl_eps: float,
geofac_div: np.array,
p_cc: np.array,
p_rhodz_now: np.array,
p_mflx_tracer_h: np.array,
p_dtime,
dbl_eps,
):
)-> tuple[np.ndarray]:
geofac_div = np.expand_dims(geofac_div, axis=-1)
p_m_0 = np.maximum(0.0, p_mflx_tracer_h[c2e[:, 0]] * geofac_div[:, 0] * p_dtime)
p_m_1 = np.maximum(0.0, p_mflx_tracer_h[c2e[:, 1]] * geofac_div[:, 1] * p_dtime)
Expand All @@ -39,7 +39,7 @@ def hflx_limiter_pd_stencil_01_numpy(
p_m = p_m_0 + p_m_1 + p_m_2
r_m = np.minimum(1.0, p_cc * p_rhodz_now / (p_m + dbl_eps))

return r_m
return r_m, p_m


def test_hflx_limiter_pd_stencil_01():
Expand All @@ -49,30 +49,33 @@ def test_hflx_limiter_pd_stencil_01():
p_rhodz_now = random_field(mesh, CellDim, KDim)
p_mflx_tracer_h = random_field(mesh, EdgeDim, KDim)
r_m = zero_field(mesh, CellDim, KDim)
p_m = zero_field(mesh, CellDim, KDim)
p_dtime = np.float64(5)
dbl_eps = np.float64(1e-9)

ref = hflx_limiter_pd_stencil_01_numpy(
ref_r_m, ref_p_m = hflx_limiter_pd_stencil_01_numpy(
mesh.c2e,
p_dtime,
dbl_eps,
np.asarray(geofac_div),
np.asarray(p_cc),
np.asarray(p_rhodz_now),
np.asarray(p_mflx_tracer_h),
p_dtime,
dbl_eps,
np.asarray(p_mflx_tracer_h)
)

hflx_limiter_pd_stencil_01(
p_dtime,
dbl_eps,
as_1D_sparse_field(geofac_div, CEDim),
p_cc,
p_rhodz_now,
p_mflx_tracer_h,
r_m,
p_dtime,
dbl_eps,
p_m,
offset_provider={
"C2CE": StridedNeighborOffsetProvider(CellDim, CEDim, mesh.n_c2e),
"C2E": mesh.get_c2e_offset_provider(),
},
)
assert np.allclose(r_m, ref)
assert np.allclose(r_m, ref_r_m)
assert np.allclose(p_m, ref_p_m)

0 comments on commit 99638df

Please sign in to comment.