diff --git a/qualtran/bloqs/data_loading/qroam_clean.py b/qualtran/bloqs/data_loading/qroam_clean.py index bb70cb98f..91f076796 100644 --- a/qualtran/bloqs/data_loading/qroam_clean.py +++ b/qualtran/bloqs/data_loading/qroam_clean.py @@ -155,8 +155,8 @@ def build_from_bitsize( @log_block_sizes.default def _default_log_block_sizes(self) -> Tuple[SymbolicInt, ...]: - target_bitsize = sum(self.target_bitsizes) * sum( - prod(shape) for shape in self.target_shapes + target_bitsize = sum( + bs * prod(shape) for (bs, shape) in zip(self.target_bitsizes, self.target_shapes) ) return tuple( get_optimal_log_block_size_clean_ancilla(ilen, target_bitsize, adjoint=True) @@ -256,8 +256,8 @@ def _target_reg_side(self) -> Side: @log_block_sizes.default def _default_log_block_sizes(self) -> Tuple[SymbolicInt, ...]: - target_bitsize = sum(self.target_bitsizes) * sum( - prod(shape) for shape in self.target_shapes + target_bitsize = sum( + bs * prod(shape) for (bs, shape) in zip(self.target_bitsizes, self.target_shapes) ) return tuple( get_optimal_log_block_size_clean_ancilla(ilen, target_bitsize) diff --git a/qualtran/bloqs/data_loading/qroam_clean_test.py b/qualtran/bloqs/data_loading/qroam_clean_test.py index 932e9848c..7bc3b792f 100644 --- a/qualtran/bloqs/data_loading/qroam_clean_test.py +++ b/qualtran/bloqs/data_loading/qroam_clean_test.py @@ -18,6 +18,7 @@ from qualtran.bloqs.data_loading.qroam_clean import ( _qroam_clean_multi_data, _qroam_clean_multi_dim, + get_optimal_log_block_size_clean_ancilla, QROAMClean, QROAMCleanAdjoint, ) @@ -61,6 +62,19 @@ def test_t_complexity_2d_data_symbolic(): assert bloq_inv.t_complexity().t == 4 * expected_toffoli_inv +@pytest.mark.parametrize('n', range(3, 8)) +def test_qroam_default_log_block_sizes(n: int): + data = np.arange(2**n) + bloq = QROAMClean.build_from_data(data, data, target_bitsizes=(n.bit_length(), n.bit_length())) + bs = get_optimal_log_block_size_clean_ancilla(len(data), sum(bloq.target_bitsizes)) + assert bs == bloq.log_block_sizes[0] + bloq = bloq.adjoint() + bs = get_optimal_log_block_size_clean_ancilla( + len(data), sum(bloq.target_bitsizes), adjoint=True + ) + assert bs == bloq.log_block_sizes[0] + + def test_qroam_clean_classical_sim(): rng = np.random.default_rng(42) # 1D data, 1 dataset