Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change build_call_graph in bloqs to return dict #1392

Merged
merged 5 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

from qualtran.bloqs.bookkeeping.auto_partition import Unused
from qualtran.cirq_interop._cirq_to_bloq import CirqQuregInT, CirqQuregT
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

# NDArrays must be bound to np.generic
Expand Down Expand Up @@ -237,7 +237,7 @@ def decompose_bloq(self) -> 'CompositeBloq':
"Consider using the composite bloq directly or using `.flatten()`."
)

def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> Set['BloqCountT']:
def build_call_graph(self, ssa: Optional['SympySymbolAllocator']) -> 'BloqCountDictT':
"""Return the bloq counts by counting up all the subbloqs."""
from qualtran.resource_counting import build_cbloq_call_graph

Expand Down
11 changes: 5 additions & 6 deletions qualtran/bloqs/arithmetic/_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
will be fleshed out and moved to their final organizational location soon (written: 2024-05-06).
"""
from functools import cached_property
from typing import Set

from attrs import frozen

from qualtran import Bloq, QBit, QUInt, Register, Signature
from qualtran.bloqs.basic_gates import Toffoli
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator


@frozen
Expand All @@ -36,8 +35,8 @@ class MultiCToffoli(Bloq):
def signature(self) -> 'Signature':
return Signature([Register('ctrl', QBit(), shape=(self.n,)), Register('target', QBit())])

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(Toffoli(), self.n - 2)}
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
return {Toffoli(): self.n - 2}


@frozen
Expand All @@ -51,9 +50,9 @@ def signature(self) -> 'Signature':
[Register('x', QUInt(self.n)), Register('y', QUInt(self.n)), Register('out', QBit())]
)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
# litinski
return {(Toffoli(), self.n)}
return {Toffoli(): self.n}


@frozen
Expand Down
25 changes: 15 additions & 10 deletions qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@

if TYPE_CHECKING:
from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountDictT, BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import (
BloqCountDictT,
BloqCountT,
MutableBloqCountDictT,
SympySymbolAllocator,
)
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import SymbolicInt

Expand Down Expand Up @@ -209,10 +214,10 @@ def decompose_from_registers(
yield CNOT().on(input_bits[0], output_bits[0])
context.qubit_manager.qfree(ancillas)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
n = self.b_dtype.bitsize
n_cnot = (n - 2) * 6 + 3
return {(And(), n - 1), (And().adjoint(), n - 1), (CNOT(), n_cnot)}
return {And(): n - 1, And().adjoint(): n - 1, CNOT(): n_cnot}


@bloq_example(generalizer=ignore_split_join)
Expand Down Expand Up @@ -330,8 +335,8 @@ def decompose_from_registers(
]
return cirq.inverse(optree) if self.is_adjoint else optree

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(And(uncompute=self.is_adjoint), self.bitsize), (CNOT(), 5 * self.bitsize)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {And(uncompute=self.is_adjoint): self.bitsize, CNOT(): 5 * self.bitsize}

def __pow__(self, power: int):
if power == 1:
Expand Down Expand Up @@ -503,16 +508,16 @@ def build_composite_bloq(
def build_call_graph(
self, ssa: 'SympySymbolAllocator'
) -> Union['BloqCountDictT', Set['BloqCountT']]:
loading_cost: Tuple[Bloq, SymbolicInt]
loading_cost: MutableBloqCountDictT
if len(self.cvs) == 0:
loading_cost = (XGate(), self.bitsize) # upper bound; depends on the data.
loading_cost = {XGate(): self.bitsize} # upper bound; depends on the data.
elif len(self.cvs) == 1:
loading_cost = (CNOT(), self.bitsize) # upper bound; depends on the data.
loading_cost = {CNOT(): self.bitsize} # upper bound; depends on the data.
else:
# Otherwise, use the decomposition
return super().build_call_graph(ssa=ssa)

return {loading_cost, (Add(QUInt(self.bitsize)), 1)}
loading_cost[Add(QUInt(self.bitsize))] = 1
return loading_cost

def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> Tuple['Bloq', 'AddControlledT']:
if self.cvs:
Expand Down
10 changes: 5 additions & 5 deletions qualtran/bloqs/arithmetic/bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from qualtran.symbolics import is_symbolic, SymbolicInt

if TYPE_CHECKING:
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -90,9 +90,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'Soq

return {'x': x}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
num_flips = self.bitsize if self.is_symbolic() else sum(self._bits_k)
return {(XGate(), num_flips)}
return {XGate(): num_flips}

def on_classical_vals(self, x: 'ClassicalValT') -> dict[str, 'ClassicalValT']:
if isinstance(self.k, sympy.Expr):
Expand Down Expand Up @@ -156,8 +156,8 @@ def build_composite_bloq(self, bb: BloqBuilder, x: Soquet, y: Soquet) -> dict[st

return {'x': bb.join(xs, dtype=self.dtype), 'y': bb.join(ys, dtype=self.dtype)}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']:
return {(CNOT(), self.dtype.num_qubits)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {CNOT(): self.dtype.num_qubits}

def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT'
Expand Down
108 changes: 51 additions & 57 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,7 @@

from collections import defaultdict
from functools import cached_property
from typing import (
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import attrs
import cirq
Expand Down Expand Up @@ -65,7 +54,11 @@

if TYPE_CHECKING:
from qualtran import BloqBuilder
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import (
BloqCountDictT,
MutableBloqCountDictT,
SympySymbolAllocator,
)
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -183,22 +176,22 @@ def decompose_from_registers(
def _has_unitary_(self):
return True

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if (
not is_symbolic(self.less_than_val, self.bitsize)
and self.less_than_val >= 2**self.bitsize
):
return {(XGate(), 1)}
return {XGate(): 1}
num_set_bits = (
int(self.less_than_val).bit_count()
if not is_symbolic(self.less_than_val)
else self.bitsize
)
return {
(And(), self.bitsize),
(And().adjoint(), self.bitsize),
(CNOT(), num_set_bits + 2 * self.bitsize),
(XGate(), 2 * (1 + num_set_bits)),
And(): self.bitsize,
And().adjoint(): self.bitsize,
CNOT(): num_set_bits + 2 * self.bitsize,
XGate(): 2 * (1 + num_set_bits),
}


Expand Down Expand Up @@ -307,8 +300,8 @@ def __pow__(self, power: int) -> 'BiQubitsMixer':
return self.adjoint()
return NotImplemented # pragma: no cover

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(XGate(), 1), (CNOT(), 9), (And(uncompute=self.is_adjoint), 2)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {XGate(): 1, CNOT(): 9, And(uncompute=self.is_adjoint): 2}

def _has_unitary_(self):
return not self.is_adjoint
Expand Down Expand Up @@ -380,8 +373,8 @@ def __pow__(self, power: int) -> Union['SingleQubitCompare', cirq.Gate]:
return self.adjoint()
return self

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(XGate(), 1), (CNOT(), 4), (And(uncompute=self.is_adjoint), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {XGate(): 1, CNOT(): 4, And(uncompute=self.is_adjoint): 1}


@bloq_example
Expand Down Expand Up @@ -575,13 +568,13 @@ def decompose_from_registers(
all_ancilla = set([q for op in adjoint for q in op.qubits if q not in input_qubits])
context.qubit_manager.qfree(all_ancilla)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if is_symbolic(self.x_bitsize, self.y_bitsize):
return {
(BiQubitsMixer(), self.x_bitsize),
(BiQubitsMixer().adjoint(), self.x_bitsize),
(SingleQubitCompare(), 1),
(SingleQubitCompare().adjoint(), 1),
BiQubitsMixer(): self.x_bitsize,
BiQubitsMixer().adjoint(): self.x_bitsize,
SingleQubitCompare(): 1,
SingleQubitCompare().adjoint(): 1,
}

n = min(self.x_bitsize, self.y_bitsize)
Expand Down Expand Up @@ -613,7 +606,7 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
ret[And(1, 0).adjoint()] += 1
ret[CNOT()] += 1

return set(ret.items())
return ret

def _has_unitary_(self):
return True
Expand Down Expand Up @@ -691,8 +684,8 @@ def build_composite_bloq(
target = bb.add(XGate(), q=target)
return {'a': a, 'b': b, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(LessThanEqual(self.a_bitsize, self.b_bitsize), 1), (XGate(), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {LessThanEqual(self.a_bitsize, self.b_bitsize): 1, XGate(): 1}


@bloq_example
Expand Down Expand Up @@ -885,23 +878,23 @@ def wire_symbol(
return TextBox('t⨁(a>b)')
raise ValueError(f'Unknown register name {reg.name}')

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if self.bitsize == 1:
return {(MultiControlX(cvs=(1, 0)), 1)}
return {MultiControlX(cvs=(1, 0)): 1}

if self.signed:
return {
(CNOT(), 6 * self.bitsize - 7),
(XGate(), 2 * self.bitsize + 2),
(And(), self.bitsize - 1),
(And(uncompute=True), self.bitsize - 1),
CNOT(): 6 * self.bitsize - 7,
XGate(): 2 * self.bitsize + 2,
And(): self.bitsize - 1,
And(uncompute=True): self.bitsize - 1,
}

return {
(CNOT(), 6 * self.bitsize - 1),
(XGate(), 2 * self.bitsize + 4),
(And(), self.bitsize),
(And(uncompute=True), self.bitsize),
CNOT(): 6 * self.bitsize - 1,
XGate(): 2 * self.bitsize + 4,
And(): self.bitsize,
And(uncompute=True): self.bitsize,
}


Expand Down Expand Up @@ -941,8 +934,8 @@ def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -
return TextBox(f"⨁(x > {self.val})")
raise ValueError(f'Unknown register symbol {reg.name}')

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(LessThanConstant(self.bitsize, less_than_val=self.val), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {LessThanConstant(self.bitsize, less_than_val=self.val): 1}


@bloq_example
Expand Down Expand Up @@ -1007,8 +1000,8 @@ def build_composite_bloq(
x = bb.join(xs)
return {'x': x, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(MultiControlX(self.bits_k), 1)}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {MultiControlX(self.bits_k): 1}


def _make_equals_a_constant():
Expand Down Expand Up @@ -1134,21 +1127,22 @@ def on_classical_vals(
return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target ^ (a > b)}
return {'ctrl': ctrl, 'a': a, 'b': b, 'target': target}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
signed_ops = []
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
signed_ops: 'MutableBloqCountDictT' = {}
if isinstance(self.dtype, QInt):
signed_ops = [
(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), 2),
(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(), 2),
]
signed_ops = {
SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)): 2,
SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)).adjoint(): 2,
}
dtype = attrs.evolve(self.dtype, bitsize=self.dtype.bitsize + 1)
return {
(BitwiseNot(dtype), 2),
(BitwiseNot(QUInt(dtype.bitsize + 1)), 2),
(OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(), 1),
(OutOfPlaceAdder(self.dtype.bitsize + 1), 1),
(MultiControlX((self.cv, 1)), 1),
}.union(signed_ops)
BitwiseNot(dtype): 2,
BitwiseNot(QUInt(dtype.bitsize + 1)): 2,
OutOfPlaceAdder(self.dtype.bitsize + 1).adjoint(): 1,
OutOfPlaceAdder(self.dtype.bitsize + 1): 1,
MultiControlX((self.cv, 1)): 1,
**signed_ops,
}


@bloq_example(generalizer=ignore_split_join)
Expand Down
12 changes: 6 additions & 6 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Set, TYPE_CHECKING, Union
from typing import Dict, TYPE_CHECKING, Union

import numpy as np
import sympy
Expand Down Expand Up @@ -42,7 +42,7 @@
import quimb.tensor as qtn

from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT


Expand Down Expand Up @@ -155,11 +155,11 @@ def build_composite_bloq(
ctrl = bb.join(np.array([ctrl_q]))
return {'ctrl': ctrl, 'a': a, 'b': b}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
(And(self.cv, 1), self.a_dtype.bitsize),
(Add(self.a_dtype, self.b_dtype), 1),
(And(self.cv, 1).adjoint(), self.a_dtype.bitsize),
And(self.cv, 1): self.a_dtype.bitsize,
Add(self.a_dtype, self.b_dtype): 1,
And(self.cv, 1).adjoint(): self.a_dtype.bitsize,
}


Expand Down
Loading
Loading