Skip to content

Commit

Permalink
feat!: Hide lists and function tensors behind experimental flag (#501)
Browse files Browse the repository at this point in the history
Closes #437.

BREAKING CHANGE: Lists and function tensors are no longer available by
default. `guppylang.enable_experimental_features()` must be called
before compilation to enable them.

---------

Co-authored-by: Alan Lawrence <[email protected]>
  • Loading branch information
mark-koch and acl-cqc authored Sep 16, 2024
1 parent d9ba592 commit c867f48
Show file tree
Hide file tree
Showing 19 changed files with 232 additions and 2 deletions.
1 change: 1 addition & 0 deletions guppylang/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from guppylang.decorator import guppy
from guppylang.experimental import enable_experimental_features
from guppylang.module import GuppyModule
from guppylang.prelude import builtins, quantum
from guppylang.prelude.builtins import Bool, Float, Int, List, linst, py
Expand Down
2 changes: 2 additions & 0 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from guppylang.cfg.cfg import CFG
from guppylang.checker.core import Globals
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.experimental import check_lists_enabled
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
Expand Down Expand Up @@ -304,6 +305,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name:
return make_var(tmp, node)

def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
check_lists_enabled(node)
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
Expand Down
5 changes: 5 additions & 0 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
GuppyTypeInferenceError,
InternalGuppyError,
)
from guppylang.experimental import check_function_tensors_enabled, check_lists_enabled
from guppylang.nodes import (
DesugaredGenerator,
DesugaredListComp,
Expand Down Expand Up @@ -214,6 +215,7 @@ def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]:
return node, subst

def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]:
check_lists_enabled(node)
if not is_list_type(ty) and not is_linst_type(ty):
return self._fail(ty, node)
el_ty = get_element_type(ty)
Expand Down Expand Up @@ -265,6 +267,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
if isinstance(func_ty, TupleType) and (
function_elements := parse_function_tensor(func_ty)
):
check_function_tensors_enabled(node.func)
if any(f.parametrized for f in function_elements):
raise GuppyTypeError(
"Polymorphic functions in tuples are not supported", node.func
Expand Down Expand Up @@ -435,6 +438,7 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]:
return node, TupleType([ty for _, ty in elems])

