Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QECGatesCost] Port everything to QECGatesCost #1359

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 2 additions & 38 deletions qualtran/_infra/adjoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import cast, Dict, TYPE_CHECKING
from typing import cast

import pytest
import sympy

import qualtran.testing as qlt_testing
from qualtran import Adjoint, Bloq, CompositeBloq, Side, Signature
from qualtran import Adjoint, CompositeBloq, Side
from qualtran._infra.adjoint import _adjoint_cbloq
from qualtran.bloqs.basic_gates import CNOT, CSwap, ZeroState
from qualtran.bloqs.for_testing.atom import TestAtom
from qualtran.bloqs.for_testing.with_call_graph import TestBloqWithCallGraph
from qualtran.bloqs.for_testing.with_decomposition import TestParallelCombo, TestSerialCombo
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import LarrowTextBox, RarrowTextBox, Text

if TYPE_CHECKING:
from qualtran import BloqBuilder, SoquetT


def test_serial_combo_adjoint():
# The normal decomposition is three `TestAtom` tagged atom{0,1,2}.
Expand Down Expand Up @@ -168,37 +163,6 @@ def test_wire_symbol():
assert isinstance(adj_ws, RarrowTextBox)


class TAcceptsAdjoint(TestAtom):
def _t_complexity_(self, adjoint: bool = False) -> TComplexity:
return TComplexity(t=2 if adjoint else 1)


class TDoesNotAcceptAdjoint(TestAtom):
def _t_complexity_(self) -> TComplexity:
return TComplexity(t=3)


class DecomposesIntoTAcceptsAdjoint(Bloq):
@cached_property
def signature(self) -> Signature:
return Signature.build(q=1)

def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str, 'SoquetT']:
soqs = bb.add_d(TAcceptsAdjoint(), **soqs)
return soqs


def test_t_complexity():
assert TAcceptsAdjoint().t_complexity().t == 1
assert Adjoint(TAcceptsAdjoint()).t_complexity().t == 2

assert DecomposesIntoTAcceptsAdjoint().t_complexity().t == 1
assert Adjoint(DecomposesIntoTAcceptsAdjoint()).t_complexity().t == 2

assert TDoesNotAcceptAdjoint().t_complexity().t == 3
assert Adjoint(TDoesNotAcceptAdjoint()).t_complexity().t == 3


@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('../Adjoint')
7 changes: 0 additions & 7 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from qualtran.bloqs.arithmetic.addition import Add
from qualtran.bloqs.bookkeeping import Cast
from qualtran.bloqs.mcmt.and_bloq import And
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.simulation.classical_sim import add_ints

