Skip to content

Commit

Permalink
feat: Use cell name instead of file for notebook errors (#382)
Browse files Browse the repository at this point in the history
This makes the compiler output for notebooks deterministic. Closes #381.

Builds on top of the solution in #374, generalising from class to
function definitions.
  • Loading branch information
mark-koch authored Aug 13, 2024
1 parent 23b2a15 commit d542601
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 39 deletions.
13 changes: 12 additions & 1 deletion guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from guppylang.definition.value import CallableDef, CompiledCallableDef
from guppylang.error import GuppyError
from guppylang.hugr_builder.hugr import DFContainingVNode, Hugr, Node, OutPortV
from guppylang.ipython_inspect import find_ipython_def, is_running_ipython
from guppylang.nodes import GlobalCall
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import FunctionType, Type, type_to_row
Expand Down Expand Up @@ -175,7 +176,17 @@ def parse_py_func(f: PyFunc) -> tuple[ast.FunctionDef, str | None]:
source = "".join(source_lines) # Lines already have trailing \n's
source = textwrap.dedent(source)
func_ast = ast.parse(source).body[0]
file = inspect.getsourcefile(f)
# In Jupyter notebooks, we shouldn't use `inspect.getsourcefile(f)` since it would
# only give us a dummy temporary file
file: str | None
if is_running_ipython():
file = "<In [?]>"
if isinstance(func_ast, ast.FunctionDef):
defn = find_ipython_def(func_ast.name)
if defn is not None:
file = f"<{defn.cell_name}>"
else:
file = inspect.getsourcefile(f)
if file is None:
raise GuppyError("Couldn't determine source file for function")
annotate_location(func_ast, source, file, line_offset)
Expand Down
47 changes: 9 additions & 38 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any, cast
from typing import Any

from guppylang.ast_util import AstNode, annotate_location
from guppylang.checker.core import Globals
Expand All @@ -24,6 +24,7 @@
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.hugr_builder.hugr import OutPortV
from guppylang.ipython_inspect import find_ipython_def, is_running_ipython
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter, check_all_args
from guppylang.tys.parsing import type_from_ast
Expand Down Expand Up @@ -223,50 +224,20 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]:
return [constructor_def]


def is_running_ipython() -> bool:
    """Report whether the code is executing inside an IPython session.

    The check relies on the ``get_ipython`` name: if looking it up raises
    ``NameError`` we are not in IPython; otherwise we are in IPython exactly
    when the call returns something other than ``None``.
    """
    try:
        shell = get_ipython()  # type: ignore[name-defined]  # noqa: F821
    except NameError:
        # `get_ipython` is not defined at all
        return False
    return shell is not None


def get_ipython_cell_sources() -> list[str]:
    """Return the source text of the cells entered in the running IPython session.

    The sources are read from the ``In`` entry of the shell's user namespace;
    the first entry of that list is dropped because it is always empty.

    See https://github.com/wandb/weave/pull/1864

    Raises:
        AttributeError: If the shell object has no ``user_ns`` attribute.
    """
    ipython_shell = get_ipython()  # type: ignore[name-defined]  # noqa: F821
    if not hasattr(ipython_shell, "user_ns"):
        raise AttributeError("Cannot access user namespace")
    input_history = cast(list[str], ipython_shell.user_ns["In"])
    # Skip the leading placeholder entry, which never holds any code
    return input_history[1:]


