Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed simulators fallback to decompose_once and removed ancilla support from DensityMatrixSimulator #6127

Merged
merged 5 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions cirq-core/cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,22 +285,6 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
extracted, remainder = self.factor(qubits)
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
return remainder

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down
12 changes: 0 additions & 12 deletions cirq-core/cirq/sim/density_matrix_simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,3 @@ def test_initial_state_bad_shape():
cirq.DensityMatrixSimulationState(
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
)


def test_remove_qubits():
"""Test the remove_qubits method."""
q1 = cirq.LineQubit(0)
q2 = cirq.LineQubit(1)
state = cirq.DensityMatrixSimulationState(qubits=[q1, q2])

new_state = state.remove_qubits([q1])

assert len(new_state.qubits) == 1
assert q1 not in new_state.qubits
45 changes: 22 additions & 23 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
List,
Optional,
Sequence,
Set,
TypeVar,
TYPE_CHECKING,
Tuple,
Expand All @@ -31,8 +32,8 @@

import numpy as np

from cirq import protocols, value
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq import ops, protocols, value

from cirq.sim.simulation_state_base import SimulationStateBase

TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation')
Expand Down Expand Up @@ -166,7 +167,7 @@ def create_merged_state(self) -> Self:
"""Creates a final merged state."""
return self

def add_qubits(self: Self, qubits: Sequence['cirq.Qid']):
def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Add qubits to a new state space and take the kron product.

Note that only Density Matrix and State Vector simulators
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -181,8 +182,8 @@ def add_qubits(self: Self, qubits: Sequence['cirq.Qid']):
Raises:
ValueError: If a qubit being added is already tracked.
"""
if any(q in self.qubits for q in qubits):
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
if not qubits:
return self
return NotImplemented

def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
Expand All @@ -194,7 +195,7 @@ def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
Returns:
A new Simulation State with qubits removed. Or
`self` if there are no qubits to remove."""
if qubits is None or not qubits:
if not qubits:
return self
return NotImplemented

Expand Down Expand Up @@ -325,25 +326,23 @@ def can_represent_mixed_states(self) -> bool:
def strat_act_on_from_apply_decompose(
val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid']
) -> bool:
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
if operations is None:
if isinstance(val, ops.Gate):
decomposed = protocols.decompose_once_with_qubits(val, qubits, flatten=False, default=None)
else:
decomposed = protocols.decompose_once(val, flatten=False, default=None)
if decomposed is None:
return NotImplemented
assert len(qubits1) == len(qubits)
all_qubits = frozenset([q for op in operations for q in op.qubits])
qubit_map = dict(zip(all_qubits, all_qubits))
qubit_map.update(dict(zip(qubits1, qubits)))
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
args = args.add_qubits(new_ancilla)
if args is NotImplemented:
return NotImplemented
for operation in operations:
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
all_ancilla: Set['cirq.Qid'] = set()
for operation in ops.flatten_to_ops(decomposed):
curr_ancilla = tuple(q for q in operation.qubits if q not in args.qubits)
args = args.add_qubits(curr_ancilla)
if args is NotImplemented:
return NotImplemented
all_ancilla.update(curr_ancilla)
protocols.act_on(operation, args)
args = args.remove_qubits(new_ancilla)
if args is NotImplemented: # coverage: ignore
raise TypeError( # coverage: ignore
f"{type(args)} implements `add_qubits` but not `remove_qubits`." # coverage: ignore
) # coverage: ignore
args = args.remove_qubits(all_ancilla)
if args is NotImplemented:
raise TypeError(f"{type(args)} implements add_qubits but not remove_qubits.")
return True


Expand Down
146 changes: 52 additions & 94 deletions cirq-core/cirq/sim/simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,61 +42,26 @@ def _act_on_fallback_(
) -> bool:
return True


class AncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.CX(qubits[0], ancilla)
yield cirq.Z(ancilla) ** self._exponent
yield cirq.CX(qubits[0], ancilla)


