Skip to content

Commit

Permalink
feat: Attribute access to instance methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Feb 2, 2024
1 parent 8c9e4d2 commit 6c4b6b7
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 5 deletions.
30 changes: 30 additions & 0 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
LocalCall,
LocalName,
MakeIter,
PartialApply,
PyExpr,
TypeApply,
)
Expand Down Expand Up @@ -214,6 +215,12 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]:
)
node.func, func_ty = self._synthesize(node.func, allow_free_vars=False)

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
node.func = node.func.func
return self.visit_Call(node, ty)

# First handle direct calls of user-defined functions and extension functions
if isinstance(node.func, GlobalName) and isinstance(
node.func.value, CallableVariable
Expand Down Expand Up @@ -445,6 +452,12 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]:
raise GuppyError("Keyword arguments are not supported", node.keywords[0])
node.func, ty = self.synthesize(node.func)

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
node.func = node.func.func
return self.visit_Call(node)

# First handle direct calls of user-defined functions and extension functions
if isinstance(node.func, GlobalName) and isinstance(
node.func.value, CallableVariable
Expand All @@ -461,6 +474,23 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]:
else:
raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func)

def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, GuppyType]:
node.value, ty = self.synthesize(node.value)
func = self.ctx.globals.get_instance_func(ty, node.attr)
if func is None:
raise GuppyTypeError(
f"Expression of type `{ty}` does not have a method named `{node.attr}`"
)
# TODO: Infer type args
name = with_loc(node, with_type(func.ty, GlobalName(id=func.name, value=func)))
partial_ty = FunctionType(
func.ty.args[1:],
func.ty.returns,
func.ty.arg_names[1:] if func.ty.arg_names else None,
func.ty.quantified,
)
return with_loc(node, PartialApply(func=name, args=[node.value])), partial_ty

def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]:
node.value, ty = self.synthesize(node.value)
exp_sig = FunctionType([ty], ExistentialTypeVar.new("Iter", False))
Expand Down
6 changes: 6 additions & 0 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
GlobalName,
LocalCall,
LocalName,
PartialApply,
TypeApply,
)

Expand Down Expand Up @@ -191,6 +192,11 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV:
def visit_Call(self, node: ast.Call) -> OutPortV:
raise InternalGuppyError("Node should have been removed during type checking.")

def visit_PartialApply(self, node: PartialApply) -> OutPortV:
func = self.visit(node.func)
args = [self.visit(arg) for arg in node.args]
return self.graph.add_partial(func, args).out_port(0)

def visit_TypeApply(self, node: TypeApply) -> OutPortV:
func = self.visit(node.value)
assert isinstance(func.ty, FunctionType)
Expand Down
9 changes: 6 additions & 3 deletions guppylang/custom.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from abc import ABC, abstractmethod

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.ast_util import AstNode, with_loc, with_type
from guppylang.checker.core import Context, Globals
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.checker.func_checker import check_signature
Expand Down Expand Up @@ -105,7 +105,7 @@ def compile_call(
globals: CompiledGlobals,
node: AstNode,
) -> list[OutPortV]:
self.call_compiler._setup(type_args, dfg, graph, globals, node)
self.call_compiler._setup(type_args, dfg, graph, globals, self, node)
return self.call_compiler.compile(args)

def load(
Expand Down Expand Up @@ -190,6 +190,7 @@ class CustomCallCompiler(ABC):
dfg: DFContainer
graph: Hugr
globals: CompiledGlobals
func: CustomFunction
node: AstNode

def _setup(
Expand All @@ -198,12 +199,14 @@ def _setup(
dfg: DFContainer,
graph: Hugr,
globals: CompiledGlobals,
func: CustomFunction,
node: AstNode,
) -> None:
self.type_args = type_args
self.dfg = dfg
self.graph = graph
self.globals = globals
self.func = func
self.node = node

@abstractmethod
Expand Down Expand Up @@ -242,7 +245,7 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]:
node = self.graph.add_node(
self.op.model_copy(), inputs=args, parent=self.dfg.node
)
return_ty = get_type(self.node)
return_ty = self.func.ty.instantiate(self.type_args).returns
return [node.add_out_port(ty) for ty in type_to_row(return_ty)]


Expand Down
4 changes: 2 additions & 2 deletions guppylang/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class OutPortCF(OutPort):
TypeList = list[GuppyType]


@dataclass
@dataclass(eq=False)
class Node(ABC):
"""Base class for a node in the graph.
Expand Down Expand Up @@ -121,7 +121,7 @@ def out_ports(self) -> Iterator[OutPort]:
return (self.out_port(i) for i in range(self.num_out_ports))


@dataclass
@dataclass(eq=False)
class VNode(Node):
"""A node with typed value ports."""

Expand Down
5 changes: 5 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class GlobalCall(ast.expr):
)


class PartialApply(ast.expr):
func: ast.expr
args: list[ast.expr]


class TypeApply(ast.expr):
value: ast.expr
tys: Sequence[GuppyType]
Expand Down
7 changes: 7 additions & 0 deletions tests/integration/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,10 @@ def bar(x: int) -> int:
return foo(x)

validate(module.compile())


def test_instance_call(validate):
@compile_guppy
def foo(x: int) -> int:
return x.__add__(2)

6 changes: 6 additions & 0 deletions tests/integration/test_higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def baz(y: int) -> None:
validate(module.compile())


def test_instance_func(validate):
@compile_guppy
def foo(x: int) -> Callable[[int], int]:
return x.__add__


def test_nested(validate):
@compile_guppy
def foo(x: int) -> Callable[[int], bool]:
Expand Down

0 comments on commit 6c4b6b7

Please sign in to comment.