Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Hide lists and function tensors behind experimental flag #501

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions guppylang/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from guppylang.decorator import guppy
from guppylang.experimental import enable_experimental_features
from guppylang.module import GuppyModule
from guppylang.prelude import builtins, quantum
from guppylang.prelude.builtins import Bool, Float, Int, List, linst, py
Expand Down
2 changes: 2 additions & 0 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
NestedFunctionDef,
PyExpr,
)
from guppylang.tys.builtin import check_lists_enabled
from guppylang.tys.ty import NoneType

# In order to build expressions, need an endless stream of unique temporary variables
Expand Down Expand Up @@ -304,6 +305,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
return make_var(tmp, node)

def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
check_lists_enabled(node)
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
Expand Down
15 changes: 15 additions & 0 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dataclasses import replace
from typing import Any, NoReturn, cast

import guppylang
from guppylang.ast_util import (
AstNode,
AstVisitor,
Expand Down Expand Up @@ -78,6 +79,7 @@
from guppylang.tys.arg import TypeArg
from guppylang.tys.builtin import (
bool_type,
check_lists_enabled,
get_element_type,
is_bool_type,
is_linst_type,
Expand Down Expand Up @@ -214,6 +216,7 @@ def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]:
return node, subst

def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]:
check_lists_enabled(node)
if not is_list_type(ty) and not is_linst_type(ty):
return self._fail(ty, node)
el_ty = get_element_type(ty)
Expand Down Expand Up @@ -265,6 +268,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
if isinstance(func_ty, TupleType) and (
function_elements := parse_function_tensor(func_ty)
):
check_function_tensors_enabled(node.func)
if any(f.parametrized for f in function_elements):
raise GuppyTypeError(
"Polymorphic functions in tuples are not supported", node.func
Expand Down Expand Up @@ -435,6 +439,7 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]:
return node, TupleType([ty for _, ty in elems])

