-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathclassify_bloqs.py
163 lines (141 loc) · 5.58 KB
/
classify_bloqs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc as abc
from collections import defaultdict
from typing import cast, Dict, List, Optional, Sequence, TYPE_CHECKING, Union
import sympy
from qualtran import Adjoint, Bloq, Controlled
from qualtran.resource_counting.generalizers import (
ignore_alloc_free,
ignore_cliffords,
ignore_split_join,
)
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma
if TYPE_CHECKING:
from qualtran.resource_counting import GeneralizerT
def _get_basic_bloq_classification() -> Dict[str, str]:
"""High level classification of bloqs by the module name."""
bloq_classifier = {
'qualtran.bloqs.arithmetic': 'arithmetic',
'qualtran.bloqs.rotations': 'rotations',
'qualtran.bloqs.basic_gates.rotation': 'rotations',
'qualtran.bloqs.state_preparation': 'state_preparation',
'qualtran.bloqs.data_loading': 'data_loading',
'qualtran.bloqs.mcmt': 'multi_control_pauli',
'qualtran.bloqs.multiplexers': 'multiplexers',
'qualtran.bloqs.swap_network': 'swaps',
'qualtran.bloqs.basic_gates.swap': 'swaps',
'qualtran.bloqs.reflection': 'reflection',
'qualtran.bloqs.basic_gates.toffoli': 'toffoli',
'qualtran.bloqs.basic_gates.t_gate': 'tgate',
}
return bloq_classifier
def classify_bloq(bloq: Bloq, bloq_classification: Dict[str, str]) -> str:
    """Classify a bloq given a bloq_classification.

    Any `Adjoint` wrappers are peeled off first, so the adjoint of a bloq is
    classified like the bloq it wraps.

    Args:
        bloq: The bloq to classify.
        bloq_classification: A dictionary mapping a module-path substring to a
            classification label. Keys are tested in iteration order against the
            (unwrapped) bloq's ``__module__``; the first key that is a substring
            of it determines the result.

    Returns:
        classification: The label of the first matching key in
        bloq_classification, or ``'other'`` if no key matches.
    """
    # Unwrap by type instead of searching `__module__` for the text 'adjoint':
    # the string test also fired for unrelated bloqs whose module path merely
    # contains 'adjoint' (and which have no `.subbloq`), and it only peeled a
    # single level of adjoint.
    while isinstance(bloq, Adjoint):
        bloq = bloq.subbloq
    mod_name = bloq.__module__
    for module_substring, classification in bloq_classification.items():
        if module_substring in mod_name:
            return classification
    return 'other'
def classify_t_count_by_bloq_type(
    bloq: Bloq,
    bloq_classification: Optional[Dict[str, str]] = None,
    generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
) -> Dict[str, Union[int, sympy.Expr]]:
    """Classify (bin) the T count of a bloq's call graph by type of operation.

    The call graph of `bloq` is built with a `keep` predicate that is true for
    every bloq `classify_bloq` places in a named category (anything but
    ``'other'``), so such bloqs are treated as leaves. For each leaf in the
    resulting sigma, its own T count is multiplied by its multiplicity and
    accumulated under its category.

    Args:
        bloq: the bloq to classify.
        bloq_classification: An optional dictionary mapping module-path substrings
            to classification labels. Defaults to `_get_basic_bloq_classification()`.
        generalizer: If provided, run this function on each (sub)bloq to replace attributes
            that do not affect resource estimates with generic sympy symbols. If the function
            returns `None`, the bloq is omitted from the counts graph. If a sequence of
            generalizers is provided, each generalizer will be run in order. User
            generalizers are appended after `ignore_split_join`, `ignore_alloc_free`
            and `ignore_cliffords`, which are always applied.

    Returns:
        classified_bloqs: dictionary (a `defaultdict(int)`) containing the T count for
        different types of bloqs. Leaves whose T count is provably not positive are
        skipped; symbolic T counts are always included.
    """
    # A non-Optional local lets the nested predicate close over it cleanly.
    classification_map: Dict[str, str] = (
        _get_basic_bloq_classification() if bloq_classification is None else bloq_classification
    )

    def keeper(b: Bloq) -> bool:
        # Stop decomposing at any bloq that already has a named category.
        return classify_bloq(b, classification_map) != 'other'

    generalizers: List['GeneralizerT'] = [ignore_split_join, ignore_alloc_free, ignore_cliffords]
    if generalizer is not None:
        if isinstance(generalizer, abc.Sequence):
            generalizers.extend(generalizer)
        else:
            generalizers.append(generalizer)

    _, sigma = bloq.call_graph(generalizer=generalizers, keep=keeper)
    classified_bloqs: Dict[str, Union[int, sympy.Expr]] = defaultdict(int)
    for leaf, multiplicity in sigma.items():
        t_counts = t_counts_from_sigma(leaf.call_graph()[1])
        # `if t_counts > 0:` raises TypeError for symbolic counts (sympy cannot
        # decide e.g. `4*n > 0` for an unconstrained `n`). Skip only counts that
        # are provably not positive, such as 0; keep undecidable symbolic ones.
        if sympy.sympify(t_counts).is_positive is False:
            continue
        classified_bloqs[classify_bloq(leaf, classification_map)] += multiplicity * t_counts
    return classified_bloqs
def bloq_is_clifford(b: Bloq):
    """Return True if `b` is one of a fixed set of Clifford bloq types.

    A single level of `Adjoint` is unwrapped before the check, so the adjoint
    of a listed Clifford bloq also counts. Any other bloq yields False.
    """
    from qualtran.bloqs.basic_gates import (
        CNOT,
        CYGate,
        CZ,
        Hadamard,
        SGate,
        TwoBitSwap,
        XGate,
        YGate,
        ZGate,
    )
    from qualtran.bloqs.bookkeeping import ArbitraryClifford
    from qualtran.bloqs.mcmt.multi_target_cnot import MultiTargetCNOT

    clifford_types = (
        TwoBitSwap,
        Hadamard,
        XGate,
        ZGate,
        YGate,
        ArbitraryClifford,
        CNOT,
        MultiTargetCNOT,
        CYGate,
        CZ,
        SGate,
    )
    candidate = b.subbloq if isinstance(b, Adjoint) else b
    return isinstance(candidate, clifford_types)
def bloq_is_rotation(b: Bloq):
    """Return True if `b` is a rotation bloq, looking through `Controlled` wrappers.

    Uncontrolled bloqs count only if they are Rx/Ry/Rz or X/Y/ZPowGate.
    Inside (possibly nested) `Controlled` wrappers, SGate, TGate and
    GlobalPhase are additionally treated as rotations.
    """
    from qualtran.bloqs.basic_gates import GlobalPhase, SGate, TGate
    from qualtran.bloqs.basic_gates.rotation import Rx, Ry, Rz, XPowGate, YPowGate, ZPowGate

    rotation_types = (Rz, Rx, Ry, ZPowGate, XPowGate, YPowGate)
    controlled_phase_types = (SGate, TGate, GlobalPhase)

    current = b
    while isinstance(current, Controlled):
        # TODO https://github.com/quantumlib/Qualtran/issues/878
        # explicit representation of all two-qubit rotations.
        inner = current.subbloq
        if isinstance(inner, controlled_phase_types):
            return True
        current = inner
    return isinstance(current, rotation_types)