Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Support implicit modules for all decorators and turn builtins into implicit module #476

Merged
merged 23 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3c3f157
feat!: Support implicit modules for all decorators and turn builtins …
mark-koch Sep 11, 2024
fcf6984
Lints
mark-koch Sep 11, 2024
359c1fc
Fix test
mark-koch Sep 11, 2024
2b21a44
Fix more tests
mark-koch Sep 11, 2024
d25dab1
Fix test
mark-koch Sep 12, 2024
6f6a564
Merge remote-tracking branch 'origin/main' into feat/implicit-decorators
mark-koch Sep 12, 2024
faa5b64
feat!: Add functions to quantum module and make quantum_functional in…
mark-koch Sep 13, 2024
d99d694
Fix tket tests
mark-koch Sep 13, 2024
f5cff19
Merge remote-tracking branch 'origin/main' into feat/quantum-module
mark-koch Sep 13, 2024
6d869e4
Merge remote-tracking branch 'origin/feat/quantum-module' into feat/i…
mark-koch Sep 13, 2024
37162f9
Fix tests
mark-koch Sep 13, 2024
9287b39
Fmt
mark-koch Sep 13, 2024
f24183d
Fix test
mark-koch Sep 13, 2024
133888f
tweak comment
acl-cqc Sep 16, 2024
3f04c01
quantum.py import only no_type_check from typing
acl-cqc Sep 16, 2024
fc29e84
fix: Fix implicit imports in notebooks (#495)
mark-koch Sep 16, 2024
2f70db7
Merge remote-tracking branch 'origin/main' into HEAD
acl-cqc Sep 16, 2024
17b5647
Merge 'origin/main' including 'feat/quantum-module' into HEAD
acl-cqc Sep 16, 2024
d4532ac
Add @overloads to help document _with_optional_module; use Decorator …
acl-cqc Sep 16, 2024
e3c4453
fix overloads, oops, and type-ignore, heh
acl-cqc Sep 16, 2024
15e5c48
Pull stuff out of loop in _get_python_caller
acl-cqc Sep 16, 2024
bafadfe
assert discarded buffer is empty
acl-cqc Sep 16, 2024
8f75828
make assert mypy-ok
acl-cqc Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 123 additions & 51 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.module import GuppyModule, PyFunc, find_guppy_module_in_py_module
from guppylang.module import (
GuppyModule,
PyClass,
PyFunc,
find_guppy_module_in_py_module,
)
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType

FuncDefDecorator = Callable[[PyFunc], RawFunctionDef]
FuncDeclDecorator = Callable[[PyFunc], RawFunctionDecl]
CustomFuncDecorator = Callable[[PyFunc], RawCustomFunctionDef]
ClassDecorator = Callable[[type], type]
OpaqueTypeDecorator = Callable[[type], OpaqueTypeDef]
StructDecorator = Callable[[type], RawStructDef]
S = TypeVar("S")
T = TypeVar("T")
Decorator = Callable[[S], T]

FuncDefDecorator = Decorator[PyFunc, RawFunctionDef]
FuncDeclDecorator = Decorator[PyFunc, RawFunctionDecl]
CustomFuncDecorator = Decorator[PyFunc, RawCustomFunctionDef]
ClassDecorator = Decorator[PyClass, PyClass]
OpaqueTypeDecorator = Decorator[PyClass, OpaqueTypeDef]
StructDecorator = Decorator[PyClass, RawStructDef]


@dataclass(frozen=True)
Expand Down Expand Up @@ -79,21 +88,23 @@ def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionD
Optionally, the `GuppyModule` in which the function should be placed can
be passed to the decorator.
"""
if not isinstance(arg, GuppyModule):
# Decorator used without any arguments.
# We default to a module associated with the caller of the decorator.
f = arg
module = self.get_module()

def dec(f: Callable[..., Any], module: GuppyModule) -> RawFunctionDef:
return module.register_func_def(f)

if isinstance(arg, GuppyModule):
# Module passed.
def dec(f: Callable[..., Any]) -> RawFunctionDef:
return arg.register_func_def(f)
return self._with_optional_module(dec, arg)

return dec
def _with_optional_module(
self, dec: Callable[[S, GuppyModule], T], arg: S | GuppyModule
) -> Callable[[S], T] | T:
"""Helper function to define decorators that take an optional `GuppyModule`
argument but no other arguments.

raise ValueError(f"Invalid arguments to `@guppy` decorator: {arg}")
For example, we allow `@guppy(module)` but also `@guppy`.
"""
if isinstance(arg, GuppyModule):
return lambda s: dec(s, arg)
return dec(arg, self.get_module())

def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier:
"""Returns an identifier for the Python file/module that called the decorator.
Expand All @@ -104,89 +115,137 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier:
filename = inspect.getfile(fn)
module = inspect.getmodule(fn)
else:
for s in inspect.stack():
if s.filename != __file__:
filename = s.filename
module = inspect.getmodule(s.frame)
frame = inspect.currentframe()
while frame:
info = inspect.getframeinfo(frame)
if info and info.filename != __file__:
filename = info.filename
module = inspect.getmodule(frame)
# Skip frames from the `pretty_error` decorator
if module != guppylang.error:
break
frame = frame.f_back
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drive-by: inspect.stack() is inefficient, it suffices to get the current frame and walk back. Now that we make heavy use of implicit modues, this actually makes a difference

else:
raise GuppyError("Could not find a caller for the `@guppy` decorator")
module_path = Path(filename)
return ModuleIdentifier(
module_path, module.__name__ if module else module_path.name, module
)

def init_module(self, import_builtins: bool = True) -> None:
"""Manually initialises a Guppy module for the current Python file.

Calling this method is only required when trying to define an empty module or
a module that doesn't include the builtins.
"""
module_id = self._get_python_caller()
if module_id in self._modules:
msg = f"Module {module_id.name} is already initialised"
raise GuppyError(msg)
self._modules[module_id] = GuppyModule(module_id.name, import_builtins)

@pretty_errors
def extend_type(self, module: GuppyModule, defn: TypeDef) -> ClassDecorator:
def extend_type(
self, defn: TypeDef, module: GuppyModule | None = None
) -> ClassDecorator:
"""Decorator to add new instance functions to a type."""
module._instance_func_buffer = {}
mod = module or self.get_module()
mod._instance_func_buffer = {}

def dec(c: type) -> type:
module._register_buffered_instance_funcs(defn)
mod._register_buffered_instance_funcs(defn)
return c

return dec

@pretty_errors
def type(
self,
module: GuppyModule,
hugr_ty: ht.Type,
name: str = "",
linear: bool = False,
bound: ht.TypeBound | None = None,
module: GuppyModule | None = None,
) -> OpaqueTypeDecorator:
"""Decorator to annotate a class definitions as Guppy types.

Requires the static Hugr translation of the type. Additionally, the type can be
marked as linear. All `@guppy` annotated functions on the class are turned into
instance functions.
"""
module._instance_func_buffer = {}
mod = module or self.get_module()
mod._instance_func_buffer = {}

def dec(c: type) -> OpaqueTypeDef:
defn = OpaqueTypeDef(
DefId.fresh(module),
DefId.fresh(mod),
name or c.__name__,
None,
[],
linear,
lambda _: hugr_ty,
bound,
)
module.register_def(defn)
module._register_buffered_instance_funcs(defn)
mod.register_def(defn)
mod._register_buffered_instance_funcs(defn)
return defn

return dec

@pretty_errors
def struct(self, module: GuppyModule) -> StructDecorator:
@property
def struct(
self,
) -> Callable[[PyClass | GuppyModule], StructDecorator | RawStructDef]:
"""Decorator to define a new struct."""
module._instance_func_buffer = {}
# Note that this is a property. Thus, the code below is executed *before*
# the members of the decorated class are executed.
# At this point, we don't know if the user has called `@struct(module)` or
# just `@struct`. To be safe, we initialise the method buffer of the implicit
# module either way
caller_id = self._get_python_caller()
implicit_module_existed = caller_id in self._modules
implicit_module = self.get_module(
# But don't try to do implicit imports since we're not sure if this is
# actually an implicit module
resolve_implicit_imports=False
)
implicit_module._instance_func_buffer = {}

def dec(cls: type) -> RawStructDef:
def dec(cls: type, module: GuppyModule) -> RawStructDef:
defn = RawStructDef(DefId.fresh(module), cls.__name__, None, cls)
module.register_def(defn)
module._register_buffered_instance_funcs(defn)
# If we mistakenly initialised the method buffer of the implicit module
# we can just clear it here
if module != implicit_module:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is doing pointer equality on the module then?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, since GuppyModule doesn't implement __eq__

implicit_module._instance_func_buffer = None
if not implicit_module_existed:
self._modules.pop(caller_id)
return defn

return dec
def higher_dec(arg: GuppyModule | PyClass) -> StructDecorator | RawStructDef:
if isinstance(arg, GuppyModule):
arg._instance_func_buffer = {}
return self._with_optional_module(dec, arg)

return higher_dec

@pretty_errors
def type_var(self, module: GuppyModule, name: str, linear: bool = False) -> TypeVar:
def type_var(
self, name: str, linear: bool = False, module: GuppyModule | None = None
) -> TypeVar:
"""Creates a new type variable in a module."""
module = module or self.get_module()
defn = TypeVarDef(DefId.fresh(module), name, None, linear)
module.register_def(defn)
# Return an actual Python `TypeVar` so it can be used as an actual type in code
# that is executed by interpreter before handing it to Guppy.
return TypeVar(name)

@pretty_errors
def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef:
def nat_var(self, name: str, module: GuppyModule | None = None) -> ConstVarDef:
"""Creates a new const nat variable in a module."""
module = module or self.get_module()
defn = ConstVarDef(
DefId.fresh(module), name, None, NumericType(NumericType.Kind.Nat)
)
Expand All @@ -196,18 +255,19 @@ def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef:
@pretty_errors
def custom(
self,
module: GuppyModule,
compiler: CustomCallCompiler | None = None,
checker: CustomCallChecker | None = None,
higher_order_value: bool = True,
name: str = "",
module: GuppyModule | None = None,
) -> CustomFuncDecorator:
"""Decorator to add custom typing or compilation behaviour to function decls.

Optionally, usage of the function as a higher-order value can be disabled. In
that case, the function signature can be omitted if a custom call compiler is
provided.
"""
mod = module or self.get_module()

def dec(f: PyFunc) -> RawCustomFunctionDef:
func_ast, docstring = parse_py_func(f)
Expand All @@ -218,25 +278,25 @@ def dec(f: PyFunc) -> RawCustomFunctionDef:
)
call_checker = checker or DefaultCallChecker()
func = RawCustomFunctionDef(
DefId.fresh(module),
DefId.fresh(mod),
name or func_ast.name,
func_ast,
call_checker,
compiler or NotImplementedCallCompiler(),
higher_order_value,
)
module.register_def(func)
mod.register_def(func)
return func

