Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor Simulation protocol 2: keep indices factorized #1070

Merged
merged 15 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qualtran/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
BloqBuilder,
DidNotFlattenAnythingError,
SoquetT,
ConnectionT,
)

from ._infra.data_types import (
Expand Down
60 changes: 30 additions & 30 deletions qualtran/_infra/bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Contains the main interface for defining `Bloq`s."""

import abc
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union

if TYPE_CHECKING:
import cirq
Expand All @@ -30,11 +30,10 @@
Adjoint,
BloqBuilder,
CompositeBloq,
ConnectionT,
CtrlSpec,
GateWithRegisters,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.cirq_interop import CirqQuregT
Expand Down Expand Up @@ -246,38 +245,39 @@ def tensor_contract(self) -> 'NDArray':

return bloq_to_dense(self)

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
"""Override this method to support native quimb simulation of this Bloq.

This method is responsible for adding a tensor corresponding to the unitary, state, or
effect of the bloq to the provided tensor network `tn`. Often, this method will add
one tensor for a given Bloq, but some bloqs can be represented in a factorized form
requiring the addition of more than one tensor.

If this method is not overriden, the default implementation will try to use the bloq's
decomposition to find a dense representation for this bloq.
This method is responsible for returning tensors corresponding to the unitary, state, or
effect of the bloq. Often, this method will return one tensor for a given Bloq, but
some bloqs can be represented in a factorized form using more than one tensor.

By default, calls to `Bloq.tensor_contract()` will first decompose and flatten the bloq
before initiating the conversion to a tensor network. This has two consequences:
1) Overriding this method is only necessary if this bloq does not define a decomposition
or if the fully-decomposed form contains a bloq that does not define its tensors.
2) Even if you override this method to provide custom tensors, they may not be used
(by default) because we prefer the flat-decomposed version. This is usually desirable
for contraction performance; but for finer-grained control see
`qualtran.simulation.tensor.cbloq_to_quimb`.

Quimb defines a connection between two tensors by a shared index. The returned tensors
from this method must use the Qualtran-Quimb index convention:
- Each tensor index is a tuple `(cxn, j)`
- The `cxn: qualtran.Connection` entry identifies the connection between bloq instances.
- The second integer `j` is the bit index within high-bitsize registers,
which is necessary due to technical restrictions.

Args:
tn: The tensor network to which we add our tensor(s)
tag: An arbitrary tag that must be forwarded to `qtn.Tensor`'s `tag` attribute.
incoming: A mapping from register name to SoquetT to order left indices for
the tensor network.
outgoing: A mapping from register name to SoquetT to order right indices for
the tensor network.
incoming: A mapping from register name to Connection (or an array thereof) to use as
left indices for the tensor network. The shape of the array matches the register's
shape.
outgoing: A mapping from register name to Connection (or an array thereof) to use as
right indices for the tensor network.
"""
from qualtran.simulation.tensor import cbloq_as_contracted_tensor

cbloq = self.decompose_bloq()
tn.add(
cbloq_as_contracted_tensor(cbloq, incoming, outgoing, tags=[self.pretty_name(), tag])
)
raise NotImplementedError(f"{self} does not support tensor simulation.")

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
"""Override this method to build the bloq call graph.
Expand Down
44 changes: 42 additions & 2 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
canonicalize and return `SoquetT`.
"""

_ConnectionType = TypeVar('_ConnectionType', bound=np.generic)

ConnectionT = Union[Connection, NDArray[_ConnectionType]]
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved
"""A `Connection` or array of connections."""


def _to_tuple(x: Iterable[Connection]) -> Sequence[Connection]:
"""mypy-compatible attrs converter for CompositeBloq.connections"""
Expand Down Expand Up @@ -341,6 +346,9 @@ def flatten_once(
the bloqs have decompositions.

"""
if len(self.bloq_instances) == 0:
raise DidNotFlattenAnythingError()

bb, _ = BloqBuilder.from_signature(self.signature)

