diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 4b6a2ea2..57776d6d 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -24,7 +24,7 @@ import sys import traceback from contextlib import suppress -from typing import Any, NoReturn, cast +from typing import Any, NoReturn, cast, TypeVar from guppylang.ast_util import ( AstNode, @@ -230,6 +230,21 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: else: raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func) + def visit_PyExpr(self, node: PyExpr, ty: GuppyType) -> tuple[ast.expr, Subst]: + python_val = eval_py_expr(node, self.ctx) + if act := python_value_to_guppy_type(python_val, node, self.ctx.globals): + subst = unify(ty, act, {}) + if subst is None: + self._fail(ty, act, node) + act = act.substitute(subst) + subst = {x: s for x, s in subst.items() if x in ty.unsolved_vars} + return with_type(act, with_loc(node, ast.Constant(value=python_val))), subst + + raise GuppyError( + f"Python expression of type `{type(python_val)}` is not supported by Guppy", + node, + ) + def generic_visit(self, node: ast.expr, ty: GuppyType) -> tuple[ast.expr, Subst]: # Try to synthesize and then check if we can unify it with the given type node, synth = self._synthesize(node, allow_free_vars=False) @@ -497,34 +512,7 @@ def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]: ) def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: - # The method we used for obtaining the Python variables in scope only works in - # CPython (see `get_py_scope()`). - if sys.implementation.name != "cpython": - raise GuppyError( - "Compile-time `py(...)` expressions are only supported in CPython", node - ) - - try: - python_val = eval( # noqa: S307, PGH001 - ast.unparse(node.value), - None, - DummyEvalDict(self.ctx, node.value), - ) - except DummyEvalDict.GuppyVarUsedError as e: - raise GuppyError( - f"Guppy variable `{e.var}` cannot be accessed in a compile-time " - "`py(...)` expression", - e.node or node, - ) from None - except Exception as e: # noqa: BLE001 - # Remove the top frame pointing to the `eval` call from the stack trace - tb = e.__traceback__.tb_next if e.__traceback__ else None - raise GuppyError( - "Error occurred while evaluating Python expression:\n\n" - + "".join(traceback.format_exception(type(e), e, tb)), - node, - ) from e - + python_val = eval_py_expr(node, self.ctx) if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals): return with_loc(node, ast.Constant(value=python_val)), ty @@ -898,6 +886,38 @@ def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None: return node, elt_ty +def eval_py_expr(node: PyExpr, ctx: Context) -> Any: + """Evaluates a `py(...)` expression.""" + # The method we used for obtaining the Python variables in scope only works in + # CPython (see `get_py_scope()`). + if sys.implementation.name != "cpython": + raise GuppyError( + "Compile-time `py(...)` expressions are only supported in CPython", node + ) + + try: + python_val = eval( # noqa: S307, PGH001 + ast.unparse(node.value), + None, + DummyEvalDict(ctx, node.value), + ) + except DummyEvalDict.GuppyVarUsedError as e: + raise GuppyError( + f"Guppy variable `{e.var}` cannot be accessed in a compile-time " + "`py(...)` expression", + e.node or node, + ) from None + except Exception as e: # noqa: BLE001 + # Remove the top frame pointing to the `eval` call from the stack trace + tb = e.__traceback__.tb_next if e.__traceback__ else None + raise GuppyError( + "Error occurred while evaluating Python expression:\n\n" + + "".join(traceback.format_exception(type(e), e, tb)), + node, + ) from e + return python_val + + def python_value_to_guppy_type( v: Any, node: ast.expr, globals: Globals ) -> GuppyType | None: @@ -917,5 +937,34 @@ def python_value_to_guppy_type( if any(ty is None for ty in tys): return None return TupleType(cast(list[GuppyType], tys)) + case list(): + return _python_list_to_guppy_type(v, node, globals) case _: return None + + +T = TypeVar("T") + + +def _python_list_to_guppy_type( + vs: list[T], node: ast.expr, globals: Globals +) -> ListType | None: + """Turns a Python list into a Guppy type. + + Returns `None` if the list contains different types or types that are not + representable in Guppy. + """ + if len(vs) == 0: + return ListType(ExistentialTypeVar.new("T", False)) + + # All the list elements must have a unifiable types + v, *rest = vs + el_ty = python_value_to_guppy_type(v, node, globals) + if el_ty is None: + return None + for v in rest: + ty = python_value_to_guppy_type(v, node, globals) + if ty is None or (subst := unify(ty, el_ty, {})) is None: + return None + el_ty = el_ty.substitute(subst) + return ListType(el_ty) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 7271a990..a010eec9 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -19,7 +19,7 @@ Inst, NoneType, TupleType, - type_to_row, + type_to_row, GuppyType, ListType, ) from guppylang.hugr import ops, val from guppylang.hugr.hugr import DFContainingNode, OutPortV, VNode @@ -135,7 +135,7 @@ def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]: self.dfg[name.id].port = cond_node.add_out_port(get_type(name)) def visit_Constant(self, node: ast.Constant) -> OutPortV: - if value := python_value_to_hugr(node.value): + if value := python_value_to_hugr(node.value, get_type(node)): const = self.graph.add_constant(value, get_type(node)).out_port(0) return self.graph.add_load_constant(const).out_port(0) raise InternalGuppyError("Unsupported constant expression in compiler") @@ -294,12 +294,12 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: return False -def python_value_to_hugr(v: Any) -> val.Value | None: +def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None: """Turns a Python value into a Hugr value. Returns None if the Python value cannot be represented in Guppy. """ - from guppylang.prelude._internal import bool_value, float_value, int_value + from guppylang.prelude._internal import bool_value, float_value, int_value, list_value match v: case bool(): @@ -309,9 +309,18 @@ def python_value_to_hugr(v: Any) -> val.Value | None: case float(): return float_value(v) case tuple(elts): - vs = [python_value_to_hugr(elt) for elt in elts] + assert isinstance(exp_ty, TupleType) + vs = [ + python_value_to_hugr(elt, ty) + for elt, ty in zip(elts, exp_ty.element_types) + ] if any(value is None for value in vs): return None return val.Tuple(vs=vs) + case list(elts): + assert isinstance(exp_ty, ListType) + return list_value( + [python_value_to_hugr(elt, exp_ty.element_type) for elt in elts] + ) case _: return None diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 4baf1966..44d30b8d 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -1,5 +1,5 @@ import ast -from typing import Literal +from typing import Literal, Any from pydantic import BaseModel @@ -52,6 +52,13 @@ class ConstF64(BaseModel): value: float +class ListValue(BaseModel): + """Hugr representation of floats in the arithmetic extension.""" + + c: Literal["ListValue"] = "ListValue" + value: list[Any] + + def bool_value(b: bool) -> val.Value: """Returns the Hugr representation of a boolean value.""" return val.Sum(tag=int(b), value=val.Tuple(vs=[])) @@ -67,6 +74,11 @@ def float_value(f: float) -> val.Value: return val.ExtensionVal(c=(ConstF64(value=f),)) +def list_value(v: list[val.Value]) -> val.Value: + """Returns the Hugr representation of a list value.""" + return val.ExtensionVal(c=(ListValue(value=v),)) + + def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType: """Utility method to create Hugr logic ops.""" return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) diff --git a/tests/error/py_errors/list_different_tys.err b/tests/error/py_errors/list_different_tys.err new file mode 100644 index 00000000..dcb972d5 --- /dev/null +++ b/tests/error/py_errors/list_different_tys.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return py([1, 1.0]) + ^^^^^^^^^^^^ +GuppyError: Python expression of type `` is not supported by Guppy diff --git a/tests/error/py_errors/list_different_tys.py b/tests/error/py_errors/list_different_tys.py new file mode 100644 index 00000000..80f468c9 --- /dev/null +++ b/tests/error/py_errors/list_different_tys.py @@ -0,0 +1,6 @@ +from guppylang.decorator import guppy + + +@guppy +def foo() -> int: + return py([1, 1.0]) diff --git a/tests/error/py_errors/list_empty.err b/tests/error/py_errors/list_empty.err new file mode 100644 index 00000000..748c4047 --- /dev/null +++ b/tests/error/py_errors/list_empty.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> None: +6: xs = py([]) + ^^^^^^ +GuppyTypeError: Cannot infer type variable in expression of type `list[?T]` diff --git a/tests/error/py_errors/list_empty.py b/tests/error/py_errors/list_empty.py new file mode 100644 index 00000000..09889f52 --- /dev/null +++ b/tests/error/py_errors/list_empty.py @@ -0,0 +1,6 @@ +from guppylang.decorator import guppy + + +@guppy +def foo() -> None: + xs = py([]) diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py index 85b0ec1f..5f56d44d 100644 --- a/tests/integration/test_py.py +++ b/tests/integration/test_py.py @@ -60,3 +60,36 @@ def foo() -> int: return x validate(foo) + + +def test_list_basic(validate): + @guppy + def foo() -> list[int]: + xs = py([1, 2, 3]) + return xs + + validate(foo) + + +def test_list_empty(validate): + @guppy + def foo() -> list[int]: + return py([]) + + validate(foo) + + +def test_list_empty_nested(validate): + @guppy + def foo() -> None: + xs: list[tuple[int, list[bool]]] = py([(42, [])]) + + validate(foo) + + +def test_list_empty_multiple(validate): + @guppy + def foo() -> None: + xs: tuple[list[int], list[bool]] = py([], []) + + validate(foo)