From 66d40ec411dc9664a62172b8207eac51a34199b4 Mon Sep 17 00:00:00 2001 From: Charles Yuan Date: Tue, 23 Jul 2024 12:02:11 -0700 Subject: [PATCH 1/4] Improve static type hints for symbolic math (#1176) --- qualtran/symbolics/math_funcs.py | 126 +++++++++++++++++++++++++++++-- 1 file changed, 118 insertions(+), 8 deletions(-) diff --git a/qualtran/symbolics/math_funcs.py b/qualtran/symbolics/math_funcs.py index 743dd4e14..7fb740e4e 100644 --- a/qualtran/symbolics/math_funcs.py +++ b/qualtran/symbolics/math_funcs.py @@ -30,49 +30,129 @@ def pi(*args) -> SymbolicFloat: return sympy.pi if is_symbolic(*args) else np.pi +@overload +def log2(x: float) -> float: + ... + + +@overload +def log2(x: sympy.Expr) -> sympy.Expr: + ... + + def log2(x: SymbolicFloat) -> SymbolicFloat: from sympy.codegen.cfunctions import log2 - if not isinstance(x, sympy.Basic): + if not is_symbolic(x): return np.log2(x) return log2(x) +@overload +def sexp(x: complex) -> complex: + ... + + +@overload +def sexp(x: sympy.Expr) -> sympy.Expr: + ... + + def sexp(x: SymbolicComplex) -> SymbolicComplex: - if isinstance(x, sympy.Basic): + if is_symbolic(x): return sympy.exp(x) return np.exp(x) +@overload +def sarg(x: complex) -> float: + ... + + +@overload +def sarg(x: sympy.Expr) -> sympy.Expr: + ... + + def sarg(x: SymbolicComplex) -> SymbolicFloat: r"""Argument $t$ of a complex number $r e^{i t}$""" - if isinstance(x, sympy.Basic): + if is_symbolic(x): return sympy.arg(x) return float(np.angle(x)) +@overload +def sabs(x: float) -> float: + ... + + +@overload +def sabs(x: sympy.Expr) -> sympy.Expr: + ... + + def sabs(x: SymbolicFloat) -> SymbolicFloat: return cast(SymbolicFloat, abs(x)) +@overload +def ssqrt(x: float) -> float: + ... + + +@overload +def ssqrt(x: sympy.Expr) -> sympy.Expr: + ... + + def ssqrt(x: SymbolicFloat) -> SymbolicFloat: - if isinstance(x, sympy.Basic): + if is_symbolic(x): return sympy.sqrt(x) return np.sqrt(x) +@overload +def ceil(x: float) -> int: + ... + + +@overload +def ceil(x: sympy.Expr) -> sympy.Expr: + ... + + def ceil(x: SymbolicFloat) -> SymbolicInt: - if not isinstance(x, sympy.Basic): + if not is_symbolic(x): return int(np.ceil(x)) return sympy.ceiling(x) +@overload +def floor(x: float) -> int: + ... + + +@overload +def floor(x: sympy.Expr) -> sympy.Expr: + ... + + def floor(x: SymbolicFloat) -> SymbolicInt: - if not isinstance(x, sympy.Basic): + if not is_symbolic(x): return int(np.floor(x)) return sympy.floor(x) +@overload +def bit_length(x: float) -> int: + ... + + +@overload +def bit_length(x: sympy.Expr) -> sympy.Expr: + ... + + def bit_length(x: SymbolicFloat) -> SymbolicInt: """Returns the number of bits required to represent the integer part of positive float `x`.""" if not is_symbolic(x) and 0 <= x < 1: @@ -157,15 +237,45 @@ def ssum(args: Iterable[SymbolicT]) -> SymbolicT: return ret +@overload +def acos(x: float) -> float: + ... + + +@overload +def acos(x: sympy.Expr) -> sympy.Expr: + ... + + def acos(x: SymbolicFloat) -> SymbolicFloat: - if not isinstance(x, sympy.Basic): + if not is_symbolic(x): return np.arccos(x) return sympy.acos(x) +@overload +def sconj(x: complex) -> complex: + ... + + +@overload +def sconj(x: sympy.Expr) -> sympy.Expr: + ... + + def sconj(x: SymbolicComplex) -> SymbolicComplex: """Compute the complex conjugate.""" - return sympy.conjugate(x) if isinstance(x, sympy.Expr) else np.conjugate(x) + return sympy.conjugate(x) if is_symbolic(x) else np.conjugate(x) + + +@overload +def slen(x: Sized) -> int: + ... + + +@overload +def slen(x: Union[Shaped, HasLength]) -> sympy.Expr: + ... def slen(x: Union[Sized, Shaped, HasLength]) -> SymbolicInt: From a24bfab20d57414ad35addcdd2a5060a5e2ec945 Mon Sep 17 00:00:00 2001 From: Fionn Malone Date: Tue, 23 Jul 2024 12:40:00 -0700 Subject: [PATCH 2/4] Move SelectOracle and PrepareOracle (#1178) * Move PrepareOracle to state_preparation. * Move SelectOracle to multiplexers --- dev_tools/autogenerate-bloqs-notebooks-v2.py | 5 +- .../block_encoding/block_encoding_base.py | 2 +- .../block_encoding/lcu_block_encoding.py | 3 +- .../block_encoding/lcu_block_encoding_test.py | 3 +- .../block_encoding/linear_combination.py | 2 +- qualtran/bloqs/block_encoding/phase.py | 2 +- qualtran/bloqs/block_encoding/product.py | 2 +- .../bloqs/block_encoding/sparse_matrix.py | 2 +- .../bloqs/block_encoding/tensor_product.py | 2 +- qualtran/bloqs/block_encoding/unitary.py | 2 +- .../chemistry/df/double_factorization.py | 2 +- .../qubitization/prepare_hubbard.py | 2 +- .../qubitization/select_hubbard.py | 2 +- .../projectile/select_and_prepare.py | 3 +- .../first_quantization/select_and_prepare.py | 3 +- .../chemistry/sf/single_factorization.py | 2 +- qualtran/bloqs/chemistry/sparse/prepare.py | 2 +- .../bloqs/chemistry/sparse/select_bloq.py | 2 +- qualtran/bloqs/chemistry/thc/prepare.py | 2 +- qualtran/bloqs/chemistry/thc/select_bloq.py | 2 +- .../bloqs/chemistry/writing_algorithms.ipynb | 2 +- .../for_testing/qubitization_walk_test.py | 2 +- .../for_testing/random_select_and_prepare.py | 3 +- .../mean_estimation/complex_phase_oracle.py | 2 +- .../complex_phase_oracle_test.py | 2 +- .../mean_estimation_operator.py | 3 +- .../mean_estimation_operator_test.py | 3 +- qualtran/bloqs/multiplexers/apply_lth_bloq.py | 2 +- qualtran/bloqs/multiplexers/select_base.py | 66 +++++++++++++++++++ .../bloqs/multiplexers/select_pauli_lcu.py | 2 +- .../qubitization_walk_operator.ipynb | 4 +- .../qubitization_walk_operator.py | 3 +- .../bloqs/reflections/prepare_identity.py | 2 +- .../reflections/reflection_using_prepare.py | 2 +- .../prepare_base.py} | 49 +------------- .../state_preparation_alias_sampling.py | 2 +- qualtran/serialization/resolver_dict.py | 6 +- 37 files changed, 114 insertions(+), 88 deletions(-) create mode 100644 qualtran/bloqs/multiplexers/select_base.py rename qualtran/bloqs/{block_encoding/lcu_select_and_prepare.py => state_preparation/prepare_base.py} (62%) diff --git a/dev_tools/autogenerate-bloqs-notebooks-v2.py b/dev_tools/autogenerate-bloqs-notebooks-v2.py index 58bdbe807..ec374e39b 100644 --- a/dev_tools/autogenerate-bloqs-notebooks-v2.py +++ b/dev_tools/autogenerate-bloqs-notebooks-v2.py @@ -62,7 +62,6 @@ import qualtran.bloqs.block_encoding.block_encoding_base import qualtran.bloqs.block_encoding.chebyshev_polynomial import qualtran.bloqs.block_encoding.lcu_block_encoding -import qualtran.bloqs.block_encoding.lcu_select_and_prepare import qualtran.bloqs.block_encoding.linear_combination import qualtran.bloqs.block_encoding.phase import qualtran.bloqs.bookkeeping @@ -581,8 +580,8 @@ module=qualtran.bloqs.qubitization.qubitization_walk_operator, bloq_specs=[ qualtran.bloqs.qubitization.qubitization_walk_operator._QUBITIZATION_WALK_DOC, - qualtran.bloqs.block_encoding.lcu_select_and_prepare._SELECT_ORACLE_DOC, - qualtran.bloqs.block_encoding.lcu_select_and_prepare._PREPARE_ORACLE_DOC, + qualtran.bloqs.multiplexers.select_base._SELECT_ORACLE_DOC, + qualtran.bloqs.state_preparation.prepare_base._PREPARE_ORACLE_DOC, ], ), NotebookSpecV2( diff --git a/qualtran/bloqs/block_encoding/block_encoding_base.py b/qualtran/bloqs/block_encoding/block_encoding_base.py index a35306b48..c1f99aa6b 100644 --- a/qualtran/bloqs/block_encoding/block_encoding_base.py +++ b/qualtran/bloqs/block_encoding/block_encoding_base.py @@ -15,7 +15,7 @@ from typing import Tuple from qualtran import Bloq, BloqDocSpec, Register -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.symbolics import SymbolicFloat, SymbolicInt diff --git a/qualtran/bloqs/block_encoding/lcu_block_encoding.py b/qualtran/bloqs/block_encoding/lcu_block_encoding.py index 691a5a397..07961302c 100644 --- a/qualtran/bloqs/block_encoding/lcu_block_encoding.py +++ b/qualtran/bloqs/block_encoding/lcu_block_encoding.py @@ -30,8 +30,9 @@ SoquetT, ) from qualtran.bloqs.block_encoding.block_encoding_base import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle from qualtran.bloqs.bookkeeping import Partition +from qualtran.bloqs.multiplexers.select_base import SelectOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.symbolics import SymbolicFloat diff --git a/qualtran/bloqs/block_encoding/lcu_block_encoding_test.py b/qualtran/bloqs/block_encoding/lcu_block_encoding_test.py index 98663bcb5..24c5d83e1 100644 --- a/qualtran/bloqs/block_encoding/lcu_block_encoding_test.py +++ b/qualtran/bloqs/block_encoding/lcu_block_encoding_test.py @@ -27,7 +27,8 @@ BlackBoxPrepare, BlackBoxSelect, ) -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle +from qualtran.bloqs.multiplexers.select_base import SelectOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle def test_lcu_block_encoding(bloq_autotester): diff --git a/qualtran/bloqs/block_encoding/linear_combination.py b/qualtran/bloqs/block_encoding/linear_combination.py index b8a8061d2..8fcb52981 100644 --- a/qualtran/bloqs/block_encoding/linear_combination.py +++ b/qualtran/bloqs/block_encoding/linear_combination.py @@ -32,10 +32,10 @@ from qualtran._infra.bloq import DecomposeTypeError from qualtran.bloqs.block_encoding import BlockEncoding from qualtran.bloqs.block_encoding.lcu_block_encoding import BlackBoxPrepare, BlackBoxSelect -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.block_encoding.phase import Phase from qualtran.bloqs.bookkeeping.auto_partition import AutoPartition, Unused from qualtran.bloqs.bookkeeping.partition import Partition +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.linalg.lcu_util import preprocess_probabilities_for_reversible_sampling from qualtran.symbolics import smax, ssum, SymbolicFloat, SymbolicInt from qualtran.symbolics.types import is_symbolic diff --git a/qualtran/bloqs/block_encoding/phase.py b/qualtran/bloqs/block_encoding/phase.py index 9b2efea52..f572fc571 100644 --- a/qualtran/bloqs/block_encoding/phase.py +++ b/qualtran/bloqs/block_encoding/phase.py @@ -20,7 +20,7 @@ from qualtran import bloq_example, BloqBuilder, BloqDocSpec, QAny, Register, Signature, SoquetT from qualtran.bloqs.basic_gates import GlobalPhase from qualtran.bloqs.block_encoding import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.symbolics import SymbolicFloat, SymbolicInt diff --git a/qualtran/bloqs/block_encoding/product.py b/qualtran/bloqs/block_encoding/product.py index 111cc9857..bbd7904e5 100644 --- a/qualtran/bloqs/block_encoding/product.py +++ b/qualtran/bloqs/block_encoding/product.py @@ -33,10 +33,10 @@ ) from qualtran.bloqs.basic_gates.x_basis import XGate from qualtran.bloqs.block_encoding import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.bookkeeping.auto_partition import AutoPartition, Unused from qualtran.bloqs.bookkeeping.partition import Partition from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicFloat, SymbolicInt diff --git a/qualtran/bloqs/block_encoding/sparse_matrix.py b/qualtran/bloqs/block_encoding/sparse_matrix.py index 8586e8c6e..17e4c0a32 100644 --- a/qualtran/bloqs/block_encoding/sparse_matrix.py +++ b/qualtran/bloqs/block_encoding/sparse_matrix.py @@ -37,9 +37,9 @@ ) from qualtran.bloqs.basic_gates import Ry, Swap from qualtran.bloqs.block_encoding import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.bookkeeping.auto_partition import AutoPartition, Unused from qualtran.bloqs.data_loading import QROM +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.bloqs.state_preparation.prepare_uniform_superposition import ( PrepareUniformSuperposition, ) diff --git a/qualtran/bloqs/block_encoding/tensor_product.py b/qualtran/bloqs/block_encoding/tensor_product.py index c5cf7631e..8952340e7 100644 --- a/qualtran/bloqs/block_encoding/tensor_product.py +++ b/qualtran/bloqs/block_encoding/tensor_product.py @@ -29,8 +29,8 @@ SoquetT, ) from qualtran.bloqs.block_encoding import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.bookkeeping import Partition +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.symbolics import is_symbolic, prod, ssum, SymbolicFloat, SymbolicInt diff --git a/qualtran/bloqs/block_encoding/unitary.py b/qualtran/bloqs/block_encoding/unitary.py index 54012d740..86bdc7df6 100644 --- a/qualtran/bloqs/block_encoding/unitary.py +++ b/qualtran/bloqs/block_encoding/unitary.py @@ -29,7 +29,7 @@ SoquetT, ) from qualtran.bloqs.block_encoding import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.resource_counting import BloqCountT, SympySymbolAllocator from qualtran.symbolics import SymbolicFloat, SymbolicInt diff --git a/qualtran/bloqs/chemistry/df/double_factorization.py b/qualtran/bloqs/chemistry/df/double_factorization.py index ccbfd00d5..a5781de43 100644 --- a/qualtran/bloqs/chemistry/df/double_factorization.py +++ b/qualtran/bloqs/chemistry/df/double_factorization.py @@ -50,7 +50,6 @@ ) from qualtran.bloqs.basic_gates import CSwap, Hadamard, Toffoli from qualtran.bloqs.block_encoding import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.bookkeeping import ArbitraryClifford from qualtran.bloqs.chemistry.black_boxes import ApplyControlledZs from qualtran.bloqs.chemistry.df.prepare import ( @@ -61,6 +60,7 @@ from qualtran.bloqs.chemistry.df.select_bloq import ProgRotGateArray from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle if TYPE_CHECKING: from qualtran.resource_counting import BloqCountT, SympySymbolAllocator diff --git a/qualtran/bloqs/chemistry/hubbard_model/qubitization/prepare_hubbard.py b/qualtran/bloqs/chemistry/hubbard_model/qubitization/prepare_hubbard.py index 5ce2146fe..83b8933ae 100644 --- a/qualtran/bloqs/chemistry/hubbard_model/qubitization/prepare_hubbard.py +++ b/qualtran/bloqs/chemistry/hubbard_model/qubitization/prepare_hubbard.py @@ -22,9 +22,9 @@ from qualtran import bloq_example, BloqDocSpec, BoundedQUInt, QAny, Register, Signature from qualtran.bloqs.basic_gates import CSwap -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.mcmt.and_bloq import MultiAnd from qualtran.bloqs.mod_arithmetic import ModAddK +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.bloqs.state_preparation.prepare_uniform_superposition import ( PrepareUniformSuperposition, ) diff --git a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py index 53ec8d6dd..124b30033 100644 --- a/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py +++ b/qualtran/bloqs/chemistry/hubbard_model/qubitization/select_hubbard.py @@ -23,8 +23,8 @@ from qualtran import bloq_example, BloqDocSpec, BoundedQUInt, QAny, QBit, Register, Signature from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate, total_bits from qualtran.bloqs.basic_gates import CSwap -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle from qualtran.bloqs.multiplexers.apply_gate_to_lth_target import ApplyGateToLthQubit +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion diff --git a/qualtran/bloqs/chemistry/pbc/first_quantization/projectile/select_and_prepare.py b/qualtran/bloqs/chemistry/pbc/first_quantization/projectile/select_and_prepare.py index 5d5d12a86..3af1f1641 100644 --- a/qualtran/bloqs/chemistry/pbc/first_quantization/projectile/select_and_prepare.py +++ b/qualtran/bloqs/chemistry/pbc/first_quantization/projectile/select_and_prepare.py @@ -33,7 +33,6 @@ SoquetT, ) from qualtran.bloqs.basic_gates import CSwap, Toffoli -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle from qualtran.bloqs.chemistry.pbc.first_quantization.projectile.prepare_t import ( PrepareTFirstQuantizationWithProj, ) @@ -50,6 +49,8 @@ MultiplexedCSwap3D, UniformSuperpostionIJFirstQuantization, ) +from qualtran.bloqs.multiplexers.select_base import SelectOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.bloqs.swap_network import MultiplexedCSwap from qualtran.drawing import Circle, Text, TextBox, WireSymbol diff --git a/qualtran/bloqs/chemistry/pbc/first_quantization/select_and_prepare.py b/qualtran/bloqs/chemistry/pbc/first_quantization/select_and_prepare.py index 7f5e59765..7dd7d0c06 100644 --- a/qualtran/bloqs/chemistry/pbc/first_quantization/select_and_prepare.py +++ b/qualtran/bloqs/chemistry/pbc/first_quantization/select_and_prepare.py @@ -33,11 +33,12 @@ SoquetT, ) from qualtran.bloqs.basic_gates import Toffoli -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle from qualtran.bloqs.chemistry.pbc.first_quantization.prepare_t import PrepareTFirstQuantization from qualtran.bloqs.chemistry.pbc.first_quantization.prepare_uv import PrepareUVFirstQuantization from qualtran.bloqs.chemistry.pbc.first_quantization.select_t import SelectTFirstQuantization from qualtran.bloqs.chemistry.pbc.first_quantization.select_uv import SelectUVFirstQuantization +from qualtran.bloqs.multiplexers.select_base import SelectOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.bloqs.swap_network import MultiplexedCSwap from qualtran.drawing import Text, TextBox, WireSymbol diff --git a/qualtran/bloqs/chemistry/sf/single_factorization.py b/qualtran/bloqs/chemistry/sf/single_factorization.py index 4f358bfa9..9df810670 100644 --- a/qualtran/bloqs/chemistry/sf/single_factorization.py +++ b/qualtran/bloqs/chemistry/sf/single_factorization.py @@ -45,7 +45,6 @@ from qualtran.bloqs.basic_gates import Hadamard from qualtran.bloqs.basic_gates.swap import CSwap from qualtran.bloqs.block_encoding import BlockEncoding -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.chemistry.sf.prepare import ( InnerPrepareSingleFactorization, OuterPrepareSingleFactorization, @@ -53,6 +52,7 @@ from qualtran.bloqs.chemistry.sf.select_bloq import SelectSingleFactorization from qualtran.bloqs.reflections.prepare_identity import PrepareIdentity from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle if TYPE_CHECKING: from qualtran.resource_counting import BloqCountT, SympySymbolAllocator diff --git a/qualtran/bloqs/chemistry/sparse/prepare.py b/qualtran/bloqs/chemistry/sparse/prepare.py index 5ed2237c9..13b3ddcc7 100644 --- a/qualtran/bloqs/chemistry/sparse/prepare.py +++ b/qualtran/bloqs/chemistry/sparse/prepare.py @@ -37,8 +37,8 @@ from qualtran.bloqs.basic_gates import CSwap, Hadamard from qualtran.bloqs.basic_gates.on_each import OnEach from qualtran.bloqs.basic_gates.z_basis import CZ, ZGate -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.data_loading.select_swap_qrom import find_optimal_log_block_size, SelectSwapQROM +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.bloqs.state_preparation.prepare_uniform_superposition import ( PrepareUniformSuperposition, ) diff --git a/qualtran/bloqs/chemistry/sparse/select_bloq.py b/qualtran/bloqs/chemistry/sparse/select_bloq.py index 4e729229c..52f1443f3 100644 --- a/qualtran/bloqs/chemistry/sparse/select_bloq.py +++ b/qualtran/bloqs/chemistry/sparse/select_bloq.py @@ -31,7 +31,7 @@ ) from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate from qualtran.bloqs.basic_gates import SGate -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.multiplexers.selected_majorana_fermion import SelectedMajoranaFermion if TYPE_CHECKING: diff --git a/qualtran/bloqs/chemistry/thc/prepare.py b/qualtran/bloqs/chemistry/thc/prepare.py index 49461354c..98ae77fe9 100644 --- a/qualtran/bloqs/chemistry/thc/prepare.py +++ b/qualtran/bloqs/chemistry/thc/prepare.py @@ -42,10 +42,10 @@ ) from qualtran.bloqs.basic_gates import CSwap, Hadamard, Ry, Toffoli, XGate from qualtran.bloqs.basic_gates.on_each import OnEach -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.data_loading.select_swap_qrom import SelectSwapQROM from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.cirq_interop import CirqGateAsBloq from qualtran.drawing import Text, WireSymbol from qualtran.linalg.lcu_util import preprocess_probabilities_for_reversible_sampling diff --git a/qualtran/bloqs/chemistry/thc/select_bloq.py b/qualtran/bloqs/chemistry/thc/select_bloq.py index 525a603e3..1ff92ca58 100644 --- a/qualtran/bloqs/chemistry/thc/select_bloq.py +++ b/qualtran/bloqs/chemistry/thc/select_bloq.py @@ -33,8 +33,8 @@ ) from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate from qualtran.bloqs.basic_gates import CSwap, Toffoli, XGate -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle from qualtran.bloqs.chemistry.black_boxes import ApplyControlledZs +from qualtran.bloqs.multiplexers.select_base import SelectOracle if TYPE_CHECKING: from qualtran.resource_counting import BloqCountT, SympySymbolAllocator diff --git a/qualtran/bloqs/chemistry/writing_algorithms.ipynb b/qualtran/bloqs/chemistry/writing_algorithms.ipynb index 14cbe90e4..ffdf21eab 100644 --- a/qualtran/bloqs/chemistry/writing_algorithms.ipynb +++ b/qualtran/bloqs/chemistry/writing_algorithms.ipynb @@ -72,7 +72,7 @@ "from qualtran import Register, BoundedQUInt, QBit, QAny\n", "\n", "from qualtran.drawing import show_bloq\n", - "from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle\n", + "from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle\n", "\n", "@frozen\n", "class PrepareSecondQuantization(PrepareOracle):\n", diff --git a/qualtran/bloqs/for_testing/qubitization_walk_test.py b/qualtran/bloqs/for_testing/qubitization_walk_test.py index c4bc8c574..a823075f3 100644 --- a/qualtran/bloqs/for_testing/qubitization_walk_test.py +++ b/qualtran/bloqs/for_testing/qubitization_walk_test.py @@ -20,10 +20,10 @@ from numpy.typing import NDArray from qualtran import Signature -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.multiplexers.select_pauli_lcu import SelectPauliLCU from qualtran.bloqs.qubitization.qubitization_walk_operator import QubitizationWalkOperator from qualtran.bloqs.state_preparation import PrepareUniformSuperposition +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.symbolics import SymbolicFloat diff --git a/qualtran/bloqs/for_testing/random_select_and_prepare.py b/qualtran/bloqs/for_testing/random_select_and_prepare.py index deb7ffd55..1f2d331d8 100644 --- a/qualtran/bloqs/for_testing/random_select_and_prepare.py +++ b/qualtran/bloqs/for_testing/random_select_and_prepare.py @@ -21,9 +21,10 @@ from qualtran import BloqBuilder, BoundedQUInt, QBit, Register, SoquetT from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle from qualtran.bloqs.for_testing.matrix_gate import MatrixGate +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.qubitization.qubitization_walk_operator import QubitizationWalkOperator +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle @frozen diff --git a/qualtran/bloqs/mean_estimation/complex_phase_oracle.py b/qualtran/bloqs/mean_estimation/complex_phase_oracle.py index 19bc6c124..1b83a3238 100644 --- a/qualtran/bloqs/mean_estimation/complex_phase_oracle.py +++ b/qualtran/bloqs/mean_estimation/complex_phase_oracle.py @@ -21,8 +21,8 @@ from qualtran import GateWithRegisters, Register, Signature from qualtran._infra.gate_with_registers import merge_qubits, total_bits -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle from qualtran.bloqs.mean_estimation.arctan import ArcTan +from qualtran.bloqs.multiplexers.select_base import SelectOracle @attrs.frozen diff --git a/qualtran/bloqs/mean_estimation/complex_phase_oracle_test.py b/qualtran/bloqs/mean_estimation/complex_phase_oracle_test.py index 343e57c1c..a8d26f510 100644 --- a/qualtran/bloqs/mean_estimation/complex_phase_oracle_test.py +++ b/qualtran/bloqs/mean_estimation/complex_phase_oracle_test.py @@ -22,8 +22,8 @@ from attrs import frozen from qualtran import QAny, QBit, QFxp, Register -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle from qualtran.bloqs.mean_estimation.complex_phase_oracle import ComplexPhaseOracle +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.cirq_interop import testing as cq_testing from qualtran.testing import assert_valid_bloq_decomposition diff --git a/qualtran/bloqs/mean_estimation/mean_estimation_operator.py b/qualtran/bloqs/mean_estimation/mean_estimation_operator.py index 2e9efa1ba..9e1a92bf4 100644 --- a/qualtran/bloqs/mean_estimation/mean_estimation_operator.py +++ b/qualtran/bloqs/mean_estimation/mean_estimation_operator.py @@ -21,9 +21,10 @@ from qualtran import CtrlSpec, Register, Signature from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate, total_bits -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle from qualtran.bloqs.mean_estimation.complex_phase_oracle import ComplexPhaseOracle +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle @attrs.frozen diff --git a/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py b/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py index 5a238c682..e76734256 100644 --- a/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py +++ b/qualtran/bloqs/mean_estimation/mean_estimation_operator_test.py @@ -26,11 +26,12 @@ SpecializedSingleQubitControlledGate, total_bits, ) -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle from qualtran.bloqs.mean_estimation.mean_estimation_operator import ( CodeForRandomVariable, MeanEstimationOperator, ) +from qualtran.bloqs.multiplexers.select_base import SelectOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.testing import assert_valid_bloq_decomposition diff --git a/qualtran/bloqs/multiplexers/apply_lth_bloq.py b/qualtran/bloqs/multiplexers/apply_lth_bloq.py index 608e858e8..e66927eb6 100644 --- a/qualtran/bloqs/multiplexers/apply_lth_bloq.py +++ b/qualtran/bloqs/multiplexers/apply_lth_bloq.py @@ -22,7 +22,7 @@ from qualtran import Bloq, bloq_example, BloqDocSpec, BoundedQUInt, QBit, Register, Side from qualtran._infra.gate_with_registers import merge_qubits, SpecializedSingleQubitControlledGate -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate from qualtran.resource_counting import BloqCountT from qualtran.symbolics import ceil, log2 diff --git a/qualtran/bloqs/multiplexers/select_base.py b/qualtran/bloqs/multiplexers/select_base.py new file mode 100644 index 000000000..26ad6d3d6 --- /dev/null +++ b/qualtran/bloqs/multiplexers/select_base.py @@ -0,0 +1,66 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from functools import cached_property +from typing import Tuple + +from qualtran import BloqDocSpec, GateWithRegisters, Register, Signature + + +class SelectOracle(GateWithRegisters): + r"""Abstract base class that defines the interface for a SELECT Oracle. + + The action of a SELECT oracle on a selection register $|l\rangle$ and target register + $|\Psi\rangle$ can be defined as: + + $$ + \mathrm{SELECT} = \sum_{l}|l \rangle \langle l| \otimes U_l + $$ + + In other words, the `SELECT` oracle applies $l$'th unitary $U_l$ on the target register + $|\Psi\rangle$ when the selection register stores integer $l$. + + $$ + \mathrm{SELECT}|l\rangle |\Psi\rangle = |l\rangle U_{l}|\Psi\rangle + $$ + """ + + @property + @abc.abstractmethod + def control_registers(self) -> Tuple[Register, ...]: + ... + + @property + @abc.abstractmethod + def selection_registers(self) -> Tuple[Register, ...]: + ... + + @property + @abc.abstractmethod + def target_registers(self) -> Tuple[Register, ...]: + ... + + @cached_property + def signature(self) -> Signature: + return Signature( + [*self.control_registers, *self.selection_registers, *self.target_registers] + ) + + +_SELECT_ORACLE_DOC = BloqDocSpec( + bloq_cls=SelectOracle, + import_line='from qualtran.bloqs.multiplexers.select_base import SelectOracle', + examples=[], +) diff --git a/qualtran/bloqs/multiplexers/select_pauli_lcu.py b/qualtran/bloqs/multiplexers/select_pauli_lcu.py index 1b6735745..c6bf53c31 100644 --- a/qualtran/bloqs/multiplexers/select_pauli_lcu.py +++ b/qualtran/bloqs/multiplexers/select_pauli_lcu.py @@ -24,7 +24,7 @@ from qualtran import bloq_example, BloqDocSpec, BoundedQUInt, QAny, QBit, Register from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.multiplexers.unary_iteration_bloq import UnaryIterationGate from qualtran.resource_counting.generalizers import ( cirq_to_bloqs, diff --git a/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb b/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb index 08d44e04f..3d1e9f0a5 100644 --- a/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb +++ b/qualtran/bloqs/qubitization/qubitization_walk_operator.ipynb @@ -76,7 +76,7 @@ }, "outputs": [], "source": [ - "from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle" + "from qualtran.bloqs.multiplexers.select_base import SelectOracle" ] }, { @@ -111,7 +111,7 @@ }, "outputs": [], "source": [ - "from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle" + "from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle" ] }, { diff --git a/qualtran/bloqs/qubitization/qubitization_walk_operator.py b/qualtran/bloqs/qubitization/qubitization_walk_operator.py index 5a5ada83f..7c0993e59 100644 --- a/qualtran/bloqs/qubitization/qubitization_walk_operator.py +++ b/qualtran/bloqs/qubitization/qubitization_walk_operator.py @@ -37,8 +37,9 @@ from qualtran import bloq_example, BloqDocSpec, CtrlSpec, Register, Signature from qualtran._infra.gate_with_registers import SpecializedSingleQubitControlledGate, total_bits -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle, SelectOracle +from qualtran.bloqs.multiplexers.select_base import SelectOracle from qualtran.bloqs.reflections.reflection_using_prepare import ReflectionUsingPrepare +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.resource_counting.generalizers import ( cirq_to_bloqs, ignore_cliffords, diff --git a/qualtran/bloqs/reflections/prepare_identity.py b/qualtran/bloqs/reflections/prepare_identity.py index b7c4be99c..d17121315 100644 --- a/qualtran/bloqs/reflections/prepare_identity.py +++ b/qualtran/bloqs/reflections/prepare_identity.py @@ -20,7 +20,7 @@ from qualtran import bloq_example, BloqDocSpec, QAny, Register, Soquet from qualtran.bloqs.basic_gates import Identity from qualtran.bloqs.basic_gates.on_each import OnEach -from qualtran.bloqs.block_encoding.lcu_block_encoding import PrepareOracle +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.resource_counting.generalizers import ignore_split_join from qualtran.symbolics.types import SymbolicInt diff --git a/qualtran/bloqs/reflections/reflection_using_prepare.py b/qualtran/bloqs/reflections/reflection_using_prepare.py index fe6b784c4..9488452a9 100644 --- a/qualtran/bloqs/reflections/reflection_using_prepare.py +++ b/qualtran/bloqs/reflections/reflection_using_prepare.py @@ -33,7 +33,7 @@ from qualtran.symbolics.types import SymbolicInt if TYPE_CHECKING: - from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle + from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.resource_counting import BloqCountT, SympySymbolAllocator diff --git a/qualtran/bloqs/block_encoding/lcu_select_and_prepare.py b/qualtran/bloqs/state_preparation/prepare_base.py similarity index 62% rename from qualtran/bloqs/block_encoding/lcu_select_and_prepare.py rename to qualtran/bloqs/state_preparation/prepare_base.py index e64ba6d15..44b6e2258 100644 --- a/qualtran/bloqs/block_encoding/lcu_select_and_prepare.py +++ b/qualtran/bloqs/state_preparation/prepare_base.py @@ -22,53 +22,6 @@ from qualtran.symbolics import SymbolicFloat -class SelectOracle(GateWithRegisters): - r"""Abstract base class that defines the interface for a SELECT Oracle. - - The action of a SELECT oracle on a selection register $|l\rangle$ and target register - $|\Psi\rangle$ can be defined as: - - $$ - \mathrm{SELECT} = \sum_{l}|l \rangle \langle l| \otimes U_l - $$ - - In other words, the `SELECT` oracle applies $l$'th unitary $U_l$ on the target register - $|\Psi\rangle$ when the selection register stores integer $l$. - - $$ - \mathrm{SELECT}|l\rangle |\Psi\rangle = |l\rangle U_{l}|\Psi\rangle - $$ - """ - - @property - @abc.abstractmethod - def control_registers(self) -> Tuple[Register, ...]: - ... - - @property - @abc.abstractmethod - def selection_registers(self) -> Tuple[Register, ...]: - ... - - @property - @abc.abstractmethod - def target_registers(self) -> Tuple[Register, ...]: - ... - - @cached_property - def signature(self) -> Signature: - return Signature( - [*self.control_registers, *self.selection_registers, *self.target_registers] - ) - - -_SELECT_ORACLE_DOC = BloqDocSpec( - bloq_cls=SelectOracle, - import_line='from qualtran.bloqs.block_encoding.lcu_select_and_prepare import SelectOracle', - examples=[], -) - - class PrepareOracle(GateWithRegisters): r"""Abstract base class that defines the API for a PREPARE Oracle. @@ -109,6 +62,6 @@ def l1_norm_of_coeffs(self) -> Optional['SymbolicFloat']: _PREPARE_ORACLE_DOC = BloqDocSpec( bloq_cls=PrepareOracle, - import_line='from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle', + import_line='from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle', examples=[], ) diff --git a/qualtran/bloqs/state_preparation/state_preparation_alias_sampling.py b/qualtran/bloqs/state_preparation/state_preparation_alias_sampling.py index 200155568..838edafec 100644 --- a/qualtran/bloqs/state_preparation/state_preparation_alias_sampling.py +++ b/qualtran/bloqs/state_preparation/state_preparation_alias_sampling.py @@ -31,8 +31,8 @@ from qualtran._infra.gate_with_registers import total_bits from qualtran.bloqs.arithmetic import LessThanEqual from qualtran.bloqs.basic_gates import CSwap, Hadamard, OnEach -from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle from qualtran.bloqs.data_loading.qrom import QROM +from qualtran.bloqs.state_preparation.prepare_base import PrepareOracle from qualtran.bloqs.state_preparation.prepare_uniform_superposition import ( PrepareUniformSuperposition, ) diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index d9f3a9522..ea2b837d3 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -40,7 +40,6 @@ import qualtran.bloqs.block_encoding import qualtran.bloqs.block_encoding.chebyshev_polynomial import qualtran.bloqs.block_encoding.lcu_block_encoding -import qualtran.bloqs.block_encoding.lcu_select_and_prepare import qualtran.bloqs.block_encoding.linear_combination import qualtran.bloqs.block_encoding.phase import qualtran.bloqs.block_encoding.product @@ -113,6 +112,7 @@ import qualtran.bloqs.mod_arithmetic import qualtran.bloqs.multiplexers.apply_gate_to_lth_target import qualtran.bloqs.multiplexers.apply_lth_bloq +import qualtran.bloqs.multiplexers.select_base import qualtran.bloqs.multiplexers.select_pauli_lcu import qualtran.bloqs.multiplexers.selected_majorana_fermion import qualtran.bloqs.multiplexers.unary_iteration_bloq @@ -343,6 +343,7 @@ "qualtran.bloqs.mean_estimation.mean_estimation_operator.MeanEstimationOperator": qualtran.bloqs.mean_estimation.mean_estimation_operator.MeanEstimationOperator, "qualtran.bloqs.multiplexers.apply_gate_to_lth_target.ApplyGateToLthQubit": qualtran.bloqs.multiplexers.apply_gate_to_lth_target.ApplyGateToLthQubit, "qualtran.bloqs.multiplexers.apply_lth_bloq.ApplyLthBloq": qualtran.bloqs.multiplexers.apply_lth_bloq.ApplyLthBloq, + "qualtran.bloqs.multiplexers.select_base.SelectOracle": qualtran.bloqs.multiplexers.select_base.SelectOracle, "qualtran.bloqs.multiplexers.select_pauli_lcu.SelectPauliLCU": qualtran.bloqs.multiplexers.select_pauli_lcu.SelectPauliLCU, "qualtran.bloqs.multiplexers.selected_majorana_fermion.SelectedMajoranaFermion": qualtran.bloqs.multiplexers.selected_majorana_fermion.SelectedMajoranaFermion, "qualtran.bloqs.multiplexers.unary_iteration_bloq.UnaryIterationGate": qualtran.bloqs.multiplexers.unary_iteration_bloq.UnaryIterationGate, @@ -370,9 +371,8 @@ "qualtran.bloqs.rotations.quantum_variable_rotation.QvrInterface": qualtran.bloqs.rotations.quantum_variable_rotation.QvrInterface, "qualtran.bloqs.rotations.quantum_variable_rotation.QvrPhaseGradient": qualtran.bloqs.rotations.quantum_variable_rotation.QvrPhaseGradient, "qualtran.bloqs.rotations.quantum_variable_rotation.QvrZPow": qualtran.bloqs.rotations.quantum_variable_rotation.QvrZPow, - "qualtran.bloqs.block_encoding.lcu_select_and_prepare.PrepareOracle": qualtran.bloqs.block_encoding.lcu_select_and_prepare.PrepareOracle, - "qualtran.bloqs.block_encoding.lcu_select_and_prepare.SelectOracle": qualtran.bloqs.block_encoding.lcu_select_and_prepare.SelectOracle, "qualtran.bloqs.state_preparation.prepare_uniform_superposition.PrepareUniformSuperposition": qualtran.bloqs.state_preparation.prepare_uniform_superposition.PrepareUniformSuperposition, + "qualtran.bloqs.state_preparation.prepare_base.PrepareOracle": qualtran.bloqs.state_preparation.prepare_base.PrepareOracle, "qualtran.bloqs.state_preparation.state_preparation_alias_sampling.StatePreparationAliasSampling": qualtran.bloqs.state_preparation.state_preparation_alias_sampling.StatePreparationAliasSampling, "qualtran.bloqs.state_preparation.state_preparation_alias_sampling.SparseStatePreparationAliasSampling": qualtran.bloqs.state_preparation.state_preparation_alias_sampling.SparseStatePreparationAliasSampling, "qualtran.bloqs.state_preparation.state_preparation_via_rotation.PRGAViaPhaseGradient": qualtran.bloqs.state_preparation.state_preparation_via_rotation.PRGAViaPhaseGradient, From 169d91d6eede068ba3012dfed2675395e2b31eed Mon Sep 17 00:00:00 2001 From: Matthew Harrigan Date: Tue, 23 Jul 2024 20:43:51 +0000 Subject: [PATCH 3/4] [surface code] MagicStateFactory and GateCounts refactor (#1154) * Factories and gate counts * [factories] Update THC notebook --- qualtran/resource_counting/_bloq_counts.py | 13 +- .../resource_counting/_bloq_counts_test.py | 6 +- qualtran/surface_code/__init__.py | 1 + qualtran/surface_code/algorithm_summary.py | 2 +- .../surface_code/algorithm_summary_test.py | 6 +- qualtran/surface_code/ccz2t_cost_model.py | 135 ++++++++++-------- .../surface_code/ccz2t_cost_model_test.py | 8 +- qualtran/surface_code/data_block.py | 8 +- qualtran/surface_code/fifteen_to_one.py | 67 ++++----- qualtran/surface_code/fifteen_to_one_test.py | 32 +++-- qualtran/surface_code/magic_state_factory.py | 43 ++++-- qualtran/surface_code/multi_factory.py | 28 ++-- qualtran/surface_code/thc_compilation.ipynb | 20 +-- qualtran/surface_code/ui.py | 72 ++++++---- 14 files changed, 256 insertions(+), 185 deletions(-) diff --git a/qualtran/resource_counting/_bloq_counts.py b/qualtran/resource_counting/_bloq_counts.py index 73aaf7b87..bcba4e918 100644 --- a/qualtran/resource_counting/_bloq_counts.py +++ b/qualtran/resource_counting/_bloq_counts.py @@ -118,8 +118,6 @@ class GateCounts: Specifically, this class holds counts for the number of `TGate` (and adjoint), `Toffoli`, `TwoBitCSwap`, `And`, clifford bloqs, single qubit rotations, and measurements. - In addition to this, the class holds a heuristic approximation for the depth of the - circuit `depth` which we compute as the depth of the call graph. """ t: int = 0 @@ -129,7 +127,6 @@ class GateCounts: clifford: int = 0 rotation: int = 0 measurement: int = 0 - depth: int = 0 def __add__(self, other): if not isinstance(other, GateCounts): @@ -143,7 +140,6 @@ def __add__(self, other): clifford=self.clifford + other.clifford, rotation=self.rotation + other.rotation, measurement=self.measurement + other.measurement, - depth=self.depth + other.depth, ) def __mul__(self, other): @@ -155,7 +151,6 @@ def __mul__(self, other): clifford=other * self.clifford, rotation=other * self.rotation, measurement=other * self.measurement, - depth=other * self.depth, ) def __rmul__(self, other): @@ -195,6 +190,11 @@ def total_t_count( + ts_per_rotation * self.rotation ) + def total_t_and_ccz_count(self, ts_per_rotation: int = 11) -> Dict[str, int]: + n_ccz = self.toffoli + self.cswap + self.and_bloq + n_t = self.t + ts_per_rotation * self.rotation + return {'n_t': n_t, 'n_ccz': n_ccz} + @frozen class QECGatesCost(CostKey[GateCounts]): @@ -236,12 +236,9 @@ def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts]) totals = GateCounts() callees = get_bloq_callee_counts(bloq) logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees)) - depth = 0 for callee, n_times_called in callees: callee_cost = get_callee_cost(callee) totals += n_times_called * callee_cost - depth = max(depth, callee_cost.depth + 1) - totals = attrs.evolve(totals, depth=depth) return totals def zero(self) -> GateCounts: diff --git a/qualtran/resource_counting/_bloq_counts_test.py b/qualtran/resource_counting/_bloq_counts_test.py index 0cb0587b1..8e9cda963 100644 --- a/qualtran/resource_counting/_bloq_counts_test.py +++ b/qualtran/resource_counting/_bloq_counts_test.py @@ -58,7 +58,7 @@ def test_gate_counts(): def test_qec_gates_cost(): algo = make_example_costing_bloqs() gc = get_cost_value(algo, QECGatesCost()) - assert gc == GateCounts(toffoli=100, t=2 * 2 * 10, clifford=2 * 10, depth=2) + assert gc == GateCounts(toffoli=100, t=2 * 2 * 10, clifford=2 * 10) @pytest.mark.parametrize( @@ -78,12 +78,12 @@ def test_qec_gates_cost(): rotations.phase_gradient.PhaseGradientUnitary( bitsize=10, exponent=1, is_controlled=False, eps=1e-10 ), - GateCounts(rotation=10, depth=1), + GateCounts(rotation=10), ], # Recursive [ mcmt.MultiControlPauli(cvs=(1, 1, 1), target_gate=cirq.X), - GateCounts(and_bloq=2, depth=2, measurement=2, clifford=3), + GateCounts(and_bloq=2, measurement=2, clifford=3), ], ], ) diff --git a/qualtran/surface_code/__init__.py b/qualtran/surface_code/__init__.py index 776e357dd..4ae5de414 100644 --- a/qualtran/surface_code/__init__.py +++ b/qualtran/surface_code/__init__.py @@ -21,6 +21,7 @@ ) from qualtran.surface_code.data_block import ( CompactDataBlock, + DataBlock, FastDataBlock, IntermediateDataBlock, SimpleDataBlock, diff --git a/qualtran/surface_code/algorithm_summary.py b/qualtran/surface_code/algorithm_summary.py index 8d1628eb9..54f05f4d5 100644 --- a/qualtran/surface_code/algorithm_summary.py +++ b/qualtran/surface_code/algorithm_summary.py @@ -126,6 +126,6 @@ def from_bloq(bloq: 'Bloq') -> 'AlgorithmSummary': toffoli_gates=gate_count.toffoli + gate_count.and_bloq + gate_count.cswap, rotation_gates=gate_count.rotation, measurements=gate_count.measurement, - rotation_circuit_depth=gate_count.depth, + rotation_circuit_depth=gate_count.rotation, algorithm_qubits=float(get_cost_value(bloq, _QUBIT_COUNT)), ) diff --git a/qualtran/surface_code/algorithm_summary_test.py b/qualtran/surface_code/algorithm_summary_test.py index 807730fa2..f1ff8ab9e 100644 --- a/qualtran/surface_code/algorithm_summary_test.py +++ b/qualtran/surface_code/algorithm_summary_test.py @@ -103,18 +103,18 @@ def test_subtraction(): [mcmt.And(), AlgorithmSummary(algorithm_qubits=3, toffoli_gates=1)], [ basic_gates.ZPowGate(exponent=0.1, global_shift=0.0, eps=1e-11), - AlgorithmSummary(algorithm_qubits=1, rotation_gates=1), + AlgorithmSummary(algorithm_qubits=1, rotation_gates=1, rotation_circuit_depth=1), ], [ rotations.phase_gradient.PhaseGradientUnitary( bitsize=10, exponent=1, is_controlled=False, eps=1e-10 ), - AlgorithmSummary(algorithm_qubits=10, rotation_gates=10, rotation_circuit_depth=1), + AlgorithmSummary(algorithm_qubits=10, rotation_gates=10, rotation_circuit_depth=10), ], [ mcmt.MultiControlPauli(cvs=(1, 1, 1), target_gate=cirq.X), AlgorithmSummary( - algorithm_qubits=6, toffoli_gates=2, rotation_circuit_depth=2, measurements=2 + algorithm_qubits=6, toffoli_gates=2, rotation_circuit_depth=0, measurements=2 ), ], ], diff --git a/qualtran/surface_code/ccz2t_cost_model.py b/qualtran/surface_code/ccz2t_cost_model.py index b7530d068..233136401 100644 --- a/qualtran/surface_code/ccz2t_cost_model.py +++ b/qualtran/surface_code/ccz2t_cost_model.py @@ -13,16 +13,18 @@ # limitations under the License. import math -from typing import Callable, cast, Iterable, Iterator, Optional, Tuple +from typing import Callable, cast, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING from attrs import frozen -import qualtran.surface_code.quantum_error_correction_scheme_summary as qec -from qualtran.surface_code.data_block import DataBlock, SimpleDataBlock -from qualtran.surface_code.magic_count import MagicCount -from qualtran.surface_code.magic_state_factory import MagicStateFactory -from qualtran.surface_code.multi_factory import MultiFactory -from qualtran.surface_code.physical_cost import PhysicalCost +from .data_block import DataBlock, SimpleDataBlock +from .magic_state_factory import MagicStateFactory +from .multi_factory import MultiFactory +from .physical_cost import PhysicalCost + +if TYPE_CHECKING: + from qualtran.resource_counting import GateCounts + from qualtran.surface_code import DataBlock, LogicalErrorModel @frozen @@ -40,21 +42,20 @@ class CCZ2TFactory(MagicStateFactory): distillation_l1_d: int = 15 distillation_l2_d: int = 31 - qec_scheme: qec.QuantumErrorCorrectionSchemeSummary = qec.FowlerSuperconductingQubits # ------------------------------------------------------------------------------- # ---- Level 0 --------- # ------------------------------------------------------------------------------- - def l0_state_injection_error(self, phys_err: float) -> float: + def l0_state_injection_error(self, error_model: 'LogicalErrorModel') -> float: """Error rate associated with the level-0 creation of a |T> state. By using the techniques of Ying Li (https://arxiv.org/abs/1410.7808), this can be done with approximately the same error rate as the underlying physical error rate. """ - return phys_err + return error_model.physical_error - def l0_topo_error_t_gate(self, phys_err: float) -> float: + def l0_topo_error_t_gate(self, error_model: 'LogicalErrorModel') -> float: """Topological error associated with level-0 distillation. For a level-1 code distance of `d1`, this construction uses a `d1/2` distance code @@ -63,74 +64,66 @@ def l0_topo_error_t_gate(self, phys_err: float) -> float: # The chance of a logical error occurring within a lattice surgery unit cell at # code distance d1*0.5. - topo_error_per_unit_cell = self.qec_scheme.logical_error_rate( - physical_error_rate=phys_err, code_distance=self.distillation_l1_d // 2 - ) + topo_error_per_unit_cell = error_model(code_distance=self.distillation_l1_d // 2) # It takes approximately 100 L0 unit cells to get the injected state where # it needs to be and perform the T gate. return 100 * topo_error_per_unit_cell - def l0_error(self, phys_err: float) -> float: + def l0_error(self, error_model: 'LogicalErrorModel') -> float: """Chance of failure of a T gate performed with an injected (level-0) T state. As a simplifying approximation here (and elsewhere) we assume different sources of error are independent, and we merely add the probabilities. """ - return self.l0_state_injection_error(phys_err) + self.l0_topo_error_t_gate(phys_err) + return self.l0_state_injection_error(error_model) + self.l0_topo_error_t_gate(error_model) # ------------------------------------------------------------------------------- # ---- Level 1 --------- # ------------------------------------------------------------------------------- - def l1_topo_error_factory(self, phys_err: float) -> float: + def l1_topo_error_factory(self, error_model: 'LogicalErrorModel') -> float: """Topological error associated with a L1 T factory.""" # The L1 T factory uses approximately 1000 L1 unit cells. - return 1000 * self.qec_scheme.logical_error_rate( - physical_error_rate=phys_err, code_distance=self.distillation_l1_d - ) + return 1000 * error_model(code_distance=self.distillation_l1_d) - def l1_topo_error_t_gate(self, phys_err: float) -> float: + def l1_topo_error_t_gate(self, error_model: 'LogicalErrorModel') -> float: # It takes approximately 100 L1 unit cells to get the L1 state produced by the # factory to where it needs to be and perform the T gate. - return 100 * self.qec_scheme.logical_error_rate( - physical_error_rate=phys_err, code_distance=self.distillation_l1_d - ) + return 100 * error_model(code_distance=self.distillation_l1_d) - def l1_distillation_error(self, phys_err: float) -> float: + def l1_distillation_error(self, error_model: 'LogicalErrorModel') -> float: """The error due to level-0 faulty T states making it through distillation undetected. The level 1 distillation procedure detects any two errors. There are 35 weight-three errors that can make it through undetected. """ - return 35 * self.l0_error(phys_err) ** 3 + return 35 * self.l0_error(error_model) ** 3 - def l1_error(self, phys_err: float) -> float: + def l1_error(self, error_model: 'LogicalErrorModel') -> float: """Chance of failure of a T gate performed with a T state produced from the L1 factory.""" return ( - self.l1_topo_error_factory(phys_err) - + self.l1_topo_error_t_gate(phys_err) - + self.l1_distillation_error(phys_err) + self.l1_topo_error_factory(error_model) + + self.l1_topo_error_t_gate(error_model) + + self.l1_distillation_error(error_model) ) # ------------------------------------------------------------------------------- # ---- Level 2 --------- # ------------------------------------------------------------------------------- - def l2_error(self, phys_err: float) -> float: + def l2_error(self, error_model: 'LogicalErrorModel') -> float: """Chance of failure of the level two factory. This is the chance of failure of a CCZ gate or a pair of T gates performed with a CCZ state. """ # The L2 CCZ factory and catalyzed T factory both use approximately 1000 L2 unit cells. - l2_topo_error_factory = 1000 * self.qec_scheme.logical_error_rate( - physical_error_rate=phys_err, code_distance=self.distillation_l2_d - ) + l2_topo_error_factory = 1000 * error_model(self.distillation_l2_d) # Distillation error for this level. - l2_distillation_error = 28 * self.l1_error(phys_err) ** 2 + l2_distillation_error = 28 * self.l1_error(error_model) ** 2 return l2_topo_error_factory + l2_distillation_error @@ -138,21 +131,27 @@ def l2_error(self, phys_err: float) -> float: # ---- Totals --------- # ------------------------------------------------------------------------------- - def footprint(self) -> int: + def n_physical_qubits(self) -> int: l1 = 4 * 8 * 2 * self.distillation_l1_d**2 l2 = 4 * 8 * 2 * self.distillation_l2_d**2 return 6 * l1 + l2 - def distillation_error(self, n_magic: MagicCount, phys_err: float) -> float: + def factory_error( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> float: """Error resulting from the magic state distillation part of the computation.""" - n_ccz_states = n_magic.n_ccz + math.ceil(n_magic.n_t / 2) - return self.l2_error(phys_err) * n_ccz_states + counts = n_logical_gates.total_t_and_ccz_count() + total_ccz_states = counts['n_ccz'] + math.ceil(counts['n_t'] / 2) + return self.l2_error(logical_error_model) * total_ccz_states - def n_cycles(self, n_magic: MagicCount, phys_err: float = 1e-3) -> int: + def n_cycles( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> int: """The number of error-correction cycles to distill enough magic states.""" distillation_d = max(2 * self.distillation_l1_d + 1, self.distillation_l2_d) - n_ccz_states = n_magic.n_ccz + math.ceil(n_magic.n_t / 2) - catalyzations = math.ceil(n_magic.n_t / 2) + counts = n_logical_gates.total_t_and_ccz_count() + n_ccz_states = counts['n_ccz'] + math.ceil(counts['n_t'] / 2) + catalyzations = math.ceil(counts['n_t'] / 2) # Naive depth of 8.5, but can be overlapped to effective depth of 5.5 # See section 2, paragraph 2 of the reference. @@ -163,7 +162,7 @@ def n_cycles(self, n_magic: MagicCount, phys_err: float = 1e-3) -> int: def get_ccz2t_costs( *, - n_magic: MagicCount, + n_logical_gates: 'GateCounts', n_algo_qubits: int, phys_err: float, cycle_time_us: float, @@ -175,27 +174,33 @@ def get_ccz2t_costs( Note that this function can return failure probabilities larger than 1. Args: - n_magic: The number of magic states (T, Toffoli) required to execute the algorithm + n_logical_gates: The number of algorithm logical gates. n_algo_qubits: Number of algorithm logical qubits. phys_err: The physical error rate of the device. cycle_time_us: The number of microseconds it takes to execute a surface code cycle. factory: magic state factory configuration. Used to evaluate distillation error and cost. data_block: data block configuration. Used to evaluate data error and footprint. """ - err_model = qec.LogicalErrorModel( - qec_scheme=qec.FowlerSuperconductingQubits, physical_error=phys_err + from qualtran.surface_code import FowlerSuperconductingQubits, LogicalErrorModel + + err_model = LogicalErrorModel(qec_scheme=FowlerSuperconductingQubits, physical_error=phys_err) + distillation_error = factory.factory_error( + n_logical_gates=n_logical_gates, logical_error_model=err_model + ) + n_generation_cycles = factory.n_cycles( + n_logical_gates=n_logical_gates, logical_error_model=err_model ) - distillation_error = factory.distillation_error(n_magic=n_magic, phys_err=phys_err) - n_generation_cycles = factory.n_cycles(n_magic=n_magic, phys_err=phys_err) n_consumption_cycles = data_block.n_cycles( - n_logical_gates=n_magic, logical_error_model=err_model + n_logical_gates=n_logical_gates, logical_error_model=err_model ) n_cycles = max(n_generation_cycles, n_consumption_cycles) data_error = data_block.data_error( n_algo_qubits=n_algo_qubits, n_cycles=int(n_cycles), logical_error_model=err_model ) failure_prob = distillation_error + data_error - footprint = factory.footprint() + data_block.n_physical_qubits(n_algo_qubits=n_algo_qubits) + footprint = factory.n_physical_qubits() + data_block.n_physical_qubits( + n_algo_qubits=n_algo_qubits + ) duration_hr = (cycle_time_us * n_cycles) / (1_000_000 * 60 * 60) return PhysicalCost(failure_prob=failure_prob, footprint=footprint, duration_hr=duration_hr) @@ -203,7 +208,7 @@ def get_ccz2t_costs( def get_ccz2t_costs_from_error_budget( *, - n_magic: MagicCount, + n_logical_gates: 'GateCounts', n_algo_qubits: int, phys_err: float = 1e-3, error_budget: float = 1e-2, @@ -215,7 +220,7 @@ def get_ccz2t_costs_from_error_budget( """Physical costs using the model from catalyzed CCZ to 2T paper. Args: - n_magic: The number of magic states (T, Toffoli) required to execute the algorithm + n_logical_gates: Number of algorithm logical gates. n_algo_qubits: Number of algorithm logical qubits. phys_err: The physical error rate of the device. This sets the suppression factor for increasing code distance. @@ -242,8 +247,13 @@ def get_ccz2t_costs_from_error_budget( if factory is None: factory = CCZ2TFactory() - distillation_error = factory.distillation_error(n_magic=n_magic, phys_err=phys_err) - n_cycles = factory.n_cycles(n_magic=n_magic, phys_err=phys_err) + from qualtran.surface_code import FowlerSuperconductingQubits, LogicalErrorModel + + err_model = LogicalErrorModel(qec_scheme=FowlerSuperconductingQubits, physical_error=phys_err) + distillation_error = factory.factory_error( + n_logical_gates=n_logical_gates, logical_error_model=err_model + ) + n_cycles = factory.n_cycles(n_logical_gates=n_logical_gates, logical_error_model=err_model) if data_block is None: # Use "left over" budget for data qubits. @@ -255,13 +265,13 @@ def get_ccz2t_costs_from_error_budget( n_logical_qubits = math.ceil((1 + routing_overhead) * n_algo_qubits) data_unit_cells = n_logical_qubits * n_cycles target_err_per_round = err_budget / data_unit_cells - data_d = qec.FowlerSuperconductingQubits.code_distance_from_budget( + data_d = FowlerSuperconductingQubits.code_distance_from_budget( physical_error_rate=phys_err, budget=target_err_per_round ) data_block = SimpleDataBlock(data_d=data_d, routing_overhead=routing_overhead) return get_ccz2t_costs( - n_magic=n_magic, + n_logical_gates=n_logical_gates, n_algo_qubits=n_algo_qubits, phys_err=phys_err, cycle_time_us=cycle_time_us, @@ -282,13 +292,14 @@ def iter_ccz2t_factories( automatically chosen as 2 + l1_distance, ensuring l2_distance > l1_distance. n_factories (int, optional): Number of factories to be used in parallel. """ + factory: Callable[[int, int], MagicStateFactory] if n_factories == 1: factory = CCZ2TFactory elif n_factories > 1: - def factory(distillation_l1_d, distillation_l2_d): # type: ignore[misc] + def factory(distillation_l1_d: int, distillation_l2_d: int) -> MagicStateFactory: base_factory = CCZ2TFactory( - distillation_l1_d=l1_distance, distillation_l2_d=l2_distance + distillation_l1_d=distillation_l1_d, distillation_l2_d=distillation_l2_d ) return MultiFactory(base_factory=base_factory, n_factories=n_factories) @@ -307,7 +318,7 @@ def iter_simple_data_blocks(d_start: int = 7, d_stop: int = 35): def get_ccz2t_costs_from_grid_search( *, - n_magic: MagicCount, + n_logical_gates: 'GateCounts', n_algo_qubits: int, phys_err: float = 1e-3, error_budget: float = 1e-2, @@ -316,10 +327,10 @@ def get_ccz2t_costs_from_grid_search( data_block_iter: Iterable[DataBlock] = tuple(iter_simple_data_blocks()), cost_function: Callable[[PhysicalCost], float] = (lambda pc: pc.qubit_hours), ) -> Tuple[PhysicalCost, MagicStateFactory, SimpleDataBlock]: - """Grid search over parameters to minimize space time volume. + """Grid search over parameters to minimize the space-time volume. Args: - n_magic: The number of magic states (T, Toffoli) required to execute the algorithm + n_logical_gates: Number of algorithm logical gates. n_algo_qubits: Number of algorithm logical qubits. phys_err: The physical error rate of the device. This sets the suppression factor for increasing code distance. @@ -343,7 +354,7 @@ def get_ccz2t_costs_from_grid_search( for factory in factory_iter: for data_block in data_block_iter: cost = get_ccz2t_costs( - n_magic=n_magic, + n_logical_gates=n_logical_gates, n_algo_qubits=n_algo_qubits, phys_err=phys_err, cycle_time_us=cycle_time_us, diff --git a/qualtran/surface_code/ccz2t_cost_model_test.py b/qualtran/surface_code/ccz2t_cost_model_test.py index b1782a6f3..c7365a286 100644 --- a/qualtran/surface_code/ccz2t_cost_model_test.py +++ b/qualtran/surface_code/ccz2t_cost_model_test.py @@ -14,19 +14,19 @@ import numpy as np +from qualtran.resource_counting import GateCounts from qualtran.surface_code.ccz2t_cost_model import ( CCZ2TFactory, get_ccz2t_costs_from_error_budget, get_ccz2t_costs_from_grid_search, iter_ccz2t_factories, ) -from qualtran.surface_code.magic_count import MagicCount from qualtran.surface_code.multi_factory import MultiFactory def test_vs_spreadsheet(): re = get_ccz2t_costs_from_error_budget( - n_magic=MagicCount(n_t=10**8, n_ccz=10**8), + n_logical_gates=GateCounts(t=10**8, toffoli=10**8), n_algo_qubits=100, error_budget=0.01, phys_err=1e-3, @@ -40,7 +40,7 @@ def test_vs_spreadsheet(): def test_grid_search_runs(): cost, factory, db = get_ccz2t_costs_from_grid_search( - n_magic=MagicCount(n_t=10**8, n_ccz=10**8), + n_logical_gates=GateCounts(t=10**8, toffoli=10**8), n_algo_qubits=100, phys_err=1e-3, error_budget=0.1, @@ -55,7 +55,7 @@ def test_grid_search_runs(): def test_grid_search_against_thc(): """test based on the parameters reported in section IV.C of Lee et al., PRXQuantum 2, 2021""" best_cost, best_factory, best_data_block = get_ccz2t_costs_from_grid_search( - n_magic=MagicCount(n_ccz=6665400000), + n_logical_gates=GateCounts(toffoli=6665400000), n_algo_qubits=696, error_budget=1e-2, phys_err=1e-3, diff --git a/qualtran/surface_code/data_block.py b/qualtran/surface_code/data_block.py index 1416e83fd..82a0bb662 100644 --- a/qualtran/surface_code/data_block.py +++ b/qualtran/surface_code/data_block.py @@ -20,11 +20,7 @@ if TYPE_CHECKING: from qualtran.resource_counting import GateCounts - from qualtran.surface_code import ( - LogicalErrorModel, - MagicCount, - QuantumErrorCorrectionSchemeSummary, - ) + from qualtran.surface_code import LogicalErrorModel, QuantumErrorCorrectionSchemeSummary class DataBlock(metaclass=abc.ABCMeta): @@ -87,7 +83,7 @@ def n_steps_to_consume_a_magic_state(self): """ def n_cycles( - self, n_logical_gates: 'MagicCount', logical_error_model: 'LogicalErrorModel' + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' ) -> int: """The number of surface code cycles to apply the number of gates to the data block. diff --git a/qualtran/surface_code/fifteen_to_one.py b/qualtran/surface_code/fifteen_to_one.py index a0a2ad521..7dc4860d1 100644 --- a/qualtran/surface_code/fifteen_to_one.py +++ b/qualtran/surface_code/fifteen_to_one.py @@ -12,79 +12,80 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import lru_cache -from typing import Optional +from typing import TYPE_CHECKING import cirq import numpy as np from attrs import frozen -from qualtran.surface_code.magic_count import MagicCount from qualtran.surface_code.magic_state_factory import MagicStateFactory -from qualtran.surface_code.quantum_error_correction_scheme_summary import ( - FowlerSuperconductingQubits, - QuantumErrorCorrectionSchemeSummary, -) -from qualtran.surface_code.reference import Reference from qualtran.surface_code.t_factory_utils import NoisyPauliRotation, storage_error +if TYPE_CHECKING: + from qualtran.resource_counting import GateCounts + from qualtran.surface_code import LogicalErrorModel + @frozen class FifteenToOne(MagicStateFactory): """15-to-1 Magic T state factory. reference: - [Magic State Distillation: Not as Costly as You Think] https://arxiv.org/abs/1905.06903 + [Magic State Distillation: Not as Costly as You Think](https://arxiv.org/abs/1905.06903). Attributes: d_X: Side length of the surface code along which X measurements happen. d_Z: Side length of the surface code along which Z measurements happen. d_m: Number of code cycles used in lattice surgery. qec: Quantum error correction scheme being used. - reference: A description of the source of the factory. """ d_X: int d_Z: int d_m: int - qec: QuantumErrorCorrectionSchemeSummary = FowlerSuperconductingQubits - reference: Optional[Reference] = None def __attrs_post_init__(self): assert 0 < self.d_X <= 3 * self.d_m assert self.d_m > 0 assert self.d_Z > 0 - def footprint(self) -> int: + def n_physical_qubits(self) -> int: # source: page 11 of https://arxiv.org/abs/1905.06903 return 2 * (self.d_X + 4 * self.d_Z) * 3 * self.d_X + 4 * self.d_m @lru_cache(8) - def _final_state(self, phys_err: float): - factory = _build_factory(phys_err, self.d_X, self.d_Z, self.d_m, self.qec) + def _final_state(self, logi_err_model: 'LogicalErrorModel'): + factory = _build_factory( + d_X=self.d_X, d_Z=self.d_Z, d_m=self.d_m, logical_error_model=logi_err_model + ) return ( cirq.DensityMatrixSimulator(dtype=np.complex128).simulate(factory).final_density_matrix ) @lru_cache(8) - def p_fail(self, phys_err: float) -> float: + def p_fail(self, logical_error_model: 'LogicalErrorModel') -> float: projector = np.kron(np.eye(2), np.ones((16, 16)) / 16) - return np.real_if_close(1 - np.trace(projector @ self._final_state(phys_err))).item() + return np.real_if_close( + 1 - np.trace(projector @ self._final_state(logical_error_model)) + ).item() @lru_cache(8) - def p_out(self, phys_err: float) -> float: + def p_out(self, logical_error_model: 'LogicalErrorModel') -> float: # I \otimes ones \otimes ones \otimes ones \otimes ones / 16 projector = np.kron(np.eye(2), np.ones((16, 16)) / 16) project_state = ( 1 - / (1 - self.p_fail(phys_err)) - * (projector @ self._final_state(phys_err) @ projector.T.conj()) + / (1 - self.p_fail(logical_error_model)) + * (projector @ self._final_state(logical_error_model) @ projector.T.conj()) ) # |T> int: + def n_cycles( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> int: """The number of cycles (time) required to produce the requested number of magic states. Unlike the same method for other factories. This method reports the *expected* number of cycles @@ -92,17 +93,19 @@ def n_cycles(self, n_magic: MagicCount, phys_err: float) -> int: reference: page 11 of https://arxiv.org/abs/1905.06903 """ - num_t = n_magic.n_t + 4 * n_magic.n_ccz - return np.ceil(num_t * 6 * self.d_m / (1 - self.p_fail(phys_err))) + num_t = n_logical_gates.total_t_count() + return np.ceil(num_t * 6 * self.d_m / (1 - self.p_fail(logical_error_model))) - def distillation_error(self, n_magic: MagicCount, phys_err: float) -> float: + def factory_error( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> float: """The total error expected from distilling magic states with a given physical error rate.""" - num_t = n_magic.n_t + 4 * n_magic.n_ccz - return self.p_out(phys_err) * num_t + num_t = n_logical_gates.total_t_count() + return self.p_out(logical_error_model) * num_t def _build_factory( - phys_err: float, d_X: int, d_Z: int, d_m: int, qec: QuantumErrorCorrectionSchemeSummary + *, d_X: int, d_Z: int, d_m: int, logical_error_model: 'LogicalErrorModel' ) -> cirq.Circuit: """Builds the 15-to-1 factory with its associated cost model. @@ -123,9 +126,10 @@ def _build_factory( The factory as a cirq circuit. """ qs = cirq.LineQubit.range(5) - px = qec.logical_error_rate(d_X, phys_err) - pz = qec.logical_error_rate(d_Z, phys_err) - pm = qec.logical_error_rate(d_m, phys_err) + px = logical_error_model(d_X) + pz = logical_error_model(d_Z) + pm = logical_error_model(d_m) + phys_err = logical_error_model.physical_error factory = cirq.Circuit.from_moments( cirq.H.on_each(qs), @@ -380,6 +384,5 @@ def _build_factory( return factory -FifteenToOne733 = FifteenToOne(7, 3, 3, reference=Reference(url='https://arxiv.org/abs/1905.06903')) - -FifteenToOne933 = FifteenToOne(9, 3, 3, reference=Reference(url='https://arxiv.org/abs/1905.06903')) +FifteenToOne733 = FifteenToOne(7, 3, 3) +FifteenToOne933 = FifteenToOne(9, 3, 3) diff --git a/qualtran/surface_code/fifteen_to_one_test.py b/qualtran/surface_code/fifteen_to_one_test.py index 91b38436c..1775e4bff 100644 --- a/qualtran/surface_code/fifteen_to_one_test.py +++ b/qualtran/surface_code/fifteen_to_one_test.py @@ -17,12 +17,13 @@ import pytest from attrs import frozen +from qualtran.resource_counting import GateCounts +from qualtran.surface_code import FowlerSuperconductingQubits, LogicalErrorModel from qualtran.surface_code.fifteen_to_one import FifteenToOne -from qualtran.surface_code.magic_count import MagicCount @frozen -class TestCase: +class FifteenToOneTestCase: d_X: int d_Z: int d_m: int @@ -34,19 +35,28 @@ class TestCase: PAPER_RESULTS = [ - TestCase(d_X=7, d_Z=3, d_m=3, phys_err=1e-4, p_out=4.4e-8, footprint=810, cycles=18.1), - TestCase(d_X=9, d_Z=3, d_m=3, phys_err=1e-4, p_out=9.3e-10, footprint=1150, cycles=18.1), - TestCase(d_X=11, d_Z=5, d_m=5, phys_err=1e-4, p_out=1.9e-11, footprint=2070, cycles=30), - TestCase(d_X=17, d_Z=7, d_m=7, phys_err=1e-3, p_out=4.5e-8, footprint=4620, cycles=42.6), + FifteenToOneTestCase( + d_X=7, d_Z=3, d_m=3, phys_err=1e-4, p_out=4.4e-8, footprint=810, cycles=18.1 + ), + FifteenToOneTestCase( + d_X=9, d_Z=3, d_m=3, phys_err=1e-4, p_out=9.3e-10, footprint=1150, cycles=18.1 + ), + FifteenToOneTestCase( + d_X=11, d_Z=5, d_m=5, phys_err=1e-4, p_out=1.9e-11, footprint=2070, cycles=30 + ), + FifteenToOneTestCase( + d_X=17, d_Z=7, d_m=7, phys_err=1e-3, p_out=4.5e-8, footprint=4620, cycles=42.6 + ), ] @pytest.mark.parametrize("test", PAPER_RESULTS) -def test_compare_with_paper(test: TestCase): - factory = FifteenToOne(test.d_X, test.d_Z, test.d_m) - assert f'{factory.distillation_error(MagicCount(n_t=1), test.phys_err):.1e}' == str(test.p_out) - assert round(factory.footprint(), -1) == test.footprint # rounding to the 10s digit. - assert factory.n_cycles(MagicCount(n_t=1), test.phys_err) == math.ceil(test.cycles + 1e-9) +def test_compare_with_paper(test: FifteenToOneTestCase): + factory = FifteenToOne(d_X=test.d_X, d_Z=test.d_Z, d_m=test.d_m) + lem = LogicalErrorModel(qec_scheme=FowlerSuperconductingQubits, physical_error=test.phys_err) + assert f'{factory.factory_error(GateCounts(t=1), lem):.1e}' == str(test.p_out) + assert round(factory.n_physical_qubits(), -1) == test.footprint # rounding to the 10s digit. + assert factory.n_cycles(GateCounts(t=1), lem) == math.ceil(test.cycles + 1e-9) def test_validation(): diff --git a/qualtran/surface_code/magic_state_factory.py b/qualtran/surface_code/magic_state_factory.py index 03e227a7e..8f6020554 100644 --- a/qualtran/surface_code/magic_state_factory.py +++ b/qualtran/surface_code/magic_state_factory.py @@ -13,26 +13,49 @@ # limitations under the License. import abc +from typing import TYPE_CHECKING -from qualtran.surface_code.magic_count import MagicCount +if TYPE_CHECKING: + from qualtran.resource_counting import GateCounts + from qualtran.surface_code import LogicalErrorModel class MagicStateFactory(metaclass=abc.ABCMeta): - """A cost model for the magic state distillation factory of a surface code compilation. - - A surface code layout is segregated into qubits dedicated to magic state distillation - and storing the data being processed. The former area is called the magic state distillation - factory, and we provide its costs here. + """Methods for modeling the costs of the magic state factories of a surface code compilation. + + An important consideration for a surface code compilation is how to execute arbitrary gates + to run the desired algorithm. The surface code can execute Clifford gates in a fault-tolerant + manner. Non-Clifford gates like the T gate, Toffoli or CCZ gate, or non-Clifford rotation + gates require more expensive gadgets to implement. Executing a T or CCZ gate requires first + using the technique of state distillation in an area of the computation called a "magic state + factory" to distill a noisy T or CCZ state into a "magic state" of sufficiently low error. + Such quantum states can be used to enact the non-Clifford quantum gate through gate + teleportation. + + Magic state production is thought to be an important runtime and qubit-count bottleneck in + foreseeable fault-tolerant quantum computers. + + This abstract interface specifies that each magic state factory must report its required + number of physical qubits, the number of error correction cycles to produce enough magic + states to enact a given number of logical gates and an error model, and the expected error + associated with generating those magic states. """ @abc.abstractmethod - def footprint(self) -> int: + def n_physical_qubits(self) -> int: """The number of physical qubits used by the magic state factory.""" @abc.abstractmethod - def n_cycles(self, n_magic: MagicCount, phys_err: float) -> int: + def n_cycles( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> int: """The number of cycles (time) required to produce the requested number of magic states.""" @abc.abstractmethod - def distillation_error(self, n_magic: MagicCount, phys_err: float) -> float: - """The total error expected from distilling magic states with a given physical error rate.""" + def factory_error( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> float: + """The total error expected from distilling magic states with a given physical error rate. + + This includes the cumulative effects of data-processing errors and distillation failures. + """ diff --git a/qualtran/surface_code/multi_factory.py b/qualtran/surface_code/multi_factory.py index 2cec77c7e..3d57ab743 100644 --- a/qualtran/surface_code/multi_factory.py +++ b/qualtran/surface_code/multi_factory.py @@ -11,20 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math +from typing import TYPE_CHECKING -import numpy as np from attrs import frozen -from qualtran.surface_code.magic_count import MagicCount from qualtran.surface_code.magic_state_factory import MagicStateFactory +if TYPE_CHECKING: + from qualtran.resource_counting import GateCounts + from qualtran.surface_code import LogicalErrorModel + @frozen class MultiFactory(MagicStateFactory): """Overlay of MagicStateFactory representing multiple factories of the same kind. All quantities are derived by those of `base_factory`. `footprint` is multiplied by - `n_factories`, `n_cycles` is divided by `n_factoties`, and `distillation_error` is independent + `n_factories`, `n_cycles` is divided by `n_factories`, and `distillation_error` is independent on the number of factories. Args: @@ -35,11 +39,17 @@ class MultiFactory(MagicStateFactory): base_factory: MagicStateFactory n_factories: int - def footprint(self) -> int: - return self.base_factory.footprint() * self.n_factories + def n_physical_qubits(self) -> int: + return self.base_factory.n_physical_qubits() * self.n_factories - def n_cycles(self, n_magic: MagicCount, phys_err: float = 1e-3) -> int: - return np.ceil(self.base_factory.n_cycles(n_magic, phys_err) / self.n_factories) + def n_cycles( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> int: + return math.ceil( + self.base_factory.n_cycles(n_logical_gates, logical_error_model) / self.n_factories + ) - def distillation_error(self, n_magic: MagicCount, phys_err: float) -> float: - return self.base_factory.distillation_error(n_magic, phys_err) + def factory_error( + self, n_logical_gates: 'GateCounts', logical_error_model: 'LogicalErrorModel' + ) -> float: + return self.base_factory.factory_error(n_logical_gates, logical_error_model) diff --git a/qualtran/surface_code/thc_compilation.ipynb b/qualtran/surface_code/thc_compilation.ipynb index 11d9075e6..b1e1a767b 100644 --- a/qualtran/surface_code/thc_compilation.ipynb +++ b/qualtran/surface_code/thc_compilation.ipynb @@ -23,17 +23,17 @@ "from qualtran.surface_code.ccz2t_cost_model import get_ccz2t_costs, CCZ2TFactory\n", "from qualtran.surface_code.multi_factory import MultiFactory\n", "from qualtran.surface_code.data_block import SimpleDataBlock\n", - "from qualtran.surface_code.magic_count import MagicCount\n", + "from qualtran.resource_counting import GateCounts\n", "\n", - "n_magic = MagicCount(n_ccz=6665400000) # pag. 26\n", - "n_data_qubits = 696 # Fig. 10 \n", + "n_logical_gates = GateCounts(toffoli=6665400000) # pag. 26\n", + "n_algo_qubits = 696 # Fig. 10 \n", "factory = MultiFactory(base_factory=CCZ2TFactory(distillation_l1_d=19, distillation_l2_d=31),\n", " n_factories=4)\n", "data_block = SimpleDataBlock(data_d=31, routing_overhead=0.5) \n", "\n", "cost = get_ccz2t_costs(\n", - " n_magic=n_magic,\n", - " n_algo_qubits=n_data_qubits,\n", + " n_logical_gates=n_logical_gates,\n", + " n_algo_qubits=n_algo_qubits,\n", " phys_err=1e-3,\n", " cycle_time_us=1,\n", " factory=factory,\n", @@ -74,8 +74,8 @@ " iter_ccz2t_factories\n", "\n", "best_cost, best_factory, best_data_block = get_ccz2t_costs_from_grid_search(\n", - " n_magic=n_magic,\n", - " n_algo_qubits=n_data_qubits,\n", + " n_logical_gates=n_logical_gates,\n", + " n_algo_qubits=n_algo_qubits,\n", " error_budget=1e-2,\n", " phys_err=err_model.physical_error,\n", " factory_iter=iter_ccz2t_factories(n_factories=4), # use 4 CCZ factories in parallel\n", @@ -89,9 +89,11 @@ "metadata": {}, "outputs": [], "source": [ - "distillation_error = best_factory.distillation_error(n_magic, phys_err=err_model.physical_error)\n", + "distillation_error = best_factory.factory_error(n_logical_gates, logical_error_model=err_model)\n", "data_error = best_data_block.data_error(\n", - " n_algo_qubits=n_data_qubits, n_cycles=best_factory.n_cycles(n_magic), logical_error_model=err_model\n", + " n_algo_qubits=n_algo_qubits, \n", + " n_cycles=best_factory.n_cycles(n_logical_gates, logical_error_model=err_model),\n", + " logical_error_model=err_model,\n", ")\n", "\n", "print(f\"distillation error: {distillation_error:.3%}\") # ref: 0.1% per 1e10 Toffolis\n", diff --git a/qualtran/surface_code/ui.py b/qualtran/surface_code/ui.py index 1fb258343..2372badc0 100644 --- a/qualtran/surface_code/ui.py +++ b/qualtran/surface_code/ui.py @@ -21,17 +21,22 @@ from dash import ALL, Dash, dcc, html, Input, Output from dash.exceptions import PreventUpdate -from qualtran.surface_code import ccz2t_cost_model, fifteen_to_one, magic_state_factory +from qualtran.resource_counting import GateCounts +from qualtran.surface_code import ( + AlgorithmSummary, + ccz2t_cost_model, + fifteen_to_one, + LogicalErrorModel, + magic_state_factory, +) from qualtran.surface_code import quantum_error_correction_scheme_summary as qecs from qualtran.surface_code import rotation_cost_model -from qualtran.surface_code.algorithm_summary import AlgorithmSummary from qualtran.surface_code.azure_cost_model import code_distance, minimum_time_steps from qualtran.surface_code.ccz2t_cost_model import ( get_ccz2t_costs_from_grid_search, iter_ccz2t_factories, ) from qualtran.surface_code.data_block import FastDataBlock -from qualtran.surface_code.magic_count import MagicCount from qualtran.surface_code.multi_factory import MultiFactory @@ -354,15 +359,15 @@ def create_qubit_pie_chart( physical_error_rate: float, error_budget: float, estimation_model: str, - algorithm: AlgorithmSummary, + algorithm: 'AlgorithmSummary', magic_factory: magic_state_factory.MagicStateFactory, magic_count: int, - needed_magic: MagicCount, + n_logical_gates: 'GateCounts', ) -> go.Figure: """Create a pie chart of the physical qubit utilization.""" if estimation_model == _GIDNEY_FOWLER_MODEL: res, factory, _ = get_ccz2t_costs_from_grid_search( - n_magic=needed_magic, + n_logical_gates=n_logical_gates, n_algo_qubits=int(algorithm.algorithm_qubits), phys_err=physical_error_rate, error_budget=error_budget, @@ -373,7 +378,10 @@ def create_qubit_pie_chart( 'logical qubits + routing overhead', 'Magic State Distillation', ] - memory_footprint['qubits'] = [res.footprint - factory.footprint(), factory.footprint()] + memory_footprint['qubits'] = [ + res.footprint - factory.n_physical_qubits(), + factory.n_physical_qubits(), + ] fig = px.pie( memory_footprint, values='qubits', names='source', title='Physical Qubit Utilization' ) @@ -388,7 +396,7 @@ def create_qubit_pie_chart( ] memory_footprint['qubits'] = [ FastDataBlock.get_n_tiles(int(algorithm.algorithm_qubits)), - multi_factory.footprint(), + multi_factory.n_physical_qubits(), ] fig = px.pie( memory_footprint, values='qubits', names='source', title='Physical Qubit Utilization' @@ -425,12 +433,12 @@ def create_runtime_plot( physical_error_rate: float, error_budget: float, estimation_model: str, - algorithm: AlgorithmSummary, + algorithm: 'AlgorithmSummary', qec: qecs.QuantumErrorCorrectionSchemeSummary, magic_factory: magic_state_factory.MagicStateFactory, magic_count: int, rotation_model: rotation_cost_model.RotationCostModel, - needed_magic: MagicCount, + n_logical_gates: 'GateCounts', ) -> Tuple[Dict[str, Any], go.Figure]: """Creates the runtime figure and decides whether to display it or not. @@ -442,7 +450,10 @@ def create_runtime_plot( c_min = minimum_time_steps( error_budget=error_budget, alg=algorithm, rotation_model=rotation_model ) - factory_cycles = factory.n_cycles(needed_magic, physical_error_rate) + err_model = LogicalErrorModel(qec_scheme=qec, physical_error=physical_error_rate) + factory_cycles = factory.n_cycles( + n_logical_gates=n_logical_gates, logical_error_model=err_model + ) min_num_factories = int(np.ceil(factory_cycles / c_min)) magic_counts = list( 1 + np.random.choice(min_num_factories, replace=False, size=min(min_num_factories, 5)) @@ -467,7 +478,7 @@ def create_runtime_plot( duration_name = f'Duration ({unit})' num_qubits = ( FastDataBlock.get_n_tiles(int(algorithm.algorithm_qubits)) - + factory.footprint() * magic_counts + + factory.n_physical_qubits() * magic_counts ) df = pd.DataFrame( { @@ -513,7 +524,9 @@ def update( magic_factory = _MAGIC_FACTORIES[magic_name] rotation_model = _ROTATION_MODELS[rotaion_model_name] needed_magic = algorithm.to_magic_count(rotation_model, error_budget / 3) + n_logical_gates = GateCounts(t=int(needed_magic.n_t), toffoli=int(needed_magic.n_ccz)) magic_count = int(magic_count) + logical_err_model = LogicalErrorModel(qec_scheme=qec, physical_error=physical_error_rate) return ( create_qubit_pie_chart( physical_error_rate, @@ -522,7 +535,7 @@ def update( algorithm, magic_factory, magic_count, - needed_magic, + n_logical_gates, ), *create_runtime_plot( physical_error_rate, @@ -533,17 +546,17 @@ def update( magic_factory, magic_count, rotation_model, - needed_magic, + n_logical_gates, ), - *total_magic(estimation_model, needed_magic), + *total_magic(estimation_model, n_logical_gates), *min_num_factories( - physical_error_rate, + logical_err_model, error_budget, estimation_model, algorithm, rotation_model, magic_factory, - needed_magic, + n_logical_gates, ), *compute_duration( physical_error_rate, @@ -552,14 +565,14 @@ def update( algorithm, rotation_model, magic_count, - needed_magic, + n_logical_gates, ), ) -def total_magic(estimation_model: str, needed_magic: MagicCount) -> Tuple[List[str], str]: +def total_magic(estimation_model: str, n_logical_gates: 'GateCounts') -> Tuple[List[str], str]: """Compute the number of magic states needed for the algorithm and their type.""" - total_t = needed_magic.n_t + 4 * needed_magic.n_ccz + total_t = n_logical_gates.total_t_count() total_ccz = total_t / 4 if estimation_model == _GIDNEY_FOWLER_MODEL: return ['Total Number of Toffoli gates'], f'{total_ccz:g}' @@ -568,13 +581,13 @@ def total_magic(estimation_model: str, needed_magic: MagicCount) -> Tuple[List[s def min_num_factories( - physical_error_rate, + logical_error_model: 'LogicalErrorModel', error_budget: float, estimation_model: str, - algorithm: AlgorithmSummary, + algorithm: 'AlgorithmSummary', rotation_model: rotation_cost_model.RotationCostModel, magic_factory: magic_state_factory.MagicStateFactory, - needed_magic: MagicCount, + n_logical_gates: 'GateCounts', ) -> Tuple[Dict[str, Any], int]: if estimation_model == _GIDNEY_FOWLER_MODEL: return {'display': 'none'}, 1 @@ -582,7 +595,12 @@ def min_num_factories( error_budget=error_budget, alg=algorithm, rotation_model=rotation_model ) return {'display': 'block'}, int( - np.ceil(magic_factory.n_cycles(needed_magic, physical_error_rate) / c_min) + np.ceil( + magic_factory.n_cycles( + n_logical_gates=n_logical_gates, logical_error_model=logical_error_model + ) + / c_min + ) ) @@ -590,10 +608,10 @@ def compute_duration( physical_error_rate: float, error_budget: float, estimation_model: str, - algorithm: AlgorithmSummary, + algorithm: 'AlgorithmSummary', rotation_model: rotation_cost_model.RotationCostModel, magic_count: int, - needed_magic: MagicCount, + n_logical_gates: 'GateCounts', ) -> Tuple[Dict[str, Any], str]: """Compute the duration of running the algorithm and whether to display the result or not. @@ -601,7 +619,7 @@ def compute_duration( """ if estimation_model == _GIDNEY_FOWLER_MODEL: res, _, _ = get_ccz2t_costs_from_grid_search( - n_magic=needed_magic, + n_logical_gates=n_logical_gates, n_algo_qubits=int(algorithm.algorithm_qubits), phys_err=physical_error_rate, error_budget=error_budget, From dd19667cc4df4f0fbec78b212e7ca1dd4df7468b Mon Sep 17 00:00:00 2001 From: Noureldin Date: Tue, 23 Jul 2024 15:07:58 -0700 Subject: [PATCH 4/4] Update classical action of addition gates and fix classical action bug in Join (#1174) * Update classical action of addition gates and fix classical action bug in Join * address comments * fix typo --------- Co-authored-by: Matthew Harrigan --- qualtran/bloqs/arithmetic/addition.py | 35 +++++++++++----------- qualtran/bloqs/arithmetic/addition_test.py | 8 +++-- qualtran/bloqs/arithmetic/negate.py | 2 +- qualtran/bloqs/bookkeeping/join.py | 6 +++- qualtran/simulation/classical_sim.py | 32 +++++++++++++++++++- qualtran/simulation/classical_sim_test.py | 31 +++++++++++++++++++ 6 files changed, 91 insertions(+), 23 deletions(-) diff --git a/qualtran/bloqs/arithmetic/addition.py b/qualtran/bloqs/arithmetic/addition.py index 300a7ae3b..f68b55be9 100644 --- a/qualtran/bloqs/arithmetic/addition.py +++ b/qualtran/bloqs/arithmetic/addition.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from functools import cached_property from typing import ( Dict, @@ -59,6 +58,7 @@ from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlX from qualtran.cirq_interop import decompose_from_cirq_style_method from qualtran.drawing import directional_text_box, Text +from qualtran.simulation.classical_sim import add_ints if TYPE_CHECKING: from qualtran.drawing import WireSymbol @@ -129,19 +129,10 @@ def on_classical_vals( ) -> Dict[str, 'ClassicalValT']: unsigned = isinstance(self.a_dtype, (QUInt, QMontgomeryUInt)) b_bitsize = self.b_dtype.bitsize - N = 2**b_bitsize - if unsigned: - return {'a': a, 'b': int((a + b) % N)} - - # Addition of signed integers can result in overflow. In most classical programming languages (e.g. C++) - # what happens when an overflow happens is left as an implementation detail for compiler designers. - # However for quantum subtraction the operation should be unitary and that means that the unitary of - # the bloq should be a permutation matrix. - # If we hold `a` constant then the valid range of values of `b` [-N/2, N/2) gets shifted forward or backwards - # by `a`. to keep the operation unitary overflowing values wrap around. this is the same as moving the range [0, N) - # by the same amount modulu $N$. that is add N/2 before addition and then remove it. - half_n = N >> 1 - return {'a': a, 'b': int(a + b + half_n) % N - half_n} + return { + 'a': a, + 'b': add_ints(int(a), int(b), num_bits=int(b_bitsize), is_signed=not unsigned), + } def _circuit_diagram_info_(self, _) -> cirq.CircuitDiagramInfo: wire_symbols = ["In(x)"] * int(self.a_dtype.bitsize) @@ -302,7 +293,13 @@ def adjoint(self) -> 'OutOfPlaceAdder': def on_classical_vals( self, *, a: 'ClassicalValT', b: 'ClassicalValT' ) -> Dict[str, 'ClassicalValT']: - return {'a': a, 'b': b, 'c': a + b} + if isinstance(self.bitsize, sympy.Expr): + raise ValueError(f'Classical simulation is not support for symbolic bloq {self}') + return { + 'a': a, + 'b': b, + 'c': add_ints(int(a), int(b), num_bits=self.bitsize + 1, is_signed=False), + } def with_registers(self, *new_registers: Union[int, Sequence[int]]): raise NotImplementedError("no need to implement with_registers.") @@ -421,14 +418,18 @@ def signature(self) -> 'Signature': def on_classical_vals( self, x: 'ClassicalValT', **vals: 'ClassicalValT' ) -> Dict[str, 'ClassicalValT']: + if isinstance(self.k, sympy.Expr) or isinstance(self.bitsize, sympy.Expr): + raise ValueError(f"Classical simulation isn't supported for symbolic block {self}") N = 2**self.bitsize if len(self.cvs) > 0: ctrls = vals['ctrls'] else: - return {'x': int(math.fmod(x + self.k, N))} + return { + 'x': add_ints(int(x), int(self.k), num_bits=self.bitsize, is_signed=self.signed) + } if np.all(self.cvs == ctrls): - x = int(math.fmod(x + self.k, N)) + x = add_ints(int(x), int(self.k), num_bits=self.bitsize, is_signed=self.signed) return {'ctrls': ctrls, 'x': x} diff --git a/qualtran/bloqs/arithmetic/addition_test.py b/qualtran/bloqs/arithmetic/addition_test.py index 1337dccc2..85ad18e02 100644 --- a/qualtran/bloqs/arithmetic/addition_test.py +++ b/qualtran/bloqs/arithmetic/addition_test.py @@ -247,10 +247,12 @@ def test_add_classical(): def test_out_of_place_adder(): basis_map = {} gate = OutOfPlaceAdder(bitsize=3) + cbloq = gate.decompose_bloq() for x in range(2**3): for y in range(2**3): basis_map[int(f'0b_{x:03b}_{y:03b}_0000', 2)] = int(f'0b_{x:03b}_{y:03b}_{x+y:04b}', 2) assert gate.call_classically(a=x, b=y, c=0) == (x, y, x + y) + assert cbloq.call_classically(a=x, b=y, c=0) == (x, y, x + y) op = GateHelper(gate).operation op_inv = cirq.inverse(op) cirq.testing.assert_equivalent_computational_basis_map(basis_map, cirq.Circuit(op)) @@ -323,9 +325,9 @@ def test_classical_add_signed_overflow(bitsize): assert bloq.call_classically(a=mx, b=mx) == (mx, -2) -# TODO: write tests for signed integer addition (subtraction) -# https://github.com/quantumlib/Qualtran/issues/606 -@pytest.mark.parametrize('bitsize,k,x,cvs,ctrls,result', [(5, 2, 0, (1, 0), (1, 0), 2)]) +@pytest.mark.parametrize( + 'bitsize,k,x,cvs,ctrls,result', [(5, 2, 0, (1, 0), (1, 0), 2), (6, -3, 2, (), (), -1)] +) def test_classical_add_k_signed(bitsize, k, x, cvs, ctrls, result): bloq = AddK(bitsize=bitsize, k=k, cvs=cvs, signed=True) cbloq = bloq.decompose_bloq() diff --git a/qualtran/bloqs/arithmetic/negate.py b/qualtran/bloqs/arithmetic/negate.py index bfd36a63e..c3910fc1e 100644 --- a/qualtran/bloqs/arithmetic/negate.py +++ b/qualtran/bloqs/arithmetic/negate.py @@ -61,7 +61,7 @@ def signature(self) -> 'Signature': def build_composite_bloq(self, bb: 'BloqBuilder', x: 'SoquetT') -> dict[str, 'SoquetT']: x = bb.add(BitwiseNot(self.dtype), x=x) # ~x - x = bb.add(AddK(self.dtype.num_qubits, k=1), x=x) # -x + x = bb.add(AddK(self.dtype.num_qubits, k=1, signed=isinstance(self.dtype, QInt)), x=x) # -x return {'x': x} diff --git a/qualtran/bloqs/bookkeeping/join.py b/qualtran/bloqs/bookkeeping/join.py index 49d2674cb..e2680307c 100644 --- a/qualtran/bloqs/bookkeeping/join.py +++ b/qualtran/bloqs/bookkeeping/join.py @@ -26,6 +26,7 @@ DecomposeTypeError, QBit, QDType, + QFxp, QUInt, Register, Side, @@ -95,7 +96,10 @@ def my_tensors( ] def on_classical_vals(self, reg: 'NDArray[np.uint]') -> Dict[str, int]: - return {'reg': bits_to_ints(reg)[0]} + if isinstance(self.dtype, QFxp): + # TODO(#1095): support QFxp in classical simulation + return {'reg': bits_to_ints(reg)[0]} + return {'reg': self.dtype.from_bits(reg.tolist())} def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol': if reg is None: diff --git a/qualtran/simulation/classical_sim.py b/qualtran/simulation/classical_sim.py index d62e016e3..63af89cca 100644 --- a/qualtran/simulation/classical_sim.py +++ b/qualtran/simulation/classical_sim.py @@ -14,7 +14,7 @@ """Functionality for the `Bloq.call_classically(...)` protocol.""" import itertools -from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union import networkx as nx import numpy as np @@ -265,3 +265,33 @@ def format_classical_truth_table( for invals, outvals in truth_table ] return '\n'.join([heading] + entries) + + +def add_ints(a: int, b: int, *, num_bits: Optional[int] = None, is_signed: bool = False) -> int: + r"""Performs addition modulo $2^\mathrm{num\_bits}$ of (un)signed in a reversible way. + + Addition of signed integers can result in an overflow. In most classical programming languages (e.g. C++) + what happens when an overflow happens is left as an implementation detail for compiler designers. However, + for quantum subtraction, the operation should be unitary and that means that the unitary of the bloq should + be a permutation matrix. + + If we hold `a` constant then the valid range of values of $b \in [-2^{\mathrm{num\_bits}-1}, 2^{\mathrm{num\_bits}-1})$ + gets shifted forward or backward by `a`. To keep the operation unitary overflowing values wrap around. This is the same + as moving the range $2^\mathrm{num\_bits}$ by the same amount modulo $2^\mathrm{num\_bits}$. That is add + $2^{\mathrm{num\_bits}-1})$ before addition modulo and then remove it. + + Args: + a: left operand of addition. + b: right operand of addition. + num_bits: optional num_bits. When specified addition is done in the interval [0, 2**num_bits) or + [-2**(num_bits-1), 2**(num_bits-1)) based on the value of `is_signed`. + is_signed: boolean whether the numbers are unsigned or signed ints. This value is only used when + `num_bits` is provided. + """ + c = a + b + if num_bits is not None: + N = 2**num_bits + if is_signed: + return (c + N // 2) % N - N // 2 + return c % N + return c diff --git a/qualtran/simulation/classical_sim_test.py b/qualtran/simulation/classical_sim_test.py index bd56280b7..398d640e3 100644 --- a/qualtran/simulation/classical_sim_test.py +++ b/qualtran/simulation/classical_sim_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools from typing import Dict import cirq @@ -24,6 +25,7 @@ from qualtran.bloqs.basic_gates import CNOT from qualtran.simulation.classical_sim import ( _update_assign_from_vals, + add_ints, bits_to_ints, call_cbloq_classically, ClassicalValT, @@ -168,6 +170,35 @@ def test_apply_classical_cbloq(): np.testing.assert_array_equal(z, xarr) +@pytest.mark.parametrize( + ['x', 'y', 'n_bits'], + [ + (x, y, n_bits) + for n_bits in range(1, 5) + for x, y in itertools.product(range(1 << n_bits), repeat=2) + ], +) +def test_add_ints_unsigned(x, y, n_bits): + assert add_ints(x, y, num_bits=n_bits, is_signed=False) == (x + y) % (1 << n_bits) + + +@pytest.mark.parametrize( + ['x', 'y', 'n_bits'], + [ + (x, y, n_bits) + for n_bits in range(2, 5) + for x, y in itertools.product(range(-(2 ** (n_bits - 1)), 2 ** (n_bits - 1)), repeat=2) + ], +) +def test_add_ints_signed(x, y, n_bits): + half_n = 1 << (n_bits - 1) + # Addition of signed ints `x` and `y` is a cyclic rotation of the interval [-2^(n-1), 2^(n-1)) by `y`. + interval = [*range(-(2 ** (n_bits - 1)), 2 ** (n_bits - 1))] + i = x + half_n # position of `x` in the interval + z = interval[(i + y) % len(interval)] # rotate by `y` + assert add_ints(x, y, num_bits=n_bits, is_signed=True) == z + + @pytest.mark.notebook def test_notebook(): execute_notebook('classical_sim')