Skip to content

Commit

Permalink
feat: Allow calling of methods (#440)
Browse files Browse the repository at this point in the history
Closes  #439.

* Add a new `PartialApply` node to create a closure capturing the `self`
argument of methods
* Calls of partial applies are simplified to direct calls
* Higher-order usage of methods is only allowed if `self` is not linear
  • Loading branch information
mark-koch authored Sep 2, 2024
1 parent 5fd5d7c commit 5a59da3
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 0 deletions.
26 changes: 26 additions & 0 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
IterNext,
LocalCall,
MakeIter,
PartialApply,
PlaceNode,
PyExpr,
TensorCall,
Expand Down Expand Up @@ -239,6 +240,12 @@ def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]:
if isinstance(defn, CallableDef):
return defn.check_call(node.args, ty, node, self.ctx)

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
node.func = node.func.func
return self.visit_Call(node, ty)

# Otherwise, it must be a function as a higher-order value - something
# whose type is either a FunctionType or a Tuple of FunctionTypes
if isinstance(func_ty, FunctionType):
Expand Down Expand Up @@ -371,6 +378,19 @@ def visit_Attribute(self, node: ast.Attribute) -> tuple[ast.expr, Type]:
# you loose access to all fields besides `a`).
expr = FieldAccessAndDrop(value=node.value, struct_ty=ty, field=field)
return with_loc(node, expr), field.ty
elif func := self.ctx.globals.get_instance_func(ty, node.attr):
name = with_type(
func.ty, with_loc(node, GlobalName(id=func.name, def_id=func.id))
)
# Make a closure by partially applying the `self` argument
# TODO: Try to infer some type args based on `self`
result_ty = FunctionType(
func.ty.inputs[1:],
func.ty.output,
func.ty.input_names[1:] if func.ty.input_names else None,
func.ty.params,
)
return with_loc(node, PartialApply(func=name, args=[node.value])), result_ty
raise GuppyTypeError(
f"Expression of type `{ty}` has no attribute `{node.attr}`",
# Unfortunately, `node.attr` doesn't contain source annotations, so we have
Expand Down Expand Up @@ -517,6 +537,12 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]:
if isinstance(defn, CallableDef):
return defn.synthesize_call(node.args, node, self.ctx)

# When calling a `PartialApply` node, we just move the args into this call
if isinstance(node.func, PartialApply):
node.args = [*node.func.args, *node.args]
node.func = node.func.func
return self.visit_Call(node)

# Otherwise, it must be a function as a higher-order value, or a tensor
if isinstance(ty, FunctionType):
args, return_ty, inst = synthesize_call(ty, node.args, node, self.ctx)
Expand Down
14 changes: 14 additions & 0 deletions guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
GlobalCall,
InoutReturnSentinel,
LocalCall,
PartialApply,
PlaceNode,
TensorCall,
)
Expand Down Expand Up @@ -199,6 +200,19 @@ def visit_TensorCall(self, node: TensorCall) -> None:
self.visit(arg)
self._reassign_inout_args(node.tensor_ty, node.args)

def visit_PartialApply(self, node: PartialApply) -> None:
self.visit(node.func)
for arg in node.args:
ty = get_type(arg)
if ty.linear:
raise GuppyError(
f"Capturing a value with linear type `{ty}` in a closure is not "
"allowed. Try calling the function directly instead of using it as "
"a higher-order value.",
node,
)
self.visit(arg)

def visit_FieldAccessAndDrop(self, node: FieldAccessAndDrop) -> None:
# A field access on a value that is not a place. This means the value can no
# longer be accessed after the field has been projected out. Thus, this is only
Expand Down
11 changes: 11 additions & 0 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
GlobalCall,
GlobalName,
LocalCall,
PartialApply,
PlaceNode,
ResultExpr,
TensorCall,
Expand Down Expand Up @@ -324,6 +325,16 @@ def visit_GlobalCall(self, node: GlobalCall) -> Wire:
def visit_Call(self, node: ast.Call) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")

def visit_PartialApply(self, node: PartialApply) -> Wire:
from guppylang.compiler.func_compiler import make_partial_op

func_ty = get_type(node.func)
assert isinstance(func_ty, FunctionType)
op = make_partial_op(func_ty, [get_type(arg) for arg in node.args])
return self.builder.add_op(
op, self.visit(node.func), *(self.visit(arg) for arg in node.args)
)

def visit_TypeApply(self, node: TypeApply) -> Wire:
# For now, we can only TypeApply global FunctionDefs/Decls.
if not isinstance(node.value, GlobalName):
Expand Down
16 changes: 16 additions & 0 deletions guppylang/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@ class TypeApply(ast.expr):
)


class PartialApply(ast.expr):
"""A partial function application.
This node is emitted when methods are loaded as values, since this requires
partially applying the `self` argument.
"""

func: ast.expr
args: list[ast.expr]

_fields = (
"func",
"args",
)


class FieldAccessAndDrop(ast.expr):
"""A field access on a struct, dropping all the remaining other fields."""

Expand Down
7 changes: 7 additions & 0 deletions tests/error/linear_errors/method_capture.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Guppy compilation failed. Error in file $FILE:22

20: @guppy(module)
21: def foo(s: Struct) -> Struct:
22: f = s.foo
^^^^^
GuppyError: Capturing a value with linear type `Struct` in a closure is not allowed. Try calling the function directly instead of using it as a higher-order value.
26 changes: 26 additions & 0 deletions tests/error/linear_errors/method_capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import guppylang.prelude.quantum as quantum
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.prelude.quantum import qubit


module = GuppyModule("test")
module.load(quantum)


@guppy.struct(module)
class Struct:
q: qubit

@guppy(module)
def foo(self: "Struct") -> "Struct":
return self


@guppy(module)
def foo(s: Struct) -> Struct:
f = s.foo
return f()


module.compile()
11 changes: 11 additions & 0 deletions tests/integration/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,14 @@ def bar(x: int) -> int:
return y

validate(module.compile())


def test_method_call(validate):
module = GuppyModule("module")

@guppy(module)
def foo(x: int) -> int:
return x.__add__(2)

validate(module.compile())

11 changes: 11 additions & 0 deletions tests/integration/test_higher_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ def baz(y: int) -> None:
validate(module.compile())


def test_method(validate):
module = GuppyModule("module")

@guppy(module)
def foo(x: int) -> tuple[int, Callable[[int], int]]:
f = x.__add__
return f(1), f

validate(module.compile())


def test_nested(validate):
@compile_guppy
def foo(x: int) -> Callable[[int], bool]:
Expand Down
27 changes: 27 additions & 0 deletions tests/integration/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,33 @@ def main(a: StructA[StructA[float]], b: StructB[bool, int], c: StructC) -> None:
validate(module.compile())


def test_methods(validate):
module = GuppyModule("module")

@guppy.struct(module)
class StructA:
x: int

@guppy(module)
def foo(self: "StructA", y: int) -> int:
return 2 * self.x + y

@guppy.struct(module)
class StructB:
x: int
y: float

@guppy(module)
def bar(self: "StructB", a: StructA) -> float:
return a.foo(self.x) + self.y

@guppy(module)
def main(a: StructA, b: StructB) -> tuple[int, float]:
return a.foo(1), b.bar(a)

validate(module.compile())


def test_higher_order(validate):
module = GuppyModule("module")
T = guppy.type_var(module, "T")
Expand Down

0 comments on commit 5a59da3

Please sign in to comment.