Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/C2SM/icon4py into fortran-s…
Browse files Browse the repository at this point in the history
…erialisation
  • Loading branch information
samkellerhals committed May 22, 2023
2 parents bee0752 + be34d20 commit a6713b1
Show file tree
Hide file tree
Showing 18 changed files with 337 additions and 627 deletions.
233 changes: 62 additions & 171 deletions atm_dyn_iconam/src/icon4py/atm_dyn_iconam/mo_nh_diffusion_stencil_15.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,188 +11,79 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from gt4py.eve import SourceLocation
from gt4py.next.ffront import program_ast as past
from gt4py.next.iterator.builtins import (
deref,
if_,
list_get,
named_range,
shift,
unstructured_domain,
)
from gt4py.next.iterator.runtime import closure, fendef, fundef
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.ffront.decorator import field_operator, program
from gt4py.next.ffront.experimental import as_offset
from gt4py.next.ffront.fbuiltins import Field, int32, where

from icon4py.common.dimension import C2E2C, C2E2CDim, CellDim, KDim, Koff
from icon4py.icon4pygen.metadata import FieldInfo
from icon4py.common.dimension import C2CEC, C2E2C, CECDim, CellDim, KDim, Koff


@fundef
def step(i, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset):
d_vcoef = list_get(i, deref(vcoef))
s_theta_v = shift(C2E2C, i, Koff, list_get(i, deref(zd_vertoffset)))(theta_v)
return list_get(i, deref(geofac_n2s_nbh)) * (
d_vcoef * deref(s_theta_v) + (1.0 - d_vcoef) * deref(shift(Koff, 1)(s_theta_v))
)
@field_operator
def _mo_nh_diffusion_stencil_15(
mask: Field[[CellDim, KDim], bool],
zd_vertoffset: Field[[CECDim, KDim], int32],
zd_diffcoef: Field[[CellDim, KDim], float],
geofac_n2s_c: Field[[CellDim], float],
geofac_n2s_nbh: Field[[CECDim], float],
vcoef: Field[[CECDim, KDim], float],
theta_v: Field[[CellDim, KDim], float],
z_temp: Field[[CellDim, KDim], float],
) -> Field[[CellDim, KDim], float]:

theta_v_0 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[0])))
theta_v_1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[1])))
theta_v_2 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[2])))

@fundef
def _mo_nh_diffusion_stencil_15(
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
):
summed = (
step(0, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset)
+ step(1, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset)
+ step(2, geofac_n2s_nbh, vcoef, theta_v, zd_vertoffset)
theta_v_0_m1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[0]) + int32(1)))
theta_v_1_m1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[1]) + int32(1)))
theta_v_2_m1 = theta_v(as_offset(Koff, zd_vertoffset(C2CEC[2]) + int32(1)))

sum_over_neighbors = (
geofac_n2s_nbh(C2CEC[0])
* (
vcoef(C2CEC[0]) * theta_v_0(C2E2C[0])
+ (1.0 - vcoef(C2CEC[0])) * theta_v_0_m1(C2E2C[0])
)
+ geofac_n2s_nbh(C2CEC[1])
* (
vcoef(C2CEC[1]) * theta_v_1(C2E2C[1])
+ (1.0 - vcoef(C2CEC[1])) * theta_v_1_m1(C2E2C[1])
)
+ geofac_n2s_nbh(C2CEC[2])
* (
vcoef(C2CEC[2]) * theta_v_2(C2E2C[2])
+ (1.0 - vcoef(C2CEC[2])) * theta_v_2_m1(C2E2C[2])
)
)
update = deref(z_temp) + deref(zd_diffcoef) * (
deref(theta_v) * deref(geofac_n2s_c) + summed

z_temp = where(
mask,
z_temp + zd_diffcoef * (theta_v * geofac_n2s_c + sum_over_neighbors),
z_temp,
)

return if_(deref(mask), update, deref(z_temp))
return z_temp


@fendef
@program
def mo_nh_diffusion_stencil_15(
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
horizontal_start,
horizontal_end,
vertical_start,
vertical_end,
mask: Field[[CellDim, KDim], bool],
zd_vertoffset: Field[[CECDim, KDim], int32],
zd_diffcoef: Field[[CellDim, KDim], float],
geofac_n2s_c: Field[[CellDim], float],
geofac_n2s_nbh: Field[[CECDim], float],
vcoef: Field[[CECDim, KDim], float],
theta_v: Field[[CellDim, KDim], float],
z_temp: Field[[CellDim, KDim], float],
):
closure(
unstructured_domain(
named_range(CellDim, horizontal_start, horizontal_end),
named_range(KDim, vertical_start, vertical_end),
),
_mo_nh_diffusion_stencil_15,
_mo_nh_diffusion_stencil_15(
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
[
mask,
zd_vertoffset,
zd_diffcoef,
geofac_n2s_c,
geofac_n2s_nbh,
vcoef,
theta_v,
z_temp,
],
out=z_temp,
)


_dummy_loc = SourceLocation(1, 1, "")
_metadata = {
"mask": FieldInfo(
field=past.FieldSymbol(
id="mask",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"zd_vertoffset": FieldInfo(
field=past.FieldSymbol(
id="zd_vertoffset",
type=ts.FieldType(
dims=[CellDim, C2E2CDim, KDim],
dtype=ts.ScalarType(kind=ts.ScalarKind.INT32),
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"zd_diffcoef": FieldInfo(
field=past.FieldSymbol(
id="zd_diffcoef",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"geofac_n2s_c": FieldInfo(
field=past.FieldSymbol(
id="geofac_n2s_c",
type=ts.FieldType(
dims=[CellDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"geofac_n2s_nbh": FieldInfo(
field=past.FieldSymbol(
id="geofac_n2s_nbh",
type=ts.FieldType(
dims=[CellDim, C2E2CDim],
dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"vcoef": FieldInfo(
field=past.FieldSymbol(
id="vcoef",
type=ts.FieldType(
dims=[CellDim, C2E2CDim, KDim],
dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"theta_v": FieldInfo(
field=past.FieldSymbol(
id="theta_v",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=False,
),
"z_temp": FieldInfo(
field=past.FieldSymbol(
id="z_temp",
type=ts.FieldType(
dims=[CellDim, KDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
),
location=_dummy_loc,
),
inp=True,
out=True,
),
}

# patch the fendef with metainfo for icon4pygen
mo_nh_diffusion_stencil_15.__dict__["offsets"] = [
Koff.value,
C2E2C.value,
] # could be done with a pass...
mo_nh_diffusion_stencil_15.__dict__["metadata"] = _metadata
Loading

0 comments on commit a6713b1

Please sign in to comment.