Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functional: refactor common types utils into separate module #666

Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions src/functional/ffront/common_types_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# GT4Py Project - GridTools Framework
#
# Copyright (c) 2014-2021, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
from dataclasses import dataclass
from typing import Optional, Type, TypeGuard

from functional.ffront import common_types


def is_complete_symbol_type(fo_type: common_types.SymbolType) -> TypeGuard[common_types.SymbolType]:
"""Figure out if the foast type is completely deduced."""
match fo_type:
case None:
return False
case common_types.DeferredSymbolType():
return False
case common_types.SymbolType():
return True
return False


@dataclass
class TypeInfo:
"""
Wrapper around foast types for type deduction and compatibility checks.

Examples:
---------
>>> type_a = common_types.ScalarType(kind=common_types.ScalarKind.FLOAT64)
>>> typeinfo_a = TypeInfo(type_a)
>>> typeinfo_a.is_complete
True
>>> typeinfo_a.is_arithmetic_compatible
True
>>> typeinfo_a.is_logics_compatible
False
>>> typeinfo_b = TypeInfo(None)
>>> typeinfo_b.is_any_type
True
>>> typeinfo_b.is_arithmetic_compatible
False
>>> typeinfo_b.can_be_refined_to(typeinfo_a)
True

"""

type: common_types.SymbolType # noqa: A003

@property
def is_complete(self) -> bool:
return is_complete_symbol_type(self.type)

@property
def is_any_type(self) -> bool:
return (not self.is_complete) and ((self.type is None) or (self.constraint is None))

@property
def constraint(self) -> Optional[Type[common_types.SymbolType]]:
"""Find the constraint of a deferred type or the class of a complete type."""
if self.is_complete:
return self.type.__class__
elif self.type:
return self.type.constraint
return None

@property
def is_field_type(self) -> bool:
return issubclass(self.constraint, common_types.FieldType) if self.constraint else False

@property
def is_scalar(self) -> bool:
return issubclass(self.constraint, common_types.ScalarType) if self.constraint else False

@property
def is_arithmetic_compatible(self) -> bool:
match self.type:
case common_types.FieldType(
dtype=common_types.ScalarType(kind=dtype_kind)
) | common_types.ScalarType(kind=dtype_kind):
if dtype_kind is not common_types.ScalarKind.BOOL:
return True
return False

@property
def is_logics_compatible(self) -> bool:
match self.type:
case common_types.FieldType(
dtype=common_types.ScalarType(kind=dtype_kind)
) | common_types.ScalarType(kind=dtype_kind):
if dtype_kind is common_types.ScalarKind.BOOL:
return True
return False

def can_be_refined_to(self, other: "TypeInfo") -> bool:
if self.is_any_type:
return True
if self.is_complete:
return self.type == other.type
if self.constraint:
if other.is_complete:
return isinstance(other.type, self.constraint)
elif other.constraint:
return self.constraint is other.constraint
return False


def are_broadcast_compatible(left: TypeInfo, right: TypeInfo) -> bool:
"""
Check if ``left`` and ``right`` types are compatible after optional broadcasting.

A binary field operation between two arguments can proceed and the result is a field.
on top of the dimensions, also the dtypes must match.

Examples:
---------
>>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64))
>>> are_broadcast_compatible(int_scalar_t, int_scalar_t)
True
>>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64),
... dims=...))
>>> are_broadcast_compatible(int_field_t, int_scalar_t)
True

"""
if left.is_field_type and right.is_field_type:
return left.type.dims == right.type.dims
elif left.is_field_type and right.is_scalar:
return left.type.dtype == right.type
elif left.is_scalar and left.is_field_type:
return left.type == right.type.dtype
elif left.is_scalar and right.is_scalar:
return left.type == right.type
return False


