-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* bloq counts * [counts] test and docs * [counts] real imports in test files * support symbolics * merge fixes
- Loading branch information
1 parent
5aabc66
commit bae2356
Showing
8 changed files
with
318 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import logging | ||
from collections import defaultdict | ||
from typing import Callable, Dict, Sequence, Tuple, TYPE_CHECKING | ||
|
||
import attrs | ||
import networkx as nx | ||
from attrs import field, frozen | ||
|
||
from ._call_graph import get_bloq_callee_counts | ||
from ._costing import CostKey | ||
from .classify_bloqs import bloq_is_clifford | ||
|
||
if TYPE_CHECKING: | ||
from qualtran import Bloq | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
BloqCountDict = Dict['Bloq', int] | ||
|
||
|
||
def _gateset_bloqs_to_tuple(bloqs: Sequence['Bloq']) -> Tuple['Bloq', ...]: | ||
return tuple(bloqs) | ||
|
||
|
||
@frozen | ||
class BloqCount(CostKey[BloqCountDict]): | ||
"""A cost which is the count of a specific set of bloqs forming a gateset. | ||
Often, we wish to know the number of specific gates in our algorithm. This is a generic | ||
CostKey that can count any gate (bloq) of interest. | ||
The cost value type for this cost is a mapping from bloq to its count. | ||
Args: | ||
gateset_bloqs: A sequence of bloqs which we will count. Bloqs are counted according | ||
to their equality operator. | ||
gateset_name: A string name of the gateset. Used for display and debugging purposes. | ||
""" | ||
|
||
gateset_bloqs: Sequence['Bloq'] = field(converter=_gateset_bloqs_to_tuple) | ||
gateset_name: str | ||
|
||
@classmethod | ||
def for_gateset(cls, gateset_name: str): | ||
"""Helper constructor to configure this cost for some common gatesets. | ||
Args: | ||
gateset_name: One of 't', 't+tof', 't+tof+cswap'. This will construct a | ||
`BloqCount` cost with the indicated gates as the `gateset_bloqs`. In all | ||
cases, both TGate and its adjoint are included. | ||
""" | ||
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap | ||
|
||
bloqs: Tuple['Bloq', ...] | ||
if gateset_name == 't': | ||
bloqs = (TGate(), TGate(is_adjoint=True)) | ||
elif gateset_name == 't+tof': | ||
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli()) | ||
elif gateset_name == 't+tof+cswap': | ||
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli(), TwoBitCSwap()) | ||
else: | ||
raise ValueError(f"Unknown gateset name {gateset_name}") | ||
|
||
return cls(bloqs, gateset_name=gateset_name) | ||
|
||
@classmethod | ||
def for_call_graph_leaf_bloqs(cls, g: nx.DiGraph): | ||
"""Helper constructor to configure this cost for 'leaf' bloqs in a given call graph. | ||
Args: | ||
g: The call graph. Its leaves will be used for `gateset_bloqs`. This call graph | ||
can be generated from `Bloq.call_graph()` | ||
""" | ||
leaf_bloqs = {node for node in g.nodes if not g.succ[node]} | ||
return cls(tuple(leaf_bloqs), gateset_name='leaf') | ||
|
||
def compute( | ||
self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], BloqCountDict] | ||
) -> BloqCountDict: | ||
if bloq in self.gateset_bloqs: | ||
logger.info("Computing %s: %s is in the target gateset.", self, bloq) | ||
return {bloq: 1} | ||
|
||
totals: BloqCountDict = defaultdict(lambda: 0) | ||
callees = get_bloq_callee_counts(bloq) | ||
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees)) | ||
for callee, n_times_called in callees: | ||
callee_cost = get_callee_cost(callee) | ||
for gateset_bloq, count in callee_cost.items(): | ||
totals[gateset_bloq] += n_times_called * count | ||
|
||
return dict(totals) | ||
|
||
def zero(self) -> BloqCountDict: | ||
# The additive identity of the bloq counts dictionary is an empty dictionary. | ||
return {} | ||
|
||
def __str__(self): | ||
return f'{self.gateset_name} counts' | ||
|
||
|
||
@frozen(kw_only=True) | ||
class GateCounts: | ||
"""A data class of counts of the typical target gates in a compilation. | ||
Specifically, this class holds counts for the number of `TGate` (and adjoint), `Toffoli`, | ||
`TwoBitCSwap`, `And`, and clifford bloqs. | ||
""" | ||
|
||
t: int = 0 | ||
toffoli: int = 0 | ||
cswap: int = 0 | ||
and_bloq: int = 0 | ||
clifford: int = 0 | ||
|
||
def __add__(self, other): | ||
if not isinstance(other, GateCounts): | ||
raise TypeError(f"Can only add other `GateCounts` objects, not {self}") | ||
|
||
return GateCounts( | ||
t=self.t + other.t, | ||
toffoli=self.toffoli + other.toffoli, | ||
cswap=self.cswap + other.cswap, | ||
and_bloq=self.and_bloq + other.and_bloq, | ||
clifford=self.clifford + other.clifford, | ||
) | ||
|
||
def __mul__(self, other): | ||
return GateCounts( | ||
t=other * self.t, | ||
toffoli=other * self.toffoli, | ||
cswap=other * self.cswap, | ||
and_bloq=other * self.and_bloq, | ||
clifford=other * self.clifford, | ||
) | ||
|
||
def __rmul__(self, other): | ||
return self.__mul__(other) | ||
|
||
def __str__(self): | ||
strs = [] | ||
for f in attrs.fields(self.__class__): | ||
val = getattr(self, f.name) | ||
if val != 0: | ||
strs.append(f'{f.name}: {val}') | ||
|
||
if strs: | ||
return ', '.join(strs) | ||
return '-' | ||
|
||
def total_t_count( | ||
self, ts_per_toffoli: int = 4, ts_per_cswap: int = 7, ts_per_and_bloq: int = 4 | ||
) -> int: | ||
"""Get the total number of T Gates for the `GateCounts` object. | ||
This simply multiplies each gate type by its cost in terms of T gates, which is configurable | ||
via the arguments to this method. | ||
""" | ||
return ( | ||
self.t | ||
+ ts_per_toffoli * self.toffoli | ||
+ ts_per_cswap * self.cswap | ||
+ ts_per_and_bloq * self.and_bloq | ||
) | ||
|
||
|
||
@frozen | ||
class QECGatesCost(CostKey[GateCounts]): | ||
"""Counts specifically for 'expensive' gates in a surface code error correction scheme. | ||
The cost value type for this CostKey is `GateCounts`. | ||
""" | ||
|
||
def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) -> GateCounts: | ||
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap | ||
from qualtran.bloqs.mcmt.and_bloq import And | ||
|
||
# T gates | ||
if isinstance(bloq, TGate): | ||
return GateCounts(t=1) | ||
|
||
# Toffolis | ||
if isinstance(bloq, Toffoli): | ||
return GateCounts(toffoli=1) | ||
|
||
# 'And' bloqs | ||
if isinstance(bloq, And) and not bloq.uncompute: | ||
return GateCounts(and_bloq=1) | ||
|
||
# CSwaps aka Fredkin | ||
if isinstance(bloq, TwoBitCSwap): | ||
return GateCounts(cswap=1) | ||
|
||
# Cliffords | ||
if bloq_is_clifford(bloq): | ||
return GateCounts(clifford=1) | ||
|
||
# Recursive case | ||
totals = GateCounts() | ||
callees = get_bloq_callee_counts(bloq) | ||
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees)) | ||
for callee, n_times_called in callees: | ||
callee_cost = get_callee_cost(callee) | ||
totals += n_times_called * callee_cost | ||
return totals | ||
|
||
def zero(self) -> GateCounts: | ||
return GateCounts() | ||
|
||
def validate_val(self, val: GateCounts): | ||
if not isinstance(val, GateCounts): | ||
raise TypeError(f"{self} values should be `GateCounts`, got {val}") | ||
|
||
def __str__(self): | ||
return 'gate counts' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from qualtran.bloqs.basic_gates import Hadamard, TGate, Toffoli | ||
from qualtran.bloqs.for_testing.costing import make_example_costing_bloqs | ||
from qualtran.resource_counting import BloqCount, GateCounts, get_cost_value, QECGatesCost | ||
|
||
|
||
def test_bloq_count(): | ||
algo = make_example_costing_bloqs() | ||
|
||
cost = BloqCount([Toffoli()], 'toffoli') | ||
tof_count = get_cost_value(algo, cost) | ||
|
||
# `make_example_costing_bloqs` has `func` and `func2`. `func2` has 100 Tof | ||
assert tof_count == {Toffoli(): 100} | ||
|
||
t_and_tof_count = get_cost_value(algo, BloqCount.for_gateset('t+tof')) | ||
assert t_and_tof_count == {Toffoli(): 100, TGate(): 2 * 10, TGate().adjoint(): 2 * 10} | ||
|
||
g, _ = algo.call_graph() | ||
leaf = BloqCount.for_call_graph_leaf_bloqs(g) | ||
# Note: Toffoli has a decomposition in terms of T gates. | ||
assert set(leaf.gateset_bloqs) == {Hadamard(), TGate(), TGate().adjoint()} | ||
|
||
t_count = get_cost_value(algo, leaf) | ||
assert t_count == {TGate(): 2 * 10 + 100 * 4, TGate().adjoint(): 2 * 10, Hadamard(): 2 * 10} | ||
|
||
# count things other than leaf bloqs | ||
top_level = get_cost_value(algo, BloqCount([bloq for bloq, n in algo.callees], 'top')) | ||
assert sorted(f'{k}: {v}' for k, v in top_level.items()) == ['Func1: 2', 'Func2: 1'] | ||
|
||
|
||
def test_gate_counts(): | ||
gc = GateCounts(t=100, toffoli=13) | ||
assert str(gc) == 't: 100, toffoli: 13' | ||
|
||
assert GateCounts(t=10) * 2 == GateCounts(t=20) | ||
assert 2 * GateCounts(t=10) == GateCounts(t=20) | ||
|
||
assert GateCounts(toffoli=1, cswap=1, and_bloq=1).total_t_count() == 4 + 7 + 4 | ||
|
||
|
||
def test_qec_gates_cost(): | ||
algo = make_example_costing_bloqs() | ||
gc = get_cost_value(algo, QECGatesCost()) | ||
assert gc == GateCounts(toffoli=100, t=2 * 2 * 10, clifford=2 * 10) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.