Skip to content

Commit

Permalink
Add drop_empty_moments and drop_negligible_operations transformer…
Browse files Browse the repository at this point in the history
…s. (quantumlib#4915)

* Add drop_empty_moments and drop_negligible_operations transformers

* Fix notebook tests

* Fix ejectz tests

* Replace DropEmptyMoments with drop_empty_moments.

* Fix isolated_notebook_test

* Address comments.

Co-authored-by: Cirq Bot <[email protected]>
  • Loading branch information
tanujkhattar and CirqBot authored Jan 31, 2022
1 parent bc15660 commit faea068
Show file tree
Hide file tree
Showing 22 changed files with 254 additions and 50 deletions.
2 changes: 2 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@
decompose_multi_controlled_x,
decompose_multi_controlled_rotation,
decompose_two_qubit_interaction_into_four_fsim_gates,
drop_empty_moments,
drop_negligible_operations,
is_negligible_turn,
map_moments,
map_operations,
Expand Down
5 changes: 2 additions & 3 deletions cirq/contrib/paulistring/convert_gate_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from cirq import circuits, optimizers
from cirq import circuits, optimizers, transformers

from cirq.contrib.paulistring.convert_to_pauli_string_phasors import ConvertToPauliStringPhasors

Expand All @@ -34,5 +34,4 @@ def converted_gate_set(
keep_clifford=not no_clifford_gates,
atol=atol,
).optimize_circuit(conv_circuit)
optimizers.DropEmptyMoments().optimize_circuit(conv_circuit)
return conv_circuit
return transformers.drop_empty_moments(conv_circuit)
10 changes: 6 additions & 4 deletions cirq/contrib/quimb/Cirq-to-Tensor-Networks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
" import cirq\n",
"except ImportError:\n",
" print(\"installing cirq...\")\n",
" !pip install --quiet cirq\n",
" !pip install --quiet cirq --pre\n",
" print(\"installed cirq.\")\n",
"\n",
"try:\n",
Expand Down Expand Up @@ -78,7 +78,7 @@
"source": [
"qubits = cirq.LineQubit.range(3)\n",
"circuit = cirq.testing.random_circuit(qubits, n_moments=10, op_density=0.8, random_state=52)\n",
"cirq.DropEmptyMoments().optimize_circuit(circuit)\n",
"circuit = cirq.drop_empty_moments(circuit)\n",
"SVGCircuit(circuit)"
]
},
Expand All @@ -87,7 +87,9 @@
"metadata": {},
"source": [
"### Circuit to Tensors\n",
"The circuit defines a tensor network representation. By default, the initial state is the `|0...0>` state (represented by the \"zero qubit\" operations labeled \"Q0\" in the legend. \"Q1\" are single qubit operations and \"Q2\" are two qubit operations. The open legs are the indices into the state vector and are of the form \"i{m}_q{n}\" where `m` is the time index (given by the returned `qubit_frontier` dictionary) and \"n\" is the qubit string."
"The circuit defines a tensor network representation. By default, the initial state is the `|0...0>` state (represented by the \"zero qubit\" operations labeled \"Q0\" in the legend. \"Q1\" are single qubit operations and \"Q2\" are two qubit operations. The open legs are the indices into the state vector and are of the form \"i{m}_q{n}\" where `m` is the time index (given by the returned `qubit_frontier` dictionary) and \"n\" is the qubit string.\n",
"\n",
"Note: this notebook relies on unreleased Cirq features. If you want to try these features, make sure you install cirq via `pip install cirq --pre`."
]
},
{
Expand Down Expand Up @@ -285,7 +287,7 @@
" circuit = cirq.testing.random_circuit(qubits, n_moments=n_moments, op_density=0.8)\n",
" noise_model = cirq.ConstantQubitNoiseModel(cirq.DepolarizingChannel(p=1e-3))\n",
" circuit = cirq.Circuit(noise_model.noisy_moments(circuit.moments, qubits))\n",
" cirq.DropEmptyMoments().optimize_circuit(circuit)\n",
" circuit = cirq.drop_empty_moments(circuit)\n",
" n_moments = len(circuit)\n",
" variables = {'circuit': circuit, 'qubits': qubits}\n",
"\n",
Expand Down
10 changes: 6 additions & 4 deletions cirq/contrib/quimb/Contract-a-Grid-Circuit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
" import cirq\n",
"except ImportError:\n",
" print(\"installing cirq...\")\n",
" !pip install --quiet cirq\n",
" !pip install --quiet cirq --pre\n",
" print(\"installed cirq.\")\n",
"\n",
"try:\n",
Expand All @@ -28,7 +28,9 @@
"metadata": {},
"source": [
"# Contract a Grid Circuit\n",
"Shallow circuits on a planar grid with low-weight observables permit easy contraction."
"Shallow circuits on a planar grid with low-weight observables permit easy contraction.\n",
"\n",
"Note: this notebook relies on unreleased Cirq features. If you want to try these features, make sure you install cirq via `pip install cirq --pre`."
]
},
{
Expand Down Expand Up @@ -207,8 +209,8 @@
"ccq.MergeNQubitGates(n_qubits=2).optimize_circuit(compressed_c)\n",
"ccq.MergeNQubitGates(n_qubits=1).optimize_circuit(compressed_c)\n",
"\n",
"cirq.DropNegligible(tolerance=1e-6).optimize_circuit(compressed_c)\n",
"cirq.DropEmptyMoments().optimize_circuit(compressed_c)\n",
"compressed_c = cirq.drop_negligible_operations(compressed_c, atol=1e-6)\n",
"compressed_c = cirq.drop_empty_moments(compressed_c)\n",
"print(len(list(compressed_c.all_operations())), len(compressed_c.all_qubits()))\n",
"SVGCircuit(compressed_c)"
]
Expand Down
4 changes: 2 additions & 2 deletions cirq/contrib/quimb/density_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_tensor_density_matrix_3():
def test_tensor_density_matrix_4():
qubits = cirq.LineQubit.range(4)
circuit = cirq.testing.random_circuit(qubits=qubits, n_moments=100, op_density=0.8)
cirq.DropEmptyMoments().optimize_circuit(circuit)
circuit = cirq.drop_empty_moments(circuit)
noise_model = cirq.ConstantQubitNoiseModel(cirq.DepolarizingChannel(p=1e-3))
circuit = cirq.Circuit(noise_model.noisy_moments(circuit.moments, qubits))
rho1 = cirq.final_density_matrix(circuit, dtype=np.complex128)
Expand All @@ -69,7 +69,7 @@ def test_tensor_density_matrix_4():
def test_tensor_density_matrix_gridqubit():
qubits = cirq.GridQubit.rect(2, 2)
circuit = cirq.testing.random_circuit(qubits=qubits, n_moments=10, op_density=0.8)
cirq.DropEmptyMoments().optimize_circuit(circuit)
circuit = cirq.drop_empty_moments(circuit)
noise_model = cirq.ConstantQubitNoiseModel(cirq.DepolarizingChannel(p=1e-3))
circuit = cirq.Circuit(noise_model.noisy_moments(circuit.moments, qubits))
rho1 = cirq.final_density_matrix(circuit, dtype=np.complex128)
Expand Down
5 changes: 2 additions & 3 deletions cirq/contrib/quimb/grid_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ def simplify_expectation_value_circuit(circuit_sand: cirq.Circuit):
n_op = sum(1 for _ in circuit_sand.all_operations())
while True:
MergeNQubitGates(n_qubits=1).optimize_circuit(circuit_sand)
cirq.DropNegligible(tolerance=1e-6).optimize_circuit(circuit_sand)
circuit_sand = cirq.drop_negligible_operations(circuit_sand, atol=1e-6)
MergeNQubitGates(n_qubits=2).optimize_circuit(circuit_sand)
cirq.DropNegligible(tolerance=1e-6)
cirq.DropEmptyMoments().optimize_circuit(circuit_sand)
circuit_sand = cirq.drop_empty_moments(circuit_sand)
new_n_op = sum(1 for _ in circuit_sand.all_operations())

if new_n_op < n_op:
Expand Down
4 changes: 2 additions & 2 deletions cirq/ops/pauli_string_phasor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ def test_drop_negligible():
cirq.PauliStringPhasor(cirq.PauliString({q0: cirq.Z})) ** 0.25,
cirq.PauliStringPhasor(cirq.PauliString({q0: cirq.Z})) ** sym,
)
cirq.DropNegligible().optimize_circuit(circuit)
cirq.DropEmptyMoments().optimize_circuit(circuit)
circuit = cirq.drop_negligible_operations(circuit)
circuit = cirq.drop_empty_moments(circuit)
assert circuit == expected


Expand Down
2 changes: 2 additions & 0 deletions cirq/optimizers/drop_empty_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from cirq.circuits.circuit import Circuit
from cirq.circuits import circuit as _circuit
from cirq._compat import deprecated_class


@deprecated_class(deadline='v1.0', fix='Use cirq.drop_empty_moments instead.')
class DropEmptyMoments:
"""Removes empty moments from a circuit."""

Expand Down
7 changes: 4 additions & 3 deletions cirq/optimizers/drop_empty_moments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@


def assert_optimizes(before, after):
opt = cirq.DropEmptyMoments()
opt.optimize_circuit(before)
assert before == after
with cirq.testing.assert_deprecated("Use cirq.drop_empty_moments", deadline='v1.0'):
opt = cirq.DropEmptyMoments()
opt.optimize_circuit(before)
assert before == after


