Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes to FOAST -> ITIR lowering #1196

Merged
merged 4 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from dataclasses import dataclass, field
from typing import Any, Callable

import numpy as np

from gt4py.eve import NodeTranslator
from gt4py.next.common import DimensionKind
from gt4py.next.ffront import (
Expand Down Expand Up @@ -318,9 +316,12 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.FunCall:
return self._lift_if_field(node)(
im.call_("not_")(to_value(node.operand)(self.visit(node.operand, **kwargs)))
)

dtype = type_info.extract_dtype(node.type)
assert type_info.is_arithmetic(dtype)
return self._lift_if_field(node)(
im.call_(node.op.value)(
im.literal_("0", "int"),
im.literal_("0", dtype.kind.name.lower()),
to_value(node.operand)(self.visit(node.operand, **kwargs)),
)
)
Expand Down Expand Up @@ -357,7 +358,9 @@ def visit_Compare(self, node: foast.Compare, **kwargs) -> itir.FunCall:
def _visit_shift(self, node: foast.Call, **kwargs) -> itir.FunCall:
match node.args[0]:
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
return im.shift_(offset_name, offset_index)(self.visit(node.func, **kwargs))
return im.shift_(offset_name, itir.OffsetLiteral(value=offset_index))(
self.visit(node.func, **kwargs)
)
case foast.Name(id=offset_name):
return im.shift_(offset_name)(self.visit(node.func, **kwargs))
raise FieldOperatorLoweringError("Unexpected shift arguments!")
Expand Down Expand Up @@ -388,20 +391,23 @@ def _visit_reduce(self, node: foast.Call, **kwargs) -> itir.FunCall:
)

def _visit_max_over(self, node: foast.Call, **kwargs) -> itir.FunCall:
# TODO(tehrengruber): replace greater_ with max_ builtin as soon as itir supports it
init_expr = itir.Literal(value=str(np.finfo(np.float64).min), type="float64")
dtype = type_info.extract_dtype(node.type)
min_value, _ = type_info.arithmetic_bounds(dtype)
init_expr = itir.Literal(value=str(min_value), type=dtype.kind.name.lower())
return self._make_reduction_expr(
node,
lambda expr: im.call_("if_")(im.greater_("acc", expr), "acc", expr),
lambda expr: im.call_("maximum")("acc", expr),
init_expr,
**kwargs,
)

def _visit_min_over(self, node: foast.Call, **kwargs) -> itir.FunCall:
init_expr = itir.Literal(value=str(np.finfo(np.float64).max), type="float64")
dtype = type_info.extract_dtype(node.type)
_, max_value = type_info.arithmetic_bounds(dtype)
init_expr = itir.Literal(value=str(max_value), type=dtype.kind.name.lower())
return self._make_reduction_expr(
node,
lambda expr: im.call_("if_")(im.less_("acc", expr), "acc", expr),
lambda expr: im.call_("minimum")("acc", expr),
init_expr,
**kwargs,
)
Expand Down Expand Up @@ -548,7 +554,11 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.FunCall:
if node.op is dialect_ast_enums.UnaryOperator.NOT:
return im.call_(node.op.value)(self.visit(node.operand, **kwargs))

return im.call_(node.op.value)(im.literal_("0", "int"), self.visit(node.operand, **kwargs))
dtype = type_info.extract_dtype(node.type)
assert type_info.is_arithmetic(dtype)
return im.call_(node.op.value)(
im.literal_("0", dtype.kind.name.lower()), self.visit(node.operand, **kwargs)
)

def _visit_shift(self, node: foast.Call, **kwargs) -> itir.SymRef: # type: ignore[override]
uid = f"{node.func.id}__{self._sequential_id()}"
Expand Down
12 changes: 12 additions & 0 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import functools
from typing import Callable, Iterator, Type, TypeGuard, cast

import numpy as np

from gt4py.eve.utils import XIterable, xiter
from gt4py.next.common import Dimension, GTTypeError
from gt4py.next.type_system import type_specifications as ts
Expand Down Expand Up @@ -227,6 +229,16 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool:
return is_floating_point(symbol_type) or is_integral(symbol_type)


def arithmetic_bounds(arithmetic_type: ts.ScalarType):
assert is_arithmetic(arithmetic_type)
return {
ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max),
ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max),
ts.ScalarKind.INT32: (np.iinfo(np.int32).min, np.iinfo(np.int32).max),
ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max),
}[arithmetic_type.kind]


