Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Product block encoding decomposition #1128

Merged
merged 6 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 63 additions & 73 deletions qualtran/bloqs/block_encoding/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import cast, Dict, Tuple
from typing import cast, Dict, List, Tuple, Union

import cirq
from attrs import evolve, field, frozen, validators
Expand All @@ -34,6 +34,7 @@
from qualtran.bloqs.basic_gates.x_basis import XGate
from qualtran.bloqs.block_encoding import BlockEncoding
from qualtran.bloqs.block_encoding.lcu_select_and_prepare import PrepareOracle
from qualtran.bloqs.bookkeeping.auto_partition import AutoPartition, Unused
from qualtran.bloqs.bookkeeping.partition import Partition
from qualtran.bloqs.mcmt.multi_control_multi_target_pauli import MultiControlPauli
from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicFloat, SymbolicInt
Expand Down Expand Up @@ -111,7 +112,7 @@ def ancilla_bitsize(self) -> SymbolicInt:

@cached_property
def resource_bitsize(self) -> SymbolicInt:
return ssum(u.resource_bitsize for u in self.block_encodings)
return smax(u.resource_bitsize for u in self.block_encodings)

@cached_property
def epsilon(self) -> SymbolicFloat:
Expand Down Expand Up @@ -150,84 +151,73 @@ def build_composite_bloq(
and isinstance(self.ancilla_bitsize, int)
and isinstance(self.resource_bitsize, int)
)

n = len(self.block_encodings)
res_bits_used = 0
for i, u in enumerate(reversed(self.block_encodings)):
u_soqs = {"system": system}

# split ancilla register if necessary
if self.ancilla_bitsize > 0:
anc_bits = cast(int, u.ancilla_bitsize)
flag_bits = Register("flag_bits", dtype=QBit(), shape=(n - 1,)) # type: ignore
anc_used = Register("anc_used", dtype=QAny(anc_bits))
anc_unused_bits = self.ancilla_bitsize - (n - 1) - anc_bits
anc_unused = Register("anc_unused", dtype=QAny(anc_unused_bits))
anc_regs = [flag_bits]
if anc_bits > 0:
anc_regs.append(anc_used)
if anc_unused_bits > 0:
anc_regs.append(anc_unused)
anc_part = Partition(self.ancilla_bitsize, tuple(anc_regs))
anc_part_soqs = bb.add_d(anc_part, x=soqs["ancilla"])
if anc_bits > 0:
u_soqs["ancilla"] = anc_part_soqs["anc_used"]

# split resource register if necessary
res_bits = cast(int, u.resource_bitsize)
if res_bits > 0:
res_before = Register("res_before", dtype=QAny(res_bits_used))
res = Register("res", dtype=QAny(res_bits))
res_bits_left = self.resource_bitsize - res_bits_used - res_bits
res_after = Register("res_after", dtype=QAny(res_bits_left))
res_regs = []
if res_bits_used > 0:
res_regs.append(res_before)
res_regs.append(res)
res_bits_used += res_bits
if res_bits_left > 0:
res_regs.append(res_after)
res_part = Partition(self.resource_bitsize, tuple(res_regs))
res_part_soqs = bb.add_d(res_part, x=soqs["resource"])
u_soqs["resource"] = res_part_soqs["res"]

# connect the constituent bloq
u_out_soqs = bb.add_d(u, **u_soqs)
system = u_out_soqs["system"]

# un-partition the resource register
if res_bits > 0:
res_part_soqs["res"] = u_out_soqs["resource"]
soqs["resource"] = cast(
Soquet, bb.add(evolve(res_part, partition=False), **res_part_soqs)
)
if self.ancilla_bitsize > 0:
# partition ancilla into flag and inner ancilla
anc_regs = []
if n - 1 > 0:
anc_regs.append(Register("flag_bits", dtype=QBit(), shape=(n - 1,)))
anc_bits = self.ancilla_bitsize - (n - 1)
if anc_bits > 0:
anc_regs.append(Register("ancilla", dtype=QAny(anc_bits)))
anc_part = Partition(self.ancilla_bitsize, tuple(anc_regs))
anc_part_soqs = bb.add_d(anc_part, x=soqs.pop("ancilla"))
if n - 1 > 0:
flag_bits_soq = cast(NDArray, anc_part_soqs.pop("flag_bits"))
if anc_bits > 0:
anc_soq = anc_part_soqs.pop("ancilla")
if self.resource_bitsize > 0:
res_soq = soqs.pop("resource")