# We take particular care during flattening to preserve the `binst.i` of bloq instances
Expand Down Expand Up @@ -513,7 +521,7 @@ def _cxn_to_soq_dict(
get_me: A function that says which soquet is used to derive keys for the returned
dictionary. Generally: if `cxns` is predecessor connections, this will return the
`right` element of the connection and opposite of successor connections.
get_assign: A function that says which soquet is used to dervice the values for the
get_assign: A function that says which soquet is used to derive the values for the
returned dictionary. Generally, this is the opposite side vs. `get_me`, but we
do something fancier in `cbloq_to_quimb`.
"""
Expand All @@ -538,7 +546,39 @@ def _cxn_to_soq_dict(
return soqdict


def _get_dangling_soquets(signature: Signature, right=True) -> Dict[str, SoquetT]:
def _cxn_to_cxn_dict(
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved
regs: Iterable[Register], cxns: Iterable[Connection], get_me: Callable[[Connection], Soquet]
) -> Dict[str, ConnectionT]:
"""Helper function to get a dictionary of connections keyed by register name.

Args:
regs: Left or right registers (used as a reference to initialize multidimensional
registers correctly).
cxns: Predecessor or successor connections from which we get the soquets of interest.
get_me: A function that says which soquet is used to derive keys for the returned
dictionary. Generally: if `cxns` is predecessor connections, this will return the
`right` element of the connection (opposite for successor connections).
"""
cxndict: Dict[str, ConnectionT] = {}

# Initialize multi-dimensional dictionary values.
for reg in regs:
if reg.shape:
cxndict[reg.name] = np.empty(reg.shape, dtype=object)

# In the abstract: set `soqdict[me] = assign`. Specifically: use the register name as
# keys and handle multi-dimensional registers.
for cxn in cxns:
me = get_me(cxn)
if me.reg.shape:
cxndict[me.reg.name][me.idx] = cxn # type: ignore[index]
else:
cxndict[me.reg.name] = cxn

return cxndict


def _get_dangling_soquets(signature: Signature, right: bool = True) -> Dict[str, SoquetT]:
"""Get instantiated dangling soquets from a `Signature`.

Args:
Expand Down
49 changes: 19 additions & 30 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import BloqBuilder, CompositeBloq, Soquet, SoquetT
from qualtran import BloqBuilder, CompositeBloq, ConnectionT, SoquetT
from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand Down Expand Up @@ -400,17 +400,7 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT

return vals

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
import quimb.tensor as qtn

from qualtran._infra.composite_bloq import _flatten_soquet_collection
def _tensor_data(self):
from qualtran.simulation.tensor._tensor_data_manipulation import (
active_space_for_ctrl_spec,
eye_tensor_for_signature,
Expand All @@ -419,36 +409,35 @@ def add_my_tensors(

# Create an identity tensor corresponding to the signature of current Bloq
data = eye_tensor_for_signature(self.signature)
# Verify it has the right shape
in_ind = _flatten_soquet_collection(incoming[reg.name] for reg in self.signature.lefts())
out_ind = _flatten_soquet_collection(outgoing[reg.name] for reg in self.signature.rights())
assert data.shape == tuple(2**soq.reg.bitsize for ind in [out_ind, in_ind] for soq in ind)
# Figure out the ctrl indexes for which the ctrl is "active"
active_idx = active_space_for_ctrl_spec(self.signature, self.ctrl_spec)
# Put the subbloq tensor at indices where ctrl is active.
subbloq_shape = tensor_shape_from_signature(self.subbloq.signature)
data[active_idx] = self.subbloq.tensor_contract().reshape(subbloq_shape)
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.pretty_name(), tag]))
return data

def _unitary_(self):
if isinstance(self.subbloq, GateWithRegisters):
# subbloq is a cirq gate, use the cirq-style API to derive a unitary.
return cirq.unitary(
cirq.ControlledGate(self.subbloq, control_values=self.ctrl_spec.to_cirq_cv())
)
if all(reg.side == Side.THRU for reg in self.subbloq.signature):
# subbloq has only THRU registers, so the tensor contraction corresponds
# to a unitary matrix.
return self.tensor_contract()
# Unable to determine the unitary effect.
return NotImplemented
n = self.signature.n_qubits()
return self._tensor_data().reshape(2**n, 2**n)
mpharrigan marked this conversation as resolved.
Show resolved Hide resolved

def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
import quimb.tensor as qtn

from qualtran.simulation.tensor._dense import _order_incoming_outgoing_indices

inds = _order_incoming_outgoing_indices(
self.signature, incoming=incoming, outgoing=outgoing
)
data = self._tensor_data().reshape((2,) * len(inds))
return [qtn.Tensor(data=data, inds=inds, tags=[str(self)])]

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import Text

