Skip to content

Commit

Permalink
Implement bit conversion for QMontgomeryUInt, fix dtype, symbolic dec…
Browse files Browse the repository at this point in the history
…omposition and classical action of modular gates, and add classical action tests (#1264)

* Implement bit conversion for QMontgomeryUInt, fix dtype, symbolic decomposition and classical action of modular gates, and add classical action tests

* address comments
  • Loading branch information
NoureldinYosri authored Aug 9, 2024
1 parent 2917eeb commit a4365b2
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 23 deletions.
5 changes: 3 additions & 2 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,10 +787,11 @@ def get_classical_domain(self) -> Iterable[Any]:
return range(2**self.bitsize)

def to_bits(self, x: int) -> List[int]:
raise NotImplementedError(f"to_bits not implemented for {self}")
self.assert_valid_classical_val(x)
return [int(x) for x in f'{int(x):0{self.bitsize}b}']

def from_bits(self, bits: Sequence[int]) -> int:
raise NotImplementedError(f"from_bits not implemented for {self}")
return int("".join(str(x) for x in bits), 2)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
Expand Down
7 changes: 7 additions & 0 deletions qualtran/_infra/data_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,10 @@ def test_fixed_point(val, width, signed):
_ = QFxp(width, width).to_fixed_width_int(-val)
bits_from_int = QUInt(width).to_bits(QFxp(width, width).to_fixed_width_int(val))
assert bits == bits_from_int


@pytest.mark.parametrize('bitsize', range(1, 6))
def test_montgomery_bit_conversion(bitsize):
dtype = QMontgomeryUInt(bitsize)
for v in range(1 << bitsize):
assert v == dtype.from_bits(dtype.to_bits(v))
24 changes: 23 additions & 1 deletion qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ class LinearDepthGreaterThan(Bloq):
[Improved quantum circuits for elliptic curve discrete logarithms](https://arxiv.org/abs/2306.08585).
"""

bitsize: int
bitsize: 'SymbolicInt'
signed: bool = False

@property
Expand All @@ -749,6 +749,9 @@ def on_classical_vals(
def build_composite_bloq(
self, bb: 'BloqBuilder', a: Soquet, b: Soquet, target: SoquetT
) -> Dict[str, 'SoquetT']:
if not isinstance(self.bitsize, int):
raise NotImplementedError(f'symbolic decomposition is not supported for {self}')

# Base Case: Comparing two qubits.
# Signed doesn't matter because we can't represent signed integers with 1 qubit.
if self.bitsize == 1:
Expand Down Expand Up @@ -875,6 +878,25 @@ def wire_symbol(
return TextBox('t⨁(a>b)')
raise ValueError(f'Unknown register name {reg.name}')

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
if self.bitsize == 1:
return {(MultiControlX(cvs=(1, 0)), 1)}

if self.signed:
return {
(CNOT(), 6 * self.bitsize - 7),
(XGate(), 2 * self.bitsize + 2),
(And(), self.bitsize - 1),
(And(uncompute=True), self.bitsize - 1),
}

return {
(CNOT(), 6 * self.bitsize - 1),
(XGate(), 2 * self.bitsize + 4),
(And(), self.bitsize),
(And(uncompute=True), self.bitsize),
}


@frozen
class GreaterThanConstant(Bloq):
Expand Down
2 changes: 2 additions & 0 deletions qualtran/bloqs/arithmetic/comparison_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity
from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join


def test_greater_than(bloq_autotester):
Expand Down Expand Up @@ -240,6 +241,7 @@ def test_greater_than_manual():
def test_linear_depth_greater_than_decomp(bitsize, signed):
bloq = LinearDepthGreaterThan(bitsize=bitsize, signed=signed)
qlt_testing.assert_valid_bloq_decomposition(bloq)
qlt_testing.assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join])


# TODO: write tests for signed integer comparison
Expand Down
40 changes: 34 additions & 6 deletions qualtran/bloqs/factoring/mod_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, TYPE_CHECKING
from typing import Dict, Set, TYPE_CHECKING

from attrs import frozen

Expand All @@ -25,7 +25,9 @@

if TYPE_CHECKING:
from qualtran import BloqBuilder
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import SymbolicInt


@frozen
Expand All @@ -49,8 +51,8 @@ class MontgomeryModSub(Bloq):
Fig 6c and 8
"""

bitsize: int
p: int
bitsize: 'SymbolicInt'
p: 'SymbolicInt'

@cached_property
def signature(self) -> 'Signature':
Expand All @@ -64,9 +66,13 @@ def signature(self) -> 'Signature':
def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
return {'x': x, 'y': (y - x) % self.p}
if x < self.p and y < self.p:
return {'x': x, 'y': (y - x) % self.p}
return {'x': x, 'y': y}

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if not isinstance(self.bitsize, int):
raise NotImplementedError(f'symbolic decomposition is not supported for {self}')
# Bit flip all qubits in register x.
x_split = bb.split(x)
for i in range(self.bitsize):
Expand Down Expand Up @@ -94,6 +100,14 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[
def pretty_name(self) -> str:
return f'y = y - x mod {self.p}'

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {
(XGate(), 2 * self.bitsize),
(AddK(self.bitsize, self.p + 1, signed=False), 1),
(ModAdd(self.bitsize, self.p), 1),
(AddK(self.bitsize, self.p + 1, signed=False).adjoint(), 1),
}


@frozen
class MontgomeryModNeg(Bloq):
Expand All @@ -114,8 +128,8 @@ class MontgomeryModNeg(Bloq):
Fig 6b and 8
"""

bitsize: int
p: int
bitsize: 'SymbolicInt'
p: 'SymbolicInt'

@cached_property
def signature(self) -> 'Signature':
Expand All @@ -125,6 +139,8 @@ def on_classical_vals(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
return {'x': (-1 * x) % self.p}

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']:
if not isinstance(self.bitsize, int):
raise NotImplementedError(f'symbolic decomposition is not supported for {self}')
# Initialize an ancilla qubit to |1>.
ctrl = bb.allocate(n=1)
ctrl = bb.add(XGate(), q=ctrl)
Expand Down Expand Up @@ -165,3 +181,15 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'Soque

def pretty_name(self) -> str:
return f'x = -x mod {self.p}'

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
if not isinstance(self.bitsize, int):
raise NotImplementedError(f'symbolic call graph is not supported for {self}')

# TODO: support symbolic cost
return {
(XGate(), 2),
(MultiControlX(cvs=[0] * self.bitsize), 2),
(CNOT(), self.bitsize),
(AddK(bitsize=self.bitsize, k=self.p + 1, cvs=(1,), signed=False), 1),
}
45 changes: 44 additions & 1 deletion qualtran/bloqs/factoring/mod_sub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,61 @@
# limitations under the License.

import pytest
import sympy

from qualtran.bloqs.factoring.mod_sub import MontgomeryModNeg, MontgomeryModSub
from qualtran.testing import assert_valid_bloq_decomposition
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join
from qualtran.testing import assert_equivalent_bloq_counts, assert_valid_bloq_decomposition


@pytest.mark.parametrize('bitsize,p', [(1, 1), (2, 3), (5, 8)])
def test_montgomery_mod_neg_decomp(bitsize, p):
bloq = MontgomeryModNeg(bitsize=bitsize, p=p)
assert_valid_bloq_decomposition(bloq)
assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join])


@pytest.mark.parametrize('bitsize,p', [(1, 1), (2, 3), (5, 8)])
def test_montgomery_mod_sub_decomp(bitsize, p):
bloq = MontgomeryModSub(bitsize=bitsize, p=p)
assert_valid_bloq_decomposition(bloq)
assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join])


@pytest.mark.parametrize('bitsize', [*range(1, 5), sympy.Symbol('n')])
def test_montgomery_sub_complexity(bitsize):
tcomplexity = MontgomeryModSub(bitsize, sympy.Symbol('p')).t_complexity()
assert tcomplexity.t == 24 * bitsize - 12 # 6n toffoli
assert tcomplexity.rotations == 0


@pytest.mark.parametrize('bitsize', range(1, 5))
def test_montgomery_neg_complexity(bitsize):
tcomplexity = MontgomeryModNeg(bitsize, sympy.Symbol('p')).t_complexity()
assert tcomplexity.t == 12 * bitsize - 12 # 3n toffoli
assert tcomplexity.rotations == 0


@pytest.mark.parametrize(
['prime', 'bitsize'],
[(p, bitsize) for p in [11, 13, 31] for bitsize in range(1 + p.bit_length(), 8)],
)
def test_classical_action_montgomery_sub(bitsize, prime):
b = MontgomeryModSub(bitsize, prime)
cb = b.decompose_bloq()
valid_range = range(prime)
for x in valid_range:
for y in valid_range:
assert b.call_classically(x=x, y=y) == cb.call_classically(x=x, y=y)


@pytest.mark.parametrize(
['prime', 'bitsize'],
[(p, bitsize) for p in [11, 13, 31] for bitsize in range(1 + p.bit_length(), 8)],
)
def test_classical_action_mod_neg(bitsize, prime):
b = MontgomeryModNeg(bitsize, prime)
cb = b.decompose_bloq()
valid_range = range(prime)
for x in valid_range:
assert b.call_classically(x=x) == cb.call_classically(x=x) == ((-x) % prime,)
10 changes: 0 additions & 10 deletions qualtran/bloqs/mod_arithmetic/_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,6 @@
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator


@frozen
class ModAdd(Bloq):
n: int
mod: int

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('x', QUInt(self.n)), Register('y', QUInt(self.n))])


@frozen
class ModSub(Bloq):
n: int
Expand Down
15 changes: 12 additions & 3 deletions qualtran/bloqs/mod_arithmetic/mod_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@
from qualtran.drawing import Circle, Text, TextBox, WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import is_symbolic

if TYPE_CHECKING:
from qualtran import BloqBuilder
from qualtran.symbolics import SymbolicInt


@frozen
Expand All @@ -65,19 +67,26 @@ class ModAdd(Bloq):
Construction from Figure 6a and cost summary in Figure 8.
"""

bitsize: int
mod: int
bitsize: 'SymbolicInt'
mod: 'SymbolicInt'

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('x', QUInt(self.bitsize)), Register('y', QUInt(self.bitsize))])
return Signature(
[
Register('x', QMontgomeryUInt(self.bitsize)),
Register('y', QMontgomeryUInt(self.bitsize)),
]
)

def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
return {'x': x, 'y': (x + y) % self.mod}

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if is_symbolic(self.bitsize):
raise NotImplementedError(f'symbolic decomposition is not supported for {self}')
# Allocate ancilla bits for use in addition.
junk_bit = bb.allocate(n=1)
sign = bb.allocate(n=1)
Expand Down
21 changes: 21 additions & 0 deletions qualtran/bloqs/mod_arithmetic/mod_addition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import pytest
import sympy

from qualtran import QUInt
from qualtran.bloqs.arithmetic import Add
Expand Down Expand Up @@ -69,3 +70,23 @@ def test_ctrl_mod_add_k():
def test_mod_add_valid_decomp(bitsize, p):
bloq = ModAdd(bitsize=bitsize, mod=p)
assert_valid_bloq_decomposition(bloq)


@pytest.mark.parametrize('bitsize', list(range(1, 6)) + [sympy.Symbol('n')])
def test_mod_add_symbolic_cost(bitsize):
tcomplexity = ModAdd(bitsize, sympy.Symbol('p')).t_complexity()
assert tcomplexity.t == 16 * bitsize - 4 # 4n toffoli
assert tcomplexity.rotations == 0


@pytest.mark.parametrize(
['prime', 'bitsize'],
[(p, bitsize) for p in [11, 13, 31] for bitsize in range(1 + p.bit_length(), 8)],
)
def test_classical_action_mod_add(prime, bitsize):
b = ModAdd(bitsize=bitsize, mod=prime)
cb = b.decompose_bloq()
valid_range = range(prime)
for x in valid_range:
for y in valid_range:
assert b.call_classically(x=x, y=y) == cb.call_classically(x=x, y=y)

0 comments on commit a4365b2

Please sign in to comment.