# un-partition the ancilla register
if self.ancilla_bitsize > 0:
flag_bits_soq = cast(NDArray, anc_part_soqs["flag_bits"])
if anc_bits > 0:
anc_used_soq = cast(Soquet, u_out_soqs["ancilla"])
if i == n - 1:
anc_part_soqs["anc_used"] = anc_used_soq
else:
# set corresponding flag if ancillas are all zero
ctrl, flag_bits_soq[i] = bb.add_t(
MultiControlPauli(tuple([0] * anc_bits), cirq.X),
controls=bb.split(anc_used_soq),
target=flag_bits_soq[i],
)
flag_bits_soq[i] = bb.add(XGate(), q=flag_bits_soq[i])
anc_part_soqs["anc_used"] = bb.join(cast(NDArray, ctrl))
anc_part_soqs["flag_bits"] = flag_bits_soq
soqs["ancilla"] = cast(
Soquet, bb.add(evolve(anc_part, partition=False), **anc_part_soqs)
# connect constituent bloqs
for i, u in enumerate(reversed(self.block_encodings)):
assert isinstance(u.ancilla_bitsize, int)
assert isinstance(u.resource_bitsize, int)
u_soqs = {"system": system}
partition: List[Tuple[Register, List[Union[str, Unused]]]] = [
(Register("system", dtype=QAny(u.system_bitsize)), ["system"])
]
if u.ancilla_bitsize > 0:
u_soqs["ancilla"] = anc_soq
regs: List[Union[str, Unused]] = ["ancilla"]
if anc_bits > u.ancilla_bitsize:
regs.append(Unused(anc_bits - u.ancilla_bitsize))
partition.append((Register("ancilla", dtype=QAny(anc_bits)), regs))
if u.resource_bitsize > 0:
u_soqs["resource"] = res_soq
regs = ["resource"]
if self.resource_bitsize > u.resource_bitsize:
regs.append(Unused(self.resource_bitsize - u.resource_bitsize))
partition.append((Register("resource", dtype=QAny(u.resource_bitsize)), regs))
u_out_soqs = bb.add_d(AutoPartition(u, partition, left_only=False), **u_soqs)
system = u_out_soqs.pop("system")
if u.ancilla_bitsize > 0:
anc_soq = u_out_soqs.pop("ancilla")
if u.resource_bitsize > 0:
res_soq = u_out_soqs.pop("resource")

# set corresponding flag if ancillas are all zero
if u.ancilla_bitsize > 0 and n - 1 > 0 and i != n - 1:
controls = bb.split(cast(Soquet, anc_soq))
controls[: u.ancilla_bitsize], flag_bits_soq[i] = bb.add_t(
MultiControlPauli(tuple([0] * u.ancilla_bitsize), cirq.X),
controls=controls[: u.ancilla_bitsize],
target=flag_bits_soq[i],
)
flag_bits_soq[i] = bb.add(XGate(), q=flag_bits_soq[i])
anc_soq = bb.join(controls)

out = {"system": system}
if self.ancilla_bitsize > 0:
out["ancilla"] = soqs["ancilla"]
if self.resource_bitsize > 0:
out["resource"] = soqs["resource"]
out["resource"] = res_soq
if self.ancilla_bitsize > 0:
anc_soqs: Dict[str, SoquetT] = dict()
if n - 1 > 0:
anc_soqs["flag_bits"] = flag_bits_soq
if anc_bits > 0:
anc_soqs["ancilla"] = anc_soq
out["ancilla"] = cast(Soquet, bb.add(evolve(anc_part, partition=False), **anc_soqs))
return out


Expand Down
25 changes: 16 additions & 9 deletions qualtran/bloqs/block_encoding/product_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@
import sympy

from qualtran import BloqBuilder, QAny, Register, Signature, Soquet
from qualtran.bloqs.basic_gates import CNOT, Hadamard, TGate, XGate, ZeroEffect, ZeroState
from qualtran.bloqs.basic_gates import (
CNOT,
Hadamard,
IntEffect,
IntState,
TGate,
XGate,
ZeroEffect,
ZeroState,
)
from qualtran.bloqs.block_encoding.product import (
_product_block_encoding,
_product_block_encoding_properties,
Expand All @@ -42,7 +51,7 @@ def test_product_signature():
[Register("system", QAny(1)), Register("ancilla", QAny(1))]
)
assert _product_block_encoding_properties().signature == Signature(
[Register("system", QAny(1)), Register("ancilla", QAny(3)), Register("resource", QAny(2))]
[Register("system", QAny(1)), Register("ancilla", QAny(3)), Register("resource", QAny(1))]
)
assert _product_block_encoding_symb().signature == Signature(
[
Expand Down Expand Up @@ -72,7 +81,7 @@ def test_product_params():
assert bloq.alpha == 0.5 * 0.5
assert bloq.epsilon == 0.5 * 0.01 + 0.5 * 0.1
assert bloq.ancilla_bitsize == max(2, 1) + 1
assert bloq.resource_bitsize == 1 + 1
assert bloq.resource_bitsize == max(1, 1)

bloq = _product_block_encoding_symb()
assert bloq.system_bitsize == 1
Expand Down Expand Up @@ -111,15 +120,13 @@ def test_product_single_tensors():
def test_product_properties_tensors():
bb = BloqBuilder()
system = bb.add_register("system", 1)
ancilla = bb.join(np.array([bb.add(ZeroState()), bb.add(ZeroState()), bb.add(ZeroState())]))
resource = bb.join(np.array([bb.add(ZeroState()), bb.add(ZeroState())]))
ancilla = cast(Soquet, bb.add(IntState(0, 3)))
resource = cast(Soquet, bb.add(ZeroState()))
system, ancilla, resource = bb.add_t(
_product_block_encoding_properties(), system=system, ancilla=ancilla, resource=resource
)
for q in bb.split(cast(Soquet, ancilla)):
bb.add(ZeroEffect(), q=q)
for q in bb.split(cast(Soquet, resource)):
bb.add(ZeroEffect(), q=q)
bb.add(ZeroEffect(), q=resource)
bb.add(IntEffect(0, 3), val=ancilla)
bloq = bb.finalize(system=system)

from_gate = np.matmul(TGate().tensor_contract(), Hadamard().tensor_contract())
Expand Down
Loading