From ee1e3defb941e9da29f462afc4c16f3e97147828 Mon Sep 17 00:00:00 2001 From: tatiana-s Date: Mon, 16 Dec 2024 09:30:47 +0000 Subject: [PATCH] feat: Add method to load pytket circuit without function stub (#712) Closes #670 Both finding the call location and then having either a node or a span for errors feels a bit hacky but I haven't been able to think of something better so far and it seems to work. --- guppylang/decorator.py | 52 +++++++- guppylang/definition/pytket_circuits.py | 119 ++++++++++++------ .../py_errors/load_tket2_not_installed.err | 11 ++ .../py_errors/load_tket2_not_installed.py | 21 ++++ tests/error/test_py_errors.py | 9 ++ tests/integration/test_pytket_circuits.py | 26 +++- 6 files changed, 194 insertions(+), 44 deletions(-) create mode 100644 tests/error/py_errors/load_tket2_not_installed.err create mode 100644 tests/error/py_errors/load_tket2_not_installed.py diff --git a/guppylang/decorator.py b/guppylang/decorator.py index c2a31e25..6d701b5a 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -30,7 +30,10 @@ RawFunctionDef, ) from guppylang.definition.parameter import ConstVarDef, TypeVarDef -from guppylang.definition.pytket_circuits import RawPytketDef +from guppylang.definition.pytket_circuits import ( + RawLoadPytketDef, + RawPytketDef, +) from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import MissingModuleError, pretty_errors @@ -47,7 +50,7 @@ get_calling_frame, sphinx_running, ) -from guppylang.span import SourceMap +from guppylang.span import Loc, SourceMap, Span from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter from guppylang.tys.subst import Inst @@ -501,6 +504,27 @@ def func(f: PyFunc) -> RawPytketDef: return func + @pretty_errors + def load_pytket( + self, name: str, input_circuit: Any, module: GuppyModule | None = None + ) -> RawLoadPytketDef: + """Adds a pytket circuit function definition with implicit signature.""" + err_msg = "Only pytket circuits can be passed to guppy.load_pytket" + try: + import pytket + + if not isinstance(input_circuit, pytket.circuit.Circuit): + raise TypeError(err_msg) from None + + except ImportError: + raise TypeError(err_msg) from None + + mod = module or self.get_module() + span = _find_load_call(self._sources) + defn = RawLoadPytketDef(DefId.fresh(module), name, None, span, input_circuit) + mod.register_def(defn) + return defn + class _GuppyDummy: """A dummy class with the same interface as `@guppy` that is used during sphinx @@ -586,3 +610,27 @@ def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.e node.col_offset = 0 node.end_col_offset = len(source_lines[info.lineno - 1]) - 1 return expr_ast + + +def _find_load_call(sources: SourceMap) -> Span | None: + """Helper function to find location where pytket circuit was loaded. + + Tries to define a source code span by inspecting the call stack. + """ + # Go back as first frame outside of compiler modules is 'pretty_errors_wrapped'. + if (caller_frame := get_calling_frame()) and (load_frame := caller_frame.f_back): + info = inspect.getframeinfo(load_frame) + filename = info.filename + lineno = info.lineno + sources.add_file(filename) + # If we don't support python <= 3.10, this can be done better with + # info.positions which gives you exact offsets. + # For now over approximate and make the span cover the entire line. + if load_module := inspect.getmodule(load_frame): + source_lines, _ = inspect.getsourcelines(load_module) + max_offset = len(source_lines[lineno - 1]) - 1 + + start = Loc(filename, lineno, 0) + end = Loc(filename, lineno, max_offset) + return Span(start, end) + return None diff --git a/guppylang/definition/pytket_circuits.py b/guppylang/definition/pytket_circuits.py index caaf5b02..557b4231 100644 --- a/guppylang/definition/pytket_circuits.py +++ b/guppylang/definition/pytket_circuits.py @@ -28,9 +28,9 @@ ) from guppylang.definition.ty import TypeDef from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef -from guppylang.error import GuppyError +from guppylang.error import GuppyError, InternalGuppyError from guppylang.nodes import GlobalCall -from guppylang.span import SourceMap +from guppylang.span import SourceMap, Span, ToSpan from guppylang.tys.builtin import bool_type from guppylang.tys.subst import Inst, Subst from guppylang.tys.ty import ( @@ -74,65 +74,73 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": func_ast, globals.with_python_scope(self.python_scope) ) + # Compare signatures. # TODO: Allow arrays as arguments. - # Retrieve circuit signature and compare. - try: - import pytket - - if isinstance(self.input_circuit, pytket.circuit.Circuit): - try: - import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401 - - qubit = cast(TypeDef, globals["qubit"]).check_instantiate( - [], globals - ) - - circuit_signature = FunctionType( - [FuncInput(qubit, InputFlags.Inout)] - * self.input_circuit.n_qubits, - row_to_type([bool_type()] * self.input_circuit.n_bits), - ) - - if not ( - circuit_signature.inputs == stub_signature.inputs - and circuit_signature.output == stub_signature.output - ): - # TODO: Implement pretty-printing for signatures in order to add - # a note for expected vs. actual types. - raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) - except ImportError: - err = Tket2NotInstalled(func_ast) - err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None)) - raise GuppyError(err) from None - except ImportError: - pass + circuit_signature = _signature_from_circuit( + self.input_circuit, globals, self.defined_at + ) + if not ( + circuit_signature.inputs == stub_signature.inputs + and circuit_signature.output == stub_signature.output + ): + # TODO: Implement pretty-printing for signatures in order to add + # a note for expected vs. actual types. + raise GuppyError(PytketSignatureMismatch(func_ast, self.name)) return ParsedPytketDef( self.id, self.name, func_ast, stub_signature, - self.python_scope, + self.input_circuit, + ) + + +@dataclass(frozen=True) +class RawLoadPytketDef(ParsableDef): + """A raw definition for loading pytket circuits without explicit function stub. + + Args: + id: The unique definition identifier. + name: The name of the circuit function. + defined_at: The AST node of the definition (here always None). + source_span: The source span where the circuit was loaded. + input_circuit: The user-provided pytket circuit. + """ + + source_span: Span | None + input_circuit: Any + + description: str = field(default="pytket circuit", init=False) + + def parse(self, globals: Globals, sources: SourceMap) -> "ParsedPytketDef": + """Creates a function signature based on the user-provided circuit.""" + circuit_signature = _signature_from_circuit( + self.input_circuit, globals, self.source_span + ) + + return ParsedPytketDef( + self.id, + self.name, + self.defined_at, + circuit_signature, self.input_circuit, ) @dataclass(frozen=True) class ParsedPytketDef(CallableDef, CompilableDef): - """A circuit definition with parsed and checked signature. + """A circuit definition with signature. Args: id: The unique definition identifier. name: The name of the function. - defined_at: The AST node where the function was defined. + defined_at: The AST node of the function stub, if there is one. ty: The type of the function. - python_scope: The Python scope where the function was defined. input_circuit: The user-provided pytket circuit. """ - defined_at: ast.FunctionDef ty: FunctionType - python_scope: PyScope input_circuit: Any description: str = field(default="pytket circuit", init=False) @@ -181,7 +189,6 @@ def compile_outer(self, module: DefinitionBuilder[OpVar]) -> "CompiledPytketDef" self.name, self.defined_at, self.ty, - self.python_scope, self.input_circuit, outer_func, ) @@ -214,7 +221,6 @@ class CompiledPytketDef(ParsedPytketDef, CompiledCallableDef): name: The name of the function. defined_at: The AST node where the function was defined. ty: The type of the function. - python_scope: The Python scope where the function was defined. input_circuit: The user-provided pytket circuit. func_df: The Hugr function definition. """ @@ -243,3 +249,34 @@ def compile_call( """Compiles a call to the function.""" # Use implementation from function definition. return compile_call(args, type_args, dfg, self.ty, self.func_def) + + +def _signature_from_circuit( + input_circuit: Any, globals: Globals, defined_at: ToSpan | None +) -> FunctionType: + """Helper function for inferring a function signature from a pytket circuit.""" + try: + import pytket + + if isinstance(input_circuit, pytket.circuit.Circuit): + try: + import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401 + + qubit = cast(TypeDef, globals["qubit"]).check_instantiate([], globals) + + circuit_signature = FunctionType( + [FuncInput(qubit, InputFlags.Inout)] * input_circuit.n_qubits, + row_to_type([bool_type()] * input_circuit.n_bits), + ) + except ImportError: + err = Tket2NotInstalled(defined_at) + err.add_sub_diagnostic(Tket2NotInstalled.InstallInstruction(None)) + raise GuppyError(err) from None + else: + pass + except ImportError: + raise InternalGuppyError( + "Pytket error should have been caught earlier" + ) from None + else: + return circuit_signature diff --git a/tests/error/py_errors/load_tket2_not_installed.err b/tests/error/py_errors/load_tket2_not_installed.err new file mode 100644 index 00000000..244f14c5 --- /dev/null +++ b/tests/error/py_errors/load_tket2_not_installed.err @@ -0,0 +1,11 @@ +Error: Tket2 not installed (at $FILE:14:0) + | +12 | module.load(qubit) +13 | +14 | guppy.load_pytket("guppy_circ", circ, module) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Experimental pytket compatibility requires `tket2` to be + | installed + +Help: Install tket2: `pip install tket2` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/py_errors/load_tket2_not_installed.py b/tests/error/py_errors/load_tket2_not_installed.py new file mode 100644 index 00000000..5630ad4f --- /dev/null +++ b/tests/error/py_errors/load_tket2_not_installed.py @@ -0,0 +1,21 @@ +from pytket import Circuit + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.quantum import qubit + +circ = Circuit(2) +circ.X(0) +circ.Y(1) + +module = GuppyModule("test") +module.load(qubit) + +guppy.load_pytket("guppy_circ", circ, module) + +@guppy(module) +def foo(q: qubit) -> None: + guppy_circ(q) + + +module.compile() \ No newline at end of file diff --git a/tests/error/test_py_errors.py b/tests/error/test_py_errors.py index bdb6924c..1f3889a3 100644 --- a/tests/error/test_py_errors.py +++ b/tests/error/test_py_errors.py @@ -16,6 +16,7 @@ if x.is_file() and x.suffix == ".py" and x.name not in ("__init__.py", "tket2_not_installed.py") + and x.name not in ("__init__.py", "load_tket2_not_installed.py") ] # Turn paths into strings, otherwise pytest doesn't display the names @@ -34,3 +35,11 @@ def test_tket2_not_installed(capsys, snapshot): pathlib.Path(__file__).parent.resolve() / "py_errors" / "tket2_not_installed.py" ) run_error_test(str(path), capsys, snapshot) + + +@pytest.mark.skipif(tket2_installed, reason="tket2 is installed") +def test_load_tket2_not_installed(capsys, snapshot): + path = ( + pathlib.Path(__file__).parent.resolve() / "py_errors" / "load_tket2_not_installed.py" + ) + run_error_test(str(path), capsys, snapshot) diff --git a/tests/integration/test_pytket_circuits.py b/tests/integration/test_pytket_circuits.py index dc3e1aa0..72322f16 100644 --- a/tests/integration/test_pytket_circuits.py +++ b/tests/integration/test_pytket_circuits.py @@ -118,7 +118,6 @@ def foo(q: qubit) -> bool: @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") -@pytest.mark.skip("Not implemented") def test_load_circuit(validate): from pytket import Circuit @@ -134,4 +133,29 @@ def test_load_circuit(validate): def foo(q: qubit) -> None: guppy_circ(q) + validate(module.compile()) + + +@pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") +def test_load_circuits(validate): + from pytket import Circuit + + circ1 = Circuit(1) + circ1.H(0) + + circ2 = Circuit(2) + circ2.CX(0, 1) + circ2.measure_all() + + module = GuppyModule("test") + module.load_all(quantum) + + guppy.load_pytket("guppy_circ1", circ1, module) + guppy.load_pytket("guppy_circ2", circ2, module) + + @guppy(module) + def foo(q1: qubit, q2: qubit, q3: qubit) -> tuple[bool, bool]: + guppy_circ1(q1) + return guppy_circ2(q2, q3) + validate(module.compile()) \ No newline at end of file