if reg is None:
return Text(f'C[{self.subbloq.wire_symbol(reg=None)}]')
return Text(f'C[{self.subbloq}]')
if reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
return self.subbloq.wire_symbol(reg, idx)
Expand Down
25 changes: 24 additions & 1 deletion qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
from qualtran._infra.gate_with_registers import get_named_qubits, merge_qubits
from qualtran.bloqs.basic_gates import (
CSwap,
GlobalPhase,
IntEffect,
IntState,
OneState,
Swap,
TwoBitCSwap,
XGate,
XPowGate,
YGate,
Expand All @@ -50,6 +52,7 @@
from qualtran.cirq_interop.testing import GateHelper
from qualtran.drawing import get_musical_score_data
from qualtran.drawing.musical_score import Circle, SoqData, TextBox
from qualtran.simulation.tensor import cbloq_to_quimb, get_right_and_left_inds

if TYPE_CHECKING:
from qualtran import SoquetT
Expand Down Expand Up @@ -340,7 +343,7 @@ def test_notebook():
def _verify_ctrl_tensor_for_unitary(ctrl_spec: CtrlSpec, bloq: Bloq, gate: cirq.Gate):
cbloq = Controlled(bloq, ctrl_spec)
cgate = cirq.ControlledGate(gate, control_values=ctrl_spec.to_cirq_cv())
np.testing.assert_array_equal(cbloq.tensor_contract(), cirq.unitary(cgate))
np.testing.assert_allclose(cbloq.tensor_contract(), cirq.unitary(cgate), atol=1e-8)


interesting_ctrl_specs = [
Expand All @@ -362,6 +365,26 @@ def test_controlled_tensor_for_unitary(ctrl_spec: CtrlSpec):
_verify_ctrl_tensor_for_unitary(ctrl_spec, CSwap(3), CSwap(3))


def test_controlled_tensor_without_decompose():
ctrl_spec = CtrlSpec()
bloq = TwoBitCSwap()
cbloq = Controlled(bloq, ctrl_spec)
cgate = cirq.ControlledGate(cirq.CSWAP, control_values=ctrl_spec.to_cirq_cv())

tn = cbloq_to_quimb(cbloq.as_composite_bloq())
# pylint: disable=unbalanced-tuple-unpacking
right, left = get_right_and_left_inds(tn, cbloq.signature)
# pylint: enable=unbalanced-tuple-unpacking
np.testing.assert_allclose(tn.to_dense(right, left), cirq.unitary(cgate), atol=1e-8)
np.testing.assert_allclose(cbloq.tensor_contract(), cirq.unitary(cgate), atol=1e-8)


def test_controlled_global_phase_tensor():
bloq = GlobalPhase(1.0j).controlled()
should_be = np.diag([1, 1.0j])
np.testing.assert_allclose(bloq.tensor_contract(), should_be)


@attrs.frozen
class TestCtrlStatePrepAnd(Bloq):
"""Decomposes into a Controlled-AND gate + int effects & targets where ctrl is active.
Expand Down
22 changes: 7 additions & 15 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import abc
from typing import (
Any,
cast,
Collection,
Dict,
Expand All @@ -40,7 +39,7 @@
if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import AddControlledT, BloqBuilder, CtrlSpec, SoquetT
from qualtran import AddControlledT, BloqBuilder, ConnectionT, CtrlSpec, SoquetT
from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol

Expand Down Expand Up @@ -507,22 +506,15 @@ def controlled(
def _unitary_(self):
return NotImplemented

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: 'Any',
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
if not self._unitary_.__qualname__.startswith('GateWithRegisters.'):
from qualtran.cirq_interop._cirq_to_bloq import _add_my_tensors_from_gate
from qualtran.cirq_interop._cirq_to_bloq import _my_tensors_from_gate

_add_my_tensors_from_gate(
self, self.signature, str(self), tn, tag, incoming=incoming, outgoing=outgoing
)
return _my_tensors_from_gate(self, self.signature, incoming=incoming, outgoing=outgoing)
else:
return super().add_my_tensors(tn, tag, incoming=incoming, outgoing=outgoing)
return super().my_tensors(incoming=incoming, outgoing=outgoing)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
"""Default diagram info that uses register names to name the boxes in multi-qubit gates.
Expand Down
Loading
Loading