Skip to content

Commit

Permalink
feat: Allow py expressions in type arguments (#515)
Browse files Browse the repository at this point in the history
Closes #513

Function declarations and struct definitions now also store the
`python_scope` in which they were defined in oder to resolve
py-expressions ocurring in their signatures.

Also I needed to update the way the python scope of functions is
computed: The approach using `__closure__` doesn't capture variables
that are bound in the signature. Instead we now also inspect the calling
frame to extract the `f_globals` and `f_locals` from there. The same is
also used for structs since Python classes don't have `__closure__`
  • Loading branch information
mark-koch authored Oct 1, 2024
1 parent d4eb07d commit b4fae3f
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 49 deletions.
40 changes: 26 additions & 14 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,20 +343,7 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
return with_loc(node, DesugaredListComp(elt=node.elt, generators=gens))

def visit_Call(self, node: ast.Call) -> ast.AST:
# Parse compile-time evaluated `py(...)` expression
if isinstance(node.func, ast.Name) and node.func.id == "py":
match node.args:
case []:
raise GuppyError(
"Compile-time `py(...)` expression requires an argument",
node,
)
case [arg]:
pass
case args:
arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
return with_loc(node, PyExpr(value=arg))
return self.generic_visit(node)
return is_py_expression(node) or self.generic_visit(node)

def generic_visit(self, node: ast.AST) -> ast.AST:
# Short-circuit expressions must be built using the `BranchBuilder`. However, we
Expand Down Expand Up @@ -492,6 +479,31 @@ def is_functional_annotation(stmt: ast.stmt) -> bool:
return False


def is_py_expression(node: ast.AST) -> PyExpr | None:
"""Checks if the given node is a compile-time `py(...)` expression and turns it into
a `PyExpr` AST node.
Otherwise, returns `None`.
"""
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Name)
and node.func.id == "py"
):
match node.args:
case []:
raise GuppyError(
"Compile-time `py(...)` expression requires an argument",
node,
)
case [arg]:
pass
case args:
arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load))
return with_loc(node, PyExpr(value=arg))
return None


def is_short_circuit_expr(node: ast.AST) -> bool:
"""Checks if an expression uses short-circuiting.
Expand Down
5 changes: 5 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
return defn
return None

def with_python_scope(self, python_scope: PyScope) -> "Globals":
return Globals(
self.defs, self.names, self.impls, self.python_scope | python_scope
)

def __or__(self, other: "Globals") -> "Globals":
impls = {
def_id: self.impls.get(def_id, {}) | other.impls.get(def_id, {})
Expand Down
26 changes: 10 additions & 16 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Callable, KeysView
from dataclasses import dataclass, field
from pathlib import Path
from types import FrameType, ModuleType
from types import ModuleType
from typing import Any, TypeVar, overload

import hugr.ext
Expand Down Expand Up @@ -36,6 +36,7 @@
PyClass,
PyFunc,
find_guppy_module_in_py_module,
get_calling_frame,
)
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType
Expand Down Expand Up @@ -233,8 +234,14 @@ def struct(
)
implicit_module._instance_func_buffer = {}

# Extract Python scope from the frame that called `guppy.struct`
frame = get_calling_frame()
python_scope = frame.f_globals | frame.f_locals if frame else {}

def dec(cls: type, module: GuppyModule) -> RawStructDef:
defn = RawStructDef(DefId.fresh(module), cls.__name__, None, cls)
defn = RawStructDef(
DefId.fresh(module), cls.__name__, None, cls, python_scope
)
module.register_def(defn)
module._register_buffered_instance_funcs(defn)
# If we mistakenly initialised the method buffer of the implicit module
Expand Down Expand Up @@ -450,7 +457,7 @@ def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr:

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
if caller_frame := _get_calling_frame():
if caller_frame := get_calling_frame():
info = inspect.getframeinfo(caller_frame)
if caller_module := inspect.getmodule(caller_frame):
source_lines, _ = inspect.getsourcelines(caller_module)
Expand All @@ -463,16 +470,3 @@ def _parse_expr_string(ty_str: str, parse_err: str) -> ast.expr:
node.lineno, node.col_offset = info.lineno, 0
node.end_col_offset = len(source_lines[info.lineno - 1])
return expr_ast


def _get_calling_frame() -> FrameType | None:
"""Finds the first frame that called this function outside the current module."""
frame = inspect.currentframe()
while frame:
module = inspect.getmodule(frame)
if module is None:
break
if module.__file__ != __file__:
return frame
frame = frame.f_back
return None
14 changes: 11 additions & 3 deletions guppylang/definition/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hugr.build.dfg import DefinitionBuilder, OpVar

from guppylang.ast_util import AstNode, has_empty_body, with_loc
from guppylang.checker.core import Context, Globals
from guppylang.checker.core import Context, Globals, PyScope
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import check_signature
from guppylang.compiler.core import CompiledGlobals, DFContainer
Expand All @@ -29,18 +29,25 @@ class RawFunctionDecl(ParsableDef):
"""