Expand Down Expand Up @@ -156,12 +155,6 @@ def build_composite_bloq(
ctrl = bb.join(np.array([ctrl_q]))
return {'ctrl': ctrl, 'a': a, 'b': b}

def _t_complexity_(self):
n = self.b_dtype.bitsize
num_and = self.a_dtype.bitsize + self.b_dtype.bitsize - 1
num_clifford = 33 * (n - 2) + 43
return TComplexity(t=4 * num_and, clifford=num_clifford)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {
(And(self.cv, 1), self.a_dtype.bitsize),
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/basic_gates/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def _t_complexity_(self) -> 'TComplexity':
def my_static_costs(self, cost_key: 'CostKey'):
from qualtran.resource_counting import GateCounts, QECGatesCost

if cost_key == QECGatesCost():
if isinstance(cost_key, QECGatesCost):
# This is based on the decomposition provided by `cirq.decompose_multi_controlled_rotation`
# which uses three cirq.MatrixGate's to do a controlled version of any single-qubit gate.
# The first MatrixGate happens to be a clifford, Hadamard operation in this case.
Expand Down
21 changes: 18 additions & 3 deletions qualtran/bloqs/basic_gates/rotation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,28 @@
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.basic_gates import CZPowGate, Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate
from qualtran.bloqs.basic_gates.rotation import _rx, _ry, _rz
from qualtran.resource_counting import GateCounts, get_cost_value, QECGatesCost
from qualtran.resource_counting.classify_bloqs import bloq_is_rotation, bloq_is_t_like


def test_rotation_gates():
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved
angle = np.pi / 4.0
tcount = 52
assert Rx(angle).t_complexity().t_incl_rotations() == tcount
assert Ry(angle).t_complexity().t_incl_rotations() == tcount
# In prior versions of the library, only Rz(pi/4) would simplify to a T gate in gate counts.
# The others would report the synthesis cost for an arbitrary angle, which was reported as
# 52 T-gates.
assert not bloq_is_rotation(Rx(angle))
assert not bloq_is_rotation(Ry(angle))
assert not bloq_is_rotation(Rz(angle))
assert bloq_is_t_like(Rx(angle))
assert bloq_is_t_like(Ry(angle))
assert bloq_is_t_like(Rz(angle))

assert get_cost_value(Rx(angle), QECGatesCost()) == GateCounts(t=1)
assert get_cost_value(Ry(angle), QECGatesCost()) == GateCounts(t=1)
assert get_cost_value(Rz(angle), QECGatesCost()) == GateCounts(t=1)

assert Rx(angle).t_complexity().t_incl_rotations() == 1
assert Ry(angle).t_complexity().t_incl_rotations() == 1
assert Rz(angle).t_complexity().t_incl_rotations() == 1


Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/for_testing/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def my_tensors(
]

def my_static_costs(self, cost_key: 'CostKey'):
if cost_key == QECGatesCost():
if isinstance(cost_key, QECGatesCost):
return GateCounts(t=100)
return NotImplemented

Expand Down
20 changes: 14 additions & 6 deletions qualtran/cirq_interop/_cirq_to_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from qualtran.cirq_interop._interop_qubit_manager import InteropQubitManager
from qualtran.cirq_interop.t_complexity_protocol import _from_directly_countable_cirq, TComplexity
from qualtran.resource_counting import CostKey, GateCounts, QECGatesCost

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -108,12 +109,6 @@ def my_tensors(
self.cirq_gate, self.signature, incoming=incoming, outgoing=outgoing
)

def _t_complexity_(self) -> 'TComplexity':
t_count = _from_directly_countable_cirq(self.cirq_gate)
if t_count is None:
raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}")
return t_count

def as_cirq_op(
self, qubit_manager: 'cirq.QubitManager', **in_quregs: 'CirqQuregT'
) -> Tuple[Union['cirq.Operation', None], Dict[str, 'CirqQuregT']]:
Expand Down Expand Up @@ -153,6 +148,19 @@ def pretty_name(self) -> str:
def cirq_gate(self) -> cirq.Gate:
return self.gate

def _t_complexity_(self) -> 'TComplexity':
t_count = _from_directly_countable_cirq(self.cirq_gate)
if t_count is None:
raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}")
return t_count

def my_static_costs(self, cost_key: 'CostKey'):
if isinstance(cost_key, QECGatesCost):
t_count = _from_directly_countable_cirq(self.cirq_gate)
if t_count is None:
raise ValueError(f"Cirq gate must be directly countable, not {self.cirq_gate}")
return GateCounts(t=t_count.t, rotation=t_count.rotations, clifford=t_count.clifford)


def _cirq_wire_symbol_to_qualtran_wire_symbol(symbol: str, side: Side) -> 'WireSymbol':
from qualtran.drawing import Circle, directional_text_box, ModPlus
Expand Down
11 changes: 11 additions & 0 deletions qualtran/cirq_interop/t_complexity_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __mul__(self, other: int) -> 'TComplexity':
def __rmul__(self, other: int) -> 'TComplexity':
return self.__mul__(other)

def asdict(self):
return {'t': self.t, 'rotations': self.rotations, 'clifford': self.clifford}

