Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implicit coercion of numeric types #702

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
FunctionType,
InputFlags,
NoneType,
NumericType,
OpaqueType,
StructType,
TupleType,
Expand Down Expand Up @@ -207,7 +208,7 @@ def check(
# If we already have a type for the expression, we just have to match it against
# the target
if actual := get_type_opt(expr):
subst, inst = check_type_against(actual, ty, expr, kind)
expr, subst, inst = check_type_against(actual, ty, expr, self.ctx, kind)
if inst:
expr = with_loc(expr, TypeApply(value=expr, tys=inst))
return with_type(ty.substitute(subst), expr), subst
Expand Down Expand Up @@ -329,7 +330,7 @@ def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]:
def generic_visit(self, node: ast.expr, ty: Type) -> tuple[ast.expr, Subst]:
# Try to synthesize and then check if we can unify it with the given type
node, synth = self._synthesize(node, allow_free_vars=False)
subst, inst = check_type_against(synth, ty, node, self._kind)
node, subst, inst = check_type_against(synth, ty, node, self.ctx, self._kind)

# Apply instantiation of quantified type variables
if inst:
Expand Down Expand Up @@ -759,8 +760,8 @@ def generic_visit(self, node: ast.expr) -> NoReturn:


def check_type_against(
act: Type, exp: Type, node: AstNode, kind: str = "expression"
) -> tuple[Subst, Inst]:
act: Type, exp: Type, node: ast.expr, ctx: Context, kind: str = "expression"
) -> tuple[ast.expr, Subst, Inst]:
"""Checks a type against another type.

Returns a substitution for the free variables the expected type and an instantiation
Expand Down Expand Up @@ -797,14 +798,37 @@ def check_type_against(
# Finally, check that the instantiation respects the linearity requirements
check_inst(act, inst, node)

return subst, inst
return node, subst, inst

# Otherwise, we know that `act` has no unsolved type vars, so unification is trivial
assert not act.unsolved_vars
subst = unify(exp, act, {})
if subst is None:
# Maybe we can implicitly coerce `act` to `exp`
if coerced := try_coerce_to(act, exp, node, ctx):
return coerced, {}, []
raise GuppyTypeError(TypeMismatchError(node, exp, act, kind))
return subst, []
return node, subst, []


def try_coerce_to(
act: Type, exp: Type, node: ast.expr, ctx: Context
) -> ast.expr | None:
"""Tries to implicitly coerce an expression to a different type.

Returns the coerced expression or `None` if the type cannot be implicitly coerced.
"""
# Currently, we only support implicit coercions of numeric types
if not isinstance(act, NumericType) or not isinstance(exp, NumericType):
return None
# Ordering on `NumericType.Kind` defines the coercion relation
if act.kind < exp.kind:
f = ctx.globals.get_instance_func(act, f"__{exp.kind.name.lower()}__")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so int -> float is ok but float -> int is not, correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly 👍

assert f is not None
node, subst = f.check_call([node], exp, node, ctx)
assert len(subst) == 0, "Coercion methods are not generic"
return node
return None


def check_num_args(
Expand Down
2 changes: 1 addition & 1 deletion guppylang/definition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
from guppylang.checker.expr_checker import check_type_against

expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
return expr, subst

@abstractmethod
Expand Down
7 changes: 5 additions & 2 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,10 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
# TODO: We could use the type information to infer some stuff
# in the comprehension
arr_compr, res_ty = self.synthesize_array_comprehension(compr)
subst, _ = check_type_against(res_ty, ty, self.node)
arr_compr = with_loc(self.node, arr_compr)
arr_compr, subst, _ = check_type_against(
res_ty, ty, arr_compr, self.ctx
)
return arr_compr, subst
# Or a list of array elements
case args:
Expand Down Expand Up @@ -359,7 +362,7 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
expr, res_ty = self.synthesize(args)
subst, _ = check_type_against(res_ty, ty, self.node)
expr, subst, _ = check_type_against(res_ty, ty, expr, self.ctx)
return expr, subst

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion guppylang/tys/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str:

@_visit.register
def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str:
return ty.kind.value
return ty.kind.name.lower()

@_visit.register
def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str:
Expand Down
12 changes: 8 additions & 4 deletions guppylang/tys/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from enum import Enum, Flag, auto
from functools import cached_property
from functools import cached_property, total_ordering
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast

import hugr.std.float
Expand Down Expand Up @@ -259,12 +259,16 @@ class NumericType(TypeBase):

kind: "Kind"

@total_ordering
class Kind(Enum):
"""The different kinds of numeric types."""

Nat = "nat"
Int = "int"
Float = "float"
Nat = auto()
Int = auto()
Float = auto()

def __lt__(self, other: "NumericType.Kind") -> bool:
return self.value < other.value

INT_WIDTH: ClassVar[int] = 6

Expand Down
11 changes: 11 additions & 0 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def main(a1: angle, a2: angle) -> bool:
validate(module.compile())


def test_implicit_coercion(validate):
@compile_guppy
def coerce(x: nat) -> float:
y: int = x
z: float = y
a: float = 1
return z + a

validate(coerce)


def test_angle_float_coercion(validate):
module = GuppyModule("test")
module.load(angle)
Expand Down
Loading