def is_type_or_tuple_of_type(type_: ts.TypeSpec, expected_type: type | tuple) -> bool:
"""
Return True if ``type_`` matches any of the expected.
Expand Down
12 changes: 6 additions & 6 deletions tests/next_tests/ffront_tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def combine_tuple(a: Field[[Edge], int64], b: Field[[Edge], int64]) -> Field[[Ve


@pytest.mark.xfail(raises=NotImplementedError)
def test_tuple_with_local_field_in_reduction_shifted(reduction_setup):
def test_tuple_with_local_field_in_reduction_shifted(reduction_setup, fieldview_backend):
rs = reduction_setup
Edge = rs.Edge
Vertex = rs.Vertex
Expand All @@ -400,7 +400,7 @@ def test_tuple_with_local_field_in_reduction_shifted(reduction_setup):
b = np_as_located_field(Vertex)(2 * np.ones((num_vertices,)))
out = np_as_located_field(Edge)(np.zeros((num_edges,)))

@field_operator
@field_operator(backend=fieldview_backend)
def reduce_tuple_element(
edge_field: Field[[Edge], float64], vertex_field: Field[[Vertex], float64]
) -> Field[[Edge], float64]:
Expand Down Expand Up @@ -864,7 +864,7 @@ def test_tuple_unpacking(fieldview_backend):
out3 = np_as_located_field(IDim)(np.ones((size)))
out4 = np_as_located_field(IDim)(np.ones((size)))

@field_operator
@field_operator(backend=fieldview_backend)
def unpack(
inp: Field[[IDim], float64],
) -> tuple[
Expand Down Expand Up @@ -906,7 +906,7 @@ def test_tuple_unpacking_star_multi(fieldview_backend):
Field[[IDim], float64],
]

@field_operator
@field_operator(backend=fieldview_backend)
def unpack(
inp: Field[[IDim], float64],
) -> OutType:
Expand All @@ -928,7 +928,7 @@ def test_tuple_unpacking_too_many_values(fieldview_backend):
match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"),
):

@field_operator
@field_operator(backend=fieldview_backend)
def _star_unpack() -> tuple[int, float64, int]:
a, b, c = (1, 2.0, 3, 4, 5, 6, 7.0)
return a, b, c
Expand All @@ -939,7 +939,7 @@ def test_tuple_unpacking_too_many_values(fieldview_backend):
FieldOperatorTypeDeductionError, match=(r"Assignment value must be of type tuple!")
):

@field_operator
@field_operator(backend=fieldview_backend)
def _invalid_unpack() -> tuple[int, float64, int]:
a, b, c = 1
return a
Expand Down
6 changes: 3 additions & 3 deletions tests/next_tests/ffront_tests/test_gt4py_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_minover_execution(reduction_setup, fieldview_backend):
Vertex, V2EDim = rs.Vertex, rs.V2EDim
in_field = np_as_located_field(Vertex, V2EDim)(rs.v2e_table)

@field_operator
@field_operator(backend=fieldview_backend)
def minover_fieldoperator(input: Field[[Vertex, V2EDim], int64]) -> Field[[Vertex], int64]:
return min_over(input, axis=V2EDim)

Expand All @@ -101,7 +101,7 @@ def test_minover_execution_float(reduction_setup, fieldview_backend):
in_field = np_as_located_field(Vertex, V2EDim)(in_array)
out_field = np_as_located_field(Vertex)(np.zeros(rs.num_vertices))

@field_operator
@field_operator(backend=fieldview_backend)
def minover_fieldoperator(input: Field[[Vertex, V2EDim], float64]) -> Field[[Vertex], float64]:
return min_over(input, axis=V2EDim)

Expand Down Expand Up @@ -323,7 +323,7 @@ def test_conditional_shifted(fieldview_backend):
out_I_float = np_as_located_field(IDim)(np.random.randn(size).astype("float64"))
mask = np_as_located_field(IDim)(np.zeros((size,), dtype=bool))

@field_operator()
@field_operator
def conditional_shifted(
mask: Field[[IDim], bool], a: Field[[IDim], float64], b: Field[[IDim], float64]
) -> Field[[IDim], float64]:
Expand Down