diff --git a/qualtran/_infra/adjoint_test.py b/qualtran/_infra/adjoint_test.py index 1b8567cd8..828214b19 100644 --- a/qualtran/_infra/adjoint_test.py +++ b/qualtran/_infra/adjoint_test.py @@ -11,25 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import cached_property -from typing import cast, Dict, TYPE_CHECKING +from typing import cast import pytest import sympy import qualtran.testing as qlt_testing -from qualtran import Adjoint, Bloq, CompositeBloq, Side, Signature +from qualtran import Adjoint, CompositeBloq, Side from qualtran._infra.adjoint import _adjoint_cbloq from qualtran.bloqs.basic_gates import CNOT, CSwap, ZeroState from qualtran.bloqs.for_testing.atom import TestAtom from qualtran.bloqs.for_testing.with_call_graph import TestBloqWithCallGraph from qualtran.bloqs.for_testing.with_decomposition import TestParallelCombo, TestSerialCombo -from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.drawing import LarrowTextBox, RarrowTextBox, Text -if TYPE_CHECKING: - from qualtran import BloqBuilder, SoquetT - def test_serial_combo_adjoint(): # The normal decomposition is three `TestAtom` tagged atom{0,1,2}. @@ -168,37 +163,6 @@ def test_wire_symbol(): assert isinstance(adj_ws, RarrowTextBox) -class TAcceptsAdjoint(TestAtom): - def _t_complexity_(self, adjoint: bool = False) -> TComplexity: - return TComplexity(t=2 if adjoint else 1) - - -class TDoesNotAcceptAdjoint(TestAtom): - def _t_complexity_(self) -> TComplexity: - return TComplexity(t=3) - - -class DecomposesIntoTAcceptsAdjoint(Bloq): - @cached_property - def signature(self) -> Signature: - return Signature.build(q=1) - - def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']: - soqs = bb.add_d(TAcceptsAdjoint(), **soqs) - return soqs - - -def test_t_complexity(): - assert TAcceptsAdjoint().t_complexity().t == 1 - assert Adjoint(TAcceptsAdjoint()).t_complexity().t == 2 - - assert DecomposesIntoTAcceptsAdjoint().t_complexity().t == 1 - assert Adjoint(DecomposesIntoTAcceptsAdjoint()).t_complexity().t == 2 - - assert TDoesNotAcceptAdjoint().t_complexity().t == 3 - assert Adjoint(TDoesNotAcceptAdjoint()).t_complexity().t == 3 - - @pytest.mark.notebook def test_notebook(): qlt_testing.execute_notebook('../Adjoint') diff --git a/qualtran/bloqs/arithmetic/controlled_addition.py b/qualtran/bloqs/arithmetic/controlled_addition.py index 49014ec6a..aa3f744a6 100644 --- a/qualtran/bloqs/arithmetic/controlled_addition.py +++ b/qualtran/bloqs/arithmetic/controlled_addition.py @@ -35,7 +35,6 @@ from qualtran.bloqs.arithmetic.addition import Add from qualtran.bloqs.bookkeeping import Cast from qualtran.bloqs.mcmt.and_bloq import And -from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.resource_counting.generalizers import ignore_split_join from qualtran.simulation.classical_sim import add_ints @@ -156,12 +155,6 @@ def build_composite_bloq( ctrl = bb.join(np.array([ctrl_q])) return {'ctrl': ctrl, 'a': a, 'b': b} - def _t_complexity_(self): - n = self.b_dtype.bitsize - num_and = self.a_dtype.bitsize + self.b_dtype.bitsize - 1 - num_clifford = 33 * (n - 2) + 43 - return TComplexity(t=4 * num_and, clifford=num_clifford) - def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: return { (And(self.cv, 1), self.a_dtype.bitsize), diff --git a/qualtran/bloqs/basic_gates/hadamard.py b/qualtran/bloqs/basic_gates/hadamard.py index c3fd0d791..eca9c99e8 100644 --- a/qualtran/bloqs/basic_gates/hadamard.py +++ b/qualtran/bloqs/basic_gates/hadamard.py @@ -184,7 +184,7 @@ def _t_complexity_(self) -> 'TComplexity': def my_static_costs(self, cost_key: 'CostKey'): from qualtran.resource_counting import GateCounts, QECGatesCost - if cost_key == QECGatesCost(): + if isinstance(cost_key, QECGatesCost): # This is based on the decomposition provided by `cirq.decompose_multi_controlled_rotation` # which uses three cirq.MatrixGate's to do a controlled version of any single-qubit gate. # The first MatrixGate happens to be a clifford, Hadamard operation in this case. diff --git a/qualtran/bloqs/basic_gates/rotation_test.py b/qualtran/bloqs/basic_gates/rotation_test.py index dbfd317fc..a13ad9856 100644 --- a/qualtran/bloqs/basic_gates/rotation_test.py +++ b/qualtran/bloqs/basic_gates/rotation_test.py @@ -20,13 +20,28 @@ from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.basic_gates import CZPowGate, Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate from qualtran.bloqs.basic_gates.rotation import _rx, _ry, _rz +from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost +from qualtran.resource_counting.classify_bloqs import bloq_is_rotation, bloq_is_t_like def test_rotation_gates(): angle = np.pi / 4.0 - tcount = 52 - assert Rx(angle).t_complexity().t_incl_rotations() == tcount - assert Ry(angle).t_complexity().t_incl_rotations() == tcount + # In prior versions of the library, only Rz(pi/4) would simplify to a T gate in gate counts. + # The others would report the synthesis cost for an arbitrary angle, which was reported as + # 52 T-gates. + assert not bloq_is_rotation(Rx(angle)) + assert not bloq_is_rotation(Ry(angle)) + assert not bloq_is_rotation(Rz(angle)) + assert bloq_is_t_like(Rx(angle)) + assert bloq_is_t_like(Ry(angle)) + assert bloq_is_t_like(Rz(angle)) + + assert get_cost_value(Rx(angle), QECGatesCost()) == GateCounts(t=1) + assert get_cost_value(Ry(angle), QECGatesCost()) == GateCounts(t=1) + assert get_cost_value(Rz(angle), QECGatesCost()) == GateCounts(t=1) + + assert Rx(angle).t_complexity().t_incl_rotations() == 1 + assert Ry(angle).t_complexity().t_incl_rotations() == 1 assert Rz(angle).t_complexity().t_incl_rotations() == 1 diff --git a/qualtran/bloqs/for_testing/atom.py b/qualtran/bloqs/for_testing/atom.py index 7c04867b6..a52273dee 100644 --- a/qualtran/bloqs/for_testing/atom.py +++ b/qualtran/bloqs/for_testing/atom.py @@ -68,7 +68,7 @@ def my_tensors( ] def my_static_costs(self, cost_key: 'CostKey'): - if cost_key == QECGatesCost(): + if isinstance(cost_key, QECGatesCost): return GateCounts(t=100) return NotImplemented diff --git a/qualtran/cirq_interop/_cirq_to_bloq.py b/qualtran/cirq_interop/_cirq_to_bloq.py index 01192717e..078051bb1 100644 --- a/qualtran/cirq_interop/_cirq_to_bloq.py +++ b/qualtran/cirq_interop/_cirq_to_bloq.py @@ -48,6 +48,7 @@ ) from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager from qualtran.cirq_interop.t_complexity_protocol import _from_directly_countable_cirq, TComplexity +from qualtran.resource_counting import CostKey, GateCounts, QECGatesCost if TYPE_CHECKING: import quimb.tensor as qtn @@ -108,12 +109,6 @@ def my_tensors( self.cirq_gate, self.signature, incoming=incoming, outgoing=outgoing ) - def _t_complexity_(self) -> 'TComplexity': - t_count = _from_directly_countable_cirq(self.cirq_gate) - if t_count is None: - raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}") - return t_count - def as_cirq_op( self, qubit_manager: 'cirq.QubitManager', **in_quregs: 'CirqQuregT' ) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]: @@ -153,6 +148,19 @@ def pretty_name(self) -> str: def cirq_gate(self) -> cirq.Gate: return self.gate + def _t_complexity_(self) -> 'TComplexity': + t_count = _from_directly_countable_cirq(self.cirq_gate) + if t_count is None: + raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}") + return t_count + + def my_static_costs(self, cost_key: 'CostKey'): + if isinstance(cost_key, QECGatesCost): + t_count = _from_directly_countable_cirq(self.cirq_gate) + if t_count is None: + raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}") + return GateCounts(t=t_count.t, rotation=t_count.rotations, clifford=t_count.clifford) + def _cirq_wire_symbol_to_qualtran_wire_symbol(symbol: str, side: Side) -> 'WireSymbol': from qualtran.drawing import Circle, directional_text_box, ModPlus diff --git a/qualtran/cirq_interop/t_complexity_protocol.py b/qualtran/cirq_interop/t_complexity_protocol.py index 6e96057fd..7f7c4748b 100644 --- a/qualtran/cirq_interop/t_complexity_protocol.py +++ b/qualtran/cirq_interop/t_complexity_protocol.py @@ -61,6 +61,9 @@ def __mul__(self, other: int) -> 'TComplexity': def __rmul__(self, other: int) -> 'TComplexity': return self.__mul__(other) + def asdict(self): + return {'t': self.t, 'rotations': self.rotations, 'clifford': self.clifford} + def __str__(self) -> str: return ( f'T-count: {self.t:g}\n' @@ -236,6 +239,9 @@ def _t_complexity_for_bloq(bloq: Bloq) -> Optional[TComplexity]: return _t_complexity_from_strategies(bloq, strategies) +USE_NEW_GATE_COUNTING_FLAG = True + + def t_complexity(bloq: Bloq) -> TComplexity: """Returns the TComplexity of a bloq. @@ -248,6 +254,11 @@ def t_complexity(bloq: Bloq) -> TComplexity: Raises: TypeError: if none of the strategies can derive the t complexity. """ + if USE_NEW_GATE_COUNTING_FLAG: + from qualtran.resource_counting import get_cost_value, QECGatesCost + + return get_cost_value(bloq, QECGatesCost(legacy_shims=True)).to_legacy_t_complexity() + ret = _t_complexity_for_bloq(bloq) if ret is None: raise TypeError( diff --git a/qualtran/cirq_interop/t_complexity_protocol_test.py b/qualtran/cirq_interop/t_complexity_protocol_test.py index 1b54a320d..2572b85ec 100644 --- a/qualtran/cirq_interop/t_complexity_protocol_test.py +++ b/qualtran/cirq_interop/t_complexity_protocol_test.py @@ -17,7 +17,7 @@ import pytest from attrs import frozen -from qualtran import Bloq, GateWithRegisters, Signature +from qualtran import Bloq, DecomposeNotImplementedError, GateWithRegisters, Signature from qualtran._infra.gate_with_registers import get_named_qubits from qualtran.bloqs.basic_gates import CHadamard, GlobalPhase from qualtran.bloqs.mcmt.and_bloq import And @@ -85,7 +85,7 @@ def test_t_complexity_for_bloq_via_build_call_graph(): def test_t_complexity_for_bloq_does_not_support(): - with pytest.raises(TypeError): + with pytest.raises(DecomposeNotImplementedError): _ = t_complexity(DoesNotSupportTComplexityBloq()) @@ -213,45 +213,6 @@ def test_tagged_operations(): ) -def test_cache_clear(): - class Cachable1(Bloq): - def __init__(self) -> None: - self.num_calls = 0 - - @property - def signature(self) -> 'Signature': - return Signature([]) - - def _t_complexity_(self) -> TComplexity: - self.num_calls += 1 - return TComplexity(clifford=1) - - def __hash__(self): - # Manufacture a hash collision - return hash(2) - - class Cachable2(Cachable1): - def _t_complexity_(self) -> TComplexity: - self.num_calls += 1 - return TComplexity() - - def __hash__(self): - # Manufacture a hash collision - return hash(2) - - assert t_complexity(Cachable1()) == TComplexity(clifford=1) - # Using a global cache will result in a failure of this test since `cirq.X` has - # `T-complexity(clifford=1)` but we explicitly return `TComplexity()` for IsCachable - # operation; for which the hash would be equivalent to the hash of its subgate i.e. `cirq.X`. - # TODO: t_complexity protocol will be refactored (#735) - t_complexity.cache_clear() # type: ignore[attr-defined] - op = Cachable2() - assert t_complexity(op) == TComplexity() - assert t_complexity(op) == TComplexity() - assert op.num_calls == 1 - t_complexity.cache_clear() # type: ignore[attr-defined] - - @pytest.mark.notebook def test_notebook(): execute_notebook('t_complexity') diff --git a/qualtran/drawing/bloq_counts_graph.py b/qualtran/drawing/bloq_counts_graph.py index dbadfa84e..09bb97c0c 100644 --- a/qualtran/drawing/bloq_counts_graph.py +++ b/qualtran/drawing/bloq_counts_graph.py @@ -241,6 +241,7 @@ def format_qec_gates_cost(cls, val: 'GateCounts', agg: Optional[str] = None) -> 'and_bloq': 'Ands', 'clifford': 'Cliffords', 'rotation': 'Rotations', + 'rotations': 'Rotations', 'measurement': 'Measurements', } counts_dict: Mapping[str, SymbolicInt] @@ -252,6 +253,8 @@ def format_qec_gates_cost(cls, val: 'GateCounts', agg: Optional[str] = None) -> counts_dict = val.total_t_and_ccz_count() elif agg == 'beverland': counts_dict = val.total_beverland_count() + elif agg == 'legacy': + counts_dict = val.to_legacy_t_complexity().asdict() else: raise ValueError(f"Unknown aggregation mode {agg}.") diff --git a/qualtran/resource_counting/_bloq_counts.py b/qualtran/resource_counting/_bloq_counts.py index 95cf98d69..bc699b96b 100644 --- a/qualtran/resource_counting/_bloq_counts.py +++ b/qualtran/resource_counting/_bloq_counts.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import warnings from collections import defaultdict -from typing import Callable, Dict, Sequence, Tuple, TYPE_CHECKING +from typing import Callable, cast, Dict, Sequence, Tuple, TYPE_CHECKING import attrs import networkx as nx @@ -33,6 +34,7 @@ if TYPE_CHECKING: from qualtran import Bloq + from qualtran.cirq_interop.t_complexity_protocol import TComplexity logger = logging.getLogger(__name__) @@ -209,6 +211,38 @@ def total_t_and_ccz_count(self, ts_per_rotation: int = 11) -> Dict[str, Symbolic n_t = self.t + ts_per_rotation * self.rotation return {'n_t': n_t, 'n_ccz': n_ccz} + def to_legacy_t_complexity( + self, + ts_per_toffoli: int = 4, + ts_per_cswap: int = 7, + ts_per_and_bloq: int = 4, + cliffords_per_and_bloq: int = 9, + cliffords_per_cswap: int = 10, + ) -> 'TComplexity': + """Return a legacy `TComplexity` object. + + This coalesces all the gate types into t, rotations, and clifford fields. The conversion + factors can be tweaked using the arguments to this method. + + The argument `cliffords_per_and_bloq` sets the base number of clifford gates to + add per `self.and_bloq`. To fully match the exact legacy `t_complexity` numbers, you + must enable `QECGatesCost(legacy_shims=True)`, which will enable a shim that directly + adds on clifford counts for the X-gates used to invert the And control lines. + """ + from qualtran.cirq_interop.t_complexity_protocol import TComplexity + + return TComplexity( + t=self.t + + ts_per_toffoli * self.toffoli + + ts_per_cswap * self.cswap + + ts_per_and_bloq * self.and_bloq, + rotations=cast(int, self.rotation), + clifford=self.clifford + + self.measurement + + cliffords_per_and_bloq * self.and_bloq + + cliffords_per_cswap * self.cswap, + ) + def total_beverland_count(self) -> Dict[str, SymbolicInt]: r"""Counts used by Beverland. et. al. using notation from the reference. @@ -235,18 +269,36 @@ def total_beverland_count(self) -> Dict[str, SymbolicInt]: } -@frozen +@frozen(kw_only=True) class QECGatesCost(CostKey[GateCounts]): """Counts specifically for 'expensive' gates in a surface code error correction scheme. The cost value type for this CostKey is `GateCounts`. + + Args: + legacy_shims: If enabled, modify the counting logic to match the peculiarities of + the legacy `t_complexity` protocol. """ + legacy_shims: bool = False + def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) -> GateCounts: from qualtran.bloqs.basic_gates import GlobalPhase, Identity, Toffoli, TwoBitCSwap from qualtran.bloqs.basic_gates._shims import Measure from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq - from qualtran.bloqs.mcmt.and_bloq import And + from qualtran.bloqs.mcmt import And, MultiTargetCNOT + + if self.legacy_shims: + legacy_val = bloq._t_complexity_() + if legacy_val is not NotImplemented: + warnings.warn( + "Please migrate explicit cost annotations to the general " + "`Bloq.my_static_costs` method override.", + DeprecationWarning, + ) + return GateCounts( + t=legacy_val.t, clifford=legacy_val.clifford, rotation=legacy_val.rotations + ) # T gates if bloq_is_t_like(bloq): @@ -262,14 +314,35 @@ def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) # 'And' bloqs if isinstance(bloq, And): + # To match the legacy `t_complexity` protocol, we can hack in the explicit + # counts for the clifford operations used to invert the control bit. + # Note: we *only* add in the clifford operations that correspond to correctly + # setting the control line. The other clifford operations inherent in compiling + # an And gate to the gateset considered by the legacy `t_complexity` protocol can be + # simply added in as part of `GateCounts.to_legacy_t_complexity()` + n_inverted_controls = (bloq.cv1 == 0) + int(bloq.cv2 == 0) if bloq.uncompute: - return GateCounts(measurement=1, clifford=1) - return GateCounts(and_bloq=1) + if self.legacy_shims: + return GateCounts(clifford=3 + 2 * n_inverted_controls, measurement=1) + else: + return GateCounts(measurement=1, clifford=1) + + if self.legacy_shims: + return GateCounts(and_bloq=1, clifford=2 * n_inverted_controls) + else: + return GateCounts(and_bloq=1) # CSwaps aka Fredkin if isinstance(bloq, TwoBitCSwap): return GateCounts(cswap=1) + if isinstance(bloq, MultiTargetCNOT): + if self.legacy_shims: + # Legacy mode: don't treat this as one clifford. Use its decomposition. + pass # fall through + else: + return GateCounts(clifford=1) + # Cliffords if bloq_is_clifford(bloq): return GateCounts(clifford=1) diff --git a/qualtran/resource_counting/_bloq_counts_test.py b/qualtran/resource_counting/_bloq_counts_test.py index 6688459a6..8c2328c5c 100644 --- a/qualtran/resource_counting/_bloq_counts_test.py +++ b/qualtran/resource_counting/_bloq_counts_test.py @@ -18,6 +18,8 @@ from qualtran.bloqs.basic_gates import Hadamard, TGate, Toffoli from qualtran.bloqs.basic_gates._shims import Measure from qualtran.bloqs.for_testing.costing import make_example_costing_bloqs +from qualtran.bloqs.mcmt import MultiTargetCNOT +from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.resource_counting import BloqCount, GateCounts, get_cost_value, QECGatesCost @@ -91,5 +93,17 @@ def test_qec_gates_cost(): [mcmt.MultiControlX(cvs=(1, 1, 1)), GateCounts(and_bloq=2, measurement=2, clifford=3)], ], ) -def test_algorithm_summary_counts(bloq, counts): +def test_get_cost_value_qec_gates_cost(bloq, counts): assert get_cost_value(bloq, QECGatesCost()) == counts + + +def test_count_multi_target_cnot(): + b = MultiTargetCNOT(bitsize=12) + + # MultiTargetCNOT can be done in one clifford cycle on the surface code. + assert get_cost_value(b, QECGatesCost()) == GateCounts(clifford=1) + + # And/or we could respect its decomposition. + # TODO: https://github.com/quantumlib/Qualtran/issues/1318 + assert get_cost_value(b, QECGatesCost(legacy_shims=True)) == GateCounts(clifford=23) + assert b.t_complexity() == TComplexity(clifford=23) diff --git a/qualtran/resource_counting/classify_bloqs.py b/qualtran/resource_counting/classify_bloqs.py index 60ab826e6..eae56e52e 100644 --- a/qualtran/resource_counting/classify_bloqs.py +++ b/qualtran/resource_counting/classify_bloqs.py @@ -182,26 +182,12 @@ def bloq_is_clifford(b: Bloq) -> bool: ) from qualtran.bloqs.basic_gates.rotation import Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate from qualtran.bloqs.bookkeeping import ArbitraryClifford - from qualtran.bloqs.mcmt.multi_target_cnot import MultiTargetCNOT if isinstance(b, Adjoint): b = b.subbloq if isinstance( - b, - ( - TwoBitSwap, - Hadamard, - XGate, - ZGate, - YGate, - ArbitraryClifford, - CNOT, - MultiTargetCNOT, - CYGate, - CZ, - SGate, - ), + b, (TwoBitSwap, Hadamard, XGate, ZGate, YGate, ArbitraryClifford, CNOT, CYGate, CZ, SGate) ): return True diff --git a/qualtran/serialization/bloq_test.py b/qualtran/serialization/bloq_test.py index fffd55bed..5964908ee 100644 --- a/qualtran/serialization/bloq_test.py +++ b/qualtran/serialization/bloq_test.py @@ -28,6 +28,7 @@ from qualtran.cirq_interop._cirq_to_bloq_test import TestCNOT as TestCNOTCirq from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.protos import registers_pb2 +from qualtran.resource_counting import CostKey, GateCounts, QECGatesCost from qualtran.serialization import bloq as bloq_serialization from qualtran.serialization import resolver_dict from qualtran.serialization.bloq import arg_from_proto @@ -98,6 +99,10 @@ def signature(self) -> 'Signature': def _t_complexity_(self) -> TComplexity: return TComplexity(t=7 * self.bitsize, clifford=10 * self.bitsize) + def my_static_costs(self, cost_key: 'CostKey'): + if isinstance(cost_key, QECGatesCost): + return GateCounts(t=7 * self.bitsize, clifford=10 * self.bitsize) + @dataclasses.dataclass(frozen=True) class TestTwoCSwap(Bloq): @@ -130,6 +135,7 @@ def test_cbloq_to_proto_test_two_cswap(): assert TestCSwap(bitsize) in bloq_serialization.bloqs_from_proto(cswap_proto_lib) cswap_proto = bloq_serialization.bloqs_to_proto(TestCSwap(100)).table[0].bloq + assert TestCSwap(100).t_complexity().t == 7 * 100 cbloq = TestTwoCSwap(100).decompose_bloq() proto_lib = bloq_serialization.bloqs_to_proto(cbloq) assert len(proto_lib.table) == 2