Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow lists in py expressions #113

Merged
merged 5 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 76 additions & 28 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,21 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]:
else:
raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func)

def visit_PyExpr(self, node: PyExpr, ty: GuppyType) -> tuple[ast.expr, Subst]:
python_val = eval_py_expr(node, self.ctx)
if act := python_value_to_guppy_type(python_val, node, self.ctx.globals):
subst = unify(ty, act, {})
if subst is None:
self._fail(ty, act, node)
act = act.substitute(subst)
subst = {x: s for x, s in subst.items() if x in ty.unsolved_vars}
return with_type(act, with_loc(node, ast.Constant(value=python_val))), subst

raise GuppyError(
f"Python expression of type `{type(python_val)}` is not supported by Guppy",
node,
)

def generic_visit(self, node: ast.expr, ty: GuppyType) -> tuple[ast.expr, Subst]:
# Try to synthesize and then check if we can unify it with the given type
node, synth = self._synthesize(node, allow_free_vars=False)
Expand Down Expand Up @@ -498,34 +513,7 @@ def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]:
)

def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]:
# The method we used for obtaining the Python variables in scope only works in
# CPython (see `get_py_scope()`).
if sys.implementation.name != "cpython":
raise GuppyError(
"Compile-time `py(...)` expressions are only supported in CPython", node
)

try:
python_val = eval( # noqa: S307, PGH001
ast.unparse(node.value),
None,
DummyEvalDict(self.ctx, node.value),
)
except DummyEvalDict.GuppyVarUsedError as e:
raise GuppyError(
f"Guppy variable `{e.var}` cannot be accessed in a compile-time "
"`py(...)` expression",
e.node or node,
) from None
except Exception as e: # noqa: BLE001
# Remove the top frame pointing to the `eval` call from the stack trace
tb = e.__traceback__.tb_next if e.__traceback__ else None
raise GuppyError(
"Error occurred while evaluating Python expression:\n\n"
+ "".join(traceback.format_exception(type(e), e, tb)),
node,
) from e

python_val = eval_py_expr(node, self.ctx)
if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we're converting python_val to a type. It's an arbitrary python value right? Just something inside a py(...)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'm looking up how python match works...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we're checking that the Python expression inside py(...) evaluates to something that is valid in Guppy and compute the corresponding Guppy type

return with_loc(node, ast.Constant(value=python_val)), ty

Expand Down Expand Up @@ -899,6 +887,38 @@ def check_linear_use_from_outer_scope(expr: ast.expr, locals: Locals) -> None:
return node, elt_ty


def eval_py_expr(node: PyExpr, ctx: Context) -> Any:
"""Evaluates a `py(...)` expression."""
# The method we used for obtaining the Python variables in scope only works in
# CPython (see `get_py_scope()`).
if sys.implementation.name != "cpython":
raise GuppyError(
"Compile-time `py(...)` expressions are only supported in CPython", node
)

try:
python_val = eval( # noqa: S307, PGH001
ast.unparse(node.value),
None,
DummyEvalDict(ctx, node.value),
)
except DummyEvalDict.GuppyVarUsedError as e:
raise GuppyError(
f"Guppy variable `{e.var}` cannot be accessed in a compile-time "
"`py(...)` expression",
e.node or node,
) from None
except Exception as e: # noqa: BLE001
# Remove the top frame pointing to the `eval` call from the stack trace
tb = e.__traceback__.tb_next if e.__traceback__ else None
raise GuppyError(
"Error occurred while evaluating Python expression:\n\n"
+ "".join(traceback.format_exception(type(e), e, tb)),
node,
) from e
return python_val


