Skip to content

Commit

Permalink
feat: range() with single-argument (#452)
Browse files Browse the repository at this point in the history
Addresses #429 but maybe only partially.

* `xfail` some tests that look for single nodes in the compiled Hugr, as
this currently includes the whole stdlib. See
#470.
* `@typing.overload` the `@guppy` decorator to remove mypy error when
using `@guppy(module)`
* `parse_py_class` attempts `inspect.getsourcelines` when running
IPython, as this works for the stdlib
  • Loading branch information
acl-cqc authored Sep 9, 2024
1 parent 71340d2 commit d05f369
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 26 deletions.
8 changes: 7 additions & 1 deletion guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from pathlib import Path
from types import FrameType, ModuleType
from typing import Any, TypeVar
from typing import Any, TypeVar, overload

import hugr.ext
from hugr import ops
Expand Down Expand Up @@ -64,6 +64,12 @@ class _Guppy:
def __init__(self) -> None:
self._modules = {}

@overload
def __call__(self, arg: PyFunc) -> RawFunctionDef: ...

@overload
def __call__(self, arg: GuppyModule) -> FuncDefDecorator: ...

@pretty_errors
def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionDef:
"""Decorator to annotate Python functions as Guppy code.
Expand Down
38 changes: 19 additions & 19 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,30 +226,30 @@ def compile(self, args: list[Wire]) -> list[Wire]:

def parse_py_class(cls: type) -> ast.ClassDef:
"""Parses a Python class object into an AST."""
# We cannot use `inspect.getsourcelines` if we're running in IPython. See
# If we are running IPython, `inspect.getsourcelines` works only for builtins
# (guppy stdlib), but not for most/user-defined classes - see:
# - https://bugs.python.org/issue33826
# - https://github.com/ipython/ipython/issues/11249
# - https://github.com/wandb/weave/pull/1864
if is_running_ipython():
defn = find_ipython_def(cls.__name__)
if defn is None:
raise ValueError(f"Couldn't find source for class `{cls.__name__}`")
annotate_location(defn.node, defn.cell_source, f"<{defn.cell_name}>", 1)
if not isinstance(defn.node, ast.ClassDef):
raise GuppyError("Expected a class definition", defn.node)
return defn.node
else:
source_lines, line_offset = inspect.getsourcelines(cls)
source = "".join(source_lines) # Lines already have trailing \n's
source = textwrap.dedent(source)
cls_ast = ast.parse(source).body[0]
file = inspect.getsourcefile(cls)
if file is None:
raise GuppyError("Couldn't determine source file for class")
annotate_location(cls_ast, source, file, line_offset)
if not isinstance(cls_ast, ast.ClassDef):
raise GuppyError("Expected a class definition", cls_ast)
return cls_ast
if defn is not None:
annotate_location(defn.node, defn.cell_source, f"<{defn.cell_name}>", 1)
if not isinstance(defn.node, ast.ClassDef):
raise GuppyError("Expected a class definition", defn.node)
return defn.node
# else, fall through to handle builtins.
source_lines, line_offset = inspect.getsourcelines(cls)
source = "".join(source_lines) # Lines already have trailing \n's
source = textwrap.dedent(source)
cls_ast = ast.parse(source).body[0]
file = inspect.getsourcefile(cls)
if file is None:
raise GuppyError("Couldn't determine source file for class")
annotate_location(cls_ast, source, file, line_offset)
if not isinstance(cls_ast, ast.ClassDef):
raise GuppyError("Expected a class definition", cls_ast)
return cls_ast


def try_parse_generic_base(node: ast.expr) -> list[ast.expr] | None:
Expand Down
41 changes: 39 additions & 2 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,8 +875,45 @@ def print(x): ...
def property(x): ...


@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False)
def range(x): ...
@guppy.struct(builtins)
class Range:
stop: int

@guppy(builtins)
def __iter__(self: "Range") -> "RangeIter":
return RangeIter(0, self.stop) # type: ignore[call-arg]


@guppy.struct(builtins)
class RangeIter:
next: int
stop: int

@guppy(builtins)
def __iter__(self: "RangeIter") -> "RangeIter":
return self

@guppy(builtins)
def __hasnext__(self: "RangeIter") -> tuple[bool, "RangeIter"]:
return (self.next < self.stop, self)

@guppy(builtins)
def __next__(self: "RangeIter") -> tuple[int, "RangeIter"]:
# Fine not to check bounds while we can only be called from inside a `for` loop.
# if self.start >= self.stop:
# raise StopIteration
return (self.next, RangeIter(self.next + 1, self.stop)) # type: ignore[call-arg]

@guppy(builtins)
def __end__(self: "RangeIter") -> None:
pass


@guppy(builtins)
def range(stop: int) -> Range:
"""Limited version of python range().
Only a single argument (stop/limit) is supported."""
return Range(stop) # type: ignore[call-arg]


@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False)
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def main(xs: array[float, 42]) -> int:
validate(package)

hg = package.modules[0]
[val] = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)]
vals = [data.op for node, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(vals) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[val] = vals
assert isinstance(val, ops.Const)
assert isinstance(val.val, IntVal)
assert val.val.v == 42
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from hugr import ops

from guppylang.decorator import guppy
Expand Down Expand Up @@ -68,11 +69,14 @@ def test_func_def_name():
def func_name() -> None:
return

[def_op] = [
defs = [
data.op
for n, data in func_name.modules[0].nodes()
if isinstance(data.op, ops.FuncDefn)
]
if len(defs) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[def_op] = defs
assert isinstance(def_op, ops.FuncDefn)
assert def_op.f_name == "func_name"

Expand Down
11 changes: 9 additions & 2 deletions tests/integration/test_extern.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from hugr import ops, val

from guppylang.decorator import guppy
Expand All @@ -17,7 +18,10 @@ def main() -> float:
validate(package)

hg = package.modules[0]
[c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(consts) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[c] = consts
assert isinstance(c.val, val.Extension)
assert c.val.val["symbol"] == "ext"

Expand All @@ -35,7 +39,10 @@ def main() -> int:
validate(package)

hg = package.modules[0]
[c] = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
consts = [data.op for n, data in hg.nodes() if isinstance(data.op, ops.Const)]
if len(consts) > 1:
pytest.xfail(reason="hugr-includes-whole-stdlib")
[c] = consts
assert isinstance(c.val, val.Extension)
assert c.val.val["symbol"] == "foo"

Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from guppylang.decorator import guppy
from guppylang.prelude.builtins import nat, range
from guppylang.module import GuppyModule
from tests.util import compile_guppy

def test_range(validate, run_int_fn):
module = GuppyModule("test_range")

@guppy(module)
def main() -> int:
total = 0
for x in range(5):
total += x + 100 # Make the initial 0 obvious
return total

@guppy(module)
def negative() -> int:
total = 0
for x in range(-3):
total += 100 + x
return total

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=510)
run_int_fn(compiled, expected=0, fn_name="negative")

0 comments on commit d05f369

Please sign in to comment.