Skip to content

Commit

Permalink
Add support for deep=True flag to remaining transformer primitives (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar authored Mar 21, 2022
1 parent 2693951 commit 6af5387
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 16 deletions.
63 changes: 61 additions & 2 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,17 @@ def map_moments(
circuit: CIRCUIT_TYPE,
map_func: Callable[[circuits.Moment, int], Union[circuits.Moment, Sequence[circuits.Moment]]],
*,
tags_to_ignore: Sequence[Hashable] = (),
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Applies local transformation on moments, by calling `map_func(moment)` for each moment.
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
map_func: Mapping function from (cirq.Moment, moment_index) to a sequence of moments.
tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
ignored when recursively applying the transformer primitive to sub-circuits, given
deep=True.
deep: If true, `map_func` will be recursively applied to circuits wrapped inside
any circuit operations contained within `circuit`.
Expand All @@ -79,6 +83,8 @@ def map_moments(
for i, op in circuit.findall_operations(
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
):
if set(op.tags).intersection(tags_to_ignore):
continue
op_untagged = cast(circuits.CircuitOperation, op.untagged)
mapped_op = op_untagged.replace(
circuit=map_moments(op_untagged.circuit, map_func, deep=deep)
Expand Down Expand Up @@ -190,6 +196,7 @@ def merge_operations(
merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]],
*,
tags_to_ignore: Sequence[Hashable] = (),
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges operations in a circuit by calling `merge_func` iteratively on operations.
Expand Down Expand Up @@ -226,6 +233,8 @@ def merge_operations(
tags_to_ignore: Sequence of tags which should be ignored while applying `merge_func` on
tagged operations -- i.e. `merge_func(op1, op2)` will be called only if both `op1` and
`op2` satisfy `set(op.tags).isdisjoint(tags_to_ignore)`.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
Returns:
Expand All @@ -235,9 +244,11 @@ def merge_operations(
ValueError if the merged operation acts on new qubits outside the set of qubits
corresponding to the original operations to be merged.
"""
_circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit"
tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag}

def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Operation]:
if not all(set(op.tags).isdisjoint(tags_to_ignore) for op in [op1, op2]):
if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]):
return None
new_op = merge_func(op1, op2)
qubit_set = frozenset(op1.qubits + op2.qubits)
Expand All @@ -252,6 +263,23 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
for current_moment in circuit:
new_moment = circuits.Moment()
for op in sorted(current_moment.operations, key=lambda op: op.qubits):
if (
deep
and isinstance(op.untagged, circuits.CircuitOperation)
and tags_to_ignore_set.isdisjoint(op.tags)
):
op_untagged = op.untagged
new_moment = new_moment.with_operation(
op_untagged.replace(
circuit=merge_operations(
op_untagged.circuit,
merge_func,
tags_to_ignore=tags_to_ignore,
deep=True,
)
).with_tags(*op.tags, _circuit_op_tag)
)
continue
op_qs = set(op.qubits)
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
if idx is not None and op_qs.issubset(ret_circuit[idx][op_qs].operations[0].qubits):
Expand Down Expand Up @@ -279,6 +307,12 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
new_moment = new_moment.with_operation(op)
ret_circuit += new_moment
if deep:
ret_circuit = map_operations(
ret_circuit,
lambda o, _: o.untagged.with_tags(*(set(o.tags) - {_circuit_op_tag})),
deep=True,
)
return _to_target_circuit_type(ret_circuit, circuit)


Expand All @@ -288,6 +322,7 @@ def merge_operations_to_circuit_op(
*,
tags_to_ignore: Sequence[Hashable] = (),
merged_circuit_op_tag: str = "Merged connected component",
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges connected components of operations and wraps each component into a circuit operation.
Expand All @@ -307,6 +342,8 @@ def merge_operations_to_circuit_op(
potential candidates for any connected component.
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
components.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
Returns:
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
Expand All @@ -329,7 +366,7 @@ def get_ops(op: 'cirq.Operation'):
merged_circuit_op_tag
)

return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore)
return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep)


def merge_k_qubit_unitaries_to_circuit_op(
Expand All @@ -338,6 +375,7 @@ def merge_k_qubit_unitaries_to_circuit_op(
*,
tags_to_ignore: Sequence[Hashable] = (),
merged_circuit_op_tag: Optional[str] = None,
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges connected components of operations, acting on <= k qubits, into circuit operations.
Expand All @@ -353,6 +391,8 @@ def merge_k_qubit_unitaries_to_circuit_op(
potential candidates for any connected component.
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
components. A default tag is applied if left None.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
Returns:
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
Expand All @@ -370,12 +410,16 @@ def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']
can_merge,
tags_to_ignore=tags_to_ignore,
merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.",
deep=deep,
)


def merge_moments(
circuit: CIRCUIT_TYPE,
merge_func: Callable[[circuits.Moment, circuits.Moment], Optional[circuits.Moment]],
*,
tags_to_ignore: Sequence[Hashable] = (),
deep: bool = False,
) -> CIRCUIT_TYPE:
"""Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.
Expand All @@ -384,12 +428,27 @@ def merge_moments(
merge_func: Callable to determine whether two adjacent moments in the circuit should be
merged. If the moments can be merged, the callable should return the merged moment,
else None.
tags_to_ignore: Tagged circuit operations marked with any of `tags_to_ignore` will be
ignored when recursively applying the transformer primitive to sub-circuits, given
deep=True.
deep: If true, the transformer primitive will be recursively applied to all circuits
wrapped inside circuit operations.
Returns:
Copy of input circuit with merged moments.
"""
if not circuit:
return circuit
if deep:
circuit = map_operations(
circuit,
lambda op, _: op.untagged.replace(
circuit=merge_moments(op.untagged.circuit, merge_func, deep=deep)
).with_tags(*op.tags)
if isinstance(op.untagged, circuits.CircuitOperation)
else op,
tags_to_ignore=tags_to_ignore,
)
merged_moments: List[circuits.Moment] = [circuit[0]]
for current_moment in circuit[1:]:
merged_moment = merge_func(merged_moments[-1], current_moment)
Expand Down
111 changes: 97 additions & 14 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,35 @@ def test_map_moments_drop_empty_moments():
cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(c[0], c[0]))


def test_map_moments_drop_empty_moments_deep():
op = cirq.X(cirq.NamedQubit("q"))
c_nested = cirq.FrozenCircuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))
c_orig = cirq.Circuit(
c_nested,
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
c_nested,
cirq.CircuitOperation(c_nested).repeat(5).with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
[op, op],
cirq.CircuitOperation(c_nested).repeat(6).with_tags("ignore"),
[op, op],
cirq.CircuitOperation(cirq.FrozenCircuit([op, op])).repeat(5).with_tags("preserve_tag"),
)
c_mapped = cirq.map_moments(
c_orig, lambda m, i: [] if len(m) == 0 else [m], deep=True, tags_to_ignore=("ignore",)
)
cirq.testing.assert_same_circuits(c_mapped, c_expected)


def _merge_z_moments_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
if any(op.gate != cirq.Z for m in [m1, m2] for op in m):
return None
return cirq.Moment(
cirq.Z(q) for q in (m1.qubits | m2.qubits) if m1.operates_on([q]) ^ m2.operates_on([q])
)


def test_merge_moments():
q = cirq.LineQubit.range(3)
c_orig = cirq.Circuit(
Expand All @@ -419,21 +448,8 @@ def test_merge_moments():
''',
)

def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
def is_z_moment(m):
return all(op.gate == cirq.Z for op in m)

if not (is_z_moment(m1) and is_z_moment(m2)):
return None
qubits = m1.qubits | m2.qubits

def mul(op1, op2):
return (op1 or op2) if not (op1 and op2) else cirq.decompose_once(op1 * op2)

return cirq.Moment(mul(m1.operation_at(q), m2.operation_at(q)) for q in qubits)

cirq.testing.assert_has_diagram(
cirq.merge_moments(c_orig, merge_func),
cirq.merge_moments(c_orig, _merge_z_moments_func),
'''
0: ───────@───────
Expand All @@ -444,6 +460,35 @@ def mul(op1, op2):
)


def test_merge_moments_deep():
q = cirq.LineQubit.range(3)
c_z_moments = cirq.Circuit(
[cirq.Z.on_each(q[0], q[1]), cirq.Z.on_each(q[1], q[2]), cirq.Z.on_each(q[1], q[0])],
strategy=cirq.InsertStrategy.NEW_THEN_INLINE,
)
merged_z_moment = cirq.Moment(cirq.Z.on_each(*q[1:]))
c_nested_circuit = cirq.FrozenCircuit(c_z_moments, cirq.CCX(*q), c_z_moments)
c_merged_circuit = cirq.FrozenCircuit(merged_z_moment, cirq.CCX(*q), merged_z_moment)
c_orig = cirq.Circuit(
cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"),
c_nested_circuit,
cirq.CircuitOperation(c_nested_circuit).repeat(6).with_tags("preserve_tag"),
c_nested_circuit,
cirq.CircuitOperation(c_nested_circuit).repeat(7),
)
c_expected = cirq.Circuit(
cirq.CircuitOperation(c_nested_circuit).repeat(5).with_tags("ignore"),
c_merged_circuit,
cirq.CircuitOperation(c_merged_circuit).repeat(6).with_tags("preserve_tag"),
c_merged_circuit,
cirq.CircuitOperation(c_merged_circuit).repeat(7),
)
cirq.testing.assert_same_circuits(
cirq.merge_moments(c_orig, _merge_z_moments_func, tags_to_ignore=("ignore",), deep=True),
c_expected,
)


def test_merge_moments_empty_moment_as_intermediate_step():
q = cirq.NamedQubit("q")
c_orig = cirq.Circuit([cirq.X(q), cirq.Y(q), cirq.Z(q)] * 2, cirq.X(q) ** 0.5)
Expand Down Expand Up @@ -543,7 +588,45 @@ def merge_func(op1, op2):
)


def test_merge_operations_deep():
q = cirq.LineQubit.range(2)
h_cz_y = [cirq.H(q[0]), cirq.CZ(*q), cirq.Y(q[1])]
m_cz_m = [cirq.Moment(), cirq.Moment(cirq.CZ(*q)), cirq.Moment()]
c_orig = cirq.Circuit(
h_cz_y,
cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])),
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
[cirq.CNOT(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(4),
[cirq.CNOT(*q), cirq.CZ(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(5).with_tags("preserve_tag"),
)
c_expected = cirq.Circuit(
m_cz_m,
cirq.Moment(cirq.X(q[0]).with_tags("ignore")),
cirq.CircuitOperation(cirq.FrozenCircuit(h_cz_y)).repeat(6).with_tags("ignore"),
[cirq.CNOT(*q), cirq.CNOT(*q)],
cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(4),
[cirq.CZ(*q), cirq.Moment(), cirq.Moment()],
cirq.CircuitOperation(cirq.FrozenCircuit(m_cz_m)).repeat(5).with_tags("preserve_tag"),
strategy=cirq.InsertStrategy.NEW,
)

def merge_func(op1, op2):
"""Artificial example where a CZ will absorb any merge-able operation."""
for op in [op1, op2]:
if op.gate == cirq.CZ:
return op
return None

cirq.testing.assert_same_circuits(
cirq.merge_operations(c_orig, merge_func, tags_to_ignore=["ignore"], deep=True), c_expected
)


# pylint: disable=line-too-long


def test_merge_operations_to_circuit_op_merges_connected_component():
c_orig = _create_circuit_to_merge()
cirq.testing.assert_has_diagram(
Expand Down

0 comments on commit 6af5387

Please sign in to comment.