def python_value_to_guppy_type(
v: Any, node: ast.expr, globals: Globals
) -> GuppyType | None:
Expand All @@ -918,6 +938,8 @@ def python_value_to_guppy_type(
if any(ty is None for ty in tys):
return None
return TupleType(cast(list[GuppyType], tys))
case list():
return _python_list_to_guppy_type(v, node, globals)
case _:
# Pytket conversion is an optional feature
try:
Expand All @@ -942,3 +964,29 @@ def python_value_to_guppy_type(
except ImportError:
pass
return None


def _python_list_to_guppy_type(
vs: list[Any], node: ast.expr, globals: Globals
) -> ListType | None:
"""Turns a Python list into a Guppy type.

Returns `None` if the list contains different types or types that are not
representable in Guppy.
"""
if len(vs) == 0:
return ListType(ExistentialTypeVar.new("T", False))

# All the list elements must have a unifiable types
v, *rest = vs
el_ty = python_value_to_guppy_type(v, node, globals)
if el_ty is None:
return None
for v in rest:
ty = python_value_to_guppy_type(v, node, globals)
if ty is None:
return None
if (subst := unify(ty, el_ty, {})) is None:
raise GuppyError("Python list contains elements with different types", node)
el_ty = el_ty.substitute(subst)
return ListType(el_ty)
24 changes: 20 additions & 4 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
BoolType,
BoundTypeVar,
FunctionType,
GuppyType,
Inst,
ListType,
NoneType,
TupleType,
type_to_row,
Expand Down Expand Up @@ -136,7 +138,7 @@ def _if_true(self, cond: ast.expr, inputs: list[ast.Name]) -> Iterator[None]:
self.dfg[name.id].port = cond_node.add_out_port(get_type(name))

def visit_Constant(self, node: ast.Constant) -> OutPortV:
if value := python_value_to_hugr(node.value):
if value := python_value_to_hugr(node.value, get_type(node)):
const = self.graph.add_constant(value, get_type(node)).out_port(0)
return self.graph.add_load_constant(const).out_port(0)
raise InternalGuppyError("Unsupported constant expression in compiler")
Expand Down Expand Up @@ -295,12 +297,17 @@ def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool:
return False


def python_value_to_hugr(v: Any) -> val.Value | None:
def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None:
"""Turns a Python value into a Hugr value.

Returns None if the Python value cannot be represented in Guppy.
"""
from guppylang.prelude._internal import bool_value, float_value, int_value
from guppylang.prelude._internal import (
bool_value,
float_value,
int_value,
list_value,
)

match v:
case bool():
Expand All @@ -310,10 +317,19 @@ def python_value_to_hugr(v: Any) -> val.Value | None:
case float():
return float_value(v)
case tuple(elts):
vs = [python_value_to_hugr(elt) for elt in elts]
assert isinstance(exp_ty, TupleType)
vs = [
python_value_to_hugr(elt, ty)
for elt, ty in zip(elts, exp_ty.element_types)
]
if any(value is None for value in vs):
return None
return val.Tuple(vs=vs)
case list(elts):
assert isinstance(exp_ty, ListType)
return list_value(
[python_value_to_hugr(elt, exp_ty.element_type) for elt in elts]
)
case _:
# Pytket conversion is an optional feature
try:
Expand Down
14 changes: 13 additions & 1 deletion guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Literal
from typing import Any, Literal

from pydantic import BaseModel

Expand Down Expand Up @@ -52,6 +52,13 @@ class ConstF64(BaseModel):
value: float


class ListValue(BaseModel):
"""Hugr representation of floats in the arithmetic extension."""

c: Literal["ListValue"] = "ListValue"
value: list[Any]


def bool_value(b: bool) -> val.Value:
"""Returns the Hugr representation of a boolean value."""
return val.Sum(tag=int(b), value=val.Tuple(vs=[]))
Expand All @@ -67,6 +74,11 @@ def float_value(f: float) -> val.Value:
return val.ExtensionVal(c=(ConstF64(value=f),))


def list_value(v: list[val.Value]) -> val.Value:
"""Returns the Hugr representation of a list value."""
return val.ExtensionVal(c=(ListValue(value=v),))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the , do here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(x,) is a unary tuple. We need this because of some serialisation shenanigans...



def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType:
"""Utility method to create Hugr logic ops."""
return ops.CustomOp(extension="logic", op_name=op_name, args=args or [])
Expand Down
7 changes: 7 additions & 0 deletions tests/error/py_errors/list_different_tys.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:6

4: @compile_guppy
5: def foo() -> int:
6: return py([1, 1.0])
^^^^^^^^^^^^
GuppyError: Python list contains elements with different types
6 changes: 6 additions & 0 deletions tests/error/py_errors/list_different_tys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from tests.util import compile_guppy


@compile_guppy
def foo() -> int:
return py([1, 1.0])
7 changes: 7 additions & 0 deletions tests/error/py_errors/list_empty.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:6

4: @compile_guppy
5: def foo() -> None:
6: xs = py([])
^^^^^^
GuppyTypeError: Cannot infer type variable in expression of type `list[?T]`
6 changes: 6 additions & 0 deletions tests/error/py_errors/list_empty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from tests.util import compile_guppy


@compile_guppy
def foo() -> None:
xs = py([])
33 changes: 33 additions & 0 deletions tests/integration/test_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,39 @@ def foo() -> int:
validate(foo)


def test_list_basic(validate):
@compile_guppy
def foo() -> list[int]:
xs = py([1, 2, 3])
croyzor marked this conversation as resolved.
Show resolved Hide resolved
return xs

validate(foo)


def test_list_empty(validate):
@compile_guppy
def foo() -> list[int]:
return py([])

validate(foo)


def test_list_empty_nested(validate):
@compile_guppy
def foo() -> None:
xs: list[tuple[int, list[bool]]] = py([(42, [])])

validate(foo)


def test_list_empty_multiple(validate):
@compile_guppy
def foo() -> None:
xs: tuple[list[int], list[bool]] = py([], [])

validate(foo)


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
def test_pytket_single_qubit(validate):
from pytket import Circuit
Expand Down
Loading