From 23577b27fb54f8bc5d88da05521fa092c6843a04 Mon Sep 17 00:00:00 2001 From: anurudhp Date: Fri, 21 Jun 2024 11:59:16 -0700 Subject: [PATCH] fix mypy --- qualtran/bloqs/arithmetic/sorting.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/qualtran/bloqs/arithmetic/sorting.py b/qualtran/bloqs/arithmetic/sorting.py index 08341e8d5..e201c4583 100644 --- a/qualtran/bloqs/arithmetic/sorting.py +++ b/qualtran/bloqs/arithmetic/sorting.py @@ -139,6 +139,11 @@ def build_composite_bloq(self, bb: 'BloqBuilder', xs: 'SoquetT') -> Dict[str, 'S if self.is_symbolic(): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") + # make mypy happy + assert not isinstance(self.k, sympy.Expr) + assert not isinstance(self.offset, sympy.Expr) + assert isinstance(xs, np.ndarray) + comp = Comparator(self.L) junk = [] @@ -172,6 +177,7 @@ class BitonicMerge(Bloq): def __attrs_post_init__(self): k = self.k if not is_symbolic(k): + assert not isinstance(k, sympy.Expr) assert k >= 1, "length of input lists must be positive" # TODO support non-power-of-two input lengths assert (k & (k - 1)) == 0, "length of input lists must be a power of 2" @@ -203,12 +209,16 @@ def num_comparisons(self) -> SymbolicInt: def is_symbolic(self): return is_symbolic(self.L, self.k) - def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str, 'SoquetT']: + def build_composite_bloq( + self, bb: 'BloqBuilder', xs: 'SoquetT', ys: 'SoquetT' + ) -> dict[str, 'SoquetT']: if self.is_symbolic(): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") - k = self.k - xs, ys = soqs['xs'], soqs['ys'] + assert isinstance(xs, np.ndarray) + assert isinstance(ys, np.ndarray) + + k = int(self.k) first_round_junk = [] for i in range(k): @@ -216,7 +226,7 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> dict[str first_round_junk.append(anc) result = np.concatenate([xs, ys]) - logk = bit_length(k - 1) + logk = int(bit_length(k - 1)) assert 2**logk == k all_junks = [first_round_junk] @@ -254,6 +264,7 @@ class BitonicSort(Bloq): def __attrs_post_init__(self): k = self.k if not is_symbolic(k): + assert not isinstance(k, sympy.Expr) assert k >= 1, "length of input list must be positive" # TODO support non-power-of-two input lengths assert (k & (k - 1)) == 0, "length of input list must be a power of 2" @@ -286,6 +297,8 @@ def build_composite_bloq(self, bb: 'BloqBuilder', xs: 'SoquetT') -> dict[str, 'S if self.is_symbolic(): raise DecomposeTypeError(f"Cannot decompose symbolic {self=}") + assert isinstance(xs, np.ndarray) + if self.k == 1: return {'xs': xs, 'junk': np.array([])}