Skip to content

Commit

Permalink
Port usage to new tensor network protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
mpharrigan committed Jun 19, 2024
1 parent 07fe225 commit df07276
Show file tree
Hide file tree
Showing 46 changed files with 663 additions and 645 deletions.
49 changes: 19 additions & 30 deletions qualtran/_infra/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import BloqBuilder, CompositeBloq, Soquet, SoquetT
from qualtran import BloqBuilder, CompositeBloq, ConnectionT, SoquetT
from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
Expand Down Expand Up @@ -400,17 +400,7 @@ def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT

return vals

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
import quimb.tensor as qtn

from qualtran._infra.composite_bloq import _flatten_soquet_collection
def _tensor_data(self):
from qualtran.simulation.tensor._tensor_data_manipulation import (
active_space_for_ctrl_spec,
eye_tensor_for_signature,
Expand All @@ -419,36 +409,35 @@ def add_my_tensors(

# Create an identity tensor corresponding to the signature of current Bloq
data = eye_tensor_for_signature(self.signature)
# Verify it has the right shape
in_ind = _flatten_soquet_collection(incoming[reg.name] for reg in self.signature.lefts())
out_ind = _flatten_soquet_collection(outgoing[reg.name] for reg in self.signature.rights())
assert data.shape == tuple(2**soq.reg.bitsize for ind in [out_ind, in_ind] for soq in ind)
# Figure out the ctrl indexes for which the ctrl is "active"
active_idx = active_space_for_ctrl_spec(self.signature, self.ctrl_spec)
# Put the subbloq tensor at indices where ctrl is active.
subbloq_shape = tensor_shape_from_signature(self.subbloq.signature)
data[active_idx] = self.subbloq.tensor_contract().reshape(subbloq_shape)
# Add the data to the tensor network.
tn.add(qtn.Tensor(data=data, inds=out_ind + in_ind, tags=[self.pretty_name(), tag]))
return data

def _unitary_(self):
if isinstance(self.subbloq, GateWithRegisters):
# subbloq is a cirq gate, use the cirq-style API to derive a unitary.
return cirq.unitary(
cirq.ControlledGate(self.subbloq, control_values=self.ctrl_spec.to_cirq_cv())
)
if all(reg.side == Side.THRU for reg in self.subbloq.signature):
# subbloq has only THRU registers, so the tensor contraction corresponds
# to a unitary matrix.
return self.tensor_contract()
# Unable to determine the unitary effect.
return NotImplemented
n = self.signature.n_qubits()
return self._tensor_data().reshape(2**n, 2**n)

def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
import quimb.tensor as qtn

from qualtran.simulation.tensor._dense import _order_incoming_outgoing_indices

inds = _order_incoming_outgoing_indices(
self.signature, incoming=incoming, outgoing=outgoing
)
data = self._tensor_data().reshape((2,) * len(inds))
return [qtn.Tensor(data=data, inds=inds, tags=[str(self)])]

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
from qualtran.drawing import Text

if reg is None:
return Text(f'C[{self.subbloq.wire_symbol(reg=None)}]')
return Text(f'C[{self.subbloq}]')
if reg.name not in self.ctrl_reg_names:
# Delegate to subbloq
return self.subbloq.wire_symbol(reg, idx)
Expand Down
25 changes: 24 additions & 1 deletion qualtran/_infra/controlled_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
from qualtran._infra.gate_with_registers import get_named_qubits, merge_qubits
from qualtran.bloqs.basic_gates import (
CSwap,
GlobalPhase,
IntEffect,
IntState,
OneState,
Swap,
TwoBitCSwap,
XGate,
XPowGate,
YGate,
Expand All @@ -50,6 +52,7 @@
from qualtran.cirq_interop.testing import GateHelper
from qualtran.drawing import get_musical_score_data
from qualtran.drawing.musical_score import Circle, SoqData, TextBox
from qualtran.simulation.tensor import cbloq_to_quimb, get_right_and_left_inds

if TYPE_CHECKING:
from qualtran import SoquetT
Expand Down Expand Up @@ -340,7 +343,7 @@ def test_notebook():
def _verify_ctrl_tensor_for_unitary(ctrl_spec: CtrlSpec, bloq: Bloq, gate: cirq.Gate):
cbloq = Controlled(bloq, ctrl_spec)
cgate = cirq.ControlledGate(gate, control_values=ctrl_spec.to_cirq_cv())
np.testing.assert_array_equal(cbloq.tensor_contract(), cirq.unitary(cgate))
np.testing.assert_allclose(cbloq.tensor_contract(), cirq.unitary(cgate), atol=1e-8)


interesting_ctrl_specs = [
Expand All @@ -362,6 +365,26 @@ def test_controlled_tensor_for_unitary(ctrl_spec: CtrlSpec):
_verify_ctrl_tensor_for_unitary(ctrl_spec, CSwap(3), CSwap(3))


def test_controlled_tensor_without_decompose():
ctrl_spec = CtrlSpec()
bloq = TwoBitCSwap()
cbloq = Controlled(bloq, ctrl_spec)
cgate = cirq.ControlledGate(cirq.CSWAP, control_values=ctrl_spec.to_cirq_cv())

tn = cbloq_to_quimb(cbloq.as_composite_bloq())
# pylint: disable=unbalanced-tuple-unpacking
right, left = get_right_and_left_inds(tn, cbloq.signature)
# pylint: enable=unbalanced-tuple-unpacking
np.testing.assert_allclose(tn.to_dense(right, left), cirq.unitary(cgate), atol=1e-8)
np.testing.assert_allclose(cbloq.tensor_contract(), cirq.unitary(cgate), atol=1e-8)


def test_controlled_global_phase_tensor():
bloq = GlobalPhase(1.0j).controlled()
should_be = np.diag([1, 1.0j])
np.testing.assert_allclose(bloq.tensor_contract(), should_be)


@attrs.frozen
class TestCtrlStatePrepAnd(Bloq):
"""Decomposes into a Controlled-AND gate + int effects & targets where ctrl is active.
Expand Down
22 changes: 7 additions & 15 deletions qualtran/_infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import abc
from typing import (
Any,
cast,
Collection,
Dict,
Expand All @@ -40,7 +39,7 @@
if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import AddControlledT, BloqBuilder, CtrlSpec, SoquetT
from qualtran import AddControlledT, BloqBuilder, ConnectionT, CtrlSpec, SoquetT
from qualtran.cirq_interop import CirqQuregT
from qualtran.drawing import WireSymbol

Expand Down Expand Up @@ -507,22 +506,15 @@ def controlled(
def _unitary_(self):
return NotImplemented

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: 'Any',
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
if not self._unitary_.__qualname__.startswith('GateWithRegisters.'):
from qualtran.cirq_interop._cirq_to_bloq import _add_my_tensors_from_gate
from qualtran.cirq_interop._cirq_to_bloq import _my_tensors_from_gate

_add_my_tensors_from_gate(
self, self.signature, str(self), tn, tag, incoming=incoming, outgoing=outgoing
)
return _my_tensors_from_gate(self, self.signature, incoming=incoming, outgoing=outgoing)
else:
return super().add_my_tensors(tn, tag, incoming=incoming, outgoing=outgoing)
return super().my_tensors(incoming=incoming, outgoing=outgoing)

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
"""Default diagram info that uses register names to name the boxes in multi-qubit gates.
Expand Down
26 changes: 0 additions & 26 deletions qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import math
from functools import cached_property
from typing import (
Any,
Dict,
Iterable,
Iterator,
Expand Down Expand Up @@ -63,8 +61,6 @@
from qualtran.drawing import directional_text_box, Text

if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran.drawing import WireSymbol
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
Expand Down Expand Up @@ -125,28 +121,6 @@ def dtype(self):
def signature(self):
return Signature([Register("a", self.a_dtype), Register("b", self.b_dtype)])

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
import quimb.tensor as qtn

if isinstance(self.a_dtype, QInt) or isinstance(self.b_dtype, QInt):
raise TypeError("Tensor contraction for addition is only supported for unsigned ints.")
N_a = 2**self.a_dtype.bitsize
N_b = 2**self.b_dtype.bitsize
inds = (incoming['a'], incoming['b'], outgoing['a'], outgoing['b'])
unitary = np.zeros((N_a, N_b, N_a, N_b), dtype=np.complex128)
# TODO: Add a value-to-index method on dtype to make this easier.
for a, b in itertools.product(range(N_a), range(N_b)):
unitary[a, b, a, int(math.fmod(a + b, N_b))] = 1

tn.add(qtn.Tensor(data=unitary, inds=inds, tags=[self.pretty_name(), tag]))

def decompose_bloq(self) -> 'CompositeBloq':
return decompose_from_cirq_style_method(self)

Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/arithmetic/addition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_addition_gate_counts(n: int):


@pytest.mark.parametrize('a,b', itertools.product(range(2**3), repeat=2))
def test_add_no_decompose(a, b):
def test_add_tensor_contract(a, b):
num_bits = 5
bloq = Add(QUInt(num_bits))

Expand All @@ -184,7 +184,7 @@ def test_add_no_decompose(a, b):
assert true_out_int == int(out_bin, 2)

unitary = bloq.tensor_contract()
assert unitary[output_int, input_int] == 1
np.testing.assert_allclose(unitary[output_int, input_int], 1)


@pytest.mark.parametrize('a,b,num_bits', itertools.product(range(4), range(4), range(3, 5)))
Expand Down
43 changes: 18 additions & 25 deletions qualtran/bloqs/arithmetic/multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Iterable, Sequence, Set, TYPE_CHECKING, Union
from typing import Dict, Iterable, List, Sequence, Set, TYPE_CHECKING, Union

import cirq
import numpy as np
Expand All @@ -22,6 +22,7 @@
Bloq,
bloq_example,
BloqDocSpec,
ConnectionT,
GateWithRegisters,
QFxp,
QUInt,
Expand All @@ -35,7 +36,6 @@
if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import SoquetT
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT

Expand Down Expand Up @@ -90,19 +90,12 @@ def __pow__(self, power):
)
raise NotImplementedError("PlusEqualProduct.__pow__ defined only for powers +1/-1.")

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
from qualtran.cirq_interop._cirq_to_bloq import _add_my_tensors_from_gate

_add_my_tensors_from_gate(
self, self.signature, self.pretty_name(), tn, tag, incoming=incoming, outgoing=outgoing
)
def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
from qualtran.cirq_interop._cirq_to_bloq import _my_tensors_from_gate

return _my_tensors_from_gate(self, self.signature, incoming=incoming, outgoing=outgoing)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
# TODO: The T-complexity here is approximate.
Expand Down Expand Up @@ -169,25 +162,25 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
num_toff = self.bitsize * (self.bitsize - 1)
return {(Toffoli(), num_toff)}

def add_my_tensors(
self,
tn: 'qtn.TensorNetwork',
tag: Any,
*,
incoming: Dict[str, 'SoquetT'],
outgoing: Dict[str, 'SoquetT'],
):
def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
import quimb.tensor as qtn

n = self.bitsize
N = 2**self.bitsize
data = np.zeros((N, N, N**2), dtype=np.complex128)
for x in range(N):
data[x, x, x**2] = 1

trg = incoming['result'] if self.uncompute else outgoing['result']
tn.add(
qtn.Tensor(data=data, inds=(incoming['a'], outgoing['a'], trg), tags=['Square', tag])
inds = (
[(incoming['a'], j) for j in range(n)]
+ [(outgoing['a'], j) for j in range(n)]
+ [(trg, j) for j in range(2 * n)]
)
data = data.reshape((2,) * (4 * n))
return [qtn.Tensor(data=data, inds=inds, tags=[str(self)])]

def adjoint(self) -> 'Bloq':
return Square(self.bitsize, not self.uncompute)
Expand Down
2 changes: 1 addition & 1 deletion qualtran/bloqs/arithmetic/subtraction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_subtract_bloq_decomposition():
c = (a - b) % 32
want[(a << 5) | c][a_b] = 1
got = gate.tensor_contract()
np.testing.assert_equal(got, want)
np.testing.assert_allclose(got, want)


def test_subtract_bloq_validation():
Expand Down
Loading

0 comments on commit df07276

Please sign in to comment.