diff --git a/qualtran/_infra/bloq.py b/qualtran/_infra/bloq.py index 5f96923bf..79c4e8788 100644 --- a/qualtran/_infra/bloq.py +++ b/qualtran/_infra/bloq.py @@ -40,7 +40,7 @@ from qualtran.cirq_interop import CirqQuregT from qualtran.cirq_interop.t_complexity_protocol import TComplexity from qualtran.drawing import WireSymbol - from qualtran.resource_counting import BloqCountT, GeneralizerT, SympySymbolAllocator + from qualtran.resource_counting import BloqCountT, CostKey, GeneralizerT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT @@ -296,6 +296,20 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: """ return self.decompose_bloq().build_call_graph(ssa) + def my_static_costs(self, cost_key: 'CostKey'): + """Override this method to provide static costs. + + The system will query a particular cost by asking for a `cost_key`. This method + can optionally provide a value, which will be preferred over a computed cost. + + Static costs can be provided if the particular cost cannot be easily computed or + as a performance optimization. + + This method must return `NotImplemented` if a value cannot be provided for the specified + CostKey. + """ + return NotImplemented + def call_graph( self, generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None, diff --git a/qualtran/bloqs/for_testing/costing.py b/qualtran/bloqs/for_testing/costing.py new file mode 100644 index 000000000..cc0e24114 --- /dev/null +++ b/qualtran/bloqs/for_testing/costing.py @@ -0,0 +1,61 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Sequence, Set, Tuple + +from attrs import field, frozen + +from qualtran import Bloq, Signature +from qualtran.resource_counting import BloqCountT, CostKey, SympySymbolAllocator + + +def _convert_callees(callees: Sequence[BloqCountT]) -> Tuple[BloqCountT, ...]: + # Convert to tuples in a type-checked way. + return tuple(callees) + + +@frozen +class CostingBloq(Bloq): + """A bloq that lets you set the costs via attributes.""" + + name: str + num_qubits: int + callees: Sequence[BloqCountT] = field(converter=_convert_callees, factory=tuple) + static_costs: Sequence[Tuple[CostKey, Any]] = field(converter=tuple, factory=tuple) + + @property + def signature(self) -> 'Signature': + return Signature.build(register=self.num_qubits) + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: + return set(self.callees) + + def my_static_costs(self, cost_key: 'CostKey'): + return dict(self.static_costs).get(cost_key, NotImplemented) + + def pretty_name(self): + return self.name + + def __str__(self): + return self.name + + +def make_example_costing_bloqs(): + from qualtran.bloqs.basic_gates import Hadamard, TGate, Toffoli + + func1 = CostingBloq( + 'Func1', num_qubits=10, callees=[(TGate(), 10), (TGate().adjoint(), 10), (Hadamard(), 10)] + ) + func2 = CostingBloq('Func2', num_qubits=3, callees=[(Toffoli(), 100)]) + algo = CostingBloq('Algo', num_qubits=100, callees=[(func1, 1), (func2, 1)]) + return algo diff --git a/qualtran/bloqs/for_testing/costing_test.py b/qualtran/bloqs/for_testing/costing_test.py new file mode 100644 index 000000000..fb8340b74 --- /dev/null +++ b/qualtran/bloqs/for_testing/costing_test.py @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from qualtran.bloqs.for_testing.costing import make_example_costing_bloqs +from qualtran.resource_counting import format_call_graph_debug_text + + +def test_costing_bloqs(): + algo = make_example_costing_bloqs() + g, _ = algo.call_graph() + assert ( + format_call_graph_debug_text(g) + == """\ +Algo -- 1 -> Func1 +Algo -- 1 -> Func2 +Func1 -- 10 -> Hadamard() +Func1 -- 10 -> TGate() +Func1 -- 10 -> TGate(is_adjoint=True) +Func2 -- 100 -> Toffoli() +Toffoli() -- 4 -> TGate()""" + ) diff --git a/qualtran/bloqs/phase_estimation/lp_resource_state.py b/qualtran/bloqs/phase_estimation/lp_resource_state.py index 7105a0a93..fab19e6df 100644 --- a/qualtran/bloqs/phase_estimation/lp_resource_state.py +++ b/qualtran/bloqs/phase_estimation/lp_resource_state.py @@ -20,7 +20,7 @@ import cirq import numpy as np import sympy -from numpy._typing import NDArray +from numpy.typing import NDArray from qualtran import ( Bloq, diff --git a/qualtran/resource_counting/__init__.py b/qualtran/resource_counting/__init__.py index dc77fd876..0f9d1885f 100644 --- a/qualtran/resource_counting/__init__.py +++ b/qualtran/resource_counting/__init__.py @@ -25,8 +25,10 @@ SympySymbolAllocator, get_bloq_callee_counts, get_bloq_call_graph, - print_counts_graph, build_cbloq_call_graph, + format_call_graph_debug_text, ) +from ._costing import GeneralizerT, get_cost_value, get_cost_cache, query_costs, CostKey, CostValT + from . import generalizers diff --git a/qualtran/resource_counting/_call_graph.py b/qualtran/resource_counting/_call_graph.py index 72bf75965..0e6017a97 100644 --- a/qualtran/resource_counting/_call_graph.py +++ b/qualtran/resource_counting/_call_graph.py @@ -14,7 +14,7 @@ """Functionality for the `Bloq.call_graph()` protocol.""" -import collections.abc as abc +import collections.abc from collections import defaultdict from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -231,7 +231,7 @@ def get_bloq_call_graph( keep = lambda b: False if generalizer is None: generalizer = lambda b: b - if isinstance(generalizer, abc.Sequence): + if isinstance(generalizer, collections.abc.Sequence): generalizer = _make_composite_generalizer(*generalizer) g = nx.DiGraph() @@ -243,8 +243,11 @@ def get_bloq_call_graph( return g, sigma -def print_counts_graph(g: nx.DiGraph): +def format_call_graph_debug_text(g: nx.DiGraph) -> str: """Print the graph returned from `get_bloq_counts_graph`.""" - for b in nx.topological_sort(g): - for succ in g.succ[b]: - print(b, '--', g.edges[b, succ]['n'], '->', succ) + lines = [] + for gen in nx.topological_generations(g): + for b in sorted(gen, key=str): + for succ in sorted(g.succ[b], key=str): + lines.append(f"{b} -- {g.edges[b, succ]['n']} -> {succ}") + return '\n'.join(lines) diff --git a/qualtran/resource_counting/_costing.py b/qualtran/resource_counting/_costing.py new file mode 100644 index 000000000..da866d024 --- /dev/null +++ b/qualtran/resource_counting/_costing.py @@ -0,0 +1,249 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import collections +import logging +import time +from collections import defaultdict +from typing import ( + Callable, + Dict, + Generic, + Iterable, + Optional, + Sequence, + TYPE_CHECKING, + TypeVar, + Union, +) + +from ._generalization import _make_composite_generalizer, GeneralizerT + +if TYPE_CHECKING: + from qualtran import Bloq + +logger = logging.getLogger(__name__) + +CostValT = TypeVar('CostValT') + + +class CostKey(Generic[CostValT], metaclass=abc.ABCMeta): + """Abstract base class for different types of costs. + + One important aspect of a bloq is the resources required to execute it on an error + corrected quantum computer. Since we're usually trying to minimize these resource requirements + we will generally use the catch-all term "costs". + + There are a variety of different types or flavors of costs. Each is represented by an + instance of a sublcass of `CostKey`. For example, gate counts (including T-gate counts), + qubit requirements, and circuit depth are all cost metrics that may be of interest. + + Each `CostKey` primarily encodes the behavior required to compute a cost value from a + bloq. Often, these costs are defined recursively: a bloq's costs is some combination + of the costs of the bloqs in its decomposition (i.e. the bloq 'callees'). Implementors + must override the `compute` method to define the cost computation. + + Each cost key has an associated CostValT. For example, the CostValT of a "t count" + CostKey could be an integer. For a more complicated gateset, the value could be a mapping + from gate to count. This abstract base class is generic w.r.t. `CostValT`. Subclasses + should have a concrete value type. The `validate_val` method can optionally be overridden + to raise an exception if a bad value type is encountered. The `zero` method must return + the zero (additive identity) cost value of the correct type. + """ + + @abc.abstractmethod + def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], CostValT]) -> CostValT: + """Compute this type of cost. + + When implementing a new CostKey, this method must be overridden. + Users should not call this method directly. Instead: use the `qualtran.resource_counting` + functions like `get_cost_value`, `get_cost_cache`, or `query_costs`. These provide + caching, logging, generalizers, and support for static costs. + + For recursive computations, use the provided callable to recurse. + + Args: + bloq: The bloq to compute the cost of. + get_callee_cost: A qualtran-provided function for computing costs for "callees" + of the bloq; i.e. bloqs in the decomposition. Use this function to accurately + cache intermediate cost values and respect bloqs' static costs. + + Returns: + A value of the generic type `CostValT`. Subclasses should define their value type. + """ + + @abc.abstractmethod + def zero(self) -> CostValT: + """The value corresponding to zero cost.""" + + def validate_val(self, val: CostValT): + """Assert that `val` is a valid `CostValT`. + + This method can be optionally overridden to raise an error if an invalid value + is encountered. By default, no validation is performed. + """ + + +def _get_cost_value( + bloq: 'Bloq', + cost_key: CostKey[CostValT], + *, + costs_cache: Dict['Bloq', CostValT], + generalizer: 'GeneralizerT', +) -> CostValT: + """Helper function for getting costs. + + This function tries the following strategies + 1. Use the value found in `costs_cache`, if it exists. + 2. Use the value returned by `Bloq.my_static_costs` if one is returned. + 3. Use `cost_key.compute()` and cache the result in `costs_cache`. + + Args: + bloq: The bloq. + cost_key: The cost key to get the value for. + costs_cache: A dictionary to use as a cache for computed bloq costs. This cache + will be mutated by this function. + generalizer: The generalizer to operate on each bloq before computing its cost. + """ + bloq = generalizer(bloq) + if bloq is None: + return cost_key.zero() + + # Strategy 1: Use cached value + if bloq in costs_cache: + logger.debug("Using cached %s for %s", cost_key, bloq) + return costs_cache[bloq] + + # Strategy 2: Static costs + static_cost = bloq.my_static_costs(cost_key) + if static_cost is not NotImplemented: + cost_key.validate_val(static_cost) + logger.info("Using static %s for %s", cost_key, bloq) + costs_cache[bloq] = static_cost + return static_cost + + # Strategy 3: Compute + # part a. set up caching of computed costs by currying the costs_cache. + def _get_cost_val_internal(callee: 'Bloq'): + return _get_cost_value(callee, cost_key, costs_cache=costs_cache, generalizer=generalizer) + + # part b. call the compute method and cache the result. + tstart = time.perf_counter() + computed_cost = cost_key.compute(bloq, _get_cost_val_internal) + tdur = time.perf_counter() - tstart + logger.info("Computed %s for %s in %g s", cost_key, bloq, tdur) + costs_cache[bloq] = computed_cost + return computed_cost + + +def get_cost_value( + bloq: 'Bloq', + cost_key: CostKey[CostValT], + costs_cache: Optional[Dict['Bloq', CostValT]] = None, + generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None, +) -> CostValT: + """Compute the specified cost of the provided bloq. + + Args: + bloq: The bloq to compute the cost of. + cost_key: A CostKey that specifies which cost to compute. + costs_cache: If provided, use this dictionary of cached cost values. Values in this + dictionary will be preferred over computed values (even if they disagree). This + dictionary will be mutated by the function. + generalizer: If provided, run this function on each bloq in the call graph to dynamically + modify attributes. If the function returns `None`, the bloq is ignored in the + cost computation. If a sequence of generalizers is provided, each generalizer + will be run in order. + + Returns: + The cost value. Its type depends on the provided `cost_key`. + """ + if costs_cache is None: + costs_cache = {} + if generalizer is None: + generalizer = lambda b: b + if isinstance(generalizer, collections.abc.Sequence): + generalizer = _make_composite_generalizer(*generalizer) + + cost_val = _get_cost_value(bloq, cost_key, costs_cache=costs_cache, generalizer=generalizer) + return cost_val + + +def get_cost_cache( + bloq: 'Bloq', + cost_key: CostKey[CostValT], + costs_cache: Optional[Dict['Bloq', CostValT]] = None, + generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None, +) -> Dict['Bloq', CostValT]: + """Build a cache of cost values for the bloq and its callees. + + This can be useful to inspect how callees' costs flow upwards in a given cost computation. + + Args: + bloq: The bloq to seed the cost computation. + cost_key: A CostKey that specifies which cost to compute. + costs_cache: If provided, use this dictionary for initial cached cost values. Values in this + dictionary will be preferred over computed values (even if they disagree). This + dictionary will be mutated by the function. This dictionary will be returned by the + function. + generalizer: If provided, run this function on each bloq in the call graph to dynamically + modify attributes. If the function returns `None`, the bloq is ignored in the + cost computation. If a sequence of generalizers is provided, each generalizer + will be run in order. + + Returns: + A dictionary mapping bloqs to cost values. The value type depends on the `cost_key`. + The bloqs in the mapping depend on the recursive nature of the cost key. + """ + if costs_cache is None: + costs_cache = {} + if generalizer is None: + generalizer = lambda b: b + if isinstance(generalizer, collections.abc.Sequence): + generalizer = _make_composite_generalizer(*generalizer) + + _get_cost_value(bloq, cost_key, costs_cache=costs_cache, generalizer=generalizer) + return costs_cache + + +def query_costs( + bloq: 'Bloq', + cost_keys: Iterable[CostKey], + generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None, +) -> Dict['Bloq', Dict[CostKey, CostValT]]: + """Compute a selection of costs for a bloq and its callees. + + This function can be used to annotate a call graph diagram with multiple costs + for each bloq. Specifically, the return value of this function can be used as the + `bloq_data` argument to `GraphvizCallGraph`. + + Args: + bloq: The bloq to seed the cost computation. + cost_keys: A sequence of CostKey that specifies which costs to compute. + generalizer: If provided, run this function on each bloq in the call graph to dynamically + modify attributes. If the function returns `None`, the bloq is ignored in the + cost computation. If a sequence of generalizers is provided, each generalizer + will be run in order. + + Returns: + A dictionary of dictionaries forming a table of multiple costs for multiple bloqs. + This is indexed by bloq, then cost key. + """ + costs: Dict['Bloq', Dict[CostKey, CostValT]] = defaultdict(dict) + for cost_key in cost_keys: + cost_for_bloqs = get_cost_cache(bloq, cost_key, generalizer=generalizer) + for bloq, val in cost_for_bloqs.items(): + costs[bloq][cost_key] = val + return dict(costs) diff --git a/qualtran/resource_counting/_costing_test.py b/qualtran/resource_counting/_costing_test.py new file mode 100644 index 000000000..30cfbd2ba --- /dev/null +++ b/qualtran/resource_counting/_costing_test.py @@ -0,0 +1,111 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List + +import attrs + +from qualtran import Bloq +from qualtran.bloqs.basic_gates import Hadamard, TGate +from qualtran.bloqs.for_testing.costing import CostingBloq, make_example_costing_bloqs +from qualtran.resource_counting import ( + CostKey, + get_bloq_callee_counts, + get_cost_cache, + get_cost_value, +) +from qualtran.resource_counting.generalizers import generalize_rotation_angle + + +class TestCostKey(CostKey[int]): + def __init__(self): + # For testing, keep a log of all the bloqs for which we called 'compute' on. + self._log: List[Bloq] = [] + + def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], int]) -> int: + self._log.append(bloq) + + total = 1 + for callee, n_times_called in get_bloq_callee_counts(bloq): + total += n_times_called * get_callee_cost(callee) + + return total + + def zero(self) -> int: + return 0 + + def __hash__(self): + return hash(self.__class__) + + def __eq__(self, other): + return isinstance(other, self.__class__) + + +def test_get_cost_value_caching(): + cost = TestCostKey() + algo = make_example_costing_bloqs() + assert isinstance(algo, CostingBloq) + _ = get_cost_value(algo, cost) + n_times_compute_called_on_t = sum(b == TGate() for b in cost._log) + assert n_times_compute_called_on_t == 1, 'should use cached value' + + +def test_get_cost_value_static(): + algo = make_example_costing_bloqs() + + # Modify "func1" to have static costs + func1 = algo.callees[0][0] + func1_mod = attrs.evolve(func1, static_costs=[(TestCostKey(), 123)]) + algo_mod = attrs.evolve(algo, callees=[(func1_mod, 1), algo.callees[1]]) + assert get_cost_value(func1_mod, TestCostKey()) == 123 + + # Should not call "compute" for Func1, since it has static costs + # Should not have to recurse into H, T^dag; which is only used by Func1 + cost = TestCostKey() + _ = get_cost_value(algo_mod, cost) + assert len(cost._log) == 4 + assert 'Func2' in [str(b) for b in cost._log] + assert 'Func1' not in [str(b) for b in cost._log] + assert TGate().adjoint() not in cost._log + assert Hadamard() not in cost._log + + +def test_get_cost_value_static_user_provided(): + cost = TestCostKey() + algo = make_example_costing_bloqs() + + # Provide cached costs up front for func1 + func1 = algo.callees[0][0] + + # Should not call "compute" for Func1, since we supplied an existing cache + # Should not have to recurse into H, T^dag; which is only used by Func1 + _ = get_cost_value(algo, cost, costs_cache={func1: 0}) + assert len(cost._log) == 4 + assert 'Func2' in [str(b) for b in cost._log] + assert 'Func1' not in [str(b) for b in cost._log] + assert TGate().adjoint() not in cost._log + assert Hadamard() not in cost._log + + +def test_costs_generalizer(): + assert generalize_rotation_angle(TGate().adjoint()) == TGate() + + algo = CostingBloq(name='algo', num_qubits=1, callees=[(TGate(), 1), (TGate().adjoint(), 1)]) + cost_cache = get_cost_cache(algo, TestCostKey()) + assert cost_cache[algo] == 3 + + cost_cache_gen = get_cost_cache(algo, TestCostKey(), generalizer=generalize_rotation_angle) + assert TGate() in cost_cache_gen + assert TGate().adjoint() not in cost_cache_gen + assert cost_cache_gen[algo] == 3