return dec

def hugr_op(
self,
module: GuppyModule,
op: Callable[[ht.FunctionType, Inst], ops.DataflowOp],
checker: CustomCallChecker | None = None,
higher_order_value: bool = True,
name: str = "",
module: GuppyModule | None = None,
) -> CustomFuncDecorator:
"""Decorator to annotate function declarations as HUGR ops.

Expand All @@ -249,34 +309,42 @@ def hugr_op(
value.
name: The name of the function.
"""
return self.custom(module, OpCompiler(op), checker, higher_order_value, name)
return self.custom(OpCompiler(op), checker, higher_order_value, name, module)

def declare(self, module: GuppyModule) -> FuncDeclDecorator:
@overload
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

def declare(self, arg: GuppyModule) -> RawFunctionDecl: ...

@overload
def declare(self, arg: PyFunc) -> FuncDeclDecorator: ...

def declare(self, arg: GuppyModule | PyFunc) -> FuncDeclDecorator | RawFunctionDecl:
"""Decorator to declare functions"""

def dec(f: Callable[..., Any]) -> RawFunctionDecl:
def dec(f: Callable[..., Any], module: GuppyModule) -> RawFunctionDecl:
return module.register_func_decl(f)

return dec
return self._with_optional_module(dec, arg)

def constant(
self, module: GuppyModule, name: str, ty: str, value: hv.Value
self, name: str, ty: str, value: hv.Value, module: GuppyModule | None = None
) -> RawConstDef:
"""Adds a constant to a module, backed by a `hugr.val.Value`."""
module = module or self.get_module()
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
defn = RawConstDef(DefId.fresh(module), name, None, type_ast, value)
module.register_def(defn)
return defn

