diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index 5c1ec8bf..c781088f 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -343,20 +343,7 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST: return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens)) def visit_Call(self, node: ast.Call) -> ast.AST: - # Parse compile-time evaluated `py(...)` expression - if isinstance(node.func, ast.Name) and node.func.id == "py": - match node.args: - case []: - raise GuppyError( - "Compile-time `py(...)` expression requires an argument", - node, - ) - case [arg]: - pass - case args: - arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) - return with_loc(node, PyExpr(value=arg)) - return self.generic_visit(node) + return is_py_expression(node) or self.generic_visit(node) def generic_visit(self, node: ast.AST) -> ast.AST: # Short-circuit expressions must be built using the `BranchBuilder`. However, we @@ -492,6 +479,31 @@ def is_functional_annotation(stmt: ast.stmt) -> bool: return False +def is_py_expression(node: ast.AST) -> PyExpr | None: + """Checks if the given node is a compile-time `py(...)` expression and turns it into + a `PyExpr` AST node. + + Otherwise, returns `None`. + """ + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id == "py" + ): + match node.args: + case []: + raise GuppyError( + "Compile-time `py(...)` expression requires an argument", + node, + ) + case [arg]: + pass + case args: + arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) + return with_loc(node, PyExpr(value=arg)) + return None + + def is_short_circuit_expr(node: ast.AST) -> bool: """Checks if an expression uses short-circuiting. diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 142d6873..2a8a3305 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -270,6 +270,11 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None return defn return None + def with_python_scope(self, python_scope: PyScope) -> "Globals": + return Globals( + self.defs, self.names, self.impls, self.python_scope | python_scope + ) + def __or__(self, other: "Globals") -> "Globals": impls = { def_id: self.impls.get(def_id, {}) | other.impls.get(def_id, {}) diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 80711e5d..645c5f06 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -3,7 +3,7 @@ from collections.abc import Callable, KeysView from dataclasses import dataclass, field from pathlib import Path -from types import FrameType, ModuleType +from types import ModuleType from typing import Any, TypeVar, overload import hugr.ext @@ -36,6 +36,7 @@ PyClass, PyFunc, find_guppy_module_in_py_module, + get_calling_frame, ) from guppylang.tys.subst import Inst from guppylang.tys.ty import NumericType @@ -233,8 +234,14 @@ def struct( ) implicit_module._instance_func_buffer = {} + # Extract Python scope from the frame that called `guppy.struct` + frame = get_calling_frame() + python_scope = frame.f_globals | frame.f_locals if frame else {} + def dec(cls: type, module: GuppyModule) -> RawStructDef: - defn = RawStructDef(DefId.fresh(module), cls.__name__, None, cls) + defn = RawStructDef( + DefId.fresh(module), cls.__name__, None, cls, python_scope + ) module.register_def(defn) module._register_buffered_instance_funcs(defn) # If we mistakenly initialised the method buffer of the implicit module @@ -450,7 +457,7 @@ def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr: # Try to annotate the type AST with source information. This requires us to # inspect the stack frame of the caller - if caller_frame := _get_calling_frame(): + if caller_frame := get_calling_frame(): info = inspect.getframeinfo(caller_frame) if caller_module := inspect.getmodule(caller_frame): source_lines, _ = inspect.getsourcelines(caller_module) @@ -463,16 +470,3 @@ def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr: node.lineno, node.col_offset = info.lineno, 0 node.end_col_offset = len(source_lines[info.lineno - 1]) return expr_ast - - -def _get_calling_frame() -> FrameType | None: - """Finds the first frame that called this function outside the current module.""" - frame = inspect.currentframe() - while frame: - module = inspect.getmodule(frame) - if module is None: - break - if module.__file__ != __file__: - return frame - frame = frame.f_back - return None diff --git a/guppylang/definition/declaration.py b/guppylang/definition/declaration.py index aef9714f..2352ca7b 100644 --- a/guppylang/definition/declaration.py +++ b/guppylang/definition/declaration.py @@ -7,7 +7,7 @@ from hugr.build.dfg import DefinitionBuilder, OpVar from guppylang.ast_util import AstNode, has_empty_body, with_loc -from guppylang.checker.core import Context, Globals +from guppylang.checker.core import Context, Globals, PyScope from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import check_signature from guppylang.compiler.core import CompiledGlobals, DFContainer @@ -29,18 +29,25 @@ class RawFunctionDecl(ParsableDef): """ python_func: PyFunc + python_scope: PyScope description: str = field(default="function", init=False) def parse(self, globals: Globals) -> "CheckedFunctionDecl": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func) - ty = check_signature(func_ast, globals) + ty = check_signature(func_ast, globals.with_python_scope(self.python_scope)) if not has_empty_body(func_ast): raise GuppyError( "Body of function declaration must be empty", func_ast.body[0] ) return CheckedFunctionDecl( - self.id, self.name, func_ast, ty, self.python_func, docstring + self.id, + self.name, + func_ast, + ty, + self.python_func, + self.python_scope, + docstring, ) @@ -86,6 +93,7 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledFunctionDe self.defined_at, self.ty, self.python_func, + self.python_scope, self.docstring, node, ) diff --git a/guppylang/definition/function.py b/guppylang/definition/function.py index 6b82c75b..3920fe32 100644 --- a/guppylang/definition/function.py +++ b/guppylang/definition/function.py @@ -56,7 +56,7 @@ class RawFunctionDef(ParsableDef): def parse(self, globals: Globals) -> "ParsedFunctionDef": """Parses and checks the user-provided signature of the function.""" func_ast, docstring = parse_py_func(self.python_func) - ty = check_signature(func_ast, globals) + ty = check_signature(func_ast, globals.with_python_scope(self.python_scope)) if ty.parametrized: raise GuppyError( "Generic function definitions are not supported yet", func_ast @@ -92,7 +92,7 @@ class ParsedFunctionDef(CheckableDef, CallableDef): def check(self, globals: Globals) -> "CheckedFunctionDef": """Type checks the body of the function.""" # Add python variable scope to the globals - globals = globals | Globals({}, {}, {}, self.python_scope) + globals = globals.with_python_scope(self.python_scope) cfg = check_global_func_def(self.defined_at, self.ty, globals) return CheckedFunctionDef( self.id, diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index fcf1bb29..521b5549 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -8,7 +8,7 @@ from hugr import Wire, ops from guppylang.ast_util import AstNode, annotate_location -from guppylang.checker.core import Globals +from guppylang.checker.core import Globals, PyScope from guppylang.definition.common import ( CheckableDef, CompiledDef, @@ -52,6 +52,7 @@ class RawStructDef(TypeDef, ParsableDef): """A raw struct type definition that has not been parsed yet.""" python_class: type + python_scope: PyScope def __getitem__(self, item: Any) -> "RawStructDef": """Dummy implementation to enable subscripting in the Python runtime. @@ -131,7 +132,9 @@ def parse(self, globals: Globals) -> "ParsedStructDef": used_func_names[x], ) - return ParsedStructDef(self.id, self.name, cls_def, params, fields) + return ParsedStructDef( + self.id, self.name, cls_def, params, fields, self.python_scope + ) def check_instantiate( self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None @@ -146,12 +149,14 @@ class ParsedStructDef(TypeDef, CheckableDef): defined_at: ast.ClassDef params: Sequence[Parameter] fields: Sequence[UncheckedStructField] + python_scope: PyScope def check(self, globals: Globals) -> "CheckedStructDef": """Checks that all struct fields have valid types.""" # Before checking the fields, make sure that this definition is not recursive, # otherwise the code below would not terminate. # TODO: This is not ideal (see todo in `check_instantiate`) + globals = globals.with_python_scope(self.python_scope) check_not_recursive(self, globals) param_var_mapping = {p.name: p for p in self.params} diff --git a/guppylang/module.py b/guppylang/module.py index 73b97e1f..c4a2ced4 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -2,7 +2,7 @@ import sys from collections.abc import Callable, Mapping from pathlib import Path -from types import ModuleType +from types import FrameType, ModuleType from typing import Any from hugr import Hugr, ops @@ -10,6 +10,7 @@ from hugr.ext import Package import guppylang.compiler.hugr_extension +from guppylang import decorator from guppylang.checker.core import Globals, PyScope from guppylang.compiler.core import CompiledGlobals from guppylang.definition.common import ( @@ -213,7 +214,7 @@ def register_func_decl( self, f: PyFunc, instance: TypeDef | None = None ) -> RawFunctionDecl: """Registers a Python function declaration as belonging to this Guppy module.""" - decl = RawFunctionDecl(DefId.fresh(self), f.__name__, None, f) + decl = RawFunctionDecl(DefId.fresh(self), f.__name__, None, f, get_py_scope(f)) self.register_def(decl, instance) return decl @@ -353,7 +354,8 @@ def contains(self, name: str) -> bool: def get_py_scope(f: PyFunc) -> PyScope: - """Returns a mapping of all variables captured by a Python function. + """Returns a mapping of all variables captured by a Python function together with + the `f_locals` and `f_globals` of the frame that called this function. Note that this function only works in CPython. On other platforms, an empty dictionary is returned. @@ -361,8 +363,12 @@ def get_py_scope(f: PyFunc) -> PyScope: Relies on inspecting the `__globals__` and `__closure__` attributes of the function. See https://docs.python.org/3/reference/datamodel.html#special-read-only-attributes """ + # Get variables from the calling frame + frame = get_calling_frame() + frame_vars = frame.f_globals | frame.f_locals if frame else {} + if sys.implementation.name != "cpython": - return {} + return frame_vars if inspect.ismethod(f): f = f.__func__ @@ -379,7 +385,7 @@ def get_py_scope(f: PyFunc) -> PyScope: continue nonlocals[var] = value - return nonlocals | f.__globals__.copy() + return frame_vars | nonlocals | f.__globals__.copy() def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule: @@ -407,3 +413,16 @@ def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule: ) raise GuppyError(msg) return mods[0] + + +def get_calling_frame() -> FrameType | None: + """Finds the first frame that called this function outside the compiler modules.""" + frame = inspect.currentframe() + while frame: + module = inspect.getmodule(frame) + if module is None: + break + if module.__file__ != __file__ and module != decorator: + return frame + frame = frame.f_back + return None diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 2a0ea7ee..6fdea55b 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -8,7 +8,6 @@ from guppylang.decorator import guppy from guppylang.definition.custom import DefaultCallChecker, NoopCompiler -from guppylang.error import GuppyError from guppylang.prelude._internal.checker import ( ArrayLenChecker, CallableChecker, @@ -57,13 +56,12 @@ L = guppy.type_var("L", linear=True) -def py(*_args: Any) -> Any: +def py(*args: Any) -> Any: """Function to tag compile-time evaluated Python expressions in a Guppy context. - This function throws an error when execute in a Python context. It is only intended - to be used inside Guppy functions. + This function acts like the identity when execute in a Python context. """ - raise GuppyError("`py` can only by used in a Guppy context") + return tuple(args) class _Owned: diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index 6cffe032..56985f47 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -6,7 +6,9 @@ set_location_from, shift_loc, ) -from guppylang.checker.core import Globals +from guppylang.cfg.builder import is_py_expression +from guppylang.checker.core import Context, Globals, Locals +from guppylang.checker.expr_checker import eval_py_expr from guppylang.definition.common import Definition from guppylang.definition.module import ModuleDef from guppylang.definition.parameter import ParamDef @@ -69,6 +71,19 @@ def arg_from_ast( nat_ty = NumericType(NumericType.Kind.Nat) return ConstArg(ConstValue(nat_ty, node.value)) + # Py-expressions can also be used to specify static numbers + if py_expr := is_py_expression(node): + v = eval_py_expr(py_expr, Context(globals, Locals({}))) + if isinstance(v, int): + nat_ty = NumericType(NumericType.Kind.Nat) + return ConstArg(ConstValue(nat_ty, v)) + else: + raise GuppyError( + f"Compile-time `py(...)` expression with type `{type(v)}` is not a " + "valid type argument", + node, + ) + # Finally, we also support delayed annotations in strings if isinstance(node, ast.Constant) and isinstance(node.value, str): node = _parse_delayed_annotation(node.value, node) diff --git a/tests/error/py_errors/invalid_type_arg.err b/tests/error/py_errors/invalid_type_arg.err new file mode 100644 index 00000000..a310bd6b --- /dev/null +++ b/tests/error/py_errors/invalid_type_arg.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:7 + +5: @compile_guppy +6: def foo(xs: array[int, py(1.0)]) -> None: + ^^^^^^^ +GuppyError: Compile-time `py(...)` expression with type `` is not a valid type argument diff --git a/tests/error/py_errors/invalid_type_arg.py b/tests/error/py_errors/invalid_type_arg.py new file mode 100644 index 00000000..d6993d4d --- /dev/null +++ b/tests/error/py_errors/invalid_type_arg.py @@ -0,0 +1,8 @@ +from guppylang import py +from guppylang.prelude.builtins import array +from tests.util import compile_guppy + + +@compile_guppy +def foo(xs: array[int, py(1.0)]) -> None: + pass diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py index 492c5c85..79804178 100644 --- a/tests/integration/test_py.py +++ b/tests/integration/test_py.py @@ -6,7 +6,7 @@ from guppylang.decorator import guppy from guppylang.module import GuppyModule -from guppylang.prelude.builtins import py +from guppylang.prelude.builtins import py, array from guppylang.prelude import quantum from guppylang.prelude.quantum import qubit from tests.util import compile_guppy @@ -166,3 +166,23 @@ def foo(q: qubit) -> tuple[qubit, bool]: return py(circ)(q) validate(module.compile()) + + +def test_func_type_arg(validate): + module = GuppyModule("test") + n = 10 + + @guppy(module) + def foo(xs: array[int, py(n)]) -> array[int, py(n)]: + return xs + + @guppy.declare(module) + def bar(xs: array[int, py(n)]) -> array[int, py(n)]: ... + + @guppy.struct(module) + class Baz: + xs: array[int, py(n)] + + validate(module.compile()) + +