Skip to content

Commit

Permalink
feat: Add method to load pytket circuit without function stub (#712)
Browse files Browse the repository at this point in the history
Closes #670 

Both finding the call location and then having either a node or a span
for errors feels a bit hacky but I haven't been able to think of
something better so far and it seems to work.
  • Loading branch information
tatiana-s authored Dec 16, 2024
1 parent 3ad49ff commit ee1e3de
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 44 deletions.
52 changes: 50 additions & 2 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
RawFunctionDef,
)
from guppylang.definition.parameter import ConstVarDef, TypeVarDef
from guppylang.definition.pytket_circuits import RawPytketDef
from guppylang.definition.pytket_circuits import (
RawLoadPytketDef,
RawPytketDef,
)
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import MissingModuleError, pretty_errors
Expand All @@ -47,7 +50,7 @@
get_calling_frame,
sphinx_running,
)
from guppylang.span import SourceMap
from guppylang.span import Loc, SourceMap, Span
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Inst
Expand Down Expand Up @@ -501,6 +504,27 @@ def func(f: PyFunc) -> RawPytketDef:

return func

@pretty_errors
def load_pytket(
self, name: str, input_circuit: Any, module: GuppyModule | None = None
) -> RawLoadPytketDef:
"""Adds a pytket circuit function definition with implicit signature."""
err_msg = "Only pytket circuits can be passed to guppy.load_pytket"
try:
import pytket

if not isinstance(input_circuit, pytket.circuit.Circuit):
raise TypeError(err_msg) from None

except ImportError:
raise TypeError(err_msg) from None

mod = module or self.get_module()
span = _find_load_call(self._sources)
defn = RawLoadPytketDef(DefId.fresh(module), name, None, span, input_circuit)
mod.register_def(defn)
return defn


class _GuppyDummy:
"""A dummy class with the same interface as `@guppy` that is used during sphinx
Expand Down Expand Up @@ -586,3 +610,27 @@ def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.e
node.col_offset = 0
node.end_col_offset = len(source_lines[info.lineno - 1]) - 1
return expr_ast


def _find_load_call(sources: SourceMap) -> Span | None:
"""Helper function to find location where pytket circuit was loaded.
Tries to define a source code span by inspecting the call stack.
"""
# Go back as first frame outside of compiler modules is 'pretty_errors_wrapped'.
if (caller_frame := get_calling_frame()) and (load_frame := caller_frame.f_back):
info = inspect.getframeinfo(load_frame)
filename = info.filename
lineno = info.lineno
sources.add_file(filename)
# If we don't support python <= 3.10, this can be done better with
# info.positions which gives you exact offsets.
# For now over approximate and make the span cover the entire line.
if load_module := inspect.getmodule(load_frame):
source_lines, _ = inspect.getsourcelines(load_module)
max_offset = len(source_lines[lineno - 1]) - 1

start = Loc(filename, lineno, 0)
end = Loc(filename, lineno, max_offset)
return Span(start, end)
return None
119 changes: 78 additions & 41 deletions guppylang/definition/pytket_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
)
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.error import GuppyError
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.span import SourceMap, Span, ToSpan
from guppylang.tys.builtin import bool_type
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import (
Expand Down Expand Up @@ -74,65 +74,73 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef":
func_ast, globals.with_python_scope(self.python_scope)
)

# Compare signatures.
# TODO: Allow arrays as arguments.
# Retrieve circuit signature and compare.
try:
import pytket

if isinstance(self.input_circuit, pytket.circuit.Circuit):
try:
import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401

qubit = cast(TypeDef, globals["qubit"]).check_instantiate(
[], globals
)

circuit_signature = FunctionType(
[FuncInput(qubit, InputFlags.Inout)]
* self.input_circuit.n_qubits,
row_to_type([bool_type()] * self.input_circuit.n_bits),
)

if not (
circuit_signature.inputs == stub_signature.inputs
and circuit_signature.output == stub_signature.output
):
# TODO: Implement pretty-printing for signatures in order to add
# a note for expected vs. actual types.
raise GuppyError(PytketSignatureMismatch(func_ast, self.name))
except ImportError:
err = Tket2NotInstalled(func_ast)
err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None))
raise GuppyError(err) from None
except ImportError:
pass
circuit_signature = _signature_from_circuit(
self.input_circuit, globals, self.defined_at
)
if not (
circuit_signature.inputs == stub_signature.inputs
and circuit_signature.output == stub_signature.output
):
# TODO: Implement pretty-printing for signatures in order to add
# a note for expected vs. actual types.
raise GuppyError(PytketSignatureMismatch(func_ast, self.name))

return ParsedPytketDef(
self.id,
self.name,
func_ast,
stub_signature,
self.python_scope,
self.input_circuit,
)


@dataclass(frozen=True)
class RawLoadPytketDef(ParsableDef):
"""A raw definition for loading pytket circuits without explicit function stub.
Args:
id: The unique definition identifier.
name: The name of the circuit function.
defined_at: The AST node of the definition (here always None).
source_span: The source span where the circuit was loaded.
input_circuit: The user-provided pytket circuit.
"""

source_span: Span | None
input_circuit: Any

description: str = field(default="pytket circuit", init=False)

def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef":
"""Creates a function signature based on the user-provided circuit."""
circuit_signature = _signature_from_circuit(
self.input_circuit, globals, self.source_span
)

return ParsedPytketDef(
self.id,
self.name,
self.defined_at,
circuit_signature,
self.input_circuit,
)


@dataclass(frozen=True)
class ParsedPytketDef(CallableDef, CompilableDef):
"""A circuit definition with parsed and checked signature.
"""A circuit definition with signature.
Args:
id: The unique definition identifier.
name: The name of the function.
defined_at: The AST node where the function was defined.
defined_at: The AST node of the function stub, if there is one.
ty: The type of the function.
python_scope: The Python scope where the function was defined.
input_circuit: The user-provided pytket circuit.
"""

