Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize gate count of Subtract to n-1 Toffolis #1057

Merged
merged 25 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions qualtran/bloqs/arithmetic/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
QAny,
QInt,
QMontgomeryUInt,
QUInt,
Expand All @@ -31,8 +32,9 @@
Soquet,
SoquetT,
)
from qualtran.bloqs.arithmetic.addition import Add, AddK
from qualtran.bloqs.basic_gates import XGate
from qualtran.bloqs.arithmetic.addition import Add
from qualtran.bloqs.basic_gates import OnEach, XGate
from qualtran.bloqs.bookkeeping import Allocate, Free
from qualtran.drawing import Text

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,7 +63,9 @@ class Subtract(Bloq):
b: A b_dtype.bitsize-sized input/output register (register b above).
"""

a_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
a_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field(
converter=lambda k: QUInt(k) if isinstance(k, int) else k
)
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
b_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()

@b_dtype.default
Expand Down Expand Up @@ -118,32 +122,38 @@ def wire_symbol(
raise ValueError()

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
a_dtype = (
self.a_dtype if not isinstance(self.a_dtype, QInt) else QUInt(self.a_dtype.bitsize)
)
b_dtype = (
self.b_dtype if not isinstance(self.b_dtype, QInt) else QUInt(self.b_dtype.bitsize)
)
delta = self.b_dtype.bitsize - self.a_dtype.bitsize
return {
(XGate(), self.b_dtype.bitsize),
(AddK(self.b_dtype.bitsize, k=1), 1),
(Add(a_dtype, b_dtype), 1),
}

def build_composite_bloq(self, bb: 'BloqBuilder', a: Soquet, b: Soquet) -> Dict[str, 'SoquetT']:
b = np.array([bb.add(XGate(), q=q) for q in bb.split(b)]) # 1s complement of b.
b = bb.add(
AddK(self.b_dtype.bitsize, k=1), x=bb.join(b, self.b_dtype)
) # 2s complement of b.

a_dtype = (
self.a_dtype if not isinstance(self.a_dtype, QInt) else QUInt(self.a_dtype.bitsize)
)
b_dtype = (
self.b_dtype if not isinstance(self.b_dtype, QInt) else QUInt(self.b_dtype.bitsize)
(OnEach(self.b_dtype.bitsize, XGate()), 3),
(Add(QUInt(self.b_dtype.bitsize), QUInt(self.b_dtype.bitsize)), 1),
}.union(
[
(Allocate(QAny(self.b_dtype.bitsize - self.a_dtype.bitsize)), 1),
(Free(QAny(self.b_dtype.bitsize - self.a_dtype.bitsize)), 1),
]
if delta
else []
)

a, b = bb.add(Add(a_dtype, b_dtype), a=a, b=b) # a - b
def build_composite_bloq(self, bb: 'BloqBuilder', a: Soquet, b: Soquet) -> Dict[str, 'SoquetT']:
delta = self.b_dtype.bitsize - self.a_dtype.bitsize
n_bits = self.b_dtype.bitsize
a = bb.split(a)
b = bb.split(b)
if delta:
# Add a zero prefix to `a`
a = np.concatenate([bb.split(bb.allocate(delta)), a])
a = bb.join(a, QUInt(n_bits))
b = bb.join(b, QUInt(n_bits))
a = bb.add(OnEach(n_bits, XGate()), q=a)
a, b = bb.add(Add(QUInt(n_bits), QUInt(n_bits)), a=a, b=b) # a - b
b = bb.add(OnEach(n_bits, XGate()), q=b)
a = bb.add(OnEach(n_bits, XGate()), q=a)
b = bb.join(bb.split(b), self.b_dtype)
a = bb.split(a)
if delta:
bb.free(bb.join(a[:delta]))
a = bb.join(a[delta:], self.a_dtype)
return {'a': a, 'b': b}


Expand Down
26 changes: 19 additions & 7 deletions qualtran/bloqs/arithmetic/subtraction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@
from qualtran.resource_counting.generalizers import ignore_split_join


def test_subtract_bloq_decomposition():
gate = Subtract(QInt(3), QInt(5))
@pytest.mark.parametrize(
['a_bits', 'b_bits'], [(a, b) for a in range(1, 6) for b in range(a, 6) if a + b <= 10]
)
def test_subtract_bloq_decomposition(a_bits, b_bits):
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
gate = Subtract(QInt(a_bits), QInt(b_bits))
qlt_testing.assert_valid_bloq_decomposition(gate)

want = np.zeros((256, 256))
for a_b in range(256):
a, b = a_b >> 5, a_b & 31
c = (a - b) % 32
want[(a << 5) | c][a_b] = 1
tot = 1 << (a_bits + b_bits)
want = np.zeros((tot, tot))
max_b = 1 << b_bits
for a_b in range(tot):
a, b = a_b >> b_bits, a_b & (max_b - 1)
c = (a - b) % max_b
want[(a << b_bits) | c][a_b] = 1
got = gate.tensor_contract()
np.testing.assert_equal(got, want)

Expand All @@ -45,3 +50,10 @@ def test_subtract_bloq_consitant_counts():
qlt_testing.assert_equivalent_bloq_counts(
Subtract(QInt(3), QInt(4)), generalizer=ignore_split_join
)


@pytest.mark.parametrize('n_bits', range(1, 10))
def test_t_complexity(n_bits):
complexity = Subtract(n_bits).t_complexity()
assert complexity.t == 4 * (n_bits - 1)
assert complexity.rotations == 0
Loading