-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat!: Support implicit modules for all decorators and turn builtins into implicit module #476
Changes from 6 commits
3c3f157
fcf6984
359c1fc
2b21a44
d25dab1
6f6a564
faa5b64
d99d694
f5cff19
6d869e4
37162f9
9287b39
f24183d
133888f
3f04c01
fc29e84
2f70db7
17b5647
d4532ac
e3c4453
15e5c48
bafadfe
8f75828
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,16 +30,25 @@ | |
from guppylang.definition.struct import RawStructDef | ||
from guppylang.definition.ty import OpaqueTypeDef, TypeDef | ||
from guppylang.error import GuppyError, MissingModuleError, pretty_errors | ||
from guppylang.module import GuppyModule, PyFunc, find_guppy_module_in_py_module | ||
from guppylang.module import ( | ||
GuppyModule, | ||
PyClass, | ||
PyFunc, | ||
find_guppy_module_in_py_module, | ||
) | ||
from guppylang.tys.subst import Inst | ||
from guppylang.tys.ty import NumericType | ||
|
||
FuncDefDecorator = Callable[[PyFunc], RawFunctionDef] | ||
FuncDeclDecorator = Callable[[PyFunc], RawFunctionDecl] | ||
CustomFuncDecorator = Callable[[PyFunc], RawCustomFunctionDef] | ||
ClassDecorator = Callable[[type], type] | ||
OpaqueTypeDecorator = Callable[[type], OpaqueTypeDef] | ||
StructDecorator = Callable[[type], RawStructDef] | ||
S = TypeVar("S") | ||
T = TypeVar("T") | ||
Decorator = Callable[[S], T] | ||
|
||
FuncDefDecorator = Decorator[PyFunc, RawFunctionDef] | ||
FuncDeclDecorator = Decorator[PyFunc, RawFunctionDecl] | ||
CustomFuncDecorator = Decorator[PyFunc, RawCustomFunctionDef] | ||
ClassDecorator = Decorator[PyClass, PyClass] | ||
OpaqueTypeDecorator = Decorator[PyClass, OpaqueTypeDef] | ||
StructDecorator = Decorator[PyClass, RawStructDef] | ||
|
||
|
||
@dataclass(frozen=True) | ||
|
@@ -79,21 +88,23 @@ def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionD | |
Optionally, the `GuppyModule` in which the function should be placed can | ||
be passed to the decorator. | ||
""" | ||
if not isinstance(arg, GuppyModule): | ||
# Decorator used without any arguments. | ||
# We default to a module associated with the caller of the decorator. | ||
f = arg | ||
module = self.get_module() | ||
|
||
def dec(f: Callable[..., Any], module: GuppyModule) -> RawFunctionDef: | ||
return module.register_func_def(f) | ||
|
||
if isinstance(arg, GuppyModule): | ||
# Module passed. | ||
def dec(f: Callable[..., Any]) -> RawFunctionDef: | ||
return arg.register_func_def(f) | ||
return self._with_optional_module(dec, arg) | ||
|
||
return dec | ||
def _with_optional_module( | ||
self, dec: Callable[[S, GuppyModule], T], arg: S | GuppyModule | ||
) -> Callable[[S], T] | T: | ||
"""Helper function to define decorators that take an optional `GuppyModule` | ||
argument but no other arguments. | ||
|
||
raise ValueError(f"Invalid arguments to `@guppy` decorator: {arg}") | ||
For example, we allow `@guppy(module)` but also `@guppy`. | ||
""" | ||
if isinstance(arg, GuppyModule): | ||
return lambda s: dec(s, arg) | ||
return dec(arg, self.get_module()) | ||
|
||
def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: | ||
"""Returns an identifier for the Python file/module that called the decorator. | ||
|
@@ -104,89 +115,137 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: | |
filename = inspect.getfile(fn) | ||
module = inspect.getmodule(fn) | ||
else: | ||
for s in inspect.stack(): | ||
if s.filename != __file__: | ||
filename = s.filename | ||
module = inspect.getmodule(s.frame) | ||
frame = inspect.currentframe() | ||
while frame: | ||
info = inspect.getframeinfo(frame) | ||
if info and info.filename != __file__: | ||
filename = info.filename | ||
module = inspect.getmodule(frame) | ||
# Skip frames from the `pretty_error` decorator | ||
if module != guppylang.error: | ||
break | ||
frame = frame.f_back | ||
else: | ||
raise GuppyError("Could not find a caller for the `@guppy` decorator") | ||
module_path = Path(filename) | ||
return ModuleIdentifier( | ||
module_path, module.__name__ if module else module_path.name, module | ||
) | ||
|
||
def init_module(self, import_builtins: bool = True) -> None: | ||
"""Manually initialises a Guppy module for the current Python file. | ||
|
||
Calling this method is only required when trying to define an empty module or | ||
a module that doesn't include the builtins. | ||
""" | ||
module_id = self._get_python_caller() | ||
if module_id in self._modules: | ||
msg = f"Module {module_id.name} is already initialised" | ||
raise GuppyError(msg) | ||
self._modules[module_id] = GuppyModule(module_id.name, import_builtins) | ||
|
||
@pretty_errors | ||
def extend_type(self, module: GuppyModule, defn: TypeDef) -> ClassDecorator: | ||
def extend_type( | ||
self, defn: TypeDef, module: GuppyModule | None = None | ||
) -> ClassDecorator: | ||
"""Decorator to add new instance functions to a type.""" | ||
module._instance_func_buffer = {} | ||
mod = module or self.get_module() | ||
mod._instance_func_buffer = {} | ||
|
||
def dec(c: type) -> type: | ||
module._register_buffered_instance_funcs(defn) | ||
mod._register_buffered_instance_funcs(defn) | ||
return c | ||
|
||
return dec | ||
|
||
@pretty_errors | ||
def type( | ||
self, | ||
module: GuppyModule, | ||
hugr_ty: ht.Type, | ||
name: str = "", | ||
linear: bool = False, | ||
bound: ht.TypeBound | None = None, | ||
module: GuppyModule | None = None, | ||
) -> OpaqueTypeDecorator: | ||
"""Decorator to annotate a class definitions as Guppy types. | ||
|
||
Requires the static Hugr translation of the type. Additionally, the type can be | ||
marked as linear. All `@guppy` annotated functions on the class are turned into | ||
instance functions. | ||
""" | ||
module._instance_func_buffer = {} | ||
mod = module or self.get_module() | ||
mod._instance_func_buffer = {} | ||
|
||
def dec(c: type) -> OpaqueTypeDef: | ||
defn = OpaqueTypeDef( | ||
DefId.fresh(module), | ||
DefId.fresh(mod), | ||
name or c.__name__, | ||
None, | ||
[], | ||
linear, | ||
lambda _: hugr_ty, | ||
bound, | ||
) | ||
module.register_def(defn) | ||
module._register_buffered_instance_funcs(defn) | ||
mod.register_def(defn) | ||
mod._register_buffered_instance_funcs(defn) | ||
return defn | ||
|
||
return dec | ||
|
||
@pretty_errors | ||
def struct(self, module: GuppyModule) -> StructDecorator: | ||
@property | ||
def struct( | ||
self, | ||
) -> Callable[[PyClass | GuppyModule], StructDecorator | RawStructDef]: | ||
"""Decorator to define a new struct.""" | ||
module._instance_func_buffer = {} | ||
# Note that this is a property. Thus, the code below is executed *before* | ||
# the members of the decorated class are executed. | ||
# At this point, we don't know if the user has called `@struct(module)` or | ||
# just `@struct`. To be safe, we initialise the method buffer of the implicit | ||
# module either way | ||
caller_id = self._get_python_caller() | ||
implicit_module_existed = caller_id in self._modules | ||
implicit_module = self.get_module( | ||
# But don't try to do implicit imports since we're not sure if this is | ||
# actually an implicit module | ||
resolve_implicit_imports=False | ||
) | ||
implicit_module._instance_func_buffer = {} | ||
|
||
def dec(cls: type) -> RawStructDef: | ||
def dec(cls: type, module: GuppyModule) -> RawStructDef: | ||
defn = RawStructDef(DefId.fresh(module), cls.__name__, None, cls) | ||
module.register_def(defn) | ||
module._register_buffered_instance_funcs(defn) | ||
# If we mistakenly initialised the method buffer of the implicit module | ||
# we can just clear it here | ||
if module != implicit_module: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is doing pointer equality on the module then? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, since |
||
implicit_module._instance_func_buffer = None | ||
if not implicit_module_existed: | ||
self._modules.pop(caller_id) | ||
return defn | ||
|
||
return dec | ||
def higher_dec(arg: GuppyModule | PyClass) -> StructDecorator | RawStructDef: | ||
if isinstance(arg, GuppyModule): | ||
arg._instance_func_buffer = {} | ||
return self._with_optional_module(dec, arg) | ||
|
||
return higher_dec | ||
|
||
@pretty_errors | ||
def type_var(self, module: GuppyModule, name: str, linear: bool = False) -> TypeVar: | ||
def type_var( | ||
self, name: str, linear: bool = False, module: GuppyModule | None = None | ||
) -> TypeVar: | ||
"""Creates a new type variable in a module.""" | ||
module = module or self.get_module() | ||
defn = TypeVarDef(DefId.fresh(module), name, None, linear) | ||
module.register_def(defn) | ||
# Return an actual Python `TypeVar` so it can be used as an actual type in code | ||
# that is executed by interpreter before handing it to Guppy. | ||
return TypeVar(name) | ||
|
||
@pretty_errors | ||
def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef: | ||
def nat_var(self, name: str, module: GuppyModule | None = None) -> ConstVarDef: | ||
"""Creates a new const nat variable in a module.""" | ||
module = module or self.get_module() | ||
defn = ConstVarDef( | ||
DefId.fresh(module), name, None, NumericType(NumericType.Kind.Nat) | ||
) | ||
|
@@ -196,18 +255,19 @@ def nat_var(self, module: GuppyModule, name: str) -> ConstVarDef: | |
@pretty_errors | ||
def custom( | ||
self, | ||
module: GuppyModule, | ||
compiler: CustomCallCompiler | None = None, | ||
checker: CustomCallChecker | None = None, | ||
higher_order_value: bool = True, | ||
name: str = "", | ||
module: GuppyModule | None = None, | ||
) -> CustomFuncDecorator: | ||
"""Decorator to add custom typing or compilation behaviour to function decls. | ||
|
||
Optionally, usage of the function as a higher-order value can be disabled. In | ||
that case, the function signature can be omitted if a custom call compiler is | ||
provided. | ||
""" | ||
mod = module or self.get_module() | ||
|
||
def dec(f: PyFunc) -> RawCustomFunctionDef: | ||
func_ast, docstring = parse_py_func(f) | ||
|
@@ -218,25 +278,25 @@ def dec(f: PyFunc) -> RawCustomFunctionDef: | |
) | ||
call_checker = checker or DefaultCallChecker() | ||
func = RawCustomFunctionDef( | ||
DefId.fresh(module), | ||
DefId.fresh(mod), | ||
name or func_ast.name, | ||
func_ast, | ||
call_checker, | ||
compiler or NotImplementedCallCompiler(), | ||
higher_order_value, | ||
) | ||
module.register_def(func) | ||
mod.register_def(func) | ||
return func | ||
|
||
return dec | ||
|
||
def hugr_op( | ||
self, | ||
module: GuppyModule, | ||
op: Callable[[ht.FunctionType, Inst], ops.DataflowOp], | ||
checker: CustomCallChecker | None = None, | ||
higher_order_value: bool = True, | ||
name: str = "", | ||
module: GuppyModule | None = None, | ||
) -> CustomFuncDecorator: | ||
"""Decorator to annotate function declarations as HUGR ops. | ||
|
||
|
@@ -249,34 +309,42 @@ def hugr_op( | |
value. | ||
name: The name of the function. | ||
""" | ||
return self.custom(module, OpCompiler(op), checker, higher_order_value, name) | ||
return self.custom(OpCompiler(op), checker, higher_order_value, name, module) | ||
|
||
def declare(self, module: GuppyModule) -> FuncDeclDecorator: | ||
@overload | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
def declare(self, arg: GuppyModule) -> RawFunctionDecl: ... | ||
|
||
@overload | ||
def declare(self, arg: PyFunc) -> FuncDeclDecorator: ... | ||
|
||
def declare(self, arg: GuppyModule | PyFunc) -> FuncDeclDecorator | RawFunctionDecl: | ||
"""Decorator to declare functions""" | ||
|
||
def dec(f: Callable[..., Any]) -> RawFunctionDecl: | ||
def dec(f: Callable[..., Any], module: GuppyModule) -> RawFunctionDecl: | ||
return module.register_func_decl(f) | ||
|
||
return dec | ||
return self._with_optional_module(dec, arg) | ||
|
||
def constant( | ||
self, module: GuppyModule, name: str, ty: str, value: hv.Value | ||
self, name: str, ty: str, value: hv.Value, module: GuppyModule | None = None | ||
) -> RawConstDef: | ||
"""Adds a constant to a module, backed by a `hugr.val.Value`.""" | ||
module = module or self.get_module() | ||
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`") | ||
defn = RawConstDef(DefId.fresh(module), name, None, type_ast, value) | ||
module.register_def(defn) | ||
return defn | ||
|
||
def extern( | ||
self, | ||
module: GuppyModule, | ||
name: str, | ||
ty: str, | ||
symbol: str | None = None, | ||
constant: bool = True, | ||
module: GuppyModule | None = None, | ||
) -> RawExternDef: | ||
"""Adds an extern symbol to a module.""" | ||
module = module or self.get_module() | ||
type_ast = _parse_expr_string(ty, f"Not a valid Guppy type: `{ty}`") | ||
defn = RawExternDef( | ||
DefId.fresh(module), name, None, symbol or name, constant, type_ast | ||
|
@@ -291,19 +359,23 @@ def load(self, m: ModuleType | GuppyModule) -> None: | |
module = self._modules[caller] | ||
module.load_all(m) | ||
|
||
def get_module(self, id: ModuleIdentifier | None = None) -> GuppyModule: | ||
def get_module( | ||
self, id: ModuleIdentifier | None = None, resolve_implicit_imports: bool = True | ||
) -> GuppyModule: | ||
"""Returns the local GuppyModule.""" | ||
if id is None: | ||
id = self._get_python_caller() | ||
if id not in self._modules: | ||
self._modules[id] = GuppyModule(id.name.split(".")[-1]) | ||
module = self._modules[id] | ||
# Update implicit imports | ||
if id.module: | ||
if resolve_implicit_imports and id.module: | ||
defs: dict[str, Definition | ModuleType] = {} | ||
for x, value in id.module.__dict__.items(): | ||
if isinstance(value, Definition) and value.id.module != module: | ||
defs[x] = value | ||
if isinstance(value, Definition): | ||
other_module = value.id.module | ||
if other_module and other_module != module: | ||
defs[x] = value | ||
elif isinstance(value, ModuleType): | ||
try: | ||
other_module = find_guppy_module_in_py_module(value) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Drive-by:
inspect.stack()
is inefficient, it suffices to get the current frame and walk back. Now that we make heavy use of implicit modues, this actually makes a difference