From a4365b26a01956bb0341f12e673176e3a66ab33b Mon Sep 17 00:00:00 2001 From: Noureldin Date: Fri, 9 Aug 2024 12:42:22 -0700 Subject: [PATCH] Implement bit conversion for QMontgomeryUInt, fix dtype, symbolic decomposition and classical action of modular gates, and add classical action tests (#1264) * Implement bit conversion for QMontgomeryUInt, fix dtype, symbolic decomposition and classical action of modular gates, and add classical action tests * address comments --- qualtran/_infra/data_types.py | 5 ++- qualtran/_infra/data_types_test.py | 7 +++ qualtran/bloqs/arithmetic/comparison.py | 24 +++++++++- qualtran/bloqs/arithmetic/comparison_test.py | 2 + qualtran/bloqs/factoring/mod_sub.py | 40 ++++++++++++++--- qualtran/bloqs/factoring/mod_sub_test.py | 45 ++++++++++++++++++- qualtran/bloqs/mod_arithmetic/_shims.py | 10 ----- qualtran/bloqs/mod_arithmetic/mod_addition.py | 15 +++++-- .../bloqs/mod_arithmetic/mod_addition_test.py | 21 +++++++++ 9 files changed, 146 insertions(+), 23 deletions(-) diff --git a/qualtran/_infra/data_types.py b/qualtran/_infra/data_types.py index ad2649ead..824ebc4de 100644 --- a/qualtran/_infra/data_types.py +++ b/qualtran/_infra/data_types.py @@ -787,10 +787,11 @@ def get_classical_domain(self) -> Iterable[Any]: return range(2**self.bitsize) def to_bits(self, x: int) -> List[int]: - raise NotImplementedError(f"to_bits not implemented for {self}") + self.assert_valid_classical_val(x) + return [int(x) for x in f'{int(x):0{self.bitsize}b}'] def from_bits(self, bits: Sequence[int]) -> int: - raise NotImplementedError(f"from_bits not implemented for {self}") + return int("".join(str(x) for x in bits), 2) def assert_valid_classical_val(self, val: int, debug_str: str = 'val'): if not isinstance(val, (int, np.integer)): diff --git a/qualtran/_infra/data_types_test.py b/qualtran/_infra/data_types_test.py index 2ed327ac9..a59ba43ce 100644 --- a/qualtran/_infra/data_types_test.py +++ b/qualtran/_infra/data_types_test.py @@ -469,3 +469,10 @@ def test_fixed_point(val, width, signed): _ = QFxp(width, width).to_fixed_width_int(-val) bits_from_int = QUInt(width).to_bits(QFxp(width, width).to_fixed_width_int(val)) assert bits == bits_from_int + + +@pytest.mark.parametrize('bitsize', range(1, 6)) +def test_montgomery_bit_conversion(bitsize): + dtype = QMontgomeryUInt(bitsize) + for v in range(1 << bitsize): + assert v == dtype.from_bits(dtype.to_bits(v)) diff --git a/qualtran/bloqs/arithmetic/comparison.py b/qualtran/bloqs/arithmetic/comparison.py index 179ebc1a1..c4839839c 100644 --- a/qualtran/bloqs/arithmetic/comparison.py +++ b/qualtran/bloqs/arithmetic/comparison.py @@ -726,7 +726,7 @@ class LinearDepthGreaterThan(Bloq): [Improved quantum circuits for elliptic curve discrete logarithms](https://arxiv.org/abs/2306.08585). """ - bitsize: int + bitsize: 'SymbolicInt' signed: bool = False @property @@ -749,6 +749,9 @@ def on_classical_vals( def build_composite_bloq( self, bb: 'BloqBuilder', a: Soquet, b: Soquet, target: SoquetT ) -> Dict[str, 'SoquetT']: + if not isinstance(self.bitsize, int): + raise NotImplementedError(f'symbolic decomposition is not supported for {self}') + # Base Case: Comparing two qubits. # Signed doesn't matter because we can't represent signed integers with 1 qubit. if self.bitsize == 1: @@ -875,6 +878,25 @@ def wire_symbol( return TextBox('t⨁(a>b)') raise ValueError(f'Unknown register name {reg.name}') + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: + if self.bitsize == 1: + return {(MultiControlX(cvs=(1, 0)), 1)} + + if self.signed: + return { + (CNOT(), 6 * self.bitsize - 7), + (XGate(), 2 * self.bitsize + 2), + (And(), self.bitsize - 1), + (And(uncompute=True), self.bitsize - 1), + } + + return { + (CNOT(), 6 * self.bitsize - 1), + (XGate(), 2 * self.bitsize + 4), + (And(), self.bitsize), + (And(uncompute=True), self.bitsize), + } + @frozen class GreaterThanConstant(Bloq): diff --git a/qualtran/bloqs/arithmetic/comparison_test.py b/qualtran/bloqs/arithmetic/comparison_test.py index fa0076ccd..cd54f4efd 100644 --- a/qualtran/bloqs/arithmetic/comparison_test.py +++ b/qualtran/bloqs/arithmetic/comparison_test.py @@ -37,6 +37,7 @@ ) from qualtran.cirq_interop.t_complexity_protocol import t_complexity, TComplexity from qualtran.cirq_interop.testing import assert_circuit_inp_out_cirqsim +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join def test_greater_than(bloq_autotester): @@ -240,6 +241,7 @@ def test_greater_than_manual(): def test_linear_depth_greater_than_decomp(bitsize, signed): bloq = LinearDepthGreaterThan(bitsize=bitsize, signed=signed) qlt_testing.assert_valid_bloq_decomposition(bloq) + qlt_testing.assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join]) # TODO: write tests for signed integer comparison diff --git a/qualtran/bloqs/factoring/mod_sub.py b/qualtran/bloqs/factoring/mod_sub.py index 486d98267..2959545e9 100644 --- a/qualtran/bloqs/factoring/mod_sub.py +++ b/qualtran/bloqs/factoring/mod_sub.py @@ -13,7 +13,7 @@ # limitations under the License. from functools import cached_property -from typing import Dict, TYPE_CHECKING +from typing import Dict, Set, TYPE_CHECKING from attrs import frozen @@ -25,7 +25,9 @@ if TYPE_CHECKING: from qualtran import BloqBuilder + from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT + from qualtran.symbolics import SymbolicInt @frozen @@ -49,8 +51,8 @@ class MontgomeryModSub(Bloq): Fig 6c and 8 """ - bitsize: int - p: int + bitsize: 'SymbolicInt' + p: 'SymbolicInt' @cached_property def signature(self) -> 'Signature': @@ -64,9 +66,13 @@ def signature(self) -> 'Signature': def on_classical_vals( self, x: 'ClassicalValT', y: 'ClassicalValT' ) -> Dict[str, 'ClassicalValT']: - return {'x': x, 'y': (y - x) % self.p} + if x < self.p and y < self.p: + return {'x': x, 'y': (y - x) % self.p} + return {'x': x, 'y': y} def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']: + if not isinstance(self.bitsize, int): + raise NotImplementedError(f'symbolic decomposition is not supported for {self}') # Bit flip all qubits in register x. x_split = bb.split(x) for i in range(self.bitsize): @@ -94,6 +100,14 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[ def pretty_name(self) -> str: return f'y = y - x mod {self.p}' + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: + return { + (XGate(), 2 * self.bitsize), + (AddK(self.bitsize, self.p + 1, signed=False), 1), + (ModAdd(self.bitsize, self.p), 1), + (AddK(self.bitsize, self.p + 1, signed=False).adjoint(), 1), + } + @frozen class MontgomeryModNeg(Bloq): @@ -114,8 +128,8 @@ class MontgomeryModNeg(Bloq): Fig 6b and 8 """ - bitsize: int - p: int + bitsize: 'SymbolicInt' + p: 'SymbolicInt' @cached_property def signature(self) -> 'Signature': @@ -125,6 +139,8 @@ def on_classical_vals(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']: return {'x': (-1 * x) % self.p} def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'SoquetT']: + if not isinstance(self.bitsize, int): + raise NotImplementedError(f'symbolic decomposition is not supported for {self}') # Initialize an ancilla qubit to |1>. ctrl = bb.allocate(n=1) ctrl = bb.add(XGate(), q=ctrl) @@ -165,3 +181,15 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet) -> Dict[str, 'Soque def pretty_name(self) -> str: return f'x = -x mod {self.p}' + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']: + if not isinstance(self.bitsize, int): + raise NotImplementedError(f'symbolic call graph is not supported for {self}') + + # TODO: support symbolic cost + return { + (XGate(), 2), + (MultiControlX(cvs=[0] * self.bitsize), 2), + (CNOT(), self.bitsize), + (AddK(bitsize=self.bitsize, k=self.p + 1, cvs=(1,), signed=False), 1), + } diff --git a/qualtran/bloqs/factoring/mod_sub_test.py b/qualtran/bloqs/factoring/mod_sub_test.py index 20bb2c5bb..50285006b 100644 --- a/qualtran/bloqs/factoring/mod_sub_test.py +++ b/qualtran/bloqs/factoring/mod_sub_test.py @@ -13,18 +13,61 @@ # limitations under the License. import pytest +import sympy from qualtran.bloqs.factoring.mod_sub import MontgomeryModNeg, MontgomeryModSub -from qualtran.testing import assert_valid_bloq_decomposition +from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join +from qualtran.testing import assert_equivalent_bloq_counts, assert_valid_bloq_decomposition @pytest.mark.parametrize('bitsize,p', [(1, 1), (2, 3), (5, 8)]) def test_montgomery_mod_neg_decomp(bitsize, p): bloq = MontgomeryModNeg(bitsize=bitsize, p=p) assert_valid_bloq_decomposition(bloq) + assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join]) @pytest.mark.parametrize('bitsize,p', [(1, 1), (2, 3), (5, 8)]) def test_montgomery_mod_sub_decomp(bitsize, p): bloq = MontgomeryModSub(bitsize=bitsize, p=p) assert_valid_bloq_decomposition(bloq) + assert_equivalent_bloq_counts(bloq, [ignore_alloc_free, ignore_split_join]) + + +@pytest.mark.parametrize('bitsize', [*range(1, 5), sympy.Symbol('n')]) +def test_montgomery_sub_complexity(bitsize): + tcomplexity = MontgomeryModSub(bitsize, sympy.Symbol('p')).t_complexity() + assert tcomplexity.t == 24 * bitsize - 12 # 6n toffoli + assert tcomplexity.rotations == 0 + + +@pytest.mark.parametrize('bitsize', range(1, 5)) +def test_montgomery_neg_complexity(bitsize): + tcomplexity = MontgomeryModNeg(bitsize, sympy.Symbol('p')).t_complexity() + assert tcomplexity.t == 12 * bitsize - 12 # 3n toffoli + assert tcomplexity.rotations == 0 + + +@pytest.mark.parametrize( + ['prime', 'bitsize'], + [(p, bitsize) for p in [11, 13, 31] for bitsize in range(1 + p.bit_length(), 8)], +) +def test_classical_action_montgomery_sub(bitsize, prime): + b = MontgomeryModSub(bitsize, prime) + cb = b.decompose_bloq() + valid_range = range(prime) + for x in valid_range: + for y in valid_range: + assert b.call_classically(x=x, y=y) == cb.call_classically(x=x, y=y) + + +@pytest.mark.parametrize( + ['prime', 'bitsize'], + [(p, bitsize) for p in [11, 13, 31] for bitsize in range(1 + p.bit_length(), 8)], +) +def test_classical_action_mod_neg(bitsize, prime): + b = MontgomeryModNeg(bitsize, prime) + cb = b.decompose_bloq() + valid_range = range(prime) + for x in valid_range: + assert b.call_classically(x=x) == cb.call_classically(x=x) == ((-x) % prime,) diff --git a/qualtran/bloqs/mod_arithmetic/_shims.py b/qualtran/bloqs/mod_arithmetic/_shims.py index 6b831963e..b317d1230 100644 --- a/qualtran/bloqs/mod_arithmetic/_shims.py +++ b/qualtran/bloqs/mod_arithmetic/_shims.py @@ -37,16 +37,6 @@ from qualtran.resource_counting import BloqCountT, SympySymbolAllocator -@frozen -class ModAdd(Bloq): - n: int - mod: int - - @cached_property - def signature(self) -> 'Signature': - return Signature([Register('x', QUInt(self.n)), Register('y', QUInt(self.n))]) - - @frozen class ModSub(Bloq): n: int diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition.py b/qualtran/bloqs/mod_arithmetic/mod_addition.py index 6a53ad06c..0866048c0 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition.py @@ -38,9 +38,11 @@ from qualtran.drawing import Circle, Text, TextBox, WireSymbol from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.simulation.classical_sim import ClassicalValT +from qualtran.symbolics import is_symbolic if TYPE_CHECKING: from qualtran import BloqBuilder + from qualtran.symbolics import SymbolicInt @frozen @@ -65,12 +67,17 @@ class ModAdd(Bloq): Construction from Figure 6a and cost summary in Figure 8. """ - bitsize: int - mod: int + bitsize: 'SymbolicInt' + mod: 'SymbolicInt' @cached_property def signature(self) -> 'Signature': - return Signature([Register('x', QUInt(self.bitsize)), Register('y', QUInt(self.bitsize))]) + return Signature( + [ + Register('x', QMontgomeryUInt(self.bitsize)), + Register('y', QMontgomeryUInt(self.bitsize)), + ] + ) def on_classical_vals( self, x: 'ClassicalValT', y: 'ClassicalValT' @@ -78,6 +85,8 @@ def on_classical_vals( return {'x': x, 'y': (x + y) % self.mod} def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']: + if is_symbolic(self.bitsize): + raise NotImplementedError(f'symbolic decomposition is not supported for {self}') # Allocate ancilla bits for use in addition. junk_bit = bb.allocate(n=1) sign = bb.allocate(n=1) diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py index 7a227c901..2c0a4e596 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +import sympy from qualtran import QUInt from qualtran.bloqs.arithmetic import Add @@ -69,3 +70,23 @@ def test_ctrl_mod_add_k(): def test_mod_add_valid_decomp(bitsize, p): bloq = ModAdd(bitsize=bitsize, mod=p) assert_valid_bloq_decomposition(bloq) + + +@pytest.mark.parametrize('bitsize', list(range(1, 6)) + [sympy.Symbol('n')]) +def test_mod_add_symbolic_cost(bitsize): + tcomplexity = ModAdd(bitsize, sympy.Symbol('p')).t_complexity() + assert tcomplexity.t == 16 * bitsize - 4 # 4n toffoli + assert tcomplexity.rotations == 0 + + +@pytest.mark.parametrize( + ['prime', 'bitsize'], + [(p, bitsize) for p in [11, 13, 31] for bitsize in range(1 + p.bit_length(), 8)], +) +def test_classical_action_mod_add(prime, bitsize): + b = ModAdd(bitsize=bitsize, mod=prime) + cb = b.decompose_bloq() + valid_range = range(prime) + for x in valid_range: + for y in valid_range: + assert b.call_classically(x=x, y=y) == cb.call_classically(x=x, y=y)