Skip to content

Commit

Permalink
hflx_limiter_pd_stencil_02.py: use int fields
Browse files Browse the repository at this point in the history
  • Loading branch information
halungge committed Oct 4, 2022
1 parent 5b9aa1b commit 2efb391
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
12 changes: 6 additions & 6 deletions advection/src/icon4py/advection/hflx_limiter_pd_stencil_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from functional.ffront.decorator import field_operator, program
from functional.ffront.fbuiltins import Field, where
from functional.ffront.fbuiltins import Field, int32, where

from icon4py.common.dimension import E2C, CellDim, EdgeDim, KDim


@field_operator
def _hflx_limiter_pd_stencil_02(
refin_ctrl: Field[[EdgeDim], float],
refin_ctrl: Field[[EdgeDim], int32],
r_m: Field[[CellDim, KDim], float],
p_mflx_tracer_h_in: Field[[EdgeDim, KDim], float],
bound: float,
bound: int32,
) -> Field[[EdgeDim, KDim], float]:
p_mflx_tracer_h_out = where(
refin_ctrl == bound,
p_mflx_tracer_h_in,
where(
(p_mflx_tracer_h_in > 0.0) | (p_mflx_tracer_h_in == 0.0),
p_mflx_tracer_h_in >= 0.0,
p_mflx_tracer_h_in * r_m(E2C[0]),
p_mflx_tracer_h_in * r_m(E2C[1]),
),
Expand All @@ -38,10 +38,10 @@ def _hflx_limiter_pd_stencil_02(

@program
def hflx_limiter_pd_stencil_02(
refin_ctrl: Field[[EdgeDim], float],
refin_ctrl: Field[[EdgeDim], int32],
r_m: Field[[CellDim, KDim], float],
p_mflx_tracer_h: Field[[EdgeDim, KDim], float],
bound: float,
bound: int32,
):
_hflx_limiter_pd_stencil_02(
refin_ctrl,
Expand Down
71 changes: 63 additions & 8 deletions advection/tests/test_hflx_limiter_pd_stencil_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from icon4py.common.dimension import CellDim, EdgeDim, KDim
from icon4py.testutils.simple_mesh import SimpleMesh
from icon4py.testutils.utils import random_field
from icon4py.testutils.utils import constant_field, random_field


def hflx_limiter_pd_stencil_02_numpy(
Expand All @@ -42,14 +42,12 @@ def hflx_limiter_pd_stencil_02_numpy(
return p_mflx_tracer_h_out


def test_hflx_limiter_pd_stencil_02():
def test_hflx_limiter_pd_stencil_02_nowhere_matching_refin_ctl():
mesh = SimpleMesh()

refin_ctrl = random_field(mesh, EdgeDim)
bound = np.int32(7)
refin_ctrl = constant_field(mesh, 4, EdgeDim, dtype=np.int32)
r_m = random_field(mesh, CellDim, KDim)
p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim)
p_mflx_tracer_h_out = random_field(mesh, EdgeDim, KDim)
bound = np.float64(7.0)

ref = hflx_limiter_pd_stencil_02_numpy(
mesh.e2c,
Expand All @@ -63,10 +61,67 @@ def test_hflx_limiter_pd_stencil_02():
refin_ctrl,
r_m,
p_mflx_tracer_h_in,
p_mflx_tracer_h_out,
bound,
offset_provider={
"E2C": mesh.get_e2c_offset_provider(),
},
)
assert np.allclose(p_mflx_tracer_h_out, ref)
assert np.allclose(p_mflx_tracer_h_in, ref)


def test_hflx_limiter_pd_stencil_02_everywhere_matching_refin_ctl():
mesh = SimpleMesh()
bound = np.int32(7)
refin_ctrl = constant_field(mesh, bound, EdgeDim, dtype=np.int32)
r_m = random_field(mesh, CellDim, KDim)
p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim)

hflx_limiter_pd_stencil_02(
refin_ctrl,
r_m,
p_mflx_tracer_h_in,
bound,
offset_provider={
"E2C": mesh.get_e2c_offset_provider(),
},
)
assert np.allclose(p_mflx_tracer_h_in, p_mflx_tracer_h_in)


def test_hflx_limiter_pd_stencil_02_partly_matching_refin_ctl():
mesh = SimpleMesh()
bound = np.int32(4)
refin_ctrl = constant_field(mesh, 5, EdgeDim, dtype=np.int32)
refin_ctrl[2:6] = bound
r_m = random_field(mesh, CellDim, KDim)
p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim)

hflx_limiter_pd_stencil_02(
refin_ctrl,
r_m,
p_mflx_tracer_h_in,
bound,
offset_provider={
"E2C": mesh.get_e2c_offset_provider(),
},
)
assert np.allclose(p_mflx_tracer_h_in, p_mflx_tracer_h_in)


def test_hflx_limiter_pd_stencil_02_everywhere_matching_refin_ctl_does_not_change_inout_arg():
mesh = SimpleMesh()
bound = np.int32(7)
refin_ctrl = constant_field(mesh, bound, EdgeDim, dtype=np.int32)
r_m = random_field(mesh, CellDim, KDim)
p_mflx_tracer_h_in = random_field(mesh, EdgeDim, KDim)

hflx_limiter_pd_stencil_02(
refin_ctrl,
r_m,
p_mflx_tracer_h_in,
bound,
offset_provider={
"E2C": mesh.get_e2c_offset_provider(),
},
)
assert np.allclose(p_mflx_tracer_h_in, p_mflx_tracer_h_in)

0 comments on commit 2efb391

Please sign in to comment.