defined_at: ast.FunctionDef
ty: FunctionType
python_scope: PyScope
input_circuit: Any

description: str = field(default="pytket circuit", init=False)
Expand Down Expand Up @@ -181,7 +189,6 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef"
self.name,
self.defined_at,
self.ty,
self.python_scope,
self.input_circuit,
outer_func,
)
Expand Down Expand Up @@ -214,7 +221,6 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef):
name: The name of the function.
defined_at: The AST node where the function was defined.
ty: The type of the function.
python_scope: The Python scope where the function was defined.
input_circuit: The user-provided pytket circuit.
func_df: The Hugr function definition.
"""
Expand Down Expand Up @@ -243,3 +249,34 @@ def compile_call(
"""Compiles a call to the function."""
# Use implementation from function definition.
return compile_call(args, type_args, dfg, self.ty, self.func_def)


def _signature_from_circuit(
input_circuit: Any, globals: Globals, defined_at: ToSpan | None
) -> FunctionType:
"""Helper function for inferring a function signature from a pytket circuit."""
try:
import pytket

if isinstance(input_circuit, pytket.circuit.Circuit):
try:
import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401

qubit = cast(TypeDef, globals["qubit"]).check_instantiate([], globals)

circuit_signature = FunctionType(
[FuncInput(qubit, InputFlags.Inout)] * input_circuit.n_qubits,
row_to_type([bool_type()] * input_circuit.n_bits),
)
except ImportError:
err = Tket2NotInstalled(defined_at)
err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None))
raise GuppyError(err) from None
else:
pass
except ImportError:
raise InternalGuppyError(
"Pytket error should have been caught earlier"
) from None
else:
return circuit_signature
11 changes: 11 additions & 0 deletions tests/error/py_errors/load_tket2_not_installed.err
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Error: Tket2 not installed (at $FILE:14:0)
|
12 | module.load(qubit)
13 |
14 | guppy.load_pytket("guppy_circ", circ, module)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Experimental pytket compatibility requires `tket2` to be
| installed

Help: Install tket2: `pip install tket2`

Guppy compilation failed due to 1 previous error
21 changes: 21 additions & 0 deletions tests/error/py_errors/load_tket2_not_installed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pytket import Circuit

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.quantum import qubit

circ = Circuit(2)
circ.X(0)
circ.Y(1)

module = GuppyModule("test")
module.load(qubit)

guppy.load_pytket("guppy_circ", circ, module)

@guppy(module)
def foo(q: qubit) -> None:
guppy_circ(q)


module.compile()
9 changes: 9 additions & 0 deletions tests/error/test_py_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
if x.is_file()
and x.suffix == ".py"
and x.name not in ("__init__.py", "tket2_not_installed.py")
and x.name not in ("__init__.py", "load_tket2_not_installed.py")
]

# Turn paths into strings, otherwise pytest doesn't display the names
Expand All @@ -34,3 +35,11 @@ def test_tket2_not_installed(capsys, snapshot):
pathlib.Path(__file__).parent.resolve() / "py_errors" / "tket2_not_installed.py"
)
run_error_test(str(path), capsys, snapshot)


@pytest.mark.skipif(tket2_installed, reason="tket2 is installed")
def test_load_tket2_not_installed(capsys, snapshot):
path = (
pathlib.Path(__file__).parent.resolve() / "py_errors" / "load_tket2_not_installed.py"
)
run_error_test(str(path), capsys, snapshot)
26 changes: 25 additions & 1 deletion tests/integration/test_pytket_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def foo(q: qubit) -> bool:


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
@pytest.mark.skip("Not implemented")
def test_load_circuit(validate):
from pytket import Circuit

Expand All @@ -134,4 +133,29 @@ def test_load_circuit(validate):
def foo(q: qubit) -> None:
guppy_circ(q)

validate(module.compile())


@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed")
def test_load_circuits(validate):
from pytket import Circuit

circ1 = Circuit(1)
circ1.H(0)

circ2 = Circuit(2)
circ2.CX(0, 1)
circ2.measure_all()

module = GuppyModule("test")
module.load_all(quantum)

guppy.load_pytket("guppy_circ1", circ1, module)
guppy.load_pytket("guppy_circ2", circ2, module)

@guppy(module)
def foo(q1: qubit, q2: qubit, q3: qubit) -> tuple[bool, bool]:
guppy_circ1(q1)
return guppy_circ2(q2, q3)

validate(module.compile())

0 comments on commit ee1e3de

Please sign in to comment.