Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate split dtype to join #817

Merged
merged 34 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8d31061
Ensure dtypes get used when joining typed registers.
fdmalone Mar 21, 2024
e97cff3
Fix tests.
fdmalone Mar 22, 2024
bd8027e
Update for black?
fdmalone Mar 22, 2024
b15ce91
Ensure dtypes get used when joining typed registers.
fdmalone Mar 21, 2024
a937349
Fix tests.
fdmalone Mar 22, 2024
dae08f0
Update for black?
fdmalone Mar 22, 2024
f3eed5e
Don't rely on soq reg.
fdmalone Mar 22, 2024
776ec28
WIP testing.
fdmalone Mar 22, 2024
60183cd
Add bloq autotesting + update report card.
fdmalone Mar 22, 2024
b4e28ea
Merge branch 'propagate_split_dtype_to_join' of github.com:fdmalone/Q…
fdmalone Mar 22, 2024
019812b
Fix interop.
fdmalone Mar 24, 2024
8b2b1b5
Better typed show_bloqs.
fdmalone Mar 24, 2024
08b514b
Fixes.
fdmalone Mar 25, 2024
f01af84
Fix lint errors.
fdmalone Mar 25, 2024
7b42be7
Merge branch 'main' into propagate_split_dtype_to_join
fdmalone Mar 25, 2024
b49171e
Fix formatting.
fdmalone Mar 25, 2024
5d1adec
Fix test failures.
fdmalone Mar 25, 2024
da2f1b7
Fix formatting.
fdmalone Mar 25, 2024
819f9df
Merge branch 'main' into propagate_split_dtype_to_join
fdmalone Mar 29, 2024
6d9ab70
Merge branch 'main' into propagate_split_dtype_to_join
fdmalone Apr 5, 2024
1080ea1
Clean cirq_bloq_interop typing.
fdmalone Apr 5, 2024
7d10c75
Remove prints.
fdmalone Apr 5, 2024
281dcf8
Safer casting.
fdmalone Apr 5, 2024
b620f4f
Safer checking.
fdmalone Apr 6, 2024
c52390b
Remove print.
fdmalone Apr 6, 2024
2de1b34
Format / lint.
fdmalone Apr 6, 2024
2c79572
Only cast Fxp.
fdmalone Apr 6, 2024
f78ce7a
Custom hash for _QReg for single-qubit lookup.
fdmalone Apr 8, 2024
31519a8
Address review comments.
fdmalone Apr 10, 2024
bbf40c3
Add assertion.
fdmalone Apr 10, 2024
a8cc8be
Move location of assert.
fdmalone Apr 10, 2024
5259aa9
Merge branch 'main' into propagate_split_dtype_to_join
fdmalone Apr 10, 2024
86d4429
Move assert back.
fdmalone Apr 10, 2024
5ba5cb7
Merge branch 'propagate_split_dtype_to_join' of github.com:fdmalone/Q…
fdmalone Apr 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dev_tools/qualtran_dev_tools/bloq_report_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
check_bloq_example_decompose,
check_bloq_example_make,
check_bloq_example_serialize,
check_connections_preserve_preserves_types,
check_equivalent_bloq_example_counts,
)

Expand Down Expand Up @@ -66,7 +67,7 @@ def bloq_classes_with_no_examples(


IDCOLS = ['package', 'bloq_cls', 'name']
CHECKCOLS = ['make', 'decomp', 'counts', 'serialize']
CHECKCOLS = ['make', 'decomp', 'counts', 'serialize', 'typing']


def record_for_class_with_no_examples(k: Type[Bloq]) -> Dict[str, Any]:
Expand All @@ -78,6 +79,7 @@ def record_for_class_with_no_examples(k: Type[Bloq]) -> Dict[str, Any]:
'decomp': BloqCheckResult.MISSING,
'counts': BloqCheckResult.MISSING,
'serialize': BloqCheckResult.MISSING,
'typing': BloqCheckResult.MISSING,
}


Expand All @@ -90,6 +92,7 @@ def record_for_bloq_example(be: BloqExample) -> Dict[str, Any]:
'decomp': check_bloq_example_decompose(be)[0],
'counts': check_equivalent_bloq_example_counts(be)[0],
'serialize': check_bloq_example_serialize(be)[0],
'typing': check_connections_preserve_preserves_types(be)[0],
}


