Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make static type checking work better with symbolics #1156

Merged
merged 2 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,11 +962,9 @@ def is_symbolic(self):

@property
def bits_k(self) -> Union[tuple[int, ...], HasLength]:
if self.is_symbolic():
if is_symbolic(self.bitsize) or is_symbolic(self.val):
return HasLength(self.bitsize)

assert not isinstance(self.bitsize, sympy.Expr)
assert not isinstance(self.val, sympy.Expr)
return tuple(QUInt(self.bitsize).to_bits(self.val))

def build_composite_bloq(
Expand Down
6 changes: 2 additions & 4 deletions qualtran/bloqs/arithmetic/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,8 @@ def is_symbolic(self):
return is_symbolic(self.N, self.cycle)

def build_composite_bloq(self, bb: 'BloqBuilder', x: 'SoquetT') -> dict[str, 'SoquetT']:
if self.is_symbolic():
if is_symbolic(self.cycle):
raise DecomposeTypeError(f"cannot decompose symbolic {self}")
assert not isinstance(self.cycle, Shaped)

a: 'SoquetT' = bb.allocate(dtype=QBit())

Expand Down Expand Up @@ -253,10 +252,9 @@ def from_cycle_lengths(cls, N: SymbolicInt, cycle_lengths: tuple[SymbolicInt, ..
return cls(N, cycles)

def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'SoquetT']:
if self.is_symbolic():
if is_symbolic(self.cycles):
raise DecomposeTypeError(f"cannot decompose symbolic {self}")

assert not isinstance(self.cycles, Shaped)
for cycle in self.cycles:
x = bb.add(PermutationCycle(self.N, cycle), x=x)

Expand Down
13 changes: 5 additions & 8 deletions qualtran/bloqs/arithmetic/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,19 @@ def is_symbolic(self):
@cached_property
def num_comparisons(self) -> SymbolicInt:
"""Number of `Comparator` gates used in the decomposition"""
if is_symbolic(self.k, self.offset):
if is_symbolic(self.k) or is_symbolic(self.offset):
return self.k // 2 # upper bound

full = (self.k // (2 * self.offset)) * self.offset
rest = self.k % (2 * self.offset)
return full + max(rest - self.offset, 0)

def build_composite_bloq(self, bb: 'BloqBuilder', xs: 'SoquetT') -> Dict[str, 'SoquetT']:
if is_symbolic(self.k, self.offset):
if is_symbolic(self.k) or is_symbolic(self.offset):
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")

# make mypy happy
k = int(self.k)
offset = int(self.offset)
k = self.k
offset = self.offset
assert isinstance(xs, np.ndarray)

comp = Comparator(self.bitsize)
Expand Down Expand Up @@ -225,7 +224,6 @@ class BitonicMerge(Bloq):
def __attrs_post_init__(self):
k = self.half_length
if not is_symbolic(k):
assert not isinstance(k, sympy.Expr)
assert k >= 1, "length of input lists must be positive"
# TODO(#1090) support non-power-of-two input lengths
assert (k & (k - 1)) == 0, "length of input lists must be a power of 2"
Expand Down Expand Up @@ -260,7 +258,7 @@ def build_composite_bloq(
assert isinstance(xs, np.ndarray)
assert isinstance(ys, np.ndarray)

k = int(self.half_length)
k = self.half_length

first_round_junk = []
for i in range(k):
Expand Down Expand Up @@ -340,7 +338,6 @@ class BitonicSort(Bloq):
def __attrs_post_init__(self):
k = self.k
if not is_symbolic(k):
assert not isinstance(k, sympy.Expr)
assert k >= 1, f"length of input list must be positive, got {k=}"
# TODO(#1090) support non-power-of-two input lengths
assert (k & (k - 1)) == 0, f"length of input list must be a power of 2, got {k=}"
Expand Down
23 changes: 10 additions & 13 deletions qualtran/bloqs/block_encoding/linear_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ def prepare(self) -> BlackBoxPrepare:
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")

alt, keep, mu = preprocess_probabilities_for_reversible_sampling(
unnormalized_probabilities=tuple(self.rescaled_lambd),
sub_bit_precision=cast(int, self.lambd_bits),
unnormalized_probabilities=tuple(self.rescaled_lambd), sub_bit_precision=self.lambd_bits
)
N = len(self.rescaled_lambd)

Expand Down Expand Up @@ -224,14 +223,14 @@ def select(self) -> BlackBoxSelect:
or is_symbolic(self.resource_bitsize)
):
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")
assert isinstance(self.be_ancilla_bitsize, int)
assert isinstance(self.be_resource_bitsize, int)
assert not is_symbolic(self.be_ancilla_bitsize)
assert not is_symbolic(self.be_resource_bitsize)

# make all bloqs have same ancilla and resource registers
bloqs = []
for be in self.signed_block_encodings:
assert isinstance(be.ancilla_bitsize, int)
assert isinstance(be.resource_bitsize, int)
assert not is_symbolic(be.ancilla_bitsize)
assert not is_symbolic(be.resource_bitsize)

partitions: List[Tuple[Register, List[Union[str, Unused]]]] = [
(Register("system", QAny(self.system_bitsize)), ["system"])
Expand Down Expand Up @@ -267,17 +266,15 @@ def build_composite_bloq(
or is_symbolic(self.resource_bitsize)
):
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")
assert isinstance(self.be_ancilla_bitsize, int)
assert isinstance(self.ancilla_bitsize, int)
assert isinstance(self.be_resource_bitsize, int)
assert isinstance(self.resource_bitsize, int)
assert not is_symbolic(self.be_ancilla_bitsize)
assert not is_symbolic(self.be_resource_bitsize)

# partition ancilla register
be_system_soqs: Dict[str, SoquetT] = {"system": system}
anc_regs = [Register("selection", QAny(self.prepare.selection_bitsize))]
if self.be_ancilla_bitsize > 0:
anc_regs.append(Register("ancilla", QAny(self.be_ancilla_bitsize)))
anc_part = Partition(cast(int, self.ancilla_bitsize), tuple(anc_regs))
anc_part = Partition(self.ancilla_bitsize, tuple(anc_regs))
anc_soqs = bb.add_d(anc_part, x=ancilla)
if self.be_ancilla_bitsize > 0:
be_system_soqs["ancilla"] = anc_soqs.pop("ancilla")
Expand All @@ -290,7 +287,7 @@ def build_composite_bloq(
res_regs.append(Register("resource", QAny(self.be_resource_bitsize)))
if self.prepare.junk_bitsize > 0:
res_regs.append(Register("prepare_junk", QAny(self.prepare.junk_bitsize)))
res_part = Partition(cast(int, self.resource_bitsize), tuple(res_regs))
res_part = Partition(self.resource_bitsize, tuple(res_regs))
res_soqs = bb.add_d(res_part, x=soqs.pop("resource"))
if self.be_resource_bitsize > 0:
be_system_soqs["resource"] = res_soqs.pop("resource")
Expand All @@ -303,7 +300,7 @@ def build_composite_bloq(
be_regs.append(Register("ancilla", QAny(self.be_ancilla_bitsize)))
if self.be_resource_bitsize > 0:
be_regs.append(Register("resource", QAny(self.be_resource_bitsize)))
be_part = Partition(cast(int, self.select.system_bitsize), tuple(be_regs))
be_part = Partition(self.select.system_bitsize, tuple(be_regs))

prepare_soqs = bb.add_d(self.prepare, **prepare_in_soqs)
select_out_soqs = bb.add_d(
Expand Down
9 changes: 2 additions & 7 deletions qualtran/bloqs/block_encoding/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,6 @@ def build_composite_bloq(
or is_symbolic(self.resource_bitsize)
):
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")
assert (
isinstance(self.system_bitsize, int)
and isinstance(self.ancilla_bitsize, int)
and isinstance(self.resource_bitsize, int)
)
n = len(self.block_encodings)

if self.ancilla_bitsize > 0:
Expand All @@ -176,8 +171,8 @@ def build_composite_bloq(

# connect constituent bloqs
for i, u in enumerate(reversed(self.block_encodings)):
assert isinstance(u.ancilla_bitsize, int)
assert isinstance(u.resource_bitsize, int)
assert not is_symbolic(u.ancilla_bitsize)
assert not is_symbolic(u.resource_bitsize)
u_soqs = {"system": system}
partition: List[Tuple[Register, List[Union[str, Unused]]]] = [
(Register("system", dtype=QAny(u.system_bitsize)), ["system"])
Expand Down
8 changes: 4 additions & 4 deletions qualtran/bloqs/block_encoding/tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from collections import Counter
from functools import cached_property
from typing import cast, Dict, Set, Tuple
from typing import Dict, Set, Tuple

from attrs import evolve, field, frozen, validators

Expand Down Expand Up @@ -142,13 +142,13 @@ def build_composite_bloq(
if "resource" in u.signature._lefts
)

sys_part = Partition(cast(int, self.system_bitsize), regs=sys_regs)
sys_part = Partition(self.system_bitsize, regs=sys_regs)
sys_out_regs = list(bb.add_t(sys_part, x=system))
if len(anc_regs) > 0:
anc_part = Partition(cast(int, self.ancilla_bitsize), regs=anc_regs)
anc_part = Partition(self.ancilla_bitsize, regs=anc_regs)
anc_out_regs = list(bb.add_t(anc_part, x=soqs["ancilla"]))
if len(res_regs) > 0:
res_part = Partition(cast(int, self.resource_bitsize), regs=res_regs)
res_part = Partition(self.resource_bitsize, regs=res_regs)
res_out_regs = list(bb.add_t(res_part, x=soqs["resource"]))
sys_i = 0
anc_i = 0
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/data_loading/select_swap_qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from qualtran.bloqs.data_loading.qrom_base import QROMBase
from qualtran.bloqs.swap_network import SwapWithZero
from qualtran.drawing import Circle, Text, TextBox, WireSymbol
from qualtran.symbolics import ceil, is_symbolic, log2, prod, SymbolicInt
from qualtran.symbolics import ceil, is_symbolic, log2, prod, SymbolicFloat, SymbolicInt

if TYPE_CHECKING:
from qualtran import Bloq
Expand All @@ -46,7 +46,7 @@ def find_optimal_log_block_size(
* iteration_length/2^k + target_bitsize*(2^k - 1) is minimized.
The corresponding block size for SelectSwapQROM would be 2^k.
"""
k = 0.5 * log2(iteration_length / target_bitsize)
k: SymbolicFloat = 0.5 * log2(iteration_length / target_bitsize)
if is_symbolic(k):
return ceil(k)

Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/qft/qft_text_book.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import cast, Iterator, Set
from typing import Iterator, Set

import attrs
import cirq
Expand Down Expand Up @@ -94,7 +94,7 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> Set['BloqCountT']:
)
}
else:
for i in range(1, cast(int, self.bitsize)):
for i in range(1, self.bitsize):
ret |= {(PhaseGradientUnitary(i, exponent=0.5, is_controlled=True), 1)}
if self.with_reverse:
ret |= {(TwoBitSwap(), self.bitsize // 2)}
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/rotations/phase_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import cast, Dict, Iterator, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union

import attrs
import cirq
Expand Down Expand Up @@ -138,7 +138,7 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> Set['BloqCountT']:
}

ret: Set['BloqCountT'] = set()
for i in range(cast(int, self.bitsize)):
for i in range(self.bitsize):
ret.add((gate(exponent=self.exponent / 2**i, eps=self.eps / self.bitsize), 1))
return ret

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def k_l_logL(self) -> Tuple[SymbolicInt, SymbolicInt, SymbolicInt]:
k, n, logL = 0, self.n, bit_length(self.n - 1)
if is_symbolic(n):
return 0, self.n, bit_length(self.n - 1)
n = int(n)
while n > 1 and n % 2 == 0:
k += 1
logL -= 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ class StatePreparationViaRotations(GateWithRegisters):
def __attrs_post_init__(self):
if is_symbolic(self.state_coefficients):
return
assert isinstance(self.state_coefficients, tuple)
# a valid quantum state has a number of coefficients that is a power of two
assert slen(self.state_coefficients) == 2**self.state_bitsize
# negative number of control bits is not allowed
Expand Down
18 changes: 17 additions & 1 deletion qualtran/symbolics/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import overload, TypeVar, Union

import sympy
from attrs import field, frozen, validators
from cirq._doc import document
from typing_extensions import TypeIs

SymbolicFloat = Union[float, sympy.Expr]
document(SymbolicFloat, """A floating point value or a sympy expression.""")
Expand Down Expand Up @@ -63,7 +64,22 @@ def is_symbolic(self):
return True


T = TypeVar('T')


@overload
def is_symbolic(
arg: Union[T, sympy.Expr, Shaped, HasLength], /
) -> TypeIs[Union[sympy.Expr, Shaped, HasLength]]:
...


@overload
def is_symbolic(*args) -> bool:
...


def is_symbolic(*args) -> Union[TypeIs[Union[sympy.Expr, Shaped, HasLength]], bool]:
"""Returns whether the inputs contain any symbolic object.

Returns:
Expand Down
Loading