diff --git a/guppylang/prelude/_internal/compiler/quantum.py b/guppylang/prelude/_internal/compiler/quantum.py index efd87f37..7adea95e 100644 --- a/guppylang/prelude/_internal/compiler/quantum.py +++ b/guppylang/prelude/_internal/compiler/quantum.py @@ -70,7 +70,7 @@ def __init__(self, opname: str): def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: from guppylang.prelude._internal.util import quantum_op - [q, angle] = args + [*qs, angle] = args [halfturns] = self.builder.add_op(ops.UnpackTuple([FLOAT_T]), angle) [mb_rotation] = self.builder.add_op(from_halfturns(), halfturns) @@ -81,11 +81,14 @@ def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: with conditional.add_case(1) as case: case.set_outputs(*case.inputs()) - q = self.builder.add_op( + qs = self.builder.add_op( quantum_op(self.opname)( - ht.FunctionType([ht.Qubit, ROTATION_T], [ht.Qubit]), [] + ht.FunctionType( + [ht.Qubit for _ in qs] + [ROTATION_T], [ht.Qubit for _ in qs] + ), + [], ), - q, + *qs, conditional, ) - return CallReturnWires(regular_returns=[], inout_returns=[q]) + return CallReturnWires(regular_returns=[], inout_returns=list(qs)) diff --git a/guppylang/prelude/quantum.py b/guppylang/prelude/quantum.py index 3c97a0d3..d1062cb7 100644 --- a/guppylang/prelude/quantum.py +++ b/guppylang/prelude/quantum.py @@ -35,6 +35,10 @@ def h(q: qubit) -> None: ... def cz(control: qubit, target: qubit) -> None: ... +@guppy.hugr_op(quantum_op("CY")) +def cy(control: qubit, target: qubit) -> None: ... + + @guppy.hugr_op(quantum_op("CX")) def cx(control: qubit, target: qubit) -> None: ... @@ -79,6 +83,18 @@ def rz(q: qubit, angle: angle) -> None: ... def rx(q: qubit, angle: angle) -> None: ... +@guppy.custom(RotationCompiler("Ry")) +def ry(q: qubit, angle: angle) -> None: ... + + +@guppy.custom(RotationCompiler("CRz")) +def crz(control: qubit, target: qubit, angle: angle) -> None: ... + + +@guppy.hugr_op(quantum_op("Toffoli")) +def toffoli(control1: qubit, control2: qubit, target: qubit) -> None: ... + + @guppy.hugr_op(quantum_op("QAlloc")) def dirty_qubit() -> qubit: ... diff --git a/guppylang/prelude/quantum_functional.py b/guppylang/prelude/quantum_functional.py index 92b95f1d..96aea683 100644 --- a/guppylang/prelude/quantum_functional.py +++ b/guppylang/prelude/quantum_functional.py @@ -39,6 +39,13 @@ def cx(control: qubit @ owned, target: qubit @ owned) -> tuple[qubit, qubit]: return control, target +@guppy(quantum_functional) +@no_type_check +def cy(control: qubit @ owned, target: qubit @ owned) -> tuple[qubit, qubit]: + quantum.cy(control, target) + return control, target + + @guppy(quantum_functional) @no_type_check def t(q: qubit @ owned) -> qubit: @@ -109,6 +116,31 @@ def rx(q: qubit @ owned, angle: angle) -> qubit: return q +@guppy(quantum_functional) +@no_type_check +def ry(q: qubit @ owned, angle: angle) -> qubit: + quantum.ry(q, angle) + return q + + +@guppy(quantum_functional) +@no_type_check +def crz( + control: qubit @ owned, target: qubit @ owned, angle: angle +) -> tuple[qubit, qubit]: + quantum.crz(control, target, angle) + return control, target + + +@guppy(quantum_functional) +@no_type_check +def toffoli( + control1: qubit @ owned, control2: qubit @ owned, target: qubit @ owned +) -> tuple[qubit, qubit, qubit]: + quantum.toffoli(control1, control2, target) + return control1, control2, target + + @guppy(quantum_functional) @no_type_check def phased_x(q: qubit @ owned, angle1: angle, angle2: angle) -> qubit: diff --git a/tests/integration/test_quantum.py b/tests/integration/test_quantum.py index 643036af..900f5d77 100644 --- a/tests/integration/test_quantum.py +++ b/tests/integration/test_quantum.py @@ -18,6 +18,7 @@ ) from guppylang.prelude.quantum_functional import ( cx, + cy, cz, h, t, @@ -30,7 +31,10 @@ zz_max, phased_x, rx, + ry, rz, + crz, + toffoli, zz_phase, reset, quantum_functional, @@ -85,6 +89,7 @@ def test_2qb_op(validate): @compile_quantum_guppy def test(q1: qubit @owned, q2: qubit @owned) -> tuple[qubit, qubit]: q1, q2 = cx(q1, q2) + q1, q2 = cy(q1, q2) q1, q2 = cz(q1, q2) q1, q2 = zz_max(q1, q2) return (q1, q2) @@ -92,6 +97,15 @@ def test(q1: qubit @owned, q2: qubit @owned) -> tuple[qubit, qubit]: validate(test) +def test_3qb_op(validate): + @compile_quantum_guppy + def test(q1: qubit @owned, q2: qubit @owned, q3: qubit @owned) -> tuple[qubit, qubit, qubit]: + q1, q2, q3 = toffoli(q1, q2, q3) + return (q1, q2, q3) + + validate(test) + + def test_measure_ops(validate): """Compile various measurement-related operations.""" @@ -110,9 +124,11 @@ def test_parametric(validate): """Compile various parametric operations.""" @compile_quantum_guppy - def test(q1: qubit @owned, q2: qubit @owned, a1: angle, a2: angle) -> tuple[qubit, qubit]: + def test(q1: qubit @owned, q2: qubit @owned, a1: angle, a2: angle, a3: angle) -> tuple[qubit, qubit]: q1 = rx(q1, a1) - q2 = rz(q2, a2) + q1 = ry(q1, a1) + q2 = rz(q2, a3) q1 = phased_x(q1, a1, a2) q1, q2 = zz_phase(q1, q2, a1) + q1, q2 = crz(q1, q2, a3) return (q1, q2)