python_func: PyFunc
python_scope: PyScope
description: str = field(default="function", init=False)

def parse(self, globals: Globals) -> "CheckedFunctionDecl":
"""Parses and checks the user-provided signature of the function."""
func_ast, docstring = parse_py_func(self.python_func)
ty = check_signature(func_ast, globals)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if not has_empty_body(func_ast):
raise GuppyError(
"Body of function declaration must be empty", func_ast.body[0]
)
return CheckedFunctionDecl(
self.id, self.name, func_ast, ty, self.python_func, docstring
self.id,
self.name,
func_ast,
ty,
self.python_func,
self.python_scope,
docstring,
)


Expand Down Expand Up @@ -86,6 +93,7 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledFunctionDe
self.defined_at,
self.ty,
self.python_func,
self.python_scope,
self.docstring,
node,
)
Expand Down
4 changes: 2 additions & 2 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class RawFunctionDef(ParsableDef):
def parse(self, globals: Globals) -> "ParsedFunctionDef":
"""Parses and checks the user-provided signature of the function."""
func_ast, docstring = parse_py_func(self.python_func)
ty = check_signature(func_ast, globals)
ty = check_signature(func_ast, globals.with_python_scope(self.python_scope))
if ty.parametrized:
raise GuppyError(
"Generic function definitions are not supported yet", func_ast
Expand Down Expand Up @@ -92,7 +92,7 @@ class ParsedFunctionDef(CheckableDef, CallableDef):
def check(self, globals: Globals) -> "CheckedFunctionDef":
"""Type checks the body of the function."""
# Add python variable scope to the globals
globals = globals | Globals({}, {}, {}, self.python_scope)
globals = globals.with_python_scope(self.python_scope)
cfg = check_global_func_def(self.defined_at, self.ty, globals)
return CheckedFunctionDef(
self.id,
Expand Down
9 changes: 7 additions & 2 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from hugr import Wire, ops

from guppylang.ast_util import AstNode, annotate_location
from guppylang.checker.core import Globals
from guppylang.checker.core import Globals, PyScope
from guppylang.definition.common import (
CheckableDef,
CompiledDef,
Expand Down Expand Up @@ -52,6 +52,7 @@ class RawStructDef(TypeDef, ParsableDef):
"""A raw struct type definition that has not been parsed yet."""

python_class: type
python_scope: PyScope

def __getitem__(self, item: Any) -> "RawStructDef":
"""Dummy implementation to enable subscripting in the Python runtime.
Expand Down Expand Up @@ -131,7 +132,9 @@ def parse(self, globals: Globals) -> "ParsedStructDef":
used_func_names[x],
)

return ParsedStructDef(self.id, self.name, cls_def, params, fields)
return ParsedStructDef(
self.id, self.name, cls_def, params, fields, self.python_scope
)

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
Expand All @@ -146,12 +149,14 @@ class ParsedStructDef(TypeDef, CheckableDef):
defined_at: ast.ClassDef
params: Sequence[Parameter]
fields: Sequence[UncheckedStructField]
python_scope: PyScope

def check(self, globals: Globals) -> "CheckedStructDef":
"""Checks that all struct fields have valid types."""
# Before checking the fields, make sure that this definition is not recursive,
# otherwise the code below would not terminate.
# TODO: This is not ideal (see todo in `check_instantiate`)
globals = globals.with_python_scope(self.python_scope)
check_not_recursive(self, globals)

param_var_mapping = {p.name: p for p in self.params}
Expand Down
29 changes: 24 additions & 5 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import sys
from collections.abc import Callable, Mapping
from pathlib import Path
from types import ModuleType
from types import FrameType, ModuleType
from typing import Any

from hugr import Hugr, ops
from hugr.build.function import Module
from hugr.ext import Package

import guppylang.compiler.hugr_extension
from guppylang import decorator
from guppylang.checker.core import Globals, PyScope
from guppylang.compiler.core import CompiledGlobals
from guppylang.definition.common import (
Expand Down Expand Up @@ -213,7 +214,7 @@ def register_func_decl(
self, f: PyFunc, instance: TypeDef | None = None
) -> RawFunctionDecl:
"""Registers a Python function declaration as belonging to this Guppy module."""
decl = RawFunctionDecl(DefId.fresh(self), f.__name__, None, f)
decl = RawFunctionDecl(DefId.fresh(self), f.__name__, None, f, get_py_scope(f))
self.register_def(decl, instance)
return decl

Expand Down Expand Up @@ -353,16 +354,21 @@ def contains(self, name: str) -> bool:


def get_py_scope(f: PyFunc) -> PyScope:
"""Returns a mapping of all variables captured by a Python function.
"""Returns a mapping of all variables captured by a Python function together with
the `f_locals` and `f_globals` of the frame that called this function.
Note that this function only works in CPython. On other platforms, an empty
dictionary is returned.
Relies on inspecting the `__globals__` and `__closure__` attributes of the function.
See https://docs.python.org/3/reference/datamodel.html#special-read-only-attributes
"""
# Get variables from the calling frame
frame = get_calling_frame()
frame_vars = frame.f_globals | frame.f_locals if frame else {}

if sys.implementation.name != "cpython":
return {}
return frame_vars

if inspect.ismethod(f):
f = f.__func__
Expand All @@ -379,7 +385,7 @@ def get_py_scope(f: PyFunc) -> PyScope:
continue
nonlocals[var] = value

return nonlocals | f.__globals__.copy()
return frame_vars | nonlocals | f.__globals__.copy()


def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule:
Expand Down Expand Up @@ -407,3 +413,16 @@ def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule:
)
raise GuppyError(msg)
return mods[0]


def get_calling_frame() -> FrameType | None:
"""Finds the first frame that called this function outside the compiler modules."""
frame = inspect.currentframe()
while frame:
module = inspect.getmodule(frame)
if module is None:
break
if module.__file__ != __file__ and module != decorator:
return frame
frame = frame.f_back
return None
8 changes: 3 additions & 5 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from guppylang.decorator import guppy
from guppylang.definition.custom import DefaultCallChecker, NoopCompiler
from guppylang.error import GuppyError
from guppylang.prelude._internal.checker import (
ArrayLenChecker,
CallableChecker,
Expand Down Expand Up @@ -57,13 +56,12 @@
L = guppy.type_var("L", linear=True)


def py(*_args: Any) -> Any:
def py(*args: Any) -> Any:
"""Function to tag compile-time evaluated Python expressions in a Guppy context.
This function throws an error when execute in a Python context. It is only intended
to be used inside Guppy functions.
This function acts like the identity when execute in a Python context.
"""
raise GuppyError("`py` can only by used in a Guppy context")
return tuple(args)


class _Owned:
Expand Down
17 changes: 16 additions & 1 deletion guppylang/tys/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
set_location_from,
shift_loc,
)
from guppylang.checker.core import Globals
from guppylang.cfg.builder import is_py_expression
from guppylang.checker.core import Context, Globals, Locals
from guppylang.checker.expr_checker import eval_py_expr
from guppylang.definition.common import Definition
from guppylang.definition.module import ModuleDef
from guppylang.definition.parameter import ParamDef
Expand Down Expand Up @@ -69,6 +71,19 @@ def arg_from_ast(
nat_ty = NumericType(NumericType.Kind.Nat)
return ConstArg(ConstValue(nat_ty, node.value))

# Py-expressions can also be used to specify static numbers
if py_expr := is_py_expression(node):
v = eval_py_expr(py_expr, Context(globals, Locals({})))
if isinstance(v, int):
nat_ty = NumericType(NumericType.Kind.Nat)
return ConstArg(ConstValue(nat_ty, v))
else:
raise GuppyError(
f"Compile-time `py(...)` expression with type `{type(v)}` is not a "
"valid type argument",
node,
)

# Finally, we also support delayed annotations in strings
if isinstance(node, ast.Constant) and isinstance(node.value, str):
node = _parse_delayed_annotation(node.value, node)
Expand Down
6 changes: 6 additions & 0 deletions tests/error/py_errors/invalid_type_arg.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Guppy compilation failed. Error in file $FILE:7

5: @compile_guppy
6: def foo(xs: array[int, py(1.0)]) -> None:
^^^^^^^
GuppyError: Compile-time `py(...)` expression with type `<class 'float'>` is not a valid type argument
8 changes: 8 additions & 0 deletions tests/error/py_errors/invalid_type_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from guppylang import py
from guppylang.prelude.builtins import array
from tests.util import compile_guppy


@compile_guppy
def foo(xs: array[int, py(1.0)]) -> None:
pass
Loading

0 comments on commit b4fae3f

Please sign in to comment.