Skip to content

Commit

Permalink
[QECGatesCost] Port everything to QECGatesCost (#1359)
Browse files Browse the repository at this point in the history
* Port everything to QECGatesCost

* test name

* link issue
  • Loading branch information
mpharrigan authored Aug 30, 2024
1 parent 8fc12b2 commit a36c50f
Show file tree
Hide file tree
Showing 13 changed files with 154 additions and 119 deletions.
40 changes: 2 additions & 38 deletions qualtran/_infra/adjoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import cast, Dict, TYPE_CHECKING
from typing import cast

import pytest
import sympy

import qualtran.testing as qlt_testing
from qualtran import Adjoint, Bloq, CompositeBloq, Side, Signature
from qualtran import Adjoint, CompositeBloq, Side
from qualtran._infra.adjoint import _adjoint_cbloq
from qualtran.bloqs.basic_gates import CNOT, CSwap, ZeroState
from qualtran.bloqs.for_testing.atom import TestAtom
from qualtran.bloqs.for_testing.with_call_graph import TestBloqWithCallGraph
from qualtran.bloqs.for_testing.with_decomposition import TestParallelCombo, TestSerialCombo
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import LarrowTextBox, RarrowTextBox, Text

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT


def test_serial_combo_adjoint():
# The normal decomposition is three `TestAtom` tagged atom{0,1,2}.
Expand Down Expand Up @@ -168,37 +163,6 @@ def test_wire_symbol():
assert isinstance(adj_ws, RarrowTextBox)


class TAcceptsAdjoint(TestAtom):
def _t_complexity_(self, adjoint: bool = False) -> TComplexity:
return TComplexity(t=2 if adjoint else 1)


class TDoesNotAcceptAdjoint(TestAtom):
def _t_complexity_(self) -> TComplexity:
return TComplexity(t=3)


class DecomposesIntoTAcceptsAdjoint(Bloq):
@cached_property
def signature(self) -> Signature:
return Signature.build(q=1)

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
soqs = bb.add_d(TAcceptsAdjoint(), **soqs)
return soqs


def test_t_complexity():
assert TAcceptsAdjoint().t_complexity().t == 1
assert Adjoint(TAcceptsAdjoint()).t_complexity().t == 2

assert DecomposesIntoTAcceptsAdjoint().t_complexity().t == 1
assert Adjoint(DecomposesIntoTAcceptsAdjoint()).t_complexity().t == 2

assert TDoesNotAcceptAdjoint().t_complexity().t == 3
assert Adjoint(TDoesNotAcceptAdjoint()).t_complexity().t == 3


@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('../Adjoint')
7 changes: 0 additions & 7 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from qualtran.bloqs.arithmetic.addition import Add
from qualtran.bloqs.bookkeeping import Cast
from qualtran.bloqs.mcmt.and_bloq import And
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.simulation.classical_sim import add_ints

Expand Down Expand Up @@ -156,12 +155,6 @@ def build_composite_bloq(
ctrl = bb.join(np.array([ctrl_q]))
return {'ctrl': ctrl, 'a': a, 'b': b}

def _t_complexity_(self):
n = self.b_dtype.bitsize
num_and = self.a_dtype.bitsize + self.b_dtype.bitsize - 1
num_clifford = 33 * (n - 2) + 43
return TComplexity(t=4 * num_and, clifford=num_clifford)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {
(And(self.cv, 1), self.a_dtype.bitsize),
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/basic_gates/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _t_complexity_(self) -> 'TComplexity':
def my_static_costs(self, cost_key: 'CostKey'):
from qualtran.resource_counting import GateCounts, QECGatesCost

if cost_key == QECGatesCost():
if isinstance(cost_key, QECGatesCost):
# This is based on the decomposition provided by `cirq.decompose_multi_controlled_rotation`
# which uses three cirq.MatrixGate's to do a controlled version of any single-qubit gate.
# The first MatrixGate happens to be a clifford, Hadamard operation in this case.
Expand Down
23 changes: 19 additions & 4 deletions qualtran/bloqs/basic_gates/rotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,28 @@
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.basic_gates import CZPowGate, Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate
from qualtran.bloqs.basic_gates.rotation import _rx, _ry, _rz
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost
from qualtran.resource_counting.classify_bloqs import bloq_is_rotation, bloq_is_t_like


def test_rotation_gates():
def test_t_like_rotation_gates():
angle = np.pi / 4.0
tcount = 52
assert Rx(angle).t_complexity().t_incl_rotations() == tcount
assert Ry(angle).t_complexity().t_incl_rotations() == tcount
# In prior versions of the library, only Rz(pi/4) would simplify to a T gate in gate counts.
# The others would report the synthesis cost for an arbitrary angle, which was reported as
# 52 T-gates.
assert not bloq_is_rotation(Rx(angle))
assert not bloq_is_rotation(Ry(angle))
assert not bloq_is_rotation(Rz(angle))
assert bloq_is_t_like(Rx(angle))
assert bloq_is_t_like(Ry(angle))
assert bloq_is_t_like(Rz(angle))

assert get_cost_value(Rx(angle), QECGatesCost()) == GateCounts(t=1)
assert get_cost_value(Ry(angle), QECGatesCost()) == GateCounts(t=1)
assert get_cost_value(Rz(angle), QECGatesCost()) == GateCounts(t=1)

assert Rx(angle).t_complexity().t_incl_rotations() == 1
assert Ry(angle).t_complexity().t_incl_rotations() == 1
assert Rz(angle).t_complexity().t_incl_rotations() == 1


Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/for_testing/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def my_tensors(
]

def my_static_costs(self, cost_key: 'CostKey'):
if cost_key == QECGatesCost():
if isinstance(cost_key, QECGatesCost):
return GateCounts(t=100)
return NotImplemented

Expand Down
20 changes: 14 additions & 6 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager
from qualtran.cirq_interop.t_complexity_protocol import _from_directly_countable_cirq, TComplexity
from qualtran.resource_counting import CostKey, GateCounts, QECGatesCost

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -108,12 +109,6 @@ def my_tensors(
self.cirq_gate, self.signature, incoming=incoming, outgoing=outgoing
)

def _t_complexity_(self) -> 'TComplexity':
t_count = _from_directly_countable_cirq(self.cirq_gate)
if t_count is None:
raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}")
return t_count

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', **in_quregs: 'CirqQuregT'
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
Expand Down Expand Up @@ -153,6 +148,19 @@ def pretty_name(self) -> str:
def cirq_gate(self) -> cirq.Gate:
return self.gate

def _t_complexity_(self) -> 'TComplexity':
t_count = _from_directly_countable_cirq(self.cirq_gate)
if t_count is None:
raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}")
return t_count

def my_static_costs(self, cost_key: 'CostKey'):
if isinstance(cost_key, QECGatesCost):
t_count = _from_directly_countable_cirq(self.cirq_gate)
if t_count is None:
raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}")
return GateCounts(t=t_count.t, rotation=t_count.rotations, clifford=t_count.clifford)