def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]:
check_lists_enabled(node)
if len(node.elts) == 0:
raise GuppyTypeInferenceError(
"Cannot infer type variable in expression of type `list[?T]`", node
Expand Down Expand Up @@ -602,6 +606,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
elif isinstance(ty, TupleType) and (
function_elems := parse_function_tensor(ty)
):
check_function_tensors_enabled(node.func)
if any(f.parametrized for f in function_elems):
raise GuppyTypeError(
"Polymorphic functions in tuples are not supported", node.func
Expand Down
73 changes: 73 additions & 0 deletions guppylang/experimental.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from ast import expr
from types import TracebackType

from guppylang.ast_util import AstNode
from guppylang.error import GuppyError

EXPERIMENTAL_FEATURES_ENABLED = False


class enable_experimental_features:
"""Enables experimental Guppy features.
Can be used as a context manager to enable experimental features in a `with` block.
"""

def __init__(self) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
self.original = EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = True

def __enter__(self) -> None:
pass

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = self.original


class disable_experimental_features:
"""Disables experimental Guppy features.
Can be used as a context manager to enable experimental features in a `with` block.
"""

def __init__(self) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
self.original = EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = False

def __enter__(self) -> None:
pass

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
global EXPERIMENTAL_FEATURES_ENABLED
EXPERIMENTAL_FEATURES_ENABLED = self.original


def check_function_tensors_enabled(node: expr | None = None) -> None:
if not EXPERIMENTAL_FEATURES_ENABLED:
raise GuppyError(
"Function tensors are an experimental feature. Use "
"`guppylang.enable_experimental_features()` to enable them.",
node,
)


def check_lists_enabled(loc: AstNode | None = None) -> None:
if not EXPERIMENTAL_FEATURES_ENABLED:
raise GuppyError(
"Lists are an experimental feature and not fully supported yet. Use "
"`guppylang.enable_experimental_features()` to enable them.",
loc,
)
5 changes: 4 additions & 1 deletion guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from guppylang.definition.struct import CheckedStructDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors
from guppylang.experimental import enable_experimental_features

PyClass = type
PyFunc = Callable[..., Any]
Expand Down Expand Up @@ -89,7 +90,9 @@ def __init__(self, name: str, import_builtins: bool = True):
if import_builtins:
import guppylang.prelude.builtins as builtins

self.load_all(builtins)
# Std lib is allowed to use experimental features
with enable_experimental_features():
self.load_all(builtins)

def load(
self,
Expand Down
19 changes: 18 additions & 1 deletion guppylang/tys/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from guppylang.definition.common import DefId
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.experimental import check_lists_enabled
from guppylang.tys.arg import Argument, ConstArg, TypeArg
from guppylang.tys.const import ConstValue
from guppylang.tys.param import ConstParam, TypeParam
Expand Down Expand Up @@ -109,6 +110,7 @@ class _ListTypeDef(OpaqueTypeDef):
def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
check_lists_enabled(loc)
if len(args) == 1:
[arg] = args
if isinstance(arg, TypeArg) and arg.ty.linear:
Expand All @@ -118,6 +120,21 @@ def check_instantiate(
return super().check_instantiate(args, globals, loc)


@dataclass(frozen=True)
class _LinstTypeDef(OpaqueTypeDef):
"""Type definition associated with the builtin `linst` type.
We have a custom definition to disable usage of linsts unless experimental features
are enabled.
"""

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
check_lists_enabled(loc)
return super().check_instantiate(args, globals, loc)


def _list_to_hugr(args: Sequence[Argument]) -> ht.Type:
# Type checker ensures that we get a single arg of kind type
[arg] = args
Expand Down Expand Up @@ -163,7 +180,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type:
float_type_def = _NumericTypeDef(
DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float)
)
linst_type_def = OpaqueTypeDef(
linst_type_def = _LinstTypeDef(
id=DefId.fresh(),
name="linst",
defined_at=None,
Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import argparse
from pathlib import Path

import guppylang

guppylang.enable_experimental_features()


def pytest_addoption(parser):
def dir_path(s):
Expand Down
Empty file.
7 changes: 7 additions & 0 deletions tests/error/experimental_errors/function_tensor.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:19

17: @guppy(module)
18: def main() -> tuple[int, int]:
19: return (f, g)(1, 2)
^^^^^^
GuppyError: Function tensors are an experimental feature. Use `guppylang.enable_experimental_features()` to enable them.
22 changes: 22 additions & 0 deletions tests/error/experimental_errors/function_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def f(x: int) -> int:
return x


@guppy(module)
def g(x: int) -> int:
return x


@guppy(module)
def main() -> tuple[int, int]:
return (f, g)(1, 2)


module.compile()
6 changes: 6 additions & 0 deletions tests/error/experimental_errors/linst_type.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def main(x: linst[int]) -> linst[int]:
^^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Use `guppylang.enable_experimental_features()` to enable them.
13 changes: 13 additions & 0 deletions tests/error/experimental_errors/linst_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from guppylang.decorator import guppy
from guppylang.prelude.builtins import linst
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main(x: linst[int]) -> linst[int]:
return x


module.compile()
7 changes: 7 additions & 0 deletions tests/error/experimental_errors/list_comprehension.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def main() -> None:
9: [i for i in range(10)]
^^^^^^^^^^^^^^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Use `guppylang.enable_experimental_features()` to enable them.
12 changes: 12 additions & 0 deletions tests/error/experimental_errors/list_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main() -> None:
[i for i in range(10)]


module.compile()
7 changes: 7 additions & 0 deletions tests/error/experimental_errors/list_literal.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:9

7: @guppy(module)
8: def main() -> None:
9: [1, 2, 3]
^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Use `guppylang.enable_experimental_features()` to enable them.
12 changes: 12 additions & 0 deletions tests/error/experimental_errors/list_literal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main() -> None:
[1, 2, 3]


module.compile()
6 changes: 6 additions & 0 deletions tests/error/experimental_errors/list_type.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:8

6: @guppy(module)
7: def main(x: list[int]) -> list[int]:
^^^^^^^^^
GuppyError: Lists are an experimental feature and not fully supported yet. Use `guppylang.enable_experimental_features()` to enable them.
12 changes: 12 additions & 0 deletions tests/error/experimental_errors/list_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule

module = GuppyModule("test")


@guppy(module)
def main(x: list[int]) -> list[int]:
return x


module.compile()
21 changes: 21 additions & 0 deletions tests/error/test_experimental_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pathlib
import pytest

from guppylang.experimental import disable_experimental_features
from tests.error.util import run_error_test

path = pathlib.Path(__file__).parent.resolve() / "experimental_errors"
files = [
x
for x in path.iterdir()
if x.is_file() and x.suffix == ".py" and x.name != "__init__.py"
]

# Turn paths into strings, otherwise pytest doesn't display the names
files = [str(f) for f in files]


@pytest.mark.parametrize("file", files)
def test_experimental_errors(file, capsys):
with disable_experimental_features():
run_error_test(file, capsys)

0 comments on commit c867f48

Please sign in to comment.