-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Costs] Bloq & Gate counts #958
Merged
mpharrigan
merged 8 commits into
quantumlib:main
from
mpharrigan:2024-04/generic-costs-counts
May 30, 2024
Merged
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
c7f3e12
bloq counts
mpharrigan a97c6da
[counts] test and docs
mpharrigan fe09151
[counts] real imports in test files
mpharrigan 69f5c4a
Merge remote-tracking branch 'origin/main' into 2024-04/generic-costs…
mpharrigan 5a3edff
support symbolics
mpharrigan 5d3f9b8
Merge branch 'main' into 2024-04/generic-costs-counts
mpharrigan 6124daa
Merge remote-tracking branch 'origin/main' into 2024-04/generic-costs…
mpharrigan ef6ed66
merge fixes
mpharrigan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import logging | ||
from collections import defaultdict | ||
from typing import Callable, Dict, Sequence, Tuple, TYPE_CHECKING | ||
|
||
import attrs | ||
import networkx as nx | ||
from attrs import field, frozen | ||
|
||
from ._call_graph import get_bloq_callee_counts | ||
from ._costing import CostKey | ||
from .classify_bloqs import bloq_is_clifford | ||
|
||
if TYPE_CHECKING: | ||
from qualtran import Bloq | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
BloqCountDict = Dict['Bloq', int] | ||
|
||
|
||
def _gateset_bloqs_to_tuple(bloqs: Sequence['Bloq']) -> Tuple['Bloq', ...]: | ||
return tuple(bloqs) | ||
|
||
|
||
@frozen | ||
class BloqCount(CostKey[BloqCountDict]): | ||
"""A cost which is the count of a specific set of bloqs forming a gateset. | ||
|
||
Often, we wish to know the number of specific gates in our algorithm. This is a generic | ||
CostKey that can count any gate (bloq) of interest. | ||
|
||
The cost value type for this cost is a mapping from bloq to its count. | ||
|
||
Args: | ||
gateset_bloqs: A sequence of bloqs which we will count. Bloqs are counted according | ||
to their equality operator. | ||
gateset_name: A string name of the gateset. Used for display and debugging purposes. | ||
""" | ||
|
||
gateset_bloqs: Sequence['Bloq'] = field(converter=_gateset_bloqs_to_tuple) | ||
gateset_name: str | ||
|
||
@classmethod | ||
def for_gateset(cls, gateset_name: str): | ||
"""Helper constructor to configure this cost for some common gatesets. | ||
|
||
Args: | ||
gateset_name: One of 't', 't+tof', 't+tof+cswap'. This will construct a | ||
`BloqCount` cost with the indicated gates as the `gateset_bloqs`. In all | ||
cases, both TGate and its adjoint are included. | ||
""" | ||
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap | ||
|
||
bloqs: Tuple['Bloq', ...] | ||
if gateset_name == 't': | ||
bloqs = (TGate(), TGate(is_adjoint=True)) | ||
elif gateset_name == 't+tof': | ||
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli()) | ||
elif gateset_name == 't+tof+cswap': | ||
bloqs = (TGate(), TGate(is_adjoint=True), Toffoli(), TwoBitCSwap()) | ||
else: | ||
raise ValueError(f"Unknown gateset name {gateset_name}") | ||
|
||
return cls(bloqs, gateset_name=gateset_name) | ||
|
||
@classmethod | ||
def for_call_graph_leaf_bloqs(cls, g: nx.DiGraph): | ||
"""Helper constructor to configure this cost for 'leaf' bloqs in a given call graph. | ||
|
||
Args: | ||
g: The call graph. Its leaves will be used for `gateset_bloqs`. This call graph | ||
can be generated from `Bloq.call_graph()` | ||
""" | ||
leaf_bloqs = {node for node in g.nodes if not g.succ[node]} | ||
return cls(tuple(leaf_bloqs), gateset_name='leaf') | ||
|
||
def compute( | ||
self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], BloqCountDict] | ||
) -> BloqCountDict: | ||
if bloq in self.gateset_bloqs: | ||
logger.info("Computing %s: %s is in the target gateset.", self, bloq) | ||
return {bloq: 1} | ||
|
||
totals: BloqCountDict = defaultdict(lambda: 0) | ||
callees = get_bloq_callee_counts(bloq) | ||
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees)) | ||
for callee, n_times_called in callees: | ||
callee_cost = get_callee_cost(callee) | ||
for gateset_bloq, count in callee_cost.items(): | ||
totals[gateset_bloq] += n_times_called * count | ||
|
||
return dict(totals) | ||
|
||
def zero(self) -> BloqCountDict: | ||
# The additive identity of the bloq counts dictionary is an empty dictionary. | ||
return {} | ||
|
||
def __str__(self): | ||
return f'{self.gateset_name} counts' | ||
|
||
|
||
@frozen(kw_only=True) | ||
class GateCounts: | ||
"""A data class of counts of the typical target gates in a compilation. | ||
|
||
Specifically, this class holds counts for the number of `TGate` (and adjoint), `Toffoli`, | ||
`TwoBitCSwap`, `And`, and clifford bloqs. | ||
""" | ||
|
||
t: int = 0 | ||
toffoli: int = 0 | ||
cswap: int = 0 | ||
and_bloq: int = 0 | ||
clifford: int = 0 | ||
|
||
def __add__(self, other): | ||
if not isinstance(other, GateCounts): | ||
raise TypeError(f"Can only add other `GateCounts` objects, not {self}") | ||
|
||
return GateCounts( | ||
t=self.t + other.t, | ||
toffoli=self.toffoli + other.toffoli, | ||
cswap=self.cswap + other.cswap, | ||
and_bloq=self.and_bloq + other.and_bloq, | ||
clifford=self.clifford + other.clifford, | ||
) | ||
|
||
def __mul__(self, other): | ||
return GateCounts( | ||
t=other * self.t, | ||
toffoli=other * self.toffoli, | ||
cswap=other * self.cswap, | ||
and_bloq=other * self.and_bloq, | ||
clifford=other * self.clifford, | ||
) | ||
|
||
def __rmul__(self, other): | ||
return self.__mul__(other) | ||
|
||
def __str__(self): | ||
strs = [] | ||
for f in attrs.fields(self.__class__): | ||
val = getattr(self, f.name) | ||
if val != 0: | ||
strs.append(f'{f.name}: {val}') | ||
|
||
if strs: | ||
return ', '.join(strs) | ||
return '-' | ||
|
||
def total_t_count( | ||
self, ts_per_toffoli: int = 4, ts_per_cswap: int = 7, ts_per_and_bloq: int = 4 | ||
) -> int: | ||
"""Get the total number of T Gates for the `GateCounts` object. | ||
|
||
This simply multiplies each gate type by its cost in terms of T gates, which is configurable | ||
via the arguments to this method. | ||
""" | ||
return ( | ||
self.t | ||
+ ts_per_toffoli * self.toffoli | ||
+ ts_per_cswap * self.cswap | ||
+ ts_per_and_bloq * self.and_bloq | ||
) | ||
|
||
|
||
@frozen | ||
class QECGatesCost(CostKey[GateCounts]): | ||
"""Counts specifically for 'expensive' gates in a surface code error correction scheme. | ||
|
||
The cost value type for this CostKey is `GateCounts`. | ||
""" | ||
|
||
def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) -> GateCounts: | ||
from qualtran.bloqs.basic_gates import TGate, Toffoli, TwoBitCSwap | ||
from qualtran.bloqs.mcmt.and_bloq import And | ||
|
||
# T gates | ||
if isinstance(bloq, TGate): | ||
return GateCounts(t=1) | ||
|
||
# Toffolis | ||
if isinstance(bloq, Toffoli): | ||
return GateCounts(toffoli=1) | ||
|
||
# 'And' bloqs | ||
if isinstance(bloq, And) and not bloq.uncompute: | ||
return GateCounts(and_bloq=1) | ||
|
||
# CSwaps aka Fredkin | ||
if isinstance(bloq, TwoBitCSwap): | ||
return GateCounts(cswap=1) | ||
|
||
# Cliffords | ||
if bloq_is_clifford(bloq): | ||
return GateCounts(clifford=1) | ||
|
||
# Recursive case | ||
totals = GateCounts() | ||
callees = get_bloq_callee_counts(bloq) | ||
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees)) | ||
for callee, n_times_called in callees: | ||
callee_cost = get_callee_cost(callee) | ||
totals += n_times_called * callee_cost | ||
return totals | ||
|
||
def zero(self) -> GateCounts: | ||
return GateCounts() | ||
|
||
def validate_val(self, val: GateCounts): | ||
if not isinstance(val, GateCounts): | ||
raise TypeError(f"{self} values should be `GateCounts`, got {val}") | ||
|
||
def __str__(self): | ||
return 'gate counts' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from qualtran.bloqs.basic_gates import Hadamard, TGate, Toffoli | ||
from qualtran.bloqs.for_testing.costing import make_example_costing_bloqs | ||
from qualtran.resource_counting import BloqCount, GateCounts, get_cost_value, QECGatesCost | ||
|
||
|
||
def test_bloq_count(): | ||
algo = make_example_costing_bloqs() | ||
|
||
cost = BloqCount([Toffoli()], 'toffoli') | ||
tof_count = get_cost_value(algo, cost) | ||
|
||
# `make_example_costing_bloqs` has `func` and `func2`. `func2` has 100 Tof | ||
assert tof_count == {Toffoli(): 100} | ||
|
||
t_and_tof_count = get_cost_value(algo, BloqCount.for_gateset('t+tof')) | ||
assert t_and_tof_count == {Toffoli(): 100, TGate(): 2 * 10, TGate().adjoint(): 2 * 10} | ||
|
||
g, _ = algo.call_graph() | ||
leaf = BloqCount.for_call_graph_leaf_bloqs(g) | ||
# Note: Toffoli has a decomposition in terms of T gates. | ||
assert set(leaf.gateset_bloqs) == {Hadamard(), TGate(), TGate().adjoint()} | ||
|
||
t_count = get_cost_value(algo, leaf) | ||
assert t_count == {TGate(): 2 * 10 + 100 * 4, TGate().adjoint(): 2 * 10, Hadamard(): 2 * 10} | ||
|
||
# count things other than leaf bloqs | ||
top_level = get_cost_value(algo, BloqCount([bloq for bloq, n in algo.callees], 'top')) | ||
assert sorted(f'{k}: {v}' for k, v in top_level.items()) == ['Func1: 2', 'Func2: 1'] | ||
|
||
|
||
def test_gate_counts(): | ||
gc = GateCounts(t=100, toffoli=13) | ||
assert str(gc) == 't: 100, toffoli: 13' | ||
|
||
assert GateCounts(t=10) * 2 == GateCounts(t=20) | ||
assert 2 * GateCounts(t=10) == GateCounts(t=20) | ||
|
||
assert GateCounts(toffoli=1, cswap=1, and_bloq=1).total_t_count() == 4 + 7 + 4 | ||
|
||
|
||
def test_qec_gates_cost(): | ||
algo = make_example_costing_bloqs() | ||
gc = get_cost_value(algo, QECGatesCost()) | ||
assert gc == GateCounts(toffoli=100, t=2 * 2 * 10, clifford=2 * 10) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason you just picked this test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
somewhat arbitrary -- seemed like a reasonably complicated example