def extern(
self,
module: GuppyModule,
name: str,
ty: str,
symbol: str | None = None,
constant: bool = True,
module: GuppyModule | None = None,
) -> RawExternDef:
"""Adds an extern symbol to a module."""
module = module or self.get_module()
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`")
defn = RawExternDef(
DefId.fresh(module), name, None, symbol or name, constant, type_ast
Expand All @@ -291,19 +359,23 @@ def load(self, m: ModuleType | GuppyModule) -> None:
module = self._modules[caller]
module.load_all(m)

def get_module(self, id: ModuleIdentifier | None = None) -> GuppyModule:
def get_module(
self, id: ModuleIdentifier | None = None, resolve_implicit_imports: bool = True
) -> GuppyModule:
"""Returns the local GuppyModule."""
if id is None:
id = self._get_python_caller()
if id not in self._modules:
self._modules[id] = GuppyModule(id.name.split(".")[-1])
module = self._modules[id]
# Update implicit imports
if id.module:
if resolve_implicit_imports and id.module:
defs: dict[str, Definition | ModuleType] = {}
for x, value in id.module.__dict__.items():
if isinstance(value, Definition) and value.id.module != module:
defs[x] = value
if isinstance(value, Definition):
other_module = value.id.module
if other_module and other_module != module:
defs[x] = value
elif isinstance(value, ModuleType):
try:
other_module = find_guppy_module_in_py_module(value)
Expand Down
1 change: 1 addition & 0 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors

PyClass = type
PyFunc = Callable[..., Any]
PyFuncDefOrDecl = tuple[bool, PyFunc]

Expand Down
Loading
Loading