Expand Down
42 changes: 32 additions & 10 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"""

import abc
from enum import Enum
from typing import Any, Iterable, List, Sequence, Union

import attrs
Expand Down Expand Up @@ -549,36 +550,57 @@ def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: s
QAnyUInt = (QUInt, BoundedQUInt, QMontgomeryUInt)


class QDTypeCheckingSeverity(Enum):
"""The level of type checking to enforce"""

LOOSE = 0
"""Allow most type conversions between QAnyInt, QFxp and QAny."""

ANY = 1
"""Disallow numeric type conversions but allow QAny and single bit conversion."""

STRICT = 2
"""Strictly enforce type checking between registers. Only single bit conversions are allowed."""


def _check_uint_fxp_consistent(a: QUInt, b: QFxp) -> bool:
"""A uint is consistent with a whole or totally fractional unsigned QFxp."""
"""A uint / qfxp is consistent with a whole or totally fractional unsigned QFxp."""
if b.signed:
return False
return a.num_qubits == b.num_qubits and (b.num_frac == 0 or b.num_int == 0)


def check_dtypes_consistent(dtype_a: QDType, dtype_b: QDType, strict: bool = False) -> bool:
def check_dtypes_consistent(
dtype_a: QDType,
dtype_b: QDType,
type_checking_severity: QDTypeCheckingSeverity = QDTypeCheckingSeverity.LOOSE,
) -> bool:
"""Check if two types are consistent given our current definition on consistent types.

Args:
dtype_a: The dtype to check against the reference.
dtype_b: The reference dtype.
strict: Whether to compare types literally
type_checking_severity: Severity of type checking to perform.

Returns:
True if the types are consistent.
"""
if dtype_a == dtype_b:
same_dtypes = dtype_a == dtype_b
same_n_qubits = dtype_a.num_qubits == dtype_b.num_qubits
if same_dtypes:
# Same types are always ok.
return True
elif dtype_a.num_qubits == 1 and same_n_qubits:
# Single qubit types are ok.
return True
if strict:
if type_checking_severity == QDTypeCheckingSeverity.STRICT:
return False
same_n_qubits = dtype_a.num_qubits == dtype_b.num_qubits
if isinstance(dtype_a, QAny) or isinstance(dtype_b, QAny):
# QAny -> any dtype and any dtype -> QAny
return same_n_qubits
elif dtype_a.num_qubits == 1 and same_n_qubits:
# Single qubit types are ok.
return True
elif isinstance(dtype_a, QAnyInt) and isinstance(dtype_b, QAnyInt):
if type_checking_severity == QDTypeCheckingSeverity.ANY:
return False
if isinstance(dtype_a, QAnyInt) and isinstance(dtype_b, QAnyInt):
# A subset of the integers should be freely interchangeable.
return same_n_qubits
elif isinstance(dtype_a, QAnyUInt) and isinstance(dtype_b, QFxp):
Expand Down
14 changes: 8 additions & 6 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ class GreaterThan(Bloq):
b: n-bit-sized input registers.
target: A single bit output register to store the result of A > B.
"""

a_bitsize: int
b_bitsize: int

Expand Down Expand Up @@ -645,6 +646,7 @@ class LinearDepthGreaterThan(Bloq):

