-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat!: Hide lists and function tensors behind experimental flag #501
Merged
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
from dataclasses import replace | ||
from typing import Any, NoReturn, cast | ||
|
||
import guppylang | ||
from guppylang.ast_util import ( | ||
AstNode, | ||
AstVisitor, | ||
|
@@ -78,6 +79,7 @@ | |
from guppylang.tys.arg import TypeArg | ||
from guppylang.tys.builtin import ( | ||
bool_type, | ||
check_lists_enabled, | ||
get_element_type, | ||
is_bool_type, | ||
is_linst_type, | ||
|
@@ -214,6 +216,7 @@ def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: | |
return node, subst | ||
|
||
def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]: | ||
check_lists_enabled(node) | ||
if not is_list_type(ty) and not is_linst_type(ty): | ||
return self._fail(ty, node) | ||
el_ty = get_element_type(ty) | ||
|
@@ -265,6 +268,7 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: | |
if isinstance(func_ty, TupleType) and ( | ||
function_elements := parse_function_tensor(func_ty) | ||
): | ||
check_function_tensors_enabled(node.func) | ||
if any(f.parametrized for f in function_elements): | ||
raise GuppyTypeError( | ||
"Polymorphic functions in tuples are not supported", node.func | ||
|
@@ -435,6 +439,7 @@ def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: | |
return node, TupleType([ty for _, ty in elems]) | ||
|
||
def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]: | ||
check_lists_enabled(node) | ||
if len(node.elts) == 0: | ||
raise GuppyTypeInferenceError( | ||
"Cannot infer type variable in expression of type `list[?T]`", node | ||
|
@@ -602,6 +607,7 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: | |
elif isinstance(ty, TupleType) and ( | ||
function_elems := parse_function_tensor(ty) | ||
): | ||
check_function_tensors_enabled(node.func) | ||
if any(f.parametrized for f in function_elems): | ||
raise GuppyTypeError( | ||
"Polymorphic functions in tuples are not supported", node.func | ||
|
@@ -990,6 +996,15 @@ def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr: | |
return with_type(ty, node) | ||
|
||
|
||
def check_function_tensors_enabled(node: ast.expr | None = None) -> None: | ||
if not guppylang.experimental.EXPERIMENTAL_FEATURES_ENABLED: | ||
raise GuppyError( | ||
"Function tensors are an experimental feature. Call " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe |
||
"`guppylang.enable_experimental_features()` to enable them.", | ||
node, | ||
) | ||
|
||
|
||
def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type]: | ||
"""Tries to turn a node into a bool""" | ||
if is_bool_type(node_ty): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from types import TracebackType | ||
|
||
EXPERIMENTAL_FEATURES_ENABLED = False | ||
|
||
|
||
class enable_experimental_features: | ||
"""Enables experimental Guppy features. | ||
|
||
Can be used as a context manager to enable experimental features in a `with` block. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
global EXPERIMENTAL_FEATURES_ENABLED | ||
self.original = EXPERIMENTAL_FEATURES_ENABLED | ||
EXPERIMENTAL_FEATURES_ENABLED = True | ||
|
||
def __enter__(self) -> None: | ||
pass | ||
|
||
def __exit__( | ||
self, | ||
exc_type: type[BaseException] | None, | ||
exc_val: BaseException | None, | ||
exc_tb: TracebackType | None, | ||
) -> None: | ||
global EXPERIMENTAL_FEATURES_ENABLED | ||
EXPERIMENTAL_FEATURES_ENABLED = self.original | ||
|
||
|
||
class disable_experimental_features: | ||
"""Disables experimental Guppy features. | ||
|
||
Can be used as a context manager to enable experimental features in a `with` block. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
global EXPERIMENTAL_FEATURES_ENABLED | ||
self.original = EXPERIMENTAL_FEATURES_ENABLED | ||
EXPERIMENTAL_FEATURES_ENABLED = False | ||
|
||
def __enter__(self) -> None: | ||
pass | ||
|
||
def __exit__( | ||
self, | ||
exc_type: type[BaseException] | None, | ||
exc_val: BaseException | None, | ||
exc_tb: TracebackType | None, | ||
) -> None: | ||
global EXPERIMENTAL_FEATURES_ENABLED | ||
EXPERIMENTAL_FEATURES_ENABLED = self.original |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
import hugr.std.collections | ||
from hugr import tys as ht | ||
|
||
import guppylang | ||
from guppylang.ast_util import AstNode | ||
from guppylang.definition.common import DefId | ||
from guppylang.definition.ty import OpaqueTypeDef, TypeDef | ||
|
@@ -109,6 +110,7 @@ class _ListTypeDef(OpaqueTypeDef): | |
def check_instantiate( | ||
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None | ||
) -> OpaqueType: | ||
check_lists_enabled(loc) | ||
if len(args) == 1: | ||
[arg] = args | ||
if isinstance(arg, TypeArg) and arg.ty.linear: | ||
|
@@ -118,6 +120,30 @@ def check_instantiate( | |
return super().check_instantiate(args, globals, loc) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class _LinstTypeDef(OpaqueTypeDef): | ||
"""Type definition associated with the builtin `linst` type. | ||
|
||
We have a custom definition to disable usage of linsts unless experimental features | ||
are enabled. | ||
""" | ||
|
||
def check_instantiate( | ||
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None | ||
) -> OpaqueType: | ||
check_lists_enabled(loc) | ||
return super().check_instantiate(args, globals, loc) | ||
|
||
|
||
def check_lists_enabled(loc: AstNode | None = None) -> None: | ||
if not guppylang.experimental.EXPERIMENTAL_FEATURES_ENABLED: | ||
raise GuppyError( | ||
"Lists are an experimental feature and not fully supported yet. Call " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto*2 ( |
||
"`guppylang.enable_experimental_features()` to enable them.", | ||
loc, | ||
) | ||
|
||
|
||
def _list_to_hugr(args: Sequence[Argument]) -> ht.Type: | ||
# Type checker ensures that we get a single arg of kind type | ||
[arg] = args | ||
|
@@ -163,7 +189,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: | |
float_type_def = _NumericTypeDef( | ||
DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float) | ||
) | ||
linst_type_def = OpaqueTypeDef( | ||
linst_type_def = _LinstTypeDef( | ||
id=DefId.fresh(), | ||
name="linst", | ||
defined_at=None, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:19 | ||
|
||
17: @guppy(module) | ||
18: def main() -> tuple[int, int]: | ||
19: return (f, g)(1, 2) | ||
^^^^^^ | ||
GuppyError: Function tensors are an experimental feature. Call `guppylang.enable_experimental_features()` to enable them. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy(module) | ||
def f(x: int) -> int: | ||
return x | ||
|
||
|
||
@guppy(module) | ||
def g(x: int) -> int: | ||
return x | ||
|
||
|
||
@guppy(module) | ||
def main() -> tuple[int, int]: | ||
return (f, g)(1, 2) | ||
|
||
|
||
module.compile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:9 | ||
|
||
7: @guppy(module) | ||
8: def main(x: linst[int]) -> linst[int]: | ||
^^^^^^^^^^ | ||
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.prelude.builtins import linst | ||
from guppylang.module import GuppyModule | ||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy(module) | ||
def main(x: linst[int]) -> linst[int]: | ||
return x | ||
|
||
|
||
module.compile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:9 | ||
|
||
7: @guppy(module) | ||
8: def main() -> None: | ||
9: [i for i in range(10)] | ||
^^^^^^^^^^^^^^^^^^^^^^ | ||
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy(module) | ||
def main() -> None: | ||
[i for i in range(10)] | ||
|
||
|
||
module.compile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Guppy compilation failed. Error in file $FILE:9 | ||
|
||
7: @guppy(module) | ||
8: def main() -> None: | ||
9: [1, 2, 3] | ||
^^^^^^^^^ | ||
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy(module) | ||
def main() -> None: | ||
[1, 2, 3] | ||
|
||
|
||
module.compile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Guppy compilation failed. Error in file $FILE:8 | ||
|
||
6: @guppy(module) | ||
7: def main(x: list[int]) -> list[int]: | ||
^^^^^^^^^ | ||
GuppyError: Lists are an experimental feature and not fully supported yet. Call `guppylang.enable_experimental_features()` to enable them. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from guppylang.decorator import guppy | ||
from guppylang.module import GuppyModule | ||
|
||
module = GuppyModule("test") | ||
|
||
|
||
@guppy(module) | ||
def main(x: list[int]) -> list[int]: | ||
return x | ||
|
||
|
||
module.compile() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import pathlib | ||
import pytest | ||
|
||
from guppylang.experimental import disable_experimental_features | ||
from tests.error.util import run_error_test | ||
|
||
path = pathlib.Path(__file__).parent.resolve() / "experimental_errors" | ||
files = [ | ||
x | ||
for x in path.iterdir() | ||
if x.is_file() and x.suffix == ".py" and x.name != "__init__.py" | ||
] | ||
|
||
# Turn paths into strings, otherwise pytest doesn't display the names | ||
files = [str(f) for f in files] | ||
|
||
|
||
@pytest.mark.parametrize("file", files) | ||
def test_experimental_errors(file, capsys): | ||
with disable_experimental_features(): | ||
run_error_test(file, capsys) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I wonder if the two
check_xxx_enabled
should live together inexperimental.py