From 6c4b6b7a8ce55b441ee25a3551f8a9dcf6c64831 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 2 Feb 2024 12:36:34 +0000 Subject: [PATCH] feat: Attribute access to instance methods --- guppylang/checker/expr_checker.py | 30 ++++++++++++++++++++++++++ guppylang/compiler/expr_compiler.py | 6 ++++++ guppylang/custom.py | 9 +++++--- guppylang/hugr/hugr.py | 4 ++-- guppylang/nodes.py | 5 +++++ tests/integration/test_call.py | 7 ++++++ tests/integration/test_higher_order.py | 6 ++++++ 7 files changed, 62 insertions(+), 5 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 36570037..4f5fd6e9 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -73,6 +73,7 @@ LocalCall, LocalName, MakeIter, + PartialApply, PyExpr, TypeApply, ) @@ -214,6 +215,12 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: ) node.func, func_ty = self._synthesize(node.func, allow_free_vars=False) + # When calling a `PartialApply` node, we just move the args into this call + if isinstance(node.func, PartialApply): + node.args = [*node.func.args, *node.args] + node.func = node.func.func + return self.visit_Call(node, ty) + # First handle direct calls of user-defined functions and extension functions if isinstance(node.func, GlobalName) and isinstance( node.func.value, CallableVariable @@ -445,6 +452,12 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: raise GuppyError("Keyword arguments are not supported", node.keywords[0]) node.func, ty = self.synthesize(node.func) + # When calling a `PartialApply` node, we just move the args into this call + if isinstance(node.func, PartialApply): + node.args = [*node.func.args, *node.args] + node.func = node.func.func + return self.visit_Call(node) + # First handle direct calls of user-defined functions and extension functions if isinstance(node.func, GlobalName) and isinstance( node.func.value, CallableVariable @@ -461,6 +474,23 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: else: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, GuppyType]: + node.value, ty = self.synthesize(node.value) + func = self.ctx.globals.get_instance_func(ty, node.attr) + if func is None: + raise GuppyTypeError( + f"Expression of type `{ty}` does not have a method named `{node.attr}`" + ) + # TODO: Infer type args + name = with_loc(node, with_type(func.ty, GlobalName(id=func.name, value=func))) + partial_ty = FunctionType( + func.ty.args[1:], + func.ty.returns, + func.ty.arg_names[1:] if func.ty.arg_names else None, + func.ty.quantified, + ) + return with_loc(node, PartialApply(func=name, args=[node.value])), partial_ty + def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: node.value, ty = self.synthesize(node.value) exp_sig = FunctionType([ty], ExistentialTypeVar.new("Iter", False)) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 8d05fbed..8369ac40 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -33,6 +33,7 @@ GlobalName, LocalCall, LocalName, + PartialApply, TypeApply, ) @@ -191,6 +192,11 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: def visit_Call(self, node: ast.Call) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") + def visit_PartialApply(self, node: PartialApply) -> OutPortV: + func = self.visit(node.func) + args = [self.visit(arg) for arg in node.args] + return self.graph.add_partial(func, args).out_port(0) + def visit_TypeApply(self, node: TypeApply) -> OutPortV: func = self.visit(node.value) assert isinstance(func.ty, FunctionType) diff --git a/guppylang/custom.py b/guppylang/custom.py index e1142448..44c88f2a 100644 --- a/guppylang/custom.py +++ b/guppylang/custom.py @@ -1,7 +1,7 @@ import ast from abc import ABC, abstractmethod -from guppylang.ast_util import AstNode, get_type, with_loc, with_type +from guppylang.ast_util import AstNode, with_loc, with_type from guppylang.checker.core import Context, Globals from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import check_signature @@ -105,7 +105,7 @@ def compile_call( globals: CompiledGlobals, node: AstNode, ) -> list[OutPortV]: - self.call_compiler._setup(type_args, dfg, graph, globals, node) + self.call_compiler._setup(type_args, dfg, graph, globals, self, node) return self.call_compiler.compile(args) def load( @@ -190,6 +190,7 @@ class CustomCallCompiler(ABC): dfg: DFContainer graph: Hugr globals: CompiledGlobals + func: CustomFunction node: AstNode def _setup( @@ -198,12 +199,14 @@ def _setup( dfg: DFContainer, graph: Hugr, globals: CompiledGlobals, + func: CustomFunction, node: AstNode, ) -> None: self.type_args = type_args self.dfg = dfg self.graph = graph self.globals = globals + self.func = func self.node = node @abstractmethod @@ -242,7 +245,7 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: node = self.graph.add_node( self.op.model_copy(), inputs=args, parent=self.dfg.node ) - return_ty = get_type(self.node) + return_ty = self.func.ty.instantiate(self.type_args).returns return [node.add_out_port(ty) for ty in type_to_row(return_ty)] diff --git a/guppylang/hugr/hugr.py b/guppylang/hugr/hugr.py index 01a6cd30..113cf1e4 100644 --- a/guppylang/hugr/hugr.py +++ b/guppylang/hugr/hugr.py @@ -73,7 +73,7 @@ class OutPortCF(OutPort): TypeList = list[GuppyType] -@dataclass +@dataclass(eq=False) class Node(ABC): """Base class for a node in the graph. @@ -121,7 +121,7 @@ def out_ports(self) -> Iterator[OutPort]: return (self.out_port(i) for i in range(self.num_out_ports)) -@dataclass +@dataclass(eq=False) class VNode(Node): """A node with typed value ports.""" diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 908fe55f..38ff6c2f 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -50,6 +50,11 @@ class GlobalCall(ast.expr): ) +class PartialApply(ast.expr): + func: ast.expr + args: list[ast.expr] + + class TypeApply(ast.expr): value: ast.expr tys: Sequence[GuppyType] diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py index 01a520c9..ae7f3be3 100644 --- a/tests/integration/test_call.py +++ b/tests/integration/test_call.py @@ -51,3 +51,10 @@ def bar(x: int) -> int: return foo(x) validate(module.compile()) + + +def test_instance_call(validate): + @compile_guppy + def foo(x: int) -> int: + return x.__add__(2) + diff --git a/tests/integration/test_higher_order.py b/tests/integration/test_higher_order.py index 94de2b31..c0182623 100644 --- a/tests/integration/test_higher_order.py +++ b/tests/integration/test_higher_order.py @@ -55,6 +55,12 @@ def baz(y: int) -> None: validate(module.compile()) +def test_instance_func(validate): + @compile_guppy + def foo(x: int) -> Callable[[int], int]: + return x.__add__ + + def test_nested(validate): @compile_guppy def foo(x: int) -> Callable[[int], bool]: