Skip to content

Commit

Permalink
Add support for deep=True to cirq.eject_phased_paulis transformer (#…
Browse files Browse the repository at this point in the history
…5116)

- Adds support to recursively run `cirq.eject_phased_paulis` transformer on circuits wrapped inside a circuit operation by setting deep=True in transformer context.
- Part of #5039
  • Loading branch information
tanujkhattar authored Mar 21, 2022
1 parent d2f284d commit ca4bb72
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 19 deletions.
2 changes: 1 addition & 1 deletion cirq-core/cirq/transformers/eject_phased_paulis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import cirq


@transformer_api.transformer
@transformer_api.transformer(add_deep_support=True)
def eject_phased_paulis(
circuit: 'cirq.AbstractCircuit',
*,
Expand Down
50 changes: 32 additions & 18 deletions cirq-core/cirq/transformers/eject_phased_paulis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Iterable, cast

import dataclasses
import numpy as np
import pytest
import sympy
Expand Down Expand Up @@ -53,29 +54,42 @@ def assert_optimizes(
)

# And match the expected circuit.
assert circuit == expected, (
"Circuit wasn't optimized as expected.\n"
"INPUT:\n"
"{}\n"
"\n"
"EXPECTED OUTPUT:\n"
"{}\n"
"\n"
"ACTUAL OUTPUT:\n"
"{}\n"
"\n"
"EXPECTED OUTPUT (detailed):\n"
"{!r}\n"
"\n"
"ACTUAL OUTPUT (detailed):\n"
"{!r}"
).format(before, expected, circuit, expected, circuit)
cirq.testing.assert_same_circuits(circuit, expected)

# And it should be idempotent.
circuit = cirq.eject_phased_paulis(
circuit, eject_parameterized=eject_parameterized, context=context
)
assert circuit == expected
cirq.testing.assert_same_circuits(circuit, expected)

# Nested sub-circuits should also get optimized.
q = before.all_qubits()
c_nested = cirq.Circuit(
[cirq.PhasedXPowGate(phase_exponent=0.5).on_each(*q), (cirq.Z ** 0.5).on_each(*q)],
cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore"),
[cirq.Y.on_each(*q), cirq.X.on_each(*q)],
cirq.CircuitOperation(before.freeze()).repeat(3).with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=0.75).on_each(*q),
cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")),
cirq.Z.on_each(*q),
cirq.Moment(cirq.CircuitOperation(expected.freeze()).repeat(3).with_tags("preserve_tag")),
)
if context is None:
context = cirq.TransformerContext(tags_to_ignore=("ignore",), deep=True)
else:
context = dataclasses.replace(
context, tags_to_ignore=context.tags_to_ignore + ("ignore",), deep=True
)
c_nested = cirq.eject_phased_paulis(
c_nested, context=context, eject_parameterized=eject_parameterized
)
cirq.testing.assert_same_circuits(c_nested, c_expected)
c_nested = cirq.eject_phased_paulis(
c_nested, context=context, eject_parameterized=eject_parameterized
)
cirq.testing.assert_same_circuits(c_nested, c_expected)


def quick_circuit(*moments: Iterable[cirq.OP_TREE]) -> cirq.Circuit:
Expand Down

0 comments on commit ca4bb72

Please sign in to comment.