Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Attribute access to instance methods #146

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
LocalCall,
LocalName,
MakeIter,
PartialApply,
PyExpr,
TypeApply,
)
Expand Down Expand Up @@ -214,6 +215,12 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]:
)
node.func, func_ty = self._synthesize(node.func, allow_free_vars=False)

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
node.func = node.func.func
return self.visit_Call(node, ty)

# First handle direct calls of user-defined functions and extension functions
if isinstance(node.func, GlobalName) and isinstance(
node.func.value, CallableVariable
Expand Down Expand Up @@ -445,6 +452,12 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]:
raise GuppyError("Keyword arguments are not supported", node.keywords[0])
node.func, ty = self.synthesize(node.func)

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
node.func = node.func.func
return self.visit_Call(node)

# First handle direct calls of user-defined functions and extension functions
if isinstance(node.func, GlobalName) and isinstance(
node.func.value, CallableVariable
Expand All @@ -461,6 +474,23 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]:
else:
raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func)

def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, GuppyType]:
node.value, ty = self.synthesize(node.value)
func = self.ctx.globals.get_instance_func(ty, node.attr)
if func is None:
raise GuppyTypeError(
f"Expression of type `{ty}` does not have a method named `{node.attr}`"
)
# TODO: Infer type args
name = with_loc(node, with_type(func.ty, GlobalName(id=func.name, value=func)))
partial_ty = FunctionType(
func.ty.args[1:],
func.ty.returns,
func.ty.arg_names[1:] if func.ty.arg_names else None,
func.ty.quantified,
)
return with_loc(node, PartialApply(func=name, args=[node.value])), partial_ty

def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([ty], ExistentialTypeVar.new("Iter", False))
Expand Down
6 changes: 6 additions & 0 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
GlobalName,
LocalCall,
LocalName,
PartialApply,
TypeApply,
)

Expand Down Expand Up @@ -191,6 +192,11 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV:
def visit_Call(self, node: ast.Call) -> OutPortV:
raise InternalGuppyError("Node should have been removed during type checking.")

def visit_PartialApply(self, node: PartialApply) -> OutPortV:
func = self.visit(node.func)
args = [self.visit(arg) for arg in node.args]
return self.graph.add_partial(func, args).out_port(0)

def visit_TypeApply(self, node: TypeApply) -> OutPortV:
func = self.visit(node.value)
assert isinstance(func.ty, FunctionType)
Expand Down
9 changes: 6 additions & 3 deletions guppylang/custom.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from abc import ABC, abstractmethod

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.ast_util import AstNode, with_loc, with_type
from guppylang.checker.core import Context, Globals
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import check_signature
Expand Down Expand Up @@ -105,7 +105,7 @@ def compile_call(
globals: CompiledGlobals,
node: AstNode,
) -> list[OutPortV]:
self.call_compiler._setup(type_args, dfg, graph, globals, node)
self.call_compiler._setup(type_args, dfg, graph, globals, self, node)
return self.call_compiler.compile(args)

def load(
Expand Down Expand Up @@ -190,6 +190,7 @@ class CustomCallCompiler(ABC):
dfg: DFContainer
graph: Hugr
globals: CompiledGlobals
func: CustomFunction
node: AstNode

def _setup(
Expand All @@ -198,12 +199,14 @@ def _setup(
dfg: DFContainer,
graph: Hugr,
globals: CompiledGlobals,
func: CustomFunction,
node: AstNode,
) -> None:
self.type_args = type_args
self.dfg = dfg
self.graph = graph
self.globals = globals
self.func = func
self.node = node

@abstractmethod
Expand Down Expand Up @@ -242,7 +245,7 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]:
node = self.graph.add_node(
self.op.model_copy(), inputs=args, parent=self.dfg.node
)
return_ty = get_type(self.node)
return_ty = self.func.ty.instantiate(self.type_args).returns
return [node.add_out_port(ty) for ty in type_to_row(return_ty)]


Expand Down
4 changes: 2 additions & 2 deletions guppylang/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class OutPortCF(OutPort):
TypeList = list[GuppyType]


@dataclass
@dataclass(eq=False)
class Node(ABC):
"""Base class for a node in the graph.

Expand Down Expand Up @@ -121,7 +121,7 @@ def out_ports(self) -> Iterator[OutPort]:
return (self.out_port(i) for i in range(self.num_out_ports))


@dataclass
@dataclass(eq=False)
class VNode(Node):
"""A node with typed value ports."""

Expand Down
5 changes: 5 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class GlobalCall(ast.expr):
)


class PartialApply(ast.expr):
func: ast.expr
args: list[ast.expr]


class TypeApply(ast.expr):
value: ast.expr
tys: Sequence[GuppyType]
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,10 @@ def bar(x: int) -> int:
return foo(x)

validate(module.compile())


def test_instance_call(validate):
@compile_guppy
def foo(x: int) -> int:
return x.__add__(2)

6 changes: 6 additions & 0 deletions tests/integration/test_higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def baz(y: int) -> None:
validate(module.compile())


def test_instance_func(validate):
@compile_guppy
def foo(x: int) -> Callable[[int], int]:
return x.__add__


def test_nested(validate):
@compile_guppy
def foo(x: int) -> Callable[[int], bool]:
Expand Down
Loading