From e6877b7ae49d29924f40c90100f5c3391d735cec Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 21 Mar 2022 12:01:04 -0700 Subject: [PATCH] Add support for deep=True to cirq.drop_negligible_operations transformer --- .../drop_negligible_operations.py | 4 ++- .../drop_negligible_operations_test.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/transformers/drop_negligible_operations.py b/cirq-core/cirq/transformers/drop_negligible_operations.py index a2a2ed1c069..51c11366e83 100644 --- a/cirq-core/cirq/transformers/drop_negligible_operations.py +++ b/cirq-core/cirq/transformers/drop_negligible_operations.py @@ -43,6 +43,8 @@ def drop_negligible_operations( Returns: Copy of the transformed input circuit. """ + if context is None: + context = transformer_api.TransformerContext() def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE': return ( @@ -50,5 +52,5 @@ def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE': ) return transformer_primitives.map_operations( - circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else () + circuit, map_func, tags_to_ignore=context.tags_to_ignore, deep=context.deep ).unfreeze(copy=False) diff --git a/cirq-core/cirq/transformers/drop_negligible_operations_test.py b/cirq-core/cirq/transformers/drop_negligible_operations_test.py index 531d135a7cd..3f454decf95 100644 --- a/cirq-core/cirq/transformers/drop_negligible_operations_test.py +++ b/cirq-core/cirq/transformers/drop_negligible_operations_test.py @@ -59,3 +59,39 @@ def test_clears_known_empties_even_at_zero_tolerance(): cirq.Moment(), ), ) + + +def test_recursively_runs_inside_circuit_ops_deep(): + a = cirq.NamedQubit('a') + small_op = cirq.Z(a) ** 0.000001 + nested_circuit = cirq.FrozenCircuit( + cirq.X(a), small_op, small_op.with_tags(NO_COMPILE_TAG), small_op, cirq.Y(a) + ) + nested_circuit_dropped = cirq.FrozenCircuit( + cirq.Moment(cirq.X(a)), + cirq.Moment(), + cirq.Moment(small_op.with_tags(NO_COMPILE_TAG)), + cirq.Moment(), + cirq.Moment(cirq.Y(a)), + ) + c_orig = cirq.Circuit( + small_op, + cirq.CircuitOperation(nested_circuit).repeat(6).with_tags(NO_COMPILE_TAG), + small_op, + cirq.CircuitOperation(nested_circuit).repeat(5).with_tags("preserve_tag"), + small_op, + ) + c_expected = cirq.Circuit( + cirq.Moment(), + cirq.Moment(cirq.CircuitOperation(nested_circuit).repeat(6).with_tags(NO_COMPILE_TAG)), + cirq.Moment(), + cirq.Moment( + cirq.CircuitOperation(nested_circuit_dropped).repeat(5).with_tags("preserve_tag") + ), + cirq.Moment(), + ) + context = cirq.TransformerContext(tags_to_ignore=[NO_COMPILE_TAG], deep=True) + cirq.testing.assert_same_circuits( + cirq.drop_negligible_operations(c_orig, context=context, atol=0.001), + c_expected, + )