Skip to content

Commit

Permalink
feat: Allow lists in py expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jan 18, 2024
1 parent e0761ff commit 6de5059
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 35 deletions.
107 changes: 78 additions & 29 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import sys
import traceback
from contextlib import suppress
from typing import Any, NoReturn, cast
from typing import Any, NoReturn, cast, TypeVar

from guppylang.ast_util import (
AstNode,
Expand Down Expand Up @@ -230,6 +230,21 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]:
else:
raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func)

def visit_PyExpr(self, node: PyExpr, ty: GuppyType) -> tuple[ast.expr, Subst]:
python_val = eval_py_expr(node, self.ctx)
if act := python_value_to_guppy_type(python_val, node, self.ctx.globals):
subst = unify(ty, act, {})
if subst is None:
self._fail(ty, act, node)
act = act.substitute(subst)
subst = {x: s for x, s in subst.items() if x in ty.unsolved_vars}
return with_type(act, with_loc(node, ast.Constant(value=python_val))), subst

raise GuppyError(
f"Python expression of type `{type(python_val)}` is not supported by Guppy",
node,
)

def generic_visit(self, node: ast.expr, ty: GuppyType) -> tuple[ast.expr, Subst]:
# Try to synthesize and then check if we can unify it with the given type
node, synth = self._synthesize(node, allow_free_vars=False)
Expand Down Expand Up @@ -497,34 +512,7 @@ def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]:
)

def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]:
# The method we used for obtaining the Python variables in scope only works in
# CPython (see `get_py_scope()`).
if sys.implementation.name != "cpython":
raise GuppyError(
"Compile-time `py(...)` expressions are only supported in CPython", node
)

try:
python_val = eval( # noqa: S307, PGH001
ast.unparse(node.value),
None,
DummyEvalDict(self.ctx, node.value),
)
except DummyEvalDict.GuppyVarUsedError as e:
raise GuppyError(
f"Guppy variable `{e.var}` cannot be accessed in a compile-time "
"`py(...)` expression",
e.node or node,
) from None
except Exception as e: # noqa: BLE001
# Remove the top frame pointing to the `eval` call from the stack trace
tb = e.__traceback__.tb_next if e.__traceback__ else None
raise GuppyError(
"Error occurred while evaluating Python expression:\n\n"
+ "".join(traceback.format_exception(type(e), e, tb)),
node,
) from e

python_val = eval_py_expr(node, self.ctx)
if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals):
return with_loc(node, ast.Constant(value=python_val)), ty

Expand Down Expand Up @@ -898,6 +886,38 @@ def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None:
return node, elt_ty


def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
"""Evaluates a `py(...)` expression."""
# The method we used for obtaining the Python variables in scope only works in
# CPython (see `get_py_scope()`).
if sys.implementation.name != "cpython":
raise GuppyError(
"Compile-time `py(...)` expressions are only supported in CPython", node
)

try:
python_val = eval( # noqa: S307, PGH001
ast.unparse(node.value),
None,
DummyEvalDict(ctx, node.value),
)
except DummyEvalDict.GuppyVarUsedError as e:
raise GuppyError(
f"Guppy variable `{e.var}` cannot be accessed in a compile-time "
"`py(...)` expression",
e.node or node,
) from None
except Exception as e: # noqa: BLE001
# Remove the top frame pointing to the `eval` call from the stack trace
tb = e.__traceback__.tb_next if e.__traceback__ else None
raise GuppyError(
"Error occurred while evaluating Python expression:\n\n"
+ "".join(traceback.format_exception(type(e), e, tb)),
node,
) from e
return python_val


def python_value_to_guppy_type(
v: Any, node: ast.expr, globals: Globals
) -> GuppyType | None:
Expand All @@ -917,5 +937,34 @@ def python_value_to_guppy_type(
if any(ty is None for ty in tys):
return None
return TupleType(cast(list[GuppyType], tys))
case list():
return _python_list_to_guppy_type(v, node, globals)
case _:
return None


T = TypeVar("T")


def _python_list_to_guppy_type(
vs: list[T], node: ast.expr, globals: Globals
) -> ListType | None:
"""Turns a Python list into a Guppy type.
Returns `None` if the list contains different types or types that are not
representable in Guppy.
"""
if len(vs) == 0:
return ListType(ExistentialTypeVar.new("T", False))

# All the list elements must have a unifiable types
v, *rest = vs
el_ty = python_value_to_guppy_type(v, node, globals)
if el_ty is None:
return None
for v in rest:
ty = python_value_to_guppy_type(v, node, globals)
if ty is None or (subst := unify(ty, el_ty, {})) is None:
return None
el_ty = el_ty.substitute(subst)
return ListType(el_ty)
19 changes: 14 additions & 5 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Inst,
NoneType,
TupleType,
type_to_row,
type_to_row, GuppyType, ListType,
)
from guppylang.hugr import ops, val
from guppylang.hugr.hugr import DFContainingNode, OutPortV, VNode
Expand Down Expand Up @@ -135,7 +135,7 @@ def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]:
self.dfg[name.id].port = cond_node.add_out_port(get_type(name))