def broadcast_typeinfos(left: TypeInfo, right: TypeInfo) -> TypeInfo:
"""
Decide the result type of a binary operation between arguments of ``left`` and ``right`` type.

Return None if the two types are not compatible even after broadcasting.

Examples:
---------
>>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64))
>>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64),
... dims=...))
>>> assert broadcast_typeinfos(int_field_t, int_scalar_t).type == int_field_t.type

"""
if not are_broadcast_compatible(left, right):
return None
if left.is_scalar and right.is_field_type:
return right
return left
170 changes: 11 additions & 159 deletions src/functional/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,160 +11,13 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
from dataclasses import dataclass
from typing import Optional, Type, TypeGuard
from typing import Optional

import functional.ffront.field_operator_ast as foast
from eve import NodeTranslator, SymbolTableTrait
from functional.common import GTSyntaxError
from functional.ffront import common_types


def is_complete_symbol_type(fo_type: common_types.SymbolType) -> TypeGuard[common_types.SymbolType]:
"""Figure out if the foast type is completely deduced."""
match fo_type:
case None:
return False
case common_types.DeferredSymbolType():
return False
case common_types.SymbolType():
return True
return False


@dataclass
class TypeInfo:
"""
Wrapper around foast types for type deduction and compatibility checks.

Examples:
---------
>>> type_a = common_types.ScalarType(kind=common_types.ScalarKind.FLOAT64)
>>> typeinfo_a = TypeInfo(type_a)
>>> typeinfo_a.is_complete
True
>>> typeinfo_a.is_arithmetic_compatible
True
>>> typeinfo_a.is_logics_compatible
False
>>> typeinfo_b = TypeInfo(None)
>>> typeinfo_b.is_any_type
True
>>> typeinfo_b.is_arithmetic_compatible
False
>>> typeinfo_b.can_be_refined_to(typeinfo_a)
True

"""

type: common_types.SymbolType # noqa: A003

@property
def is_complete(self) -> bool:
return is_complete_symbol_type(self.type)

@property
def is_any_type(self) -> bool:
return (not self.is_complete) and ((self.type is None) or (self.constraint is None))

@property
def constraint(self) -> Optional[Type[common_types.SymbolType]]:
"""Find the constraint of a deferred type or the class of a complete type."""
if self.is_complete:
return self.type.__class__
elif self.type:
return self.type.constraint
return None

@property
def is_field_type(self) -> bool:
return issubclass(self.constraint, common_types.FieldType) if self.constraint else False

@property
def is_scalar(self) -> bool:
return issubclass(self.constraint, common_types.ScalarType) if self.constraint else False

@property
def is_arithmetic_compatible(self) -> bool:
match self.type:
case common_types.FieldType(
dtype=common_types.ScalarType(kind=dtype_kind)
) | common_types.ScalarType(kind=dtype_kind):
if dtype_kind is not common_types.ScalarKind.BOOL:
return True
return False

@property
def is_logics_compatible(self) -> bool:
match self.type:
case common_types.FieldType(
dtype=common_types.ScalarType(kind=dtype_kind)
) | common_types.ScalarType(kind=dtype_kind):
if dtype_kind is common_types.ScalarKind.BOOL:
return True
return False

def can_be_refined_to(self, other: "TypeInfo") -> bool:
if self.is_any_type:
return True
if self.is_complete:
return self.type == other.type
if self.constraint:
if other.is_complete:
return isinstance(other.type, self.constraint)
elif other.constraint:
return self.constraint is other.constraint
return False


def are_broadcast_compatible(left: TypeInfo, right: TypeInfo) -> bool:
"""
Check if ``left`` and ``right`` types are compatible after optional broadcasting.

A binary field operation between two arguments can proceed and the result is a field.
on top of the dimensions, also the dtypes must match.

Examples:
---------
>>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64))
>>> are_broadcast_compatible(int_scalar_t, int_scalar_t)
True
>>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64),
... dims=...))
>>> are_broadcast_compatible(int_field_t, int_scalar_t)
True

"""
if left.is_field_type and right.is_field_type:
return left.type.dims == right.type.dims
elif left.is_field_type and right.is_scalar:
return left.type.dtype == right.type
elif left.is_scalar and left.is_field_type:
return left.type == right.type.dtype
elif left.is_scalar and right.is_scalar:
return left.type == right.type
return False


