Skip to content

Commit

Permalink
Update Density Matrix and State Vector Simulators to work when an ope…
Browse files Browse the repository at this point in the history
…ration allocates new qubits as part of its decomposition (quantumlib#6108)

* WIP add factoring and kron methods to sim state for adding and removing ancillas in state vector and density matrix simulators

* add test cases

* add delegating gate test case

* update test

* all tests pass

* add test case for unitary Y

* nit

* addresses PR comments by adding empty checks. Applys formatter. Subsequent push will add more test cases per Tanuj's comment

* nit formatting changes, add docustring with input/output for remove_qubits

* merge this branch and tanujkhattar/Cirq@ccde689

* merging branches, adding test coverage in next push

* format files

* add coverage tests

* change assert

* coverage and type check tests should pass

* incorporate tanujkhattar/Cirq@1db8ac5

* nit

* remove block comment

* add coverage

---------

Co-authored-by: Tanuj Khattar <[email protected]>
  • Loading branch information
senecameeks and tanujkhattar authored Jun 7, 2023
1 parent cb891f1 commit d38de16
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 14 deletions.
2 changes: 1 addition & 1 deletion cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def act_on(

arg_fallback = getattr(sim_state, '_act_on_fallback_', None)
if arg_fallback is not None:
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
qubits = action.qubits if is_op else qubits
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
if result is True:
return
Expand Down
16 changes: 16 additions & 0 deletions cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
extracted, remainder = self.factor(qubits)
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
return remainder

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down
12 changes: 12 additions & 0 deletions cirq/sim/density_matrix_simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,15 @@ def test_initial_state_bad_shape():
cirq.DensityMatrixSimulationState(
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
)


def test_remove_qubits():
"""Test the remove_qubits method."""
q1 = cirq.LineQubit(0)
q2 = cirq.LineQubit(1)
state = cirq.DensityMatrixSimulationState(qubits=[q1, q2])

new_state = state.remove_qubits([q1])

assert len(new_state.qubits) == 1
assert q1 not in new_state.qubits
47 changes: 45 additions & 2 deletions cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,38 @@ def create_merged_state(self) -> Self:
"""Creates a final merged state."""
return self

def add_qubits(self: Self, qubits: Sequence['cirq.Qid']):
"""Add qubits to a new state space and take the kron product.
Note that only Density Matrix and State Vector simulators
override this function.
Args:
qubits: Sequence of qubits to be added.
Returns:
NotImplemented: If the subclass does not implement this method.
Raises:
ValueError: If a qubit being added is already tracked.
"""
if any(q in self.qubits for q in qubits):
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
return NotImplemented

def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Remove qubits from the state space.
Args:
qubits: Sequence of qubits to be added.
Returns:
A new Simulation State with qubits removed. Or
`self` if there are no qubits to remove."""
if qubits is None or not qubits:
return self
return NotImplemented

def kronecker_product(self, other: Self, *, inplace=False) -> Self:
"""Joins two state spaces together."""
args = self if inplace else copy.copy(self)
Expand Down Expand Up @@ -294,13 +326,24 @@ def strat_act_on_from_apply_decompose(
val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid']
) -> bool:
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
assert len(qubits1) == len(qubits)
qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)}
if operations is None:
return NotImplemented
assert len(qubits1) == len(qubits)
all_qubits = frozenset([q for op in operations for q in op.qubits])
qubit_map = dict(zip(all_qubits, all_qubits))
qubit_map.update(dict(zip(qubits1, qubits)))
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
args = args.add_qubits(new_ancilla)
if args is NotImplemented:
return NotImplemented
for operation in operations:
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
protocols.act_on(operation, args)
args = args.remove_qubits(new_ancilla)
if args is NotImplemented: # coverage: ignore
raise TypeError( # coverage: ignore
f"{type(args)} implements `add_qubits` but not `remove_qubits`." # coverage: ignore
) # coverage: ignore
return True


Expand Down
176 changes: 165 additions & 11 deletions cirq/sim/simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import cirq
from cirq.sim import simulation_state
from cirq.testing import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla


class DummyQuantumState(cirq.QuantumStateRepresentation):
Expand All @@ -33,32 +34,90 @@ def reindex(self, axes):


class DummySimulationState(cirq.SimulationState):
def __init__(self):
super().__init__(state=DummyQuantumState(), qubits=cirq.LineQubit.range(2))
def __init__(self, qubits=cirq.LineQubit.range(2)):
super().__init__(state=DummyQuantumState(), qubits=qubits)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
return True


class AncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.CX(qubits[0], ancilla)
yield cirq.Z(ancilla) ** self._exponent
yield cirq.CX(qubits[0], ancilla)


class AncillaH(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.H(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.H(ancilla) ** self._exponent


class AncillaY(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.Y(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.Y(ancilla) ** self._exponent


class DelegatingAncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
a = cirq.NamedQubit('a')
yield cirq.CX(qubits[0], a)
yield AncillaZ(self._exponent).on(a)
yield cirq.CX(qubits[0], a)


class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)


def test_measurements():
args = DummySimulationState()
args.measure([cirq.LineQubit(0)], "test", [False], {})
assert args.log_of_measurement_results["test"] == [5]


def test_decompose():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)

args = DummySimulationState()
assert simulation_state.strat_act_on_from_apply_decompose(
Composite(), args, [cirq.LineQubit(0)]
assert (
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
is NotImplemented
)


Expand Down Expand Up @@ -101,3 +160,98 @@ def test_field_getters():
args = DummySimulationState()
assert args.prng is np.random
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_z(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_y(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaY(exp).on(q))

control_circuit = cirq.Circuit(cirq.Y(q))
control_circuit.append(cirq.Y(q))
control_circuit.append(cirq.XPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_borrowable_qubit(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(cirq.X(q))
test_circuit.append(AncillaH(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_delegating_gate_qubit(exp):
q = cirq.LineQubit(0)

test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(DelegatingAncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
def test_phase_using_dirty_ancilla(num_ancilla: int):
q = cirq.LineQubit(0)
anc = cirq.NamedQubit.range(num_ancilla, prefix='anc')

u = cirq.MatrixGate(cirq.testing.random_unitary(2 ** (num_ancilla + 1)))
test_circuit = cirq.Circuit(
u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
def test_phase_using_clean_ancilla(num_ancilla: int, theta: float):
q = cirq.LineQubit(0)
u = cirq.MatrixGate(cirq.testing.random_unitary(2))
test_circuit = cirq.Circuit(
u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


def test_add_qubits_raise_value_error(num_ancilla=1):
q = cirq.LineQubit(0)
args = cirq.StateVectorSimulationState(qubits=[q])

with pytest.raises(ValueError, match='should not already be tracked.'):
args.add_qubits([q])


def test_remove_qubits_not_implemented(num_ancilla=1):
args = DummySimulationState()

assert args.remove_qubits([cirq.LineQubit(0)]) is NotImplemented


def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None:
for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']:
test_sim = eval(test_simulator)(test_circuit)
control_sim = eval(test_simulator)(control_circuit)
assert np.allclose(test_sim, control_sim)
16 changes: 16 additions & 0 deletions cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,22 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
extracted, remainder = self.factor(qubits, inplace=True)
remainder._state._state_vector *= extracted._state._state_vector.reshape((-1,))[0]
return remainder

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down

0 comments on commit d38de16

Please sign in to comment.