Skip to content

Commit

Permalink
Add SubtractFrom to subtract from a register in place (#1158)
Browse files Browse the repository at this point in the history
* Add `SubtractFrom` to subtract from a register in place

* Update docstring

* Change output to |a>|b-a>

* Update subtraction.py

Co-authored-by: Anurudh Peduri <[email protected]>

* Update docstring

---------

Co-authored-by: Anurudh Peduri <[email protected]>
  • Loading branch information
charlesyuan314 and anurudhp authored Jul 19, 2024
1 parent 85c1dab commit 22b3d5e
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 7 deletions.
5 changes: 4 additions & 1 deletion dev_tools/autogenerate-bloqs-notebooks-v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,10 @@
NotebookSpecV2(
title='Subtraction',
module=qualtran.bloqs.arithmetic.subtraction,
bloq_specs=[qualtran.bloqs.arithmetic.subtraction._SUB_DOC],
bloq_specs=[
qualtran.bloqs.arithmetic.subtraction._SUB_DOC,
qualtran.bloqs.arithmetic.subtraction._SUB_FROM_DOC,
],
),
NotebookSpecV2(
title='Multiplication',
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/arithmetic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@
)
from qualtran.bloqs.arithmetic.negate import Negate
from qualtran.bloqs.arithmetic.sorting import BitonicSort, Comparator
from qualtran.bloqs.arithmetic.subtraction import Subtract
from qualtran.bloqs.arithmetic.subtraction import Subtract, SubtractFrom

from ._shims import CHalf, Lt, MultiCToffoli
134 changes: 134 additions & 0 deletions qualtran/bloqs/arithmetic/subtraction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,140 @@
"show_call_graph(sub_symb_g)\n",
"show_counts_sigma(sub_symb_sigma)"
]
},
{
"cell_type": "markdown",
"id": "12650fb1",
"metadata": {
"cq.autogen": "SubtractFrom.bloq_doc.md"
},
"source": [
"## `SubtractFrom`\n",
"A version of `Subtract` that subtracts the first register from the second in place.\n",
"\n",
"Implements $U|a\n",
"angle|b\n",
"angle\n",
"ightarrow |a\n",
"angle|b - a\n",
"angle$, essentially equivalent to\n",
"the statement `b -= a`.\n",
"\n",
"#### Parameters\n",
" - `dtype`: Quantum datatype used to represent the integers a, b, and b - a. \n",
"\n",
"#### Registers\n",
" - `a`: A dtype.bitsize-sized input register (register a above).\n",
" - `b`: A dtype.bitsize-sized input/output register (register b above).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "746c25eb",
"metadata": {
"cq.autogen": "SubtractFrom.bloq_doc.py"
},
"outputs": [],
"source": [
"from qualtran.bloqs.arithmetic import SubtractFrom"
]
},
{
"cell_type": "markdown",
"id": "75ea545f",
"metadata": {
"cq.autogen": "SubtractFrom.example_instances.md"
},
"source": [
"### Example Instances"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7ee57a8",
"metadata": {
"cq.autogen": "SubtractFrom.sub_from_symb"
},
"outputs": [],
"source": [
"n = sympy.Symbol('n')\n",
"sub_from_symb = SubtractFrom(QInt(bitsize=n))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8123a738",
"metadata": {
"cq.autogen": "SubtractFrom.sub_from_small"
},
"outputs": [],
"source": [
"sub_from_small = SubtractFrom(QInt(bitsize=4))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e469b17a",
"metadata": {
"cq.autogen": "SubtractFrom.sub_from_large"
},
"outputs": [],
"source": [
"sub_from_large = SubtractFrom(QInt(bitsize=64))"
]
},
{
"cell_type": "markdown",
"id": "a32af1e3",
"metadata": {
"cq.autogen": "SubtractFrom.graphical_signature.md"
},
"source": [
"#### Graphical Signature"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fd2dbd7",
"metadata": {
"cq.autogen": "SubtractFrom.graphical_signature.py"
},
"outputs": [],
"source": [
"from qualtran.drawing import show_bloqs\n",
"show_bloqs([sub_from_symb, sub_from_small, sub_from_large],\n",
" ['`sub_from_symb`', '`sub_from_small`', '`sub_from_large`'])"
]
},
{
"cell_type": "markdown",
"id": "47e50576",
"metadata": {
"cq.autogen": "SubtractFrom.call_graph.md"
},
"source": [
"### Call Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4564d561",
"metadata": {
"cq.autogen": "SubtractFrom.call_graph.py"
},
"outputs": [],
"source": [
"from qualtran.resource_counting.generalizers import ignore_split_join\n",
"sub_from_symb_g, sub_from_symb_sigma = sub_from_symb.call_graph(max_depth=1, generalizer=ignore_split_join)\n",
"show_call_graph(sub_from_symb_g)\n",
"show_counts_sigma(sub_from_symb_sigma)"
]
}
],
"metadata": {
Expand Down
87 changes: 84 additions & 3 deletions qualtran/bloqs/arithmetic/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def b_dtype_default(self):
@a_dtype.validator
def _a_dtype_validate(self, field, val):
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")
raise ValueError("Only QInt, QUInt and QMontgomeryUInt types are supported.")
if isinstance(val.num_qubits, sympy.Expr):
return
if val.bitsize > self.b_dtype.bitsize:
Expand All @@ -78,13 +78,13 @@ def _a_dtype_validate(self, field, val):
@b_dtype.validator
def _b_dtype_validate(self, field, val):
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
raise ValueError("Only QInt, QUInt and QMontgomerUInt types are supported.")
raise ValueError("Only QInt, QUInt and QMontgomeryUInt types are supported.")

@property
def dtype(self):
if self.a_dtype != self.b_dtype:
raise ValueError(
"Add.dtype is only supported when both operands have the same dtype: "
"Subtract.dtype is only supported when both operands have the same dtype: "
f"{self.a_dtype=}, {self.b_dtype=}"
)
return self.a_dtype
Expand Down Expand Up @@ -168,3 +168,84 @@ def _sub_diff_size_regs() -> Subtract:
_SUB_DOC = BloqDocSpec(
bloq_cls=Subtract, examples=[_sub_symb, _sub_small, _sub_large, _sub_diff_size_regs]
)


@frozen
class SubtractFrom(Bloq):
"""A version of `Subtract` that subtracts the first register from the second in place.
Implements $U|a\rangle|b\rangle \rightarrow |a\rangle|b - a\rangle$, essentially equivalent to
the statement `b -= a`.
Args:
dtype: Quantum datatype used to represent the integers a, b, and b - a.
Registers:
a: A dtype.bitsize-sized input register (register a above).
b: A dtype.bitsize-sized input/output register (register b above).
"""

dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()

@dtype.validator
def _dtype_validate(self, field, val):
if not isinstance(val, (QInt, QUInt, QMontgomeryUInt)):
raise ValueError("Only QInt, QUInt and QMontgomeryUInt types are supported.")

@property
def signature(self):
return Signature([Register("a", self.dtype), Register("b", self.dtype)])

def on_classical_vals(
self, a: 'ClassicalValT', b: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
unsigned = isinstance(self.dtype, (QUInt, QMontgomeryUInt))
bitsize = self.dtype.bitsize
N = 2**bitsize if unsigned else 2 ** (bitsize - 1)
return {'a': a, 'b': int(math.fmod(b - a, N))}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
from qualtran.drawing import directional_text_box

if reg is None:
return Text('')
if reg.name == 'a':
return directional_text_box('a', side=reg.side)
elif reg.name == 'b':
return directional_text_box('b - a', side=reg.side)
else:
raise ValueError()

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {(Negate(self.dtype), 1), (Subtract(self.dtype, self.dtype), 1)}

def build_composite_bloq(self, bb: 'BloqBuilder', a: Soquet, b: Soquet) -> Dict[str, 'SoquetT']:
a, b = bb.add_t(Subtract(self.dtype, self.dtype), a=a, b=b) # a, a - b
b = bb.add(Negate(self.dtype), x=b) # a, b - a
return {'a': a, 'b': b}


@bloq_example
def _sub_from_symb() -> SubtractFrom:
n = sympy.Symbol('n')
sub_from_symb = SubtractFrom(QInt(bitsize=n))
return sub_from_symb


@bloq_example
def _sub_from_small() -> SubtractFrom:
sub_from_small = SubtractFrom(QInt(bitsize=4))
return sub_from_small


@bloq_example
def _sub_from_large() -> SubtractFrom:
sub_from_large = SubtractFrom(QInt(bitsize=64))
return sub_from_large


_SUB_FROM_DOC = BloqDocSpec(
bloq_cls=SubtractFrom, examples=[_sub_from_symb, _sub_from_small, _sub_from_large]
)
59 changes: 57 additions & 2 deletions qualtran/bloqs/arithmetic/subtraction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,36 @@

import qualtran.testing as qlt_testing
from qualtran import QInt, QUInt
from qualtran.bloqs.arithmetic import Subtract
from qualtran.bloqs.arithmetic.subtraction import (
_sub_diff_size_regs,
_sub_from_large,
_sub_from_small,
_sub_from_symb,
_sub_large,
_sub_small,
_sub_symb,
Subtract,
SubtractFrom,
)
from qualtran.resource_counting.generalizers import ignore_split_join


def test_sub_symb(bloq_autotester):
bloq_autotester(_sub_symb)


def test_sub_small(bloq_autotester):
bloq_autotester(_sub_small)


def test_sub_large(bloq_autotester):
bloq_autotester(_sub_large)


def test_sub_diff_size_regs(bloq_autotester):
bloq_autotester(_sub_diff_size_regs)


def test_subtract_bloq_decomposition():
gate = Subtract(QInt(3), QInt(5))
qlt_testing.assert_valid_bloq_decomposition(gate)
Expand All @@ -41,7 +67,36 @@ def test_subtract_bloq_validation():
assert Subtract(QUInt(3)).dtype == QUInt(3)


def test_subtract_bloq_consitant_counts():
def test_subtract_bloq_consistent_counts():
qlt_testing.assert_equivalent_bloq_counts(
Subtract(QInt(3), QInt(4)), generalizer=ignore_split_join
)


def test_sub_from_symb(bloq_autotester):
bloq_autotester(_sub_from_symb)


def test_sub_from_small(bloq_autotester):
bloq_autotester(_sub_from_small)


def test_sub_from_large(bloq_autotester):
bloq_autotester(_sub_from_large)


def test_subtract_from_bloq_decomposition():
gate = SubtractFrom(QInt(4))
qlt_testing.assert_valid_bloq_decomposition(gate)

want = np.zeros((256, 256))
for a_b in range(256):
a, b = a_b >> 4, a_b & 15
c = (b - a) % 16
want[(a << 4) | c][a_b] = 1
got = gate.tensor_contract()
np.testing.assert_allclose(got, want)


def test_subtract_from_bloq_consistent_counts():
qlt_testing.assert_equivalent_bloq_counts(SubtractFrom(QInt(3)), generalizer=ignore_split_join)
2 changes: 2 additions & 0 deletions qualtran/serialization/resolver_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import qualtran.bloqs.arithmetic.negate
import qualtran.bloqs.arithmetic.permutation
import qualtran.bloqs.arithmetic.sorting
import qualtran.bloqs.arithmetic.subtraction
import qualtran.bloqs.basic_gates.cnot
import qualtran.bloqs.basic_gates.hadamard
import qualtran.bloqs.basic_gates.identity
Expand Down Expand Up @@ -177,6 +178,7 @@
"qualtran.bloqs.arithmetic.sorting.Comparator": qualtran.bloqs.arithmetic.sorting.Comparator,
"qualtran.bloqs.arithmetic.sorting.ParallelComparators": qualtran.bloqs.arithmetic.sorting.ParallelComparators,
"qualtran.bloqs.arithmetic.subtraction.Subtract": qualtran.bloqs.arithmetic.subtraction.Subtract,
"qualtran.bloqs.arithmetic.subtraction.SubtractFrom": qualtran.bloqs.arithmetic.subtraction.SubtractFrom,
"qualtran.bloqs.basic_gates.cnot.CNOT": qualtran.bloqs.basic_gates.cnot.CNOT,
"qualtran.bloqs.basic_gates.identity.Identity": qualtran.bloqs.basic_gates.identity.Identity,
"qualtran.bloqs.basic_gates.global_phase.GlobalPhase": qualtran.bloqs.basic_gates.global_phase.GlobalPhase,
Expand Down

0 comments on commit 22b3d5e

Please sign in to comment.