def test_drop():
Expand Down
2 changes: 2 additions & 0 deletions cirq/optimizers/drop_negligible.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

from cirq import protocols
from cirq.circuits import circuit as _circuit
from cirq._compat import deprecated_class

if TYPE_CHECKING:
from cirq import ops


@deprecated_class(deadline='v1.0', fix='Use cirq.drop_negligible_operations instead.')
class DropNegligible:
"""An optimization pass that removes operations with tiny effects."""

Expand Down
22 changes: 10 additions & 12 deletions cirq/optimizers/drop_negligible_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,26 @@
import cirq


def assert_optimizes(optimizer, initial_circuit: cirq.Circuit, expected_circuit: cirq.Circuit):
circuit = cirq.Circuit(initial_circuit)
optimizer.optimize_circuit(circuit)
assert circuit == expected_circuit
def assert_optimizes(atol: float, initial_circuit: cirq.Circuit, expected_circuit: cirq.Circuit):
with cirq.testing.assert_deprecated("Use cirq.drop_negligible_operations", deadline='v1.0'):
optimizer = cirq.DropNegligible(atol)
circuit = cirq.Circuit(initial_circuit)
optimizer.optimize_circuit(circuit)
assert circuit == expected_circuit


def test_leaves_big():
drop = cirq.DropNegligible(0.001)
a = cirq.NamedQubit('a')
circuit = cirq.Circuit([cirq.Moment([cirq.Z(a) ** 0.1])])

assert_optimizes(optimizer=drop, initial_circuit=circuit, expected_circuit=circuit)
assert_optimizes(0.001, initial_circuit=circuit, expected_circuit=circuit)


def test_clears_small():
drop = cirq.DropNegligible(0.001)
a = cirq.NamedQubit('a')
circuit = cirq.Circuit([cirq.Moment([cirq.Z(a) ** 0.000001])])

assert_optimizes(
optimizer=drop, initial_circuit=circuit, expected_circuit=cirq.Circuit([cirq.Moment()])
)
assert_optimizes(0.001, initial_circuit=circuit, expected_circuit=cirq.Circuit([cirq.Moment()]))


