Skip to content

Commit

Permalink
fix: Fix printing of generic function parameters (#516)
Browse files Browse the repository at this point in the history
Fixes #482
  • Loading branch information
mark-koch authored Oct 1, 2024
1 parent dd669c1 commit 5c18ef6
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
4 changes: 2 additions & 2 deletions guppylang/tys/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def to_existential(self) -> tuple[Argument, ExistentialVar]:
var = ExistentialTypeVar.fresh(self.name, self.can_be_linear)
return TypeArg(var), var

def to_bound(self, idx: int | None = None) -> Argument:
def to_bound(self, idx: int | None = None) -> TypeArg:
"""Creates a bound variable with a given index that can be instantiated for this
parameter.
"""
Expand Down Expand Up @@ -169,7 +169,7 @@ def to_existential(self) -> tuple[Argument, ExistentialVar]:
var = ExistentialConstVar.fresh(self.name, self.ty)
return ConstArg(var), var

def to_bound(self, idx: int | None = None) -> Argument:
def to_bound(self, idx: int | None = None) -> ConstArg:
"""Creates a bound variable with a given index that can be instantiated for this
parameter.
"""
Expand Down
4 changes: 2 additions & 2 deletions guppylang/tys/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str:
@_visit.register
def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str:
# TODO: Print linearity?
return self.bound_names[-param.idx - 1]
return self.bound_names[param.idx]

@_visit.register
def _visit_ConstParam(self, param: ConstParam, inside_row: bool) -> str:
kind = self._visit(param.ty, True)
name = self.bound_names[-param.idx - 1]
name = self.bound_names[param.idx]
return f"{name}: {kind}"

@_visit.register
Expand Down
2 changes: 1 addition & 1 deletion tests/error/array_errors/linear_len.err
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:14
13: def main(qs: array[qubit, 42]) -> int:
14: return len(qs)
^^^^^^^
GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall n, T: nat. array[T, n] -> int` with linear type `qubit`
GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T, n: nat. array[T, n] -> int` with linear type `qubit`
21 changes: 21 additions & 0 deletions tests/test_type_printing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from guppylang.tys.builtin import array_type_def
from guppylang.tys.param import ConstParam, TypeParam
from guppylang.tys.ty import (
FunctionType,
FuncInput,
NumericType,
OpaqueType,
InputFlags,
)


def test_generic_function_type():
ty_param = TypeParam(0, "T", can_be_linear=True)
len_param = ConstParam(1, "n", NumericType(NumericType.Kind.Nat))
array_ty = OpaqueType([ty_param.to_bound(0), len_param.to_bound(1)], array_type_def)
ty = FunctionType(
params=[ty_param, len_param],
inputs=[FuncInput(array_ty, InputFlags.Inout)],
output=ty_param.to_bound(0).ty,
)
assert str(ty) == "forall T, n: nat. array[T, n] -> T"

0 comments on commit 5c18ef6

Please sign in to comment.