def parse_py_class(cls: type) -> ast.ClassDef:
"""Parses a Python class object into an AST."""
# We cannot use `inspect.getsourcelines` if we're running in IPython. See
# - https://bugs.python.org/issue33826
# - https://github.com/ipython/ipython/issues/11249
# - https://github.com/wandb/weave/pull/1864
if is_running_ipython():
cell_sources = get_ipython_cell_sources()
# Search cells in reverse order to find the most recent version of the class
for i, cell_source in enumerate(reversed(cell_sources)):
try:
cell_ast = ast.parse(cell_source)
except SyntaxError:
continue
# Search body in reverse order to find the most recent version of the class
for node in reversed(cell_ast.body):
if getattr(node, "name", None) == cls.__name__:
cell_name = f"<In [{len(cell_sources) - i}]>"
annotate_location(node, cell_source, cell_name, 1)
if not isinstance(node, ast.ClassDef):
raise GuppyError("Expected a class definition", node)
return node
raise ValueError(f"Couldn't find source for class `{cls.__name__}`")
defn = find_ipython_def(cls.__name__)
if defn is None:
raise ValueError(f"Couldn't find source for class `{cls.__name__}`")
annotate_location(defn.node, defn.cell_source, f"<{defn.cell_name}>", 1)
if not isinstance(defn.node, ast.ClassDef):
raise GuppyError("Expected a class definition", defn.node)
return defn.node
else:
source_lines, line_offset = inspect.getsourcelines(cls)
source = "".join(source_lines) # Lines already have trailing \n's
Expand Down
56 changes: 56 additions & 0 deletions guppylang/ipython_inspect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Tools for inspecting source code when running in IPython."""

import ast
from typing import NamedTuple, cast


def is_running_ipython() -> bool:
    """Report whether the code is executing inside an IPython session.

    The check relies on the ``get_ipython`` name: if looking it up raises
    ``NameError`` we are not in IPython; otherwise we are in IPython exactly
    when the call returns something other than ``None``.
    """
    try:
        shell = get_ipython()  # type: ignore[name-defined]  # noqa: F821
    except NameError:
        # `get_ipython` is not defined at all
        return False
    return shell is not None


def get_ipython_cell_sources() -> list[str]:
    """Return the source text of the cells entered in the running IPython session.

    The sources are read from the ``In`` entry of the shell's user namespace;
    the first entry of that list is dropped because it is always empty.

    See https://github.com/wandb/weave/pull/1864

    Raises:
        AttributeError: If the shell object has no ``user_ns`` attribute.
    """
    ipython_shell = get_ipython()  # type: ignore[name-defined]  # noqa: F821
    if not hasattr(ipython_shell, "user_ns"):
        raise AttributeError("Cannot access user namespace")
    input_history = cast(list[str], ipython_shell.user_ns["In"])
    # Skip the leading placeholder entry, which never holds any code
    return input_history[1:]


class IPythonDef(NamedTuple):
"""AST of a definition in IPython together with the definition cell name."""

node: ast.FunctionDef | ast.ClassDef
cell_name: str
cell_source: str


def find_ipython_def(name: str) -> IPythonDef | None:
    """Look up the latest top-level definition called `name` in the IPython session.

    Cells are searched from the most recently entered to the oldest, and within
    each cell the statements are searched from last to first, so the newest
    matching function or class definition wins. Cells that fail to parse as
    Python are skipped. Only *top-level* function or class definitions are
    considered; nested definitions are not detected.

    See https://github.com/wandb/weave/pull/1864

    Returns:
        The matching definition together with the name and source of its cell,
        or `None` if no cell contains a matching definition.
    """
    sources = get_ipython_cell_sources()
    # Number the cells starting at 1; the number becomes part of the cell name
    numbered_cells = list(enumerate(sources, start=1))
    # Walk the cells newest-first so that redefinitions shadow older versions
    for cell_number, source in reversed(numbered_cells):
        try:
            tree = ast.parse(source)
        except SyntaxError:
            # Cell is not valid Python; nothing to find here
            continue
        # Within a cell, the last matching statement is the most recent definition
        found = next(
            (
                stmt
                for stmt in reversed(tree.body)
                if isinstance(stmt, (ast.FunctionDef, ast.ClassDef))
                and stmt.name == name
            ),
            None,
        )
        if found is not None:
            return IPythonDef(found, f"In [{cell_number}]", source)
    return None

0 comments on commit d542601

Please sign in to comment.