def test_clears_known_empties_even_at_zero_tolerance():
Expand All @@ -45,12 +43,12 @@ def test_clears_known_empties_even_at_zero_tolerance():
cirq.Z(a) ** 0, cirq.Y(a) ** 0.0000001, cirq.X(a) ** -0.0000001, cirq.CZ(a, b) ** 0
)
assert_optimizes(
optimizer=cirq.DropNegligible(tolerance=0.001),
0.001,
initial_circuit=circuit,
expected_circuit=cirq.Circuit([cirq.Moment()] * 4),
)
assert_optimizes(
optimizer=cirq.DropNegligible(tolerance=0),
0,
initial_circuit=circuit,
expected_circuit=cirq.Circuit(
[
Expand Down
6 changes: 3 additions & 3 deletions cirq/optimizers/eject_z_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def test_swap():
optimized = original.copy()

cirq.EjectZ().optimize_circuit(optimized)
cirq.DropEmptyMoments().optimize_circuit(optimized)
optimized = cirq.drop_empty_moments(optimized)

assert optimized[0].operations == (cirq.SWAP(a, b),)
# Note: EjectZ drops `global_phase` from Rz turning it into a Z
Expand All @@ -397,7 +397,7 @@ def test_swap_fsim(theta):
optimized = original.copy()

cirq.EjectZ().optimize_circuit(optimized)
cirq.DropEmptyMoments().optimize_circuit(optimized)
optimized = cirq.drop_empty_moments(optimized)

assert optimized[0].operations == (cirq.FSimGate(theta=theta, phi=0.123).on(a, b),)
# Note: EjectZ drops `global_phase` from Rz turning it into a Z
Expand All @@ -420,7 +420,7 @@ def test_swap_iswap(exponent):
optimized = original.copy()

cirq.EjectZ().optimize_circuit(optimized)
cirq.DropEmptyMoments().optimize_circuit(optimized)
optimized = cirq.drop_empty_moments(optimized)

assert optimized[0].operations == (cirq.ISWAP(a, b) ** exponent,)
# Note: EjectZ drops `global_phase` from Rz turning it into a Z
Expand Down
4 changes: 1 addition & 3 deletions cirq/optimizers/expand_composite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@


def assert_equal_mod_empty(expected, actual):
drop_empty = cirq.DropEmptyMoments()
drop_empty.optimize_circuit(actual)

actual = cirq.drop_empty_moments(actual)
assert expected == actual, f'EXPECTED {expected} : ACTUAL {actual}'


Expand Down
12 changes: 9 additions & 3 deletions cirq/optimizers/merge_interactions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
cirq.merge_single_qubit_gates_into_phased_x_z,
cirq.EjectPhasedPaulis().optimize_circuit,
cirq.EjectZ().optimize_circuit,
cirq.DropNegligible().optimize_circuit,
cirq.DropEmptyMoments().optimize_circuit,
]
for post in followup_optimizations:
post(actual)
post(expected)

followup_transformers: List[cirq.TRANSFORMER] = [
cirq.drop_negligible_operations,
cirq.drop_empty_moments,
]
for transform in followup_transformers:
actual = transform(actual).unfreeze(copy=False)
expected = transform(expected).unfreeze(copy=False)

assert actual == expected, f'ACTUAL {actual} : EXPECTED {expected}'


Expand Down Expand Up @@ -249,7 +255,7 @@ def clean_up(operations):

optimizer = cirq.MergeInteractions(allow_partial_czs=False, post_clean_up=clean_up)
optimizer.optimize_circuit(circuit)
cirq.DropEmptyMoments().optimize_circuit(circuit)
circuit = cirq.drop_empty_moments(circuit)

assert isinstance(circuit[0].operations[0].gate, Marker)
assert isinstance(circuit[-1].operations[0].gate, Marker)
Expand Down
10 changes: 8 additions & 2 deletions cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,19 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
cirq.merge_single_qubit_gates_into_phased_x_z,
cirq.EjectPhasedPaulis().optimize_circuit,
cirq.EjectZ().optimize_circuit,
cirq.DropNegligible().optimize_circuit,
cirq.DropEmptyMoments().optimize_circuit,
]
for post in followup_optimizations:
post(actual)
post(expected)

followup_transformers: List[cirq.TRANSFORMER] = [
cirq.drop_negligible_operations,
cirq.drop_empty_moments,
]
for transform in followup_transformers:
actual = transform(actual).unfreeze(copy=False)
expected = transform(expected).unfreeze(copy=False)

assert actual == expected, f'ACTUAL {actual} : EXPECTED {expected}'


Expand Down
10 changes: 5 additions & 5 deletions cirq/optimizers/merge_single_qubit_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def assert_optimizes(
optimizer(before)

# Ignore differences that would be caught by follow-up optimizations.
followup_optimizations = [cirq.DropNegligible(), cirq.DropEmptyMoments()]
for post in followup_optimizations:
post(before) # type: ignore # error: "object" not callable
post(expected) # type: ignore # error: "object" not callable
followup_transformers = [cirq.drop_negligible_operations, cirq.drop_empty_moments]
for transform in followup_transformers:
before = transform(before) # type: ignore # error: "object" not callable
expected = transform(expected) # type: ignore # error: "object" not callable

assert before == expected, f'BEFORE:\n{before}\nEXPECTED:\n{expected}'

Expand Down Expand Up @@ -150,7 +150,7 @@ def test_rewrite():
cirq.MergeSingleQubitGates(rewriter=lambda ops: cirq.H(ops[0].qubits[0])).optimize_circuit(
circuit
)
cirq.DropEmptyMoments().optimize_circuit(circuit)
circuit = cirq.drop_empty_moments(circuit)

cirq.testing.assert_same_circuits(
circuit,
Expand Down
4 changes: 4 additions & 0 deletions cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@

from cirq.transformers.align import align_left, align_right

from cirq.transformers.drop_empty_moments import drop_empty_moments

from cirq.transformers.drop_negligible_operations import drop_negligible_operations

from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements

from cirq.transformers.transformer_api import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _optimize_multiplexed_angles_circuit(operations: Sequence[ops.Operation]):
the optimized operations
"""
circuit = cirq.Circuit(operations)
cirq.optimizers.DropNegligible().optimize_circuit(circuit)
circuit = cirq.transformers.drop_negligible_operations(circuit)
if np.allclose(circuit.unitary(), np.eye(8), atol=1e-14):
return cirq.Circuit([])

Expand Down
37 changes: 37 additions & 0 deletions cirq/transformers/drop_empty_moments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer pass that removes empty moments from a circuit."""

from typing import Optional, TYPE_CHECKING
from cirq.transformers import transformer_api, transformer_primitives

if TYPE_CHECKING:
import cirq


@transformer_api.transformer
def drop_empty_moments(
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
) -> 'cirq.Circuit':
"""Removes empty moments from a circuit.
Args:
circuit: Input circuit to transform.
context: `cirq.TransformerContext` storing common configurable options for transformers.
Returns:
Copy of the transformed input circuit.
"""
return transformer_primitives.map_moments(circuit.unfreeze(False), lambda m, _: m if m else [])
Loading

0 comments on commit faea068

Please sign in to comment.