def _cirq_wire_symbol_to_qualtran_wire_symbol(symbol: str, side: Side) -> 'WireSymbol':
from qualtran.drawing import Circle, directional_text_box, ModPlus
Expand Down
11 changes: 11 additions & 0 deletions qualtran/cirq_interop/t_complexity_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __mul__(self, other: int) -> 'TComplexity':
def __rmul__(self, other: int) -> 'TComplexity':
return self.__mul__(other)

def asdict(self):
return {'t': self.t, 'rotations': self.rotations, 'clifford': self.clifford}

def __str__(self) -> str:
return (
f'T-count: {self.t:g}\n'
Expand Down Expand Up @@ -236,6 +239,9 @@ def _t_complexity_for_bloq(bloq: Bloq) -> Optional[TComplexity]:
return _t_complexity_from_strategies(bloq, strategies)


USE_NEW_GATE_COUNTING_FLAG = True


def t_complexity(bloq: Bloq) -> TComplexity:
"""Returns the TComplexity of a bloq.
Expand All @@ -248,6 +254,11 @@ def t_complexity(bloq: Bloq) -> TComplexity:
Raises:
TypeError: if none of the strategies can derive the t complexity.
"""
if USE_NEW_GATE_COUNTING_FLAG:
from qualtran.resource_counting import get_cost_value, QECGatesCost

return get_cost_value(bloq, QECGatesCost(legacy_shims=True)).to_legacy_t_complexity()

ret = _t_complexity_for_bloq(bloq)
if ret is None:
raise TypeError(
Expand Down
43 changes: 2 additions & 41 deletions qualtran/cirq_interop/t_complexity_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
from attrs import frozen

from qualtran import Bloq, GateWithRegisters, Signature
from qualtran import Bloq, DecomposeNotImplementedError, GateWithRegisters, Signature
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.basic_gates import CHadamard, GlobalPhase
from qualtran.bloqs.mcmt.and_bloq import And
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_t_complexity_for_bloq_via_build_call_graph():


def test_t_complexity_for_bloq_does_not_support():
with pytest.raises(TypeError):
with pytest.raises(DecomposeNotImplementedError):
_ = t_complexity(DoesNotSupportTComplexityBloq())


Expand Down Expand Up @@ -213,45 +213,6 @@ def test_tagged_operations():
)


def test_cache_clear():
class Cachable1(Bloq):
def __init__(self) -> None:
self.num_calls = 0

@property
def signature(self) -> 'Signature':
return Signature([])

def _t_complexity_(self) -> TComplexity:
self.num_calls += 1
return TComplexity(clifford=1)

def __hash__(self):
# Manufacture a hash collision
return hash(2)

class Cachable2(Cachable1):
def _t_complexity_(self) -> TComplexity:
self.num_calls += 1
return TComplexity()

def __hash__(self):
# Manufacture a hash collision
return hash(2)

assert t_complexity(Cachable1()) == TComplexity(clifford=1)
# Using a global cache will result in a failure of this test since `cirq.X` has
# `T-complexity(clifford=1)` but we explicitly return `TComplexity()` for IsCachable
# operation; for which the hash would be equivalent to the hash of its subgate i.e. `cirq.X`.
# TODO: t_complexity protocol will be refactored (#735)
t_complexity.cache_clear() # type: ignore[attr-defined]
op = Cachable2()
assert t_complexity(op) == TComplexity()
assert t_complexity(op) == TComplexity()
assert op.num_calls == 1
t_complexity.cache_clear() # type: ignore[attr-defined]


@pytest.mark.notebook
def test_notebook():
execute_notebook('t_complexity')
3 changes: 3 additions & 0 deletions qualtran/drawing/bloq_counts_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def format_qec_gates_cost(cls, val: 'GateCounts', agg: Optional[str] = None) ->
'and_bloq': 'Ands',
'clifford': 'Cliffords',
'rotation': 'Rotations',
'rotations': 'Rotations',
'measurement': 'Measurements',
}
counts_dict: Mapping[str, SymbolicInt]
Expand All @@ -252,6 +253,8 @@ def format_qec_gates_cost(cls, val: 'GateCounts', agg: Optional[str] = None) ->
counts_dict = val.total_t_and_ccz_count()
elif agg == 'beverland':
counts_dict = val.total_beverland_count()
elif agg == 'legacy':
counts_dict = val.to_legacy_t_complexity().asdict()
else:
raise ValueError(f"Unknown aggregation mode {agg}.")

Expand Down
Loading

0 comments on commit a36c50f

Please sign in to comment.