def visit_Constant(self, node: ast.Constant) -> OutPortV:
if value := python_value_to_hugr(node.value):
if value := python_value_to_hugr(node.value, get_type(node)):
const = self.graph.add_constant(value, get_type(node)).out_port(0)
return self.graph.add_load_constant(const).out_port(0)
raise InternalGuppyError("Unsupported constant expression in compiler")
Expand Down Expand Up @@ -294,12 +294,12 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
return False


def python_value_to_hugr(v: Any) -> val.Value | None:
def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None:
"""Turns a Python value into a Hugr value.
Returns None if the Python value cannot be represented in Guppy.
"""
from guppylang.prelude._internal import bool_value, float_value, int_value
from guppylang.prelude._internal import bool_value, float_value, int_value, list_value

match v:
case bool():
Expand All @@ -309,9 +309,18 @@ def python_value_to_hugr(v: Any) -> val.Value | None:
case float():
return float_value(v)
case tuple(elts):
vs = [python_value_to_hugr(elt) for elt in elts]
assert isinstance(exp_ty, TupleType)
vs = [
python_value_to_hugr(elt, ty)
for elt, ty in zip(elts, exp_ty.element_types)
]
if any(value is None for value in vs):
return None
return val.Tuple(vs=vs)
case list(elts):
assert isinstance(exp_ty, ListType)
return list_value(
[python_value_to_hugr(elt, exp_ty.element_type) for elt in elts]
)
case _:
return None
14 changes: 13 additions & 1 deletion guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Literal
from typing import Literal, Any

from pydantic import BaseModel

Expand Down Expand Up @@ -52,6 +52,13 @@ class ConstF64(BaseModel):
value: float


class ListValue(BaseModel):
"""Hugr representation of floats in the arithmetic extension."""

c: Literal["ListValue"] = "ListValue"
value: list[Any]


def bool_value(b: bool) -> val.Value:
"""Returns the Hugr representation of a boolean value."""
return val.Sum(tag=int(b), value=val.Tuple(vs=[]))
Expand All @@ -67,6 +74,11 @@ def float_value(f: float) -> val.Value:
return val.ExtensionVal(c=(ConstF64(value=f),))


def list_value(v: list[val.Value]) -> val.Value:
"""Returns the Hugr representation of a list value."""
return val.ExtensionVal(c=(ListValue(value=v),))


def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType:
"""Utility method to create Hugr logic ops."""
return ops.CustomOp(extension="logic", op_name=op_name, args=args or [])
Expand Down
7 changes: 7 additions & 0 deletions tests/error/py_errors/list_different_tys.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
5: def foo() -> int:
6: return py([1, 1.0])
^^^^^^^^^^^^
GuppyError: Python expression of type `<class 'list'>` is not supported by Guppy
6 changes: 6 additions & 0 deletions tests/error/py_errors/list_different_tys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from guppylang.decorator import guppy


@guppy
def foo() -> int:
return py([1, 1.0])
7 changes: 7 additions & 0 deletions tests/error/py_errors/list_empty.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:6

4: @guppy
5: def foo() -> None:
6: xs = py([])
^^^^^^
GuppyTypeError: Cannot infer type variable in expression of type `list[?T]`
6 changes: 6 additions & 0 deletions tests/error/py_errors/list_empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from guppylang.decorator import guppy


@guppy
def foo() -> None:
xs = py([])
33 changes: 33 additions & 0 deletions tests/integration/test_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,36 @@ def foo() -> int:
return x

validate(foo)


def test_list_basic(validate):
@guppy
def foo() -> list[int]:
xs = py([1, 2, 3])
return xs

validate(foo)


def test_list_empty(validate):
@guppy
def foo() -> list[int]:
return py([])

validate(foo)


def test_list_empty_nested(validate):
@guppy
def foo() -> None:
xs: list[tuple[int, list[bool]]] = py([(42, [])])

validate(foo)


def test_list_empty_multiple(validate):
@guppy
def foo() -> None:
xs: tuple[list[int], list[bool]] = py([], [])

validate(foo)

0 comments on commit 6de5059

Please sign in to comment.