def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]:
check_lists_enabled(node)
if len(node.elts) == 0:
raise GuppyTypeInferenceError(
"Cannot infer type variable in expression of type `list[?T]`", node
Expand Down Expand Up @@ -602,6 +607,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
elif isinstance(ty, TupleType) and (
function_elems := parse_function_tensor(ty)
):
check_function_tensors_enabled(node.func)
if any(f.parametrized for f in function_elems):
raise GuppyTypeError(
"Polymorphic functions in tuples are not supported", node.func
Expand Down Expand Up @@ -990,6 +996,15 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr:
return with_type(ty, node)


def check_function_tensors_enabled(node: ast.expr | None = None) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I wonder if the two check_xxx_enabled should live together in experimental.py

if not guppylang.experimental.EXPERIMENTAL_FEATURES_ENABLED:
raise GuppyError(
"Function tensors are an experimental feature. Call "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe Use would be better than Call ?

"`guppylang.enable_experimental_features()` to enable them.",
node,
)


def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type]:
"""Tries to turn a node into a bool"""
if is_bool_type(node_ty):
Expand Down
51 changes: 51 additions & 0 deletions guppylang/experimental.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from types import TracebackType

EXPERIMENTAL_FEATURES_ENABLED = False


class enable_experimental_features:
"""Enables experimental Guppy features.

Can be used as a context manager to enable experimental features in a `with` block.
"""

def __init__(self) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
self.original = EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = True

def __enter__(self) -> None:
pass

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = self.original


class disable_experimental_features:
"""Disables experimental Guppy features.

Can be used as a context manager to enable experimental features in a `with` block.
"""

def __init__(self) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
self.original = EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = False

def __enter__(self) -> None:
pass

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = self.original
5 changes: 4 additions & 1 deletion guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from guppylang.definition.struct import CheckedStructDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors
from guppylang.experimental import enable_experimental_features

PyFunc = Callable[..., Any]
PyFuncDefOrDecl = tuple[bool, PyFunc]
Expand Down Expand Up @@ -88,7 +89,9 @@ def __init__(self, name: str, import_builtins: bool = True):
if import_builtins:
import guppylang.prelude.builtins as builtins

self.load_all(builtins)
# Std lib is allowed to use experimental features
with enable_experimental_features():
self.load_all(builtins)

def load(
self,
Expand Down
28 changes: 27 additions & 1 deletion guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hugr.std.collections
from hugr import tys as ht

import guppylang
from guppylang.ast_util import AstNode
from guppylang.definition.common import DefId
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
Expand Down Expand Up @@ -109,6 +110,7 @@ class _ListTypeDef(OpaqueTypeDef):
def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
check_lists_enabled(loc)
if len(args) == 1:
[arg] = args
if isinstance(arg, TypeArg) and arg.ty.linear:
Expand All @@ -118,6 +120,30 @@ def check_instantiate(
return super().check_instantiate(args, globals, loc)


@dataclass(frozen=True)
class _LinstTypeDef(OpaqueTypeDef):
"""Type definition associated with the builtin `linst` type.

We have a custom definition to disable usage of linsts unless experimental features
are enabled.
"""

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
check_lists_enabled(loc)
return super().check_instantiate(args, globals, loc)


def check_lists_enabled(loc: AstNode | None = None) -> None:
if not guppylang.experimental.EXPERIMENTAL_FEATURES_ENABLED:
raise GuppyError(
"Lists are an experimental feature and not fully supported yet. Call "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto*2 (Call -> Use, move into experimental.py ?)

"`guppylang.enable_experimental_features()` to enable them.",
loc,
)


def _list_to_hugr(args: Sequence[Argument]) -> ht.Type:
# Type checker ensures that we get a single arg of kind type
[arg] = args
Expand Down Expand Up @@ -163,7 +189,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
float_type_def = _NumericTypeDef(
DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
)
linst_type_def = OpaqueTypeDef(
linst_type_def = _LinstTypeDef(
id=DefId.fresh(),
name="linst",
defined_at=None,
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import argparse
from pathlib import Path

import guppylang

guppylang.enable_experimental_features()


def pytest_addoption(parser):
def dir_path(s):
Expand Down
Empty file.
7 changes: 7 additions & 0 deletions tests/error/experimental_errors/function_tensor.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:19

17: @guppy(module)
18: def main() -> tuple[int, int]:
19: return (f, g)(1, 2)
^^^^^^
GuppyError: Function tensors are an experimental feature. Call `guppylang.enable_experimental_features()` to enable them.
22 changes: 22 additions & 0 deletions tests/error/experimental_errors/function_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def f(x: int) -> int:
return x


@guppy(module)
def g(x: int) -> int:
return x


@guppy(module)
def main() -> tuple[int, int]:
return (f, g)(1, 2)


module.compile()
6 changes: 6 additions & 0 deletions tests/error/experimental_errors/linst_type.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def main(x: linst[int]) -> linst[int]:
^^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them.
13 changes: 13 additions & 0 deletions tests/error/experimental_errors/linst_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from guppylang.decorator import guppy
from guppylang.prelude.builtins import linst
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main(x: linst[int]) -> linst[int]:
return x


module.compile()
7 changes: 7 additions & 0 deletions tests/error/experimental_errors/list_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def main() -> None:
9: [i for i in range(10)]
^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them.
12 changes: 12 additions & 0 deletions tests/error/experimental_errors/list_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main() -> None:
[i for i in range(10)]


module.compile()
7 changes: 7 additions & 0 deletions tests/error/experimental_errors/list_literal.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def main() -> None:
9: [1, 2, 3]
^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them.
12 changes: 12 additions & 0 deletions tests/error/experimental_errors/list_literal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main() -> None:
[1, 2, 3]


module.compile()
6 changes: 6 additions & 0 deletions tests/error/experimental_errors/list_type.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:8

6: @guppy(module)
7: def main(x: list[int]) -> list[int]:
^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them.
12 changes: 12 additions & 0 deletions tests/error/experimental_errors/list_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main(x: list[int]) -> list[int]:
return x


module.compile()
21 changes: 21 additions & 0 deletions tests/error/test_experimental_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pathlib
import pytest

from guppylang.experimental import disable_experimental_features
from tests.error.util import run_error_test

path = pathlib.Path(__file__).parent.resolve() / "experimental_errors"
files = [
x
for x in path.iterdir()
if x.is_file() and x.suffix == ".py" and x.name != "__init__.py"
]

# Turn paths into strings, otherwise pytest doesn't display the names
files = [str(f) for f in files]


@pytest.mark.parametrize("file", files)
def test_experimental_errors(file, capsys):
with disable_experimental_features():
run_error_test(file, capsys)
Loading