[Improved quantum circuits for elliptic curve discrete logarithms](https://arxiv.org/abs/2306.08585).
"""

bitsize: int
signed: bool

Expand Down Expand Up @@ -690,11 +692,11 @@ def build_composite_bloq(
if not self.signed:
a_sign = bb.allocate(n=1)
a_split = bb.split(a)
a = bb.join(np.concatenate([[a_sign], a_split]))
a = bb.join(np.concatenate([[a_sign], a_split]), dtype=QUInt(self.bitsize + 1))

b_sign = bb.allocate(n=1)
b_split = bb.split(b)
b = bb.join(np.concatenate([[b_sign], b_split]))
b = bb.join(np.concatenate([[b_sign], b_split]), dtype=QUInt(self.bitsize + 1))

# Create variable true_bitsize to account for sign bit in bloq construction.
true_bitsize = self.bitsize if self.signed else (self.bitsize + 1)
Expand Down Expand Up @@ -764,19 +766,19 @@ def build_composite_bloq(
for i in range(true_bitsize):
b_split[i] = bb.add(XGate(), q=b_split[i])

a = bb.join(a_split)
b = bb.join(b_split)
a = bb.join(a_split, dtype=QUInt(true_bitsize))
b = bb.join(b_split, dtype=QUInt(true_bitsize))

# If the input registers were unsigned we free the ancilla sign bits.
if not self.signed:
a_split = bb.split(a)
a_sign = a_split[0]
a = bb.join(a_split[1:])
a = bb.join(a_split[1:], dtype=QUInt(self.bitsize))
bb.free(a_sign)

b_split = bb.split(b)
b_sign = b_split[0]
b = bb.join(b_split[1:])
b = bb.join(b_split[1:], dtype=QUInt(self.bitsize))
bb.free(b_sign)

# Return the output registers.
Expand Down
12 changes: 8 additions & 4 deletions qualtran/bloqs/factoring/mod_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,20 @@ def build_composite_bloq(
# constant subtraction circuit.
x_split = bb.split(x)
y_split = bb.split(y)
x = bb.join(np.concatenate([[junk_bit], x_split]))
y = bb.join(np.concatenate([[sign], y_split]))
x = bb.join(
np.concatenate([[junk_bit], x_split]), dtype=QMontgomeryUInt(bitsize=self.bitsize + 1)
)
y = bb.join(
np.concatenate([[sign], y_split]), dtype=QMontgomeryUInt(bitsize=self.bitsize + 1)
)

# Perform in-place addition on quantum register y.
x, y = bb.add(Add(QMontgomeryUInt(bitsize=self.bitsize + 1)), a=x, b=y)

# Temporary solution to equalize the bitlength of the x and y registers for Add().
x_split = bb.split(x)
junk_bit = x_split[0]
x = bb.join(x_split[1:])
x = bb.join(x_split[1:], dtype=QMontgomeryUInt(bitsize=self.bitsize))

# Add constant -p to the y register.
y = bb.add(
Expand All @@ -230,7 +234,7 @@ def build_composite_bloq(
# negative.
y_split = bb.split(y)
sign = y_split[0]
y = bb.join(y_split[1:])
y = bb.join(y_split[1:], dtype=QMontgomeryUInt(bitsize=self.bitsize))

sign_split = bb.split(sign)
sign_split, y = bb.add(
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/factoring/mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'SoquetT') -> Dict[s
exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x)
base = base * base % self.mod

return {'exponent': bb.join(exponent), 'x': x}
return {'exponent': bb.join(exponent, dtype=QUInt(self.exp_bitsize)), 'x': x}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
k = ssa.new_symbol('k')
Expand Down
10 changes: 7 additions & 3 deletions qualtran/bloqs/factoring/mod_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: SoquetT) -> Dict[str, 'Soqu
# Convert x to an n + 2-bit integer by attaching two |0⟩ qubits as the least and most
# significant bits.
x_split = bb.split(x)
x = bb.join(np.concatenate([[sign], x_split, [lower_bit]]))
x = bb.join(
np.concatenate([[sign], x_split, [lower_bit]]), dtype=QMontgomeryUInt(self.bitsize + 2)
)

# Add constant -p to the x register.
x = bb.add(
Expand All @@ -167,7 +169,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: SoquetT) -> Dict[str, 'Soqu
# addition circuit.
x_split = bb.split(x)
sign = x_split[0]
x = bb.join(x_split[1:])
x = bb.join(x_split[1:], dtype=QMontgomeryUInt(self.bitsize + 1))

# Add constant p to the x register if the result of the last modular reduction is negative.
sign_split = bb.split(sign)
Expand All @@ -187,7 +189,9 @@ def build_composite_bloq(self, bb: 'BloqBuilder', x: SoquetT) -> Dict[str, 'Soqu
lower_bit = bb.add(XGate(), q=lower_bit)

free_bit = x_split[0]
x = bb.join(np.concatenate([x_split[1:-1], [lower_bit]]))
x = bb.join(
np.concatenate([x_split[1:-1], [lower_bit]]), dtype=QMontgomeryUInt(self.bitsize)
)

# Free the ancilla bits.
bb.free(free_bit)
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/factoring/mod_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def build_composite_bloq(
x_split = bb.split(x)
for i in range(self.bitsize):
x_split[i] = bb.add(XGate(), q=x_split[i])
x = bb.join(x_split)
x = bb.join(x_split, dtype=QMontgomeryUInt(self.bitsize))

# Add constant p+1 to the x register.
x = bb.add(SimpleAddConstant(bitsize=self.bitsize, k=self.p + 1, signed=False, cvs=()), x=x)
Expand All @@ -88,7 +88,7 @@ def build_composite_bloq(
x_split = bb.split(x)
for i in range(self.bitsize):
x_split[i] = bb.add(XGate(), q=x_split[i])
x = bb.join(x_split)
x = bb.join(x_split, dtype=QMontgomeryUInt(self.bitsize))

# Return the output registers.
return {'x': x, 'y': y}
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/rotations/hamming_weight_phasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
ZPowGate(exponent=(2**i) * self.exponent, eps=self.eps / len(out)),
q=out[-(i + 1)],
)
out = bb.join(out)
out = bb.join(out, dtype=QUInt(self.bitsize.bit_length()))
soqs['x'] = bb.add(
HammingWeightCompute(self.bitsize).adjoint(), x=soqs['x'], junk=junk, out=out
)
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/rotations/quantum_variable_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class QvrZPow(QvrInterface):
floating point number.
eps: Precision for synthesizing the phases.
"""

cost_reg: Register
gamma: Union[float, sympy.Expr] = 1.0
eps: Union[float, sympy.Expr] = 1e-9
Expand Down Expand Up @@ -165,7 +166,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
ZPowGate(exponent=(2**power_of_two) * self.gamma * 2, eps=self.eps / len(out)),
q=out[-(i + 1)],
)
return {self.cost_reg.name: bb.join(out)}
return {self.cost_reg.name: bb.join(out, self.cost_reg.dtype)}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
zpow = ZPowGate(exponent=self.gamma, eps=self.eps / self.cost_dtype.bitsize)
Expand Down
3 changes: 2 additions & 1 deletion qualtran/bloqs/swap_network/swap_with_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def build_composite_bloq(
# 0; we swap all values in the right subtree with all values in the left subtree. This
# takes (N / (2 ** (j + 1)) swaps at level `j`.
# Therefore, in total, we need $\sum_{j=0}^{logN-1} \frac{N}{2 ^ {j + 1}}$ controlled swaps.
selection_dtype = selection.reg.dtype
selection = bb.split(selection)
for j in range(self.selection_bitsize):
for i in range(0, self.n_target_registers - 2**j, 2 ** (j + 1)):
Expand All @@ -104,7 +105,7 @@ def build_composite_bloq(
cswap_n, ctrl=selection[sel_i], x=targets[i], y=targets[i + 2**j]
)

return {'selection': bb.join(selection), 'targets': targets}
return {'selection': bb.join(selection, dtype=selection_dtype), 'targets': targets}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
num_swaps = np.floor(
Expand Down
6 changes: 0 additions & 6 deletions qualtran/bloqs/swap_network/swap_with_zero_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@
random.seed(12345)


def _make_SwapWithZero():
from qualtran.bloqs.swap_network import SwapWithZero

return SwapWithZero(selection_bitsize=3, target_bitsize=64, n_target_registers=5)
fdmalone marked this conversation as resolved.
Show resolved Hide resolved


def test_swap_with_zero_decomp():
swz = SwapWithZero(selection_bitsize=3, target_bitsize=64, n_target_registers=5)
assert_valid_bloq_decomposition(swz)
Expand Down
5 changes: 4 additions & 1 deletion qualtran/bloqs/util_bloqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ class Cast(Bloq):
)

def __attrs_post_init__(self):
if isinstance(self.inp_dtype.bitsize, int):
if isinstance(self.inp_dtype.num_qubits, int):
if self.inp_dtype.num_qubits != self.out_dtype.num_qubits:
raise ValueError("Casting only permitted between same sized registers.")

Expand Down Expand Up @@ -470,6 +470,9 @@ def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']:
# TODO: Actually cast the values https://github.com/quantumlib/Qualtran/issues/734
return {'reg': reg}

def as_cirq_op(self, qubit_manager, reg: 'CirqQuregT') -> Tuple[None, Dict[str, 'CirqQuregT']]:
return None, {'reg': reg}

def _t_complexity_(self) -> 'TComplexity':
return TComplexity()

Expand Down
7 changes: 5 additions & 2 deletions qualtran/cirq_interop/_bloq_to_cirq.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _bloq_to_cirq_op(
soq = cxn.left
assert soq.reg.name in out_quregs, f"{soq=} should exist in {out_quregs=}."
if soq.reg.side == Side.RIGHT:
qvar_to_qreg[soq] = _QReg(out_quregs[soq.reg.name][soq.idx])
qvar_to_qreg[soq] = _QReg(out_quregs[soq.reg.name][soq.idx], dtype=soq.reg.dtype)
return op


Expand All @@ -257,7 +257,10 @@ def _cbloq_to_cirq_circuit(
circuit: The cirq.FrozenCircuit version of this composite bloq.
cirq_quregs: The output mapping from right register names to Cirq qubit arrays.
"""
cirq_quregs = {k: np.apply_along_axis(_QReg, -1, v) for k, v in cirq_quregs.items()}
cirq_quregs = {
k: np.apply_along_axis(_QReg, -1, *(v, signature.get_left(k).dtype))
for k, v in cirq_quregs.items()
}
qvar_to_qreg: Dict[Soquet, _QReg] = {
Soquet(LeftDangle, idx=idx, reg=reg): cirq_quregs[reg.name][idx]
for reg in signature.lefts()
Expand Down
Loading
Loading