def __str__(self) -> str:
return (
f'T-count: {self.t:g}\n'
Expand Down Expand Up @@ -236,6 +239,9 @@ def _t_complexity_for_bloq(bloq: Bloq) -> Optional[TComplexity]:
return _t_complexity_from_strategies(bloq, strategies)


USE_NEW_GATE_COUNTING_FLAG = True
fdmalone marked this conversation as resolved.
Show resolved Hide resolved


def t_complexity(bloq: Bloq) -> TComplexity:
"""Returns the TComplexity of a bloq.

Expand All @@ -248,6 +254,11 @@ def t_complexity(bloq: Bloq) -> TComplexity:
Raises:
TypeError: if none of the strategies can derive the t complexity.
"""
if USE_NEW_GATE_COUNTING_FLAG:
from qualtran.resource_counting import get_cost_value, QECGatesCost

return get_cost_value(bloq, QECGatesCost(legacy_shims=True)).to_legacy_t_complexity()

ret = _t_complexity_for_bloq(bloq)
if ret is None:
raise TypeError(
Expand Down
43 changes: 2 additions & 41 deletions qualtran/cirq_interop/t_complexity_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
from attrs import frozen

from qualtran import Bloq, GateWithRegisters, Signature
from qualtran import Bloq, DecomposeNotImplementedError, GateWithRegisters, Signature
from qualtran._infra.gate_with_registers import get_named_qubits
from qualtran.bloqs.basic_gates import CHadamard, GlobalPhase
from qualtran.bloqs.mcmt.and_bloq import And
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_t_complexity_for_bloq_via_build_call_graph():


def test_t_complexity_for_bloq_does_not_support():
with pytest.raises(TypeError):
with pytest.raises(DecomposeNotImplementedError):
_ = t_complexity(DoesNotSupportTComplexityBloq())


Expand Down Expand Up @@ -213,45 +213,6 @@ def test_tagged_operations():
)


def test_cache_clear():
class Cachable1(Bloq):
def __init__(self) -> None:
self.num_calls = 0

@property
def signature(self) -> 'Signature':
return Signature([])

def _t_complexity_(self) -> TComplexity:
self.num_calls += 1
return TComplexity(clifford=1)

def __hash__(self):
# Manufacture a hash collision
return hash(2)

class Cachable2(Cachable1):
def _t_complexity_(self) -> TComplexity:
self.num_calls += 1
return TComplexity()

def __hash__(self):
# Manufacture a hash collision
return hash(2)

assert t_complexity(Cachable1()) == TComplexity(clifford=1)
# Using a global cache will result in a failure of this test since `cirq.X` has
# `T-complexity(clifford=1)` but we explicitly return `TComplexity()` for IsCachable
# operation; for which the hash would be equivalent to the hash of its subgate i.e. `cirq.X`.
# TODO: t_complexity protocol will be refactored (#735)
t_complexity.cache_clear() # type: ignore[attr-defined]
op = Cachable2()
assert t_complexity(op) == TComplexity()
assert t_complexity(op) == TComplexity()
assert op.num_calls == 1
t_complexity.cache_clear() # type: ignore[attr-defined]


@pytest.mark.notebook
def test_notebook():
execute_notebook('t_complexity')
3 changes: 3 additions & 0 deletions qualtran/drawing/bloq_counts_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def format_qec_gates_cost(cls, val: 'GateCounts', agg: Optional[str] = None) ->
'and_bloq': 'Ands',
'clifford': 'Cliffords',
'rotation': 'Rotations',
'rotations': 'Rotations',
fdmalone marked this conversation as resolved.
Show resolved Hide resolved
'measurement': 'Measurements',
}
counts_dict: Mapping[str, SymbolicInt]
Expand All @@ -252,6 +253,8 @@ def format_qec_gates_cost(cls, val: 'GateCounts', agg: Optional[str] = None) ->
counts_dict = val.total_t_and_ccz_count()
elif agg == 'beverland':
counts_dict = val.total_beverland_count()
elif agg == 'legacy':
counts_dict = val.to_legacy_t_complexity().asdict()
else:
raise ValueError(f"Unknown aggregation mode {agg}.")

Expand Down
Loading
Loading