def broadcast_typeinfos(left: TypeInfo, right: TypeInfo) -> TypeInfo:
"""
Decide the result type of a binary operation between arguments of ``left`` and ``right`` type.

Return None if the two types are not compatible even after broadcasting.

Examples:
---------
>>> int_scalar_t = TypeInfo(common_types.ScalarType(kind=common_types.ScalarKind.INT64))
>>> int_field_t = TypeInfo(common_types.FieldType(dtype=common_types.ScalarType(kind=common_types.ScalarKind.INT64),
... dims=...))
>>> assert broadcast_typeinfos(int_field_t, int_scalar_t).type == int_field_t.type

"""
if not are_broadcast_compatible(left, right):
return None
if left.is_scalar and right.is_field_type:
return right
return left
from functional.ffront import common_types_utils as ct_utils


class FieldOperatorTypeDeduction(NodeTranslator):
Expand Down Expand Up @@ -212,14 +65,13 @@ def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name:
raise FieldOperatorTypeDeductionError.from_foast_node(
node, msg=f"Undeclared symbol {node.id}"
)
return node

symbol = symtable[node.id]
return foast.Name(id=node.id, type=symbol.type, location=node.location)

def visit_Assign(self, node: foast.Assign, **kwargs) -> foast.Assign:
new_value = node.value
if not is_complete_symbol_type(node.value.type):
if not ct_utils.is_complete_symbol_type(node.value.type):
new_value = self.visit(node.value, **kwargs)
new_target = self.visit(node.target, refine_type=new_value.type, **kwargs)
return foast.Assign(target=new_target, value=new_value, location=node.location)
Expand All @@ -232,7 +84,7 @@ def visit_Symbol(
) -> foast.Symbol:
symtable = kwargs["symtable"]
if refine_type:
if not TypeInfo(node.type).can_be_refined_to(TypeInfo(refine_type)):
if not ct_utils.TypeInfo(node.type).can_be_refined_to(ct_utils.TypeInfo(refine_type)):
raise FieldOperatorTypeDeductionError.from_foast_node(
node,
msg=(
Expand Down Expand Up @@ -311,13 +163,13 @@ def _deduce_arithmetic_binop_type(
right_type: common_types.SymbolType,
**kwargs,
) -> common_types.SymbolType:
left, right = TypeInfo(left_type), TypeInfo(right_type)
left, right = ct_utils.TypeInfo(left_type), ct_utils.TypeInfo(right_type)
if (
left.is_arithmetic_compatible
and right.is_arithmetic_compatible
and are_broadcast_compatible(left, right)
and ct_utils.are_broadcast_compatible(left, right)
):
return broadcast_typeinfos(left, right).type
return ct_utils.broadcast_typeinfos(left, right).type
else:
raise FieldOperatorTypeDeductionError.from_foast_node(
parent,
Expand All @@ -333,13 +185,13 @@ def _deduce_logical_binop_type(
right_type: common_types.SymbolType,
**kwargs,
) -> common_types.SymbolType:
left, right = TypeInfo(left_type), TypeInfo(right_type)
left, right = ct_utils.TypeInfo(left_type), ct_utils.TypeInfo(right_type)
if (
left.is_logics_compatible
and right.is_logics_compatible
and are_broadcast_compatible(left, right)
and ct_utils.are_broadcast_compatible(left, right)
):
return broadcast_typeinfos(left, right).type
return ct_utils.broadcast_typeinfos(left, right).type
else:
raise FieldOperatorTypeDeductionError.from_foast_node(
parent,
Expand All @@ -360,7 +212,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp:
def _is_unaryop_type_compatible(
self, op: foast.UnaryOperator, operand_type: common_types.FieldType
) -> bool:
operand_ti = TypeInfo(operand_type)
operand_ti = ct_utils.TypeInfo(operand_type)
if op in [foast.UnaryOperator.UADD, foast.UnaryOperator.USUB]:
return operand_ti.is_arithmetic_compatible
elif op is foast.UnaryOperator.NOT:
Expand Down
Loading