Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index stencils field view implementation #199

Merged
merged 14 commits into from
May 15, 2023
232 changes: 62 additions & 170 deletions atm_dyn_iconam/src/icon4py/atm_dyn_iconam/mo_nh_diffusion_stencil_15.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,187 +11,79 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import SourceLocation
from gt4py.next.ffront import program_ast as past
from gt4py.next.iterator.builtins import (
deref,
if_,
list_get,
named_range,
shift,
unstructured_domain,
)
from gt4py.next.iterator.runtime import closure, fendef, fundef
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.experimental import as_offset
from gt4py.next.ffront.fbuiltins import Field, int32, where

from icon4py.common.dimension import C2E2C, C2E2CDim, CellDim, KDim, Koff
from icon4py.icon4pygen.metadata import FieldInfo
from icon4py.common.dimension import C2CEC, C2E2C, CECDim, CellDim, KDim, Koff


def step(i, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset):
d_vcoef = list_get(i, deref(vcoef))
s_theta_v = shift(C2E2C, i, Koff, list_get(i, deref(zd_vertoffset)))(theta_v)
return list_get(i, deref(geofac_n2s_nbh)) * (
d_vcoef * deref(s_theta_v) + (1.0 - d_vcoef) * deref(shift(Koff, 1)(s_theta_v))
)
@field_operator
def _mo_nh_diffusion_stencil_15(
mask: Field[[CellDim, KDim], bool],
zd_vertoffset: Field[[CECDim, KDim], int32],
zd_diffcoef: Field[[CellDim, KDim], float],
geofac_n2s_c: Field[[CellDim], float],
geofac_n2s_nbh: Field[[CECDim], float],
vcoef: Field[[CECDim, KDim], float],
theta_v: Field[[CellDim, KDim], float],
z_temp: Field[[CellDim, KDim], float],
) -> Field[[CellDim, KDim], float]:

theta_v_0 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[0])))
theta_v_1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[1])))
theta_v_2 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[2])))

@fundef
def _mo_nh_diffusion_stencil_15(
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
):
summed = (
step(0, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset)
+ step(1, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset)
+ step(2, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset)
theta_v_0_m1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[0]) + int32(1)))
theta_v_1_m1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[1]) + int32(1)))
theta_v_2_m1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[2]) + int32(1)))

sum_over_neighbors = (
geofac_n2s_nbh(C2CEC[0])
* (
vcoef(C2CEC[0]) * theta_v_0(C2E2C[0])
+ (1.0 - vcoef(C2CEC[0])) * theta_v_0_m1(C2E2C[0])
)
+ geofac_n2s_nbh(C2CEC[1])
* (
vcoef(C2CEC[1]) * theta_v_1(C2E2C[1])
+ (1.0 - vcoef(C2CEC[1])) * theta_v_1_m1(C2E2C[1])
)
+ geofac_n2s_nbh(C2CEC[2])
* (
vcoef(C2CEC[2]) * theta_v_2(C2E2C[2])
+ (1.0 - vcoef(C2CEC[2])) * theta_v_2_m1(C2E2C[2])
)
)
update = deref(z_temp) + deref(zd_diffcoef) * (
deref(theta_v) * deref(geofac_n2s_c) + summed

z_temp = where(
mask,
z_temp + zd_diffcoef * (theta_v * geofac_n2s_c + sum_over_neighbors),
z_temp,
)

return if_(deref(mask), update, deref(z_temp))
return z_temp


@fendef
@program
def mo_nh_diffusion_stencil_15(
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
horizontal_start,
horizontal_end,
vertical_start,
vertical_end,
mask: Field[[CellDim, KDim], bool],
zd_vertoffset: Field[[CECDim, KDim], int32],
zd_diffcoef: Field[[CellDim, KDim], float],
geofac_n2s_c: Field[[CellDim], float],
geofac_n2s_nbh: Field[[CECDim], float],
vcoef: Field[[CECDim, KDim], float],
theta_v: Field[[CellDim, KDim], float],
z_temp: Field[[CellDim, KDim], float],
):
closure(
unstructured_domain(
named_range(CellDim, horizontal_start, horizontal_end),
named_range(KDim, vertical_start, vertical_end),
),
_mo_nh_diffusion_stencil_15,
_mo_nh_diffusion_stencil_15(
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
[
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
],
out=z_temp,
)


_dummy_loc = SourceLocation(1, 1, "")
_metadata = {
"mask": FieldInfo(
field=past.FieldSymbol(
id="mask",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"zd_vertoffset": FieldInfo(
field=past.FieldSymbol(
id="zd_vertoffset",
type=ts.FieldType(
dims=[CellDim, C2E2CDim, KDim],
dtype=ts.ScalarType(kind=ts.ScalarKind.INT32),
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"zd_diffcoef": FieldInfo(
field=past.FieldSymbol(
id="zd_diffcoef",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"geofac_n2s_c": FieldInfo(
field=past.FieldSymbol(
id="geofac_n2s_c",
type=ts.FieldType(
dims=[CellDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"geofac_n2s_nbh": FieldInfo(
field=past.FieldSymbol(
id="geofac_n2s_nbh",
type=ts.FieldType(
dims=[CellDim, C2E2CDim],
dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"vcoef": FieldInfo(
field=past.FieldSymbol(
id="vcoef",
type=ts.FieldType(
dims=[CellDim, C2E2CDim, KDim],
dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"theta_v": FieldInfo(
field=past.FieldSymbol(
id="theta_v",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"z_temp": FieldInfo(
field=past.FieldSymbol(
id="z_temp",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=True,
),
}

# patch the fendef with metainfo for icon4pygen
mo_nh_diffusion_stencil_15.__dict__["offsets"] = [
Koff.value,
C2E2C.value,
] # could be done with a pass...
mo_nh_diffusion_stencil_15.__dict__["metadata"] = _metadata
Loading