class AncillaH(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.H(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.H(ancilla) ** self._exponent


class AncillaY(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.Y(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.Y(ancilla) ** self._exponent
def add_qubits(self, qubits):
ret = super().add_qubits(qubits)
return self if NotImplemented else ret


class DelegatingAncillaZ(cirq.Gate):
def __init__(self, exponent=1):
def __init__(self, exponent=1, measure_ancilla: bool = False):
self._exponent = exponent
self._measure_ancilla = measure_ancilla

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
a = cirq.NamedQubit('a')
yield cirq.CX(qubits[0], a)
yield AncillaZ(self._exponent).on(a)
yield PhaseUsingCleanAncilla(self._exponent).on(a)
yield cirq.CX(qubits[0], a)
if self._measure_ancilla:
yield cirq.measure(a)


class Composite(cirq.Gate):
Expand All @@ -115,12 +80,23 @@ def test_measurements():

def test_decompose():
args = DummySimulationState()
assert (
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
is NotImplemented
assert simulation_state.strat_act_on_from_apply_decompose(
Composite(), args, [cirq.LineQubit(0)]
)


def test_decompose_for_gate_allocating_qubits_raises():
class Composite(cirq.testing.SingleQubitGate):
def _decompose_(self, qubits):
anc = cirq.NamedQubit("anc")
yield cirq.CNOT(*qubits, anc)

args = DummySimulationState()

with pytest.raises(TypeError, match="add_qubits but not remove_qubits"):
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])


def test_mapping():
args = DummySimulationState()
assert list(iter(args)) == cirq.LineQubit.range(2)
Expand Down Expand Up @@ -162,53 +138,35 @@ def test_field_getters():
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_z(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_y(exp):
@pytest.mark.parametrize('exp', np.linspace(0, 2 * np.pi, 10))
def test_delegating_gate_unitary(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaY(exp).on(q))

control_circuit = cirq.Circuit(cirq.Y(q))
control_circuit.append(cirq.Y(q))
control_circuit.append(cirq.XPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_borrowable_qubit(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(cirq.X(q))
test_circuit.append(AncillaH(exp).on(q))
test_circuit.append(DelegatingAncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_delegating_gate_qubit(exp):
@pytest.mark.parametrize('exp', np.linspace(0, 2 * np.pi, 10))
def test_delegating_gate_channel(exp):
q = cirq.LineQubit(0)

test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(DelegatingAncillaZ(exp).on(q))
test_circuit.append(DelegatingAncillaZ(exp, True).on(q))

control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
with pytest.raises(TypeError, match="DensityMatrixSimulator doesn't support"):
# TODO: This test should pass once we extend support to DensityMatrixSimulator.
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
Expand All @@ -221,7 +179,8 @@ def test_phase_using_dirty_ancilla(num_ancilla: int):
u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
Expand All @@ -233,25 +192,24 @@ def test_phase_using_clean_ancilla(num_ancilla: int, theta: float):
u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


def test_add_qubits_raise_value_error(num_ancilla=1):
q = cirq.LineQubit(0)
args = cirq.StateVectorSimulationState(qubits=[q])

with pytest.raises(ValueError, match='should not already be tracked.'):
args.add_qubits([q])
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)


def test_remove_qubits_not_implemented(num_ancilla=1):
args = DummySimulationState()

assert args.remove_qubits([cirq.LineQubit(0)]) is NotImplemented
def assert_test_circuit_for_dm_simulator(test_circuit, control_circuit) -> None:
# Density Matrix Simulator: For unitary gates, this fallbacks to `cirq.apply_channel`
# which recursively calls to `cirq.apply_unitary(decompose=True)`.
for split_untangled_states in [True, False]:
sim = cirq.DensityMatrixSimulator(split_untangled_states=split_untangled_states)
control_sim = sim.simulate(control_circuit).final_density_matrix
test_sim = sim.simulate(test_circuit).final_density_matrix
assert np.allclose(test_sim, control_sim)


def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None:
for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']:
test_sim = eval(test_simulator)(test_circuit)
control_sim = eval(test_simulator)(control_circuit)
def assert_test_circuit_for_sv_simulator(test_circuit, control_circuit) -> None:
# State Vector Simulator.
for split_untangled_states in [True, False]:
sim = cirq.Simulator(split_untangled_states=split_untangled_states)
control_sim = sim.simulate(control_circuit).final_state_vector
test_sim = sim.simulate(test_circuit).final_state_vector
assert np.allclose(test_sim, control_sim)