Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new itir.Program everywhere #596

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base-requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# VCS
-e git+https://github.com/GridTools/gt4py.git@icon4py_20241113#egg=gt4py # use tagged release until #596 & gt4py#1738 is merged
-e git+https://github.com/havogt/gt4py.git@increase_inlining_loop#egg=gt4py
git+https://github.com/GridTools/serialbox#egg=serialbox&subdirectory=src/serialbox-python

# PyPI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def reference(
**kwargs,
) -> dict:
c2e = grid.connectivities[dims.C2EDim]
c2ce = grid.get_offset_provider("C2CE").table
c2ce = grid.get_offset_provider("C2CE").ndarray

geofac_div = np.expand_dims(geofac_div, axis=-1)
e_bln_c_s = np.expand_dims(e_bln_c_s, axis=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
from gt4py.next.ffront.decorator import GridType, field_operator, program
from gt4py.next.ffront.fbuiltins import where

from icon4py.model.atmosphere.dycore.stencils.compute_contravariant_correction_of_w import (
Expand All @@ -19,7 +18,7 @@
from icon4py.model.common.type_alias import vpfloat, wpfloat


@field_operator
@gtx.field_operator
def _fused_solve_nonhydro_stencil_39_40(
e_bln_c_s: gtx.Field[gtx.Dims[dims.CEDim], wpfloat],
z_w_concorr_me: fa.EdgeKField[vpfloat],
Expand All @@ -39,7 +38,7 @@ def _fused_solve_nonhydro_stencil_39_40(
return w_concorr_c


@program(grid_type=GridType.UNSTRUCTURED)
@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED)
def fused_solve_nonhydro_stencil_39_40(
e_bln_c_s: gtx.Field[gtx.Dims[dims.CEDim], wpfloat],
z_w_concorr_me: fa.EdgeKField[vpfloat],
Expand Down
10 changes: 5 additions & 5 deletions model/common/src/icon4py/model/common/grid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,18 @@ def _get_offset_provider(self, dim, from_dim, to_dim):
), 'Neighbor table\'s "{}" data type must be gtx.int32. Instead it\'s "{}"'.format(
dim, self.connectivities[dim].dtype
)
return gtx.NeighborTableOffsetProvider(
self.connectivities[dim],
from_dim,
return gtx.as_connectivity(
[from_dim, dim],
to_dim,
self.size[dim],
has_skip_values=self._has_skip_values(dim),
self.connectivities[dim],
skip_value=-1 if self._has_skip_values(dim) else None,
)

def _get_offset_provider_for_sparse_fields(self, dim, from_dim, to_dim):
if dim not in self.connectivities:
raise MissingConnectivity()
return grid_utils.neighbortable_offset_provider_for_1d_sparse_fields(
dim,
self.connectivities[dim].shape,
from_dim,
to_dim,
Expand Down
12 changes: 6 additions & 6 deletions model/common/src/icon4py/model/common/grid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
import numpy as np
from gt4py.next import Dimension, NeighborTableOffsetProvider
from gt4py.next import Dimension


def neighbortable_offset_provider_for_1d_sparse_fields(
dim: Dimension,
old_shape: tuple[int, int],
origin_axis: Dimension,
neighbor_axis: Dimension,
Expand All @@ -22,10 +23,9 @@ def neighbortable_offset_provider_for_1d_sparse_fields(
), 'Neighbor table\'s ("{}" to "{}") data type for 1d sparse fields must be gtx.int32. Instead it\'s "{}"'.format(
origin_axis, neighbor_axis, table.dtype
)
return NeighborTableOffsetProvider(
table,
origin_axis,
return gtx.as_connectivity(
[origin_axis, dim],
neighbor_axis,
table.shape[1],
has_skip_values=has_skip_values,
table,
skip_value=-1 if has_skip_values else None,
)
9 changes: 5 additions & 4 deletions model/common/src/icon4py/model/common/states/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TODO: @halungge: allow to read configuration data

"""

import collections
import enum
import functools
Expand Down Expand Up @@ -344,8 +345,8 @@ def _get_offset_providers(self, grid: icon_grid.IconGrid) -> dict[str, gtx.Field
horizontal_offsets = {
k: v
for k, v in grid.offset_providers.items()
if isinstance(v, gtx.NeighborTableOffsetProvider)
and v.origin_axis.kind == gtx.DimensionKind.HORIZONTAL
if isinstance(v, gtx.Connectivity)
and v.domain.dims[0].kind == gtx.DimensionKind.HORIZONTAL
}
offset_providers.update(horizontal_offsets)
if dim.kind == gtx.DimensionKind.VERTICAL:
Expand Down Expand Up @@ -445,8 +446,8 @@ def _get_offset_providers(self, grid: icon_grid.IconGrid) -> dict[str, gtx.Field
horizontal_offsets = {
k: v
for k, v in grid.offset_providers.items()
if isinstance(v, gtx.NeighborTableOffsetProvider)
and v.origin_axis.kind == gtx.DimensionKind.HORIZONTAL
if isinstance(v, gtx.Connectivity)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to remove this workaround with PR 1746...

and v.domain.dims[0].kind == gtx.DimensionKind.HORIZONTAL
}
offset_providers.update(horizontal_offsets)
if dim.kind == gtx.DimensionKind.VERTICAL:
Expand Down
33 changes: 20 additions & 13 deletions tools/src/icon4pytools/common/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@
from typing import Any, TypeGuard

import icon4py.model.common.dimension
from gt4py import eve
from gt4py.next.common import Connectivity, Dimension, DimensionKind
from gt4py import eve, next as gtx
from gt4py.next.common import (
_DEFAULT_SKIP_VALUE,
Connectivity,
Dimension,
DimensionKind,
NeighborConnectivityType,
)
from gt4py.next.ffront import program_ast as past
from gt4py.next.ffront.decorator import FieldOperator, Program, program
from gt4py.next.ffront.fbuiltins import FieldOffset
Expand All @@ -36,7 +42,7 @@

@dataclass(frozen=True)
class StencilInfo:
fendef: itir.FencilDefinition
program: itir.Program
fields: dict[str, FieldInfo]
connectivity_chains: list[eve.concepts.SymbolRef]
offset_provider: dict
Expand Down Expand Up @@ -115,7 +121,7 @@ def _get_field_infos(fvprog: Program) -> dict[str, FieldInfo]:
return fields


def _provide_offset(offset: str, is_global: bool = False) -> DummyConnectivity | Dimension:
def _provide_offset(offset: str, is_global: bool = False) -> NeighborConnectivityType | Dimension:
if offset == dims.Koff.value:
assert len(dims.Koff.target) == 1
assert dims.Koff.source == dims.Koff.target[0]
Expand All @@ -124,8 +130,8 @@ def _provide_offset(offset: str, is_global: bool = False) -> DummyConnectivity |
return _provide_neighbor_table(offset, is_global)


def _provide_neighbor_table(chain: str, is_global: bool) -> DummyConnectivity:
"""Build an offset provider based on connectivity chain string.
def _provide_neighbor_table(chain: str, is_global: bool) -> NeighborConnectivityType:
"""Build a NeighborConnectivityType based on connectivity chain string.

Connectivity strings must contain one of the following connectivity type identifiers:
C (cell), E (Edge), V (Vertex) and be separated by a '2' e.g. 'E2V'. If the origin is to
Expand Down Expand Up @@ -154,11 +160,12 @@ def _provide_neighbor_table(chain: str, is_global: bool) -> DummyConnectivity:
}
location_chain: list[Dimension] = [map_to_dim.get(c) for c in chain if c not in ("2", "O")] # type: ignore[misc] # type specified

return DummyConnectivity(
return NeighborConnectivityType(
offset.target,
codomain=offset.source,
skip_value=_DEFAULT_SKIP_VALUE if skip_values else None,
dtype=gtx.IndexType, # type: ignore[attr-defined] # need to do an explicit export for mypy
max_neighbors=_calc_num_neighbors(location_chain, include_center),
has_skip_values=skip_values,
origin_axis=offset.target[0],
neighbor_axis=offset.source,
)


Expand Down Expand Up @@ -187,7 +194,7 @@ def _scan_for_offsets(fvprog: Program) -> list[eve.concepts.SymbolRef]:
)

all_offset_labels = (
fvprog.itir.pre_walk_values()
fvprog.gtir.pre_walk_values()
.if_isinstance(itir.OffsetLiteral)
.getattr("value")
.if_isinstance(str)
Expand All @@ -206,7 +213,7 @@ def get_stencil_info(
"""Generate StencilInfo dataclass from a fencil definition."""
fvprog = _get_fvprog(fencil_def)
offsets = _scan_for_offsets(fvprog)
fendef = fvprog.itir
program = fvprog.gtir

fields = _get_field_infos(fvprog)

Expand All @@ -216,7 +223,7 @@ def get_stencil_info(
offset_provider[offset] = _provide_offset(offset, is_global)
if offset != dims.Koff.value:
connectivity_chains.append(offset)
return StencilInfo(fendef, fields, connectivity_chains, offset_provider)
return StencilInfo(program, fields, connectivity_chains, offset_provider)


def _calc_num_neighbors(dim_list: list[Dimension], includes_center: bool) -> int:
Expand Down
61 changes: 31 additions & 30 deletions tools/src/icon4pytools/icon4pygen/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from pathlib import Path
from typing import Any, Iterable, List

from gt4py.next.common import Connectivity, Dimension
from gt4py.next.common import DimensionKind, OffsetProvider
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import LiftMode
from gt4py.next.program_processors.codegens.gtfn import gtfn_module
from gt4py.next.type_system import type_specifications as ts
from icon4py.model.common import dimension as dims
Expand All @@ -33,26 +32,27 @@


def transform_and_configure_fencil(
fencil: itir.FencilDefinition,
) -> itir.FencilDefinition:
program: itir.Program,
) -> itir.Program:
"""Transform the domain representation and configure the FencilDefinition parameters."""
grid_size_symbols = [itir.Sym(id=arg, type=_SIZE_TYPE) for arg in GRID_SIZE_ARGS]

for closure in fencil.closures:
if not len(closure.domain.args) == 2:
raise TypeError(f"Output domain of '{fencil.id}' must be 2-dimensional.")
assert isinstance(closure.domain.args[0], itir.FunCall) and isinstance(
closure.domain.args[1], itir.FunCall
)
horizontal_axis = closure.domain.args[0].args[0]
vertical_axis = closure.domain.args[1].args[0]
for stmt in program.body:
assert isinstance(stmt, itir.SetAt)
domain = stmt.domain
assert isinstance(domain, itir.FunCall)
if not len(domain.args) == 2:
raise TypeError(f"Output domain of '{program.id}' must be 2-dimensional.")
assert isinstance(domain.args[0], itir.FunCall) and isinstance(domain.args[1], itir.FunCall)
horizontal_axis = domain.args[0].args[0]
vertical_axis = domain.args[1].args[0]
assert isinstance(horizontal_axis, itir.AxisLiteral) and isinstance(
vertical_axis, itir.AxisLiteral
)
assert horizontal_axis.value in ["Vertex", "Edge", "Cell"]
assert vertical_axis.value == "K"

closure.domain = itir.FunCall(
stmt.domain = itir.FunCall(
fun=itir.SymRef(id="unstructured_domain"),
args=[
itir.FunCall(
Expand All @@ -66,7 +66,7 @@ def transform_and_configure_fencil(
itir.FunCall(
fun=itir.SymRef(id="named_range"),
args=[
itir.AxisLiteral(value=dims.Koff.source.value),
itir.AxisLiteral(value=dims.Koff.source.value, kind=DimensionKind.VERTICAL),
itir.SymRef(id=V_START),
itir.SymRef(id=V_END),
],
Expand All @@ -75,16 +75,18 @@ def transform_and_configure_fencil(
)

fencil_params = [
*(p for p in fencil.params if not is_size_param(p) and p not in grid_size_symbols),
*(p for p in get_missing_domain_params(fencil.params)),
*(p for p in program.params if not is_size_param(p) and p not in grid_size_symbols),
*(p for p in get_missing_domain_params(program.params)),
*grid_size_symbols,
]

return itir.FencilDefinition(
id=fencil.id,
function_definitions=fencil.function_definitions,
return itir.Program(
id=program.id,
function_definitions=program.function_definitions,
params=fencil_params,
closures=fencil.closures,
declarations=program.declarations,
body=program.body,
implicit_domain=program.implicit_domain,
)


Expand All @@ -100,32 +102,32 @@ def get_missing_domain_params(params: List[itir.Sym]) -> Iterable[itir.Sym]:
return (itir.Sym(id=p, type=_SIZE_TYPE) for p in missing_args)


def check_for_domain_bounds(fencil: itir.FencilDefinition) -> None:
def check_for_domain_bounds(program: itir.Program) -> None:
"""Check that fencil params contain domain boundaries, emit warning otherwise."""
param_ids = {param.id for param in fencil.params}
param_ids = {param.id for param in program.params}
all_domain_params_present = all(
param in param_ids for param in [H_START, H_END, V_START, V_END]
)
if not all_domain_params_present:
warnings.warn(
f"Domain boundaries are missing or have non-standard names for '{fencil.id}'. "
f"Domain boundaries are missing or have non-standard names for '{program.id}'. "
"Adapting domain to use the standard names. This feature will be removed in the future.",
DeprecationWarning,
stacklevel=2,
)


def generate_gtheader(
fencil: itir.FencilDefinition,
offset_provider: dict[str, Connectivity | Dimension],
program: itir.Program,
offset_provider: OffsetProvider,
imperative: bool,
temporaries: bool,
**kwargs: Any,
) -> str:
"""Generate a GridTools C++ header for a given stencil definition using specified configuration parameters."""
check_for_domain_bounds(fencil)
check_for_domain_bounds(program)

transformed_fencil = transform_and_configure_fencil(fencil)
transformed_fencil = transform_and_configure_fencil(program)

translation = gtfn_module.GTFNTranslationStep(
enable_itir_transforms=True,
Expand All @@ -134,7 +136,6 @@ def generate_gtheader(

if temporaries:
translation = translation.replace(
lift_mode=LiftMode.USE_TEMPORARIES,
symbolic_domain_sizes={
"Cell": "num_cells",
"Edge": "num_edges",
Expand All @@ -159,9 +160,9 @@ def __init__(self, stencil_info: StencilInfo) -> None:
def __call__(self, outpath: Path, imperative: bool, temporaries: bool) -> None:
"""Generate C++ code using the GTFN backend and write it to a file."""
gtheader = generate_gtheader(
fencil=self.stencil_info.fendef,
program=self.stencil_info.program,
offset_provider=self.stencil_info.offset_provider,
imperative=imperative,
temporaries=temporaries,
)
write_string(gtheader, outpath, f"{self.stencil_info.fendef.id}.hpp")
write_string(gtheader, outpath, f"{self.stencil_info.program.id}.hpp")
2 changes: 1 addition & 1 deletion tools/src/icon4pytools/icon4pygen/bindings/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class PyBindGen:
"""

def __init__(self, stencil_info: StencilInfo, levels_per_thread: int, block_size: int) -> None:
self.stencil_name = stencil_info.fendef.id
self.stencil_name = stencil_info.program.id
self.fields, self.offsets = self._stencil_info_to_binding_type(stencil_info)
self.levels_per_thread = levels_per_thread
self.block_size = block_size
Expand Down
Loading