Skip to content

Commit

Permalink
fix: Consider type when deciding whether to pack up returns (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored May 14, 2024
1 parent 324f2ee commit 4f24a07
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
12 changes: 8 additions & 4 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,14 @@ def visit_List(self, node: ast.List) -> OutPortV:
ops.DummyOp(name="MakeList"), inputs=[self.visit(e) for e in node.elts]
).add_out_port(get_type(node))

def _pack_returns(self, returns: list[OutPortV]) -> OutPortV:
def _pack_returns(self, returns: list[OutPortV], return_ty: Type) -> OutPortV:
"""Groups function return values into a tuple"""
if len(returns) != 1:
if isinstance(return_ty, TupleType | NoneType) and not return_ty.preserve:
assert len(returns) == (
len(return_ty.element_types) if isinstance(return_ty, TupleType) else 0
)
return self.graph.add_make_tuple(inputs=returns).out_port(0)
assert len(returns) == 1
return returns[0]

def visit_LocalCall(self, node: LocalCall) -> OutPortV:
Expand All @@ -177,7 +181,7 @@ def visit_LocalCall(self, node: LocalCall) -> OutPortV:
args = [self.visit(arg) for arg in node.args]
call = self.graph.add_indirect_call(func, args)
rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.output)))]
return self._pack_returns(rets)
return self._pack_returns(rets, func.ty.output)

def visit_GlobalCall(self, node: GlobalCall) -> OutPortV:
func = self.globals[node.def_id]
Expand All @@ -187,7 +191,7 @@ def visit_GlobalCall(self, node: GlobalCall) -> OutPortV:
rets = func.compile_call(
args, list(node.type_args), self.dfg, self.graph, self.globals, node
)
return self._pack_returns(rets)
return self._pack_returns(rets, func.ty.output)

def visit_Call(self, node: ast.Call) -> OutPortV:
raise InternalGuppyError("Node should have been removed during type checking.")
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,18 @@ def bar(x: int) -> int:
return foo(x)

validate(module.compile())


def test_unary_tuple(validate):
module = GuppyModule("module")

@guppy(module)
def foo(x: int) -> tuple[int]:
return x,

@guppy(module)
def bar(x: int) -> int:
y, = foo(x)
return y

validate(module.compile())

0 comments on commit 4f24a07

Please sign in to comment.