From 39795e1b9437a6ba1daf21303942dfa08e39e220 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Mon, 13 Jun 2022 11:02:18 -0700 Subject: [PATCH] Move CircuitDag to contrib (#5481) * Move CircuitDag to contrib - Moves CircuitDag to contrib. - cirq.CircuitDag is now deprecated and now must be changed to cirq.contrib.CircuitDag. - Also moves the related cirq.Unique class Note: the tests use a high number of cirq.Unique instances which are also moved, so unable to verify count of deprecation messages in many tests. --- cirq-core/cirq/circuits/circuit_dag.py | 4 +- cirq-core/cirq/circuits/circuit_dag_test.py | 246 ++++++++++-------- cirq-core/cirq/contrib/__init__.py | 1 + .../contrib/acquaintance/inspection_utils.py | 5 +- .../acquaintance/topological_sort_test.py | 3 +- cirq-core/cirq/contrib/circuitdag/__init__.py | 15 ++ .../cirq/contrib/circuitdag/circuit_dag.py | 204 +++++++++++++++ .../contrib/circuitdag/circuit_dag_test.py | 242 +++++++++++++++++ .../contrib/paulistring/pauli_string_dag.py | 5 +- .../paulistring/pauli_string_optimize.py | 7 +- .../cirq/contrib/paulistring/recombine.py | 10 +- cirq-core/cirq/contrib/routing/greedy.py | 10 +- cirq-core/cirq/contrib/routing/utils.py | 7 +- 13 files changed, 634 insertions(+), 125 deletions(-) create mode 100644 cirq-core/cirq/contrib/circuitdag/__init__.py create mode 100644 cirq-core/cirq/contrib/circuitdag/circuit_dag.py create mode 100644 cirq-core/cirq/contrib/circuitdag/circuit_dag_test.py diff --git a/cirq-core/cirq/circuits/circuit_dag.py b/cirq-core/cirq/circuits/circuit_dag.py index bb0aec09176..3f5eae5063f 100644 --- a/cirq-core/cirq/circuits/circuit_dag.py +++ b/cirq-core/cirq/circuits/circuit_dag.py @@ -16,7 +16,7 @@ import functools import networkx -from cirq import ops +from cirq import _compat, ops from cirq.circuits import circuit if TYPE_CHECKING: @@ -25,6 +25,7 @@ T = TypeVar('T') +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq contrib.Unique instead.') @functools.total_ordering class Unique(Generic[T]): """A wrapper for a value that doesn't compare equal to other instances. @@ -55,6 +56,7 @@ def _disjoint_qubits(op1: 'cirq.Operation', op2: 'cirq.Operation') -> bool: return not set(op1.qubits) & set(op2.qubits) +@_compat.deprecated_class(deadline='v0.16', fix='Use cirq contrib.CircuitDag instead.') class CircuitDag(networkx.DiGraph): """A representation of a Circuit as a directed acyclic graph. diff --git a/cirq-core/cirq/circuits/circuit_dag_test.py b/cirq-core/cirq/circuits/circuit_dag_test.py index 4b9c37a5a1e..1ac88ef5d88 100644 --- a/cirq-core/cirq/circuits/circuit_dag_test.py +++ b/cirq-core/cirq/circuits/circuit_dag_test.py @@ -28,15 +28,18 @@ class FakeDevice(cirq.Device): def test_wrapper_eq(): q0, q1 = cirq.LineQubit.range(2) eq = cirq.testing.EqualsTester() - eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0))) - eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0))) - eq.add_equality_group(cirq.CircuitDag.make_node(cirq.Y(q0))) - eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q1))) + + with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16', count=4): + eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0))) + eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0))) + eq.add_equality_group(cirq.CircuitDag.make_node(cirq.Y(q0))) + eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q1))) def test_wrapper_cmp(): - u0 = cirq.Unique(0) - u1 = cirq.Unique(1) + with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16', count=2): + u0 = cirq.Unique(0) + u1 = cirq.Unique(1) # The ordering of Unique instances is unpredictable u0, u1 = (u1, u0) if u1 < u0 else (u0, u1) assert u0 == u0 @@ -50,88 +53,107 @@ def test_wrapper_cmp(): def test_wrapper_cmp_failure(): - with pytest.raises(TypeError): - _ = object() < cirq.Unique(1) - with pytest.raises(TypeError): - _ = cirq.Unique(1) < object() + with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16', count=2): + with pytest.raises(TypeError): + _ = object() < cirq.Unique(1) + with pytest.raises(TypeError): + _ = cirq.Unique(1) < object() def test_wrapper_repr(): q0 = cirq.LineQubit(0) - node = cirq.CircuitDag.make_node(cirq.X(q0)) - assert repr(node) == 'cirq.Unique(' + str(id(node)) + ', cirq.X(cirq.LineQubit(0)))' + with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16'): + node = cirq.CircuitDag.make_node(cirq.X(q0)) + assert repr(node) == 'cirq.Unique(' + str(id(node)) + ', cirq.X(cirq.LineQubit(0)))' def test_init(): - dag = cirq.CircuitDag() - assert networkx.dag.is_directed_acyclic_graph(dag) - assert list(dag.nodes()) == [] - assert list(dag.edges()) == [] + with cirq.testing.assert_deprecated('Use cirq contrib.CircuitDag', deadline='v0.16', count=1): + dag = cirq.CircuitDag() + assert networkx.dag.is_directed_acyclic_graph(dag) + assert list(dag.nodes()) == [] + assert list(dag.edges()) == [] def test_append(): q0 = cirq.LineQubit(0) - dag = cirq.CircuitDag() - dag.append(cirq.X(q0)) - dag.append(cirq.Y(q0)) - assert networkx.dag.is_directed_acyclic_graph(dag) - assert len(dag.nodes()) == 2 - assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + dag = cirq.CircuitDag() + dag.append(cirq.X(q0)) + dag.append(cirq.Y(q0)) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 2 + assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] def test_two_identical_ops(): q0 = cirq.LineQubit(0) - dag = cirq.CircuitDag() - dag.append(cirq.X(q0)) - dag.append(cirq.Y(q0)) - dag.append(cirq.X(q0)) - assert networkx.dag.is_directed_acyclic_graph(dag) - assert len(dag.nodes()) == 3 - assert set((n1.val, n2.val) for n1, n2 in dag.edges()) == { - (cirq.X(q0), cirq.Y(q0)), - (cirq.X(q0), cirq.X(q0)), - (cirq.Y(q0), cirq.X(q0)), - } + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + dag = cirq.CircuitDag() + dag.append(cirq.X(q0)) + dag.append(cirq.Y(q0)) + dag.append(cirq.X(q0)) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 3 + assert set((n1.val, n2.val) for n1, n2 in dag.edges()) == { + (cirq.X(q0), cirq.Y(q0)), + (cirq.X(q0), cirq.X(q0)), + (cirq.Y(q0), cirq.X(q0)), + } def test_from_ops(): q0 = cirq.LineQubit(0) - dag = cirq.CircuitDag.from_ops(cirq.X(q0), cirq.Y(q0)) - assert networkx.dag.is_directed_acyclic_graph(dag) - assert len(dag.nodes()) == 2 - assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + dag = cirq.CircuitDag.from_ops(cirq.X(q0), cirq.Y(q0)) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 2 + assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] def test_from_circuit(): q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q0)) - dag = cirq.CircuitDag.from_circuit(circuit) - assert networkx.dag.is_directed_acyclic_graph(dag) - assert len(dag.nodes()) == 2 - assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] - assert sorted(circuit.all_qubits()) == sorted(dag.all_qubits()) + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + dag = cirq.CircuitDag.from_circuit(circuit) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 2 + assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] + assert sorted(circuit.all_qubits()) == sorted(dag.all_qubits()) def test_to_empty_circuit(): circuit = cirq.Circuit() - dag = cirq.CircuitDag.from_circuit(circuit) - assert networkx.dag.is_directed_acyclic_graph(dag) - assert circuit == dag.to_circuit() + with cirq.testing.assert_deprecated('Use cirq contrib.CircuitDag', deadline='v0.16'): + dag = cirq.CircuitDag.from_circuit(circuit) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert circuit == dag.to_circuit() def test_to_circuit(): q0 = cirq.LineQubit(0) circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q0)) - dag = cirq.CircuitDag.from_circuit(circuit) + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + dag = cirq.CircuitDag.from_circuit(circuit) - assert networkx.dag.is_directed_acyclic_graph(dag) - # Only one possible output circuit for this simple case - assert circuit == dag.to_circuit() + assert networkx.dag.is_directed_acyclic_graph(dag) + # Only one possible output circuit for this simple case + assert circuit == dag.to_circuit() - cirq.testing.assert_allclose_up_to_global_phase( - circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7 - ) + cirq.testing.assert_allclose_up_to_global_phase( + circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7 + ) def test_equality(): @@ -156,43 +178,49 @@ def test_equality(): ) eq = cirq.testing.EqualsTester() - eq.make_equality_group( - lambda: cirq.CircuitDag.from_circuit(circuit1), - lambda: cirq.CircuitDag.from_circuit(circuit2), - ) - eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit3)) - eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit4)) + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + eq.make_equality_group( + lambda: cirq.CircuitDag.from_circuit(circuit1), + lambda: cirq.CircuitDag.from_circuit(circuit2), + ) + eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit3)) + eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit4)) def test_larger_circuit(): - q0, q1, q2, q3 = [ - cirq.GridQubit(0, 5), - cirq.GridQubit(1, 5), - cirq.GridQubit(2, 5), - cirq.GridQubit(3, 5), - ] - # This circuit does not have CZ gates on adjacent qubits because the order - # dag.to_circuit() would append them is non-deterministic. - circuit = cirq.Circuit( - cirq.X(q0), - cirq.CZ(q1, q2), - cirq.CZ(q0, q1), - cirq.Y(q0), - cirq.Z(q0), - cirq.CZ(q1, q2), - cirq.X(q0), - cirq.Y(q0), - cirq.CZ(q0, q1), - cirq.T(q3), - strategy=cirq.InsertStrategy.EARLIEST, - ) - - dag = cirq.CircuitDag.from_circuit(circuit) - - assert networkx.dag.is_directed_acyclic_graph(dag) - # Operation order within a moment is non-deterministic - # but text diagrams still look the same. - desired = """ + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + q0, q1, q2, q3 = [ + cirq.GridQubit(0, 5), + cirq.GridQubit(1, 5), + cirq.GridQubit(2, 5), + cirq.GridQubit(3, 5), + ] + # This circuit does not have CZ gates on adjacent qubits because the order + # dag.to_circuit() would append them is non-deterministic. + circuit = cirq.Circuit( + cirq.X(q0), + cirq.CZ(q1, q2), + cirq.CZ(q0, q1), + cirq.Y(q0), + cirq.Z(q0), + cirq.CZ(q1, q2), + cirq.X(q0), + cirq.Y(q0), + cirq.CZ(q0, q1), + cirq.T(q3), + strategy=cirq.InsertStrategy.EARLIEST, + ) + + dag = cirq.CircuitDag.from_circuit(circuit) + + assert networkx.dag.is_directed_acyclic_graph(dag) + # Operation order within a moment is non-deterministic + # but text diagrams still look the same. + desired = """ (0, 5): ───X───@───Y───Z───X───Y───@─── │ │ (1, 5): ───@───@───@───────────────@─── @@ -201,20 +229,26 @@ def test_larger_circuit(): (3, 5): ───T─────────────────────────── """ - cirq.testing.assert_has_diagram(circuit, desired) - cirq.testing.assert_has_diagram(dag.to_circuit(), desired) + cirq.testing.assert_has_diagram(circuit, desired) + cirq.testing.assert_has_diagram(dag.to_circuit(), desired) - cirq.testing.assert_allclose_up_to_global_phase( - circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7 - ) + cirq.testing.assert_allclose_up_to_global_phase( + circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7 + ) @pytest.mark.parametrize('circuit', [cirq.testing.random_circuit(10, 10, 0.5) for _ in range(3)]) def test_is_maximalist(circuit): - dag = cirq.CircuitDag.from_circuit(circuit) - transitive_closure = networkx.dag.transitive_closure(dag) - assert cirq.CircuitDag(incoming_graph_data=transitive_closure) == dag - assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(dag.ordered_nodes(), 2)) + # This creates a number of Unique classes so the count is not consistent. + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + dag = cirq.CircuitDag.from_circuit(circuit) + transitive_closure = networkx.dag.transitive_closure(dag) + assert cirq.CircuitDag(incoming_graph_data=transitive_closure) == dag + assert not any( + dag.has_edge(b, a) for a, b in itertools.combinations(dag.ordered_nodes(), 2) + ) def _get_circuits_and_is_blockers(): @@ -230,12 +264,16 @@ def _get_circuits_and_is_blockers(): @pytest.mark.parametrize('circuit, is_blocker', _get_circuits_and_is_blockers()) def test_findall_nodes_until_blocked(circuit, is_blocker): - dag = cirq.CircuitDag.from_circuit(circuit) - all_nodes = list(dag.ordered_nodes()) - found_nodes = list(dag.findall_nodes_until_blocked(is_blocker)) - assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(found_nodes, 2)) - - blocking_nodes = set(node for node in all_nodes if is_blocker(node.val)) - blocked_nodes = blocking_nodes.union(*(dag.succ[node] for node in blocking_nodes)) - expected_nodes = set(all_nodes) - blocked_nodes - assert sorted(found_nodes) == sorted(expected_nodes) + # This creates a number of Unique classes so the count is not consistent. + with cirq.testing.assert_deprecated( + 'Use cirq contrib.CircuitDag', deadline='v0.16', count=None + ): + dag = cirq.CircuitDag.from_circuit(circuit) + all_nodes = list(dag.ordered_nodes()) + found_nodes = list(dag.findall_nodes_until_blocked(is_blocker)) + assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(found_nodes, 2)) + + blocking_nodes = set(node for node in all_nodes if is_blocker(node.val)) + blocked_nodes = blocking_nodes.union(*(dag.succ[node] for node in blocking_nodes)) + expected_nodes = set(all_nodes) - blocked_nodes + assert sorted(found_nodes) == sorted(expected_nodes) diff --git a/cirq-core/cirq/contrib/__init__.py b/cirq-core/cirq/contrib/__init__.py index 68c3800fbf5..53888be889f 100644 --- a/cirq-core/cirq/contrib/__init__.py +++ b/cirq-core/cirq/contrib/__init__.py @@ -23,3 +23,4 @@ from cirq.contrib import quirk from cirq.contrib.qcircuit import circuit_to_latex_using_qcircuit from cirq.contrib import json +from cirq.contrib.circuitdag import CircuitDag, Unique diff --git a/cirq-core/cirq/contrib/acquaintance/inspection_utils.py b/cirq-core/cirq/contrib/acquaintance/inspection_utils.py index d5d8a13241b..789e65dd8aa 100644 --- a/cirq-core/cirq/contrib/acquaintance/inspection_utils.py +++ b/cirq-core/cirq/contrib/acquaintance/inspection_utils.py @@ -14,11 +14,12 @@ from typing import FrozenSet, Sequence, Set, TYPE_CHECKING -from cirq import circuits, devices +from cirq import devices from cirq.contrib.acquaintance.executor import AcquaintanceOperation, ExecutionStrategy from cirq.contrib.acquaintance.mutation_utils import expose_acquaintance_gates from cirq.contrib.acquaintance.permutation import LogicalIndex, LogicalMapping +from cirq.contrib import circuitdag if TYPE_CHECKING: import cirq @@ -59,7 +60,7 @@ def get_acquaintance_dag(strategy: 'cirq.Circuit', initial_mapping: LogicalMappi for op in moment.operations if isinstance(op, AcquaintanceOperation) ) - return circuits.CircuitDag.from_ops(acquaintance_ops) + return circuitdag.CircuitDag.from_ops(acquaintance_ops) def get_logical_acquaintance_opportunities( diff --git a/cirq-core/cirq/contrib/acquaintance/topological_sort_test.py b/cirq-core/cirq/contrib/acquaintance/topological_sort_test.py index c813b7f913f..e6bc8e033d0 100644 --- a/cirq-core/cirq/contrib/acquaintance/topological_sort_test.py +++ b/cirq-core/cirq/contrib/acquaintance/topological_sort_test.py @@ -23,7 +23,8 @@ [ (dag, tuple(cca.random_topological_sort(dag))) for dag in [ - cirq.CircuitDag.from_circuit(cirq.testing.random_circuit(10, 10, 0.5)) for _ in range(5) + cirq.contrib.CircuitDag.from_circuit(cirq.testing.random_circuit(10, 10, 0.5)) + for _ in range(5) ] for _ in range(5) ], diff --git a/cirq-core/cirq/contrib/circuitdag/__init__.py b/cirq-core/cirq/contrib/circuitdag/__init__.py new file mode 100644 index 00000000000..7e1f1df3b79 --- /dev/null +++ b/cirq-core/cirq/contrib/circuitdag/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cirq.contrib.circuitdag.circuit_dag import CircuitDag, Unique diff --git a/cirq-core/cirq/contrib/circuitdag/circuit_dag.py b/cirq-core/cirq/contrib/circuitdag/circuit_dag.py new file mode 100644 index 00000000000..6753c4fa2ed --- /dev/null +++ b/cirq-core/cirq/contrib/circuitdag/circuit_dag.py @@ -0,0 +1,204 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Generic, Iterator, TypeVar, cast, TYPE_CHECKING + +import functools +import networkx + +from cirq import ops +from cirq.circuits import circuit + +if TYPE_CHECKING: + import cirq + +T = TypeVar('T') + + +@functools.total_ordering +class Unique(Generic[T]): + """A wrapper for a value that doesn't compare equal to other instances. + + For example: 5 == 5 but Unique(5) != Unique(5). + + Unique is used by CircuitDag to wrap operations because nodes in a graph + are considered the same node if they compare equal to each other. For + example, `X(q0)` in one moment of a circuit, and `X(q0)` in another moment + of the circuit are wrapped by `cirq.Unique(X(q0))` so they are distinct + nodes in the graph. + """ + + def __init__(self, val: T) -> None: + self.val = val + + def __repr__(self) -> str: + return f'cirq.contrib.Unique({id(self)}, {self.val!r})' + + def __lt__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return id(self) < id(other) + + +def _disjoint_qubits(op1: 'cirq.Operation', op2: 'cirq.Operation') -> bool: + """Returns true only if the operations have qubits in common.""" + return not set(op1.qubits) & set(op2.qubits) + + +class CircuitDag(networkx.DiGraph): + """A representation of a Circuit as a directed acyclic graph. + + Nodes of the graph are instances of Unique containing each operation of a + circuit. + + Edges of the graph are tuples of nodes. Each edge specifies a required + application order between two operations. The first must be applied before + the second. + + The graph is maximalist (transitive completion). + """ + + disjoint_qubits = staticmethod(_disjoint_qubits) + + def __init__( + self, + can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits, + incoming_graph_data: Any = None, + ) -> None: + """Initializes a CircuitDag. + + Args: + can_reorder: A predicate that determines if two operations may be + reordered. Graph edges are created for pairs of operations + where this returns False. + + The default predicate allows reordering only when the operations + don't share common qubits. + incoming_graph_data: Data in initialize the graph. This can be any + value supported by networkx.DiGraph() e.g. an edge list or + another graph. + device: Hardware that the circuit should be able to run on. + """ + super().__init__(incoming_graph_data) + self.can_reorder = can_reorder + + @staticmethod + def make_node(op: 'cirq.Operation') -> Unique: + return Unique(op) + + @staticmethod + def from_circuit( + circuit: circuit.Circuit, + can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits, + ) -> 'CircuitDag': + return CircuitDag.from_ops(circuit.all_operations(), can_reorder=can_reorder) + + @staticmethod + def from_ops( + *operations: 'cirq.OP_TREE', + can_reorder: Callable[['cirq.Operation', 'cirq.Operation'], bool] = _disjoint_qubits, + ) -> 'CircuitDag': + dag = CircuitDag(can_reorder=can_reorder) + for op in ops.flatten_op_tree(operations): + dag.append(cast(ops.Operation, op)) + return dag + + def append(self, op: 'cirq.Operation') -> None: + new_node = self.make_node(op) + for node in list(self.nodes()): + if not self.can_reorder(node.val, op): + self.add_edge(node, new_node) + for pred in self.pred[node]: + self.add_edge(pred, new_node) + self.add_node(new_node) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + g1 = self.copy() + g2 = other.copy() + for node, attr in g1.nodes(data=True): + attr['val'] = node.val + for node, attr in g2.nodes(data=True): + attr['val'] = node.val + + def node_match(attr1: Dict[Any, Any], attr2: Dict[Any, Any]) -> bool: + return attr1['val'] == attr2['val'] + + return networkx.is_isomorphic(g1, g2, node_match=node_match) + + def __ne__(self, other): + return not self == other + + __hash__ = None # type: ignore + + def ordered_nodes(self) -> Iterator[Unique['cirq.Operation']]: + if not self.nodes(): + return + g = self.copy() + + def get_root_node(some_node: Unique['cirq.Operation']) -> Unique['cirq.Operation']: + pred = g.pred + while pred[some_node]: + some_node = next(iter(pred[some_node])) + return some_node + + def get_first_node() -> Unique['cirq.Operation']: + return get_root_node(next(iter(g.nodes()))) + + def get_next_node(succ: networkx.classes.coreviews.AtlasView) -> Unique['cirq.Operation']: + if succ: + return get_root_node(next(iter(succ))) + + return get_first_node() + + node = get_first_node() + while True: + yield node + succ = g.succ[node] + g.remove_node(node) + + if not g.nodes(): + return + + node = get_next_node(succ) + + def all_operations(self) -> Iterator['cirq.Operation']: + return (node.val for node in self.ordered_nodes()) + + def all_qubits(self): + return frozenset(q for node in self.nodes for q in node.val.qubits) + + def to_circuit(self) -> circuit.Circuit: + return circuit.Circuit(self.all_operations(), strategy=circuit.InsertStrategy.EARLIEST) + + def findall_nodes_until_blocked( + self, is_blocker: Callable[['cirq.Operation'], bool] + ) -> Iterator[Unique['cirq.Operation']]: + """Finds all nodes before blocking ones. + + Args: + is_blocker: The predicate that indicates whether or not an + operation is blocking. + """ + remaining_dag = self.copy() + + for node in self.ordered_nodes(): + if node not in remaining_dag: + continue + if is_blocker(node.val): + successors = list(remaining_dag.succ[node]) + remaining_dag.remove_nodes_from(successors) + remaining_dag.remove_node(node) + continue + yield node diff --git a/cirq-core/cirq/contrib/circuitdag/circuit_dag_test.py b/cirq-core/cirq/contrib/circuitdag/circuit_dag_test.py new file mode 100644 index 00000000000..2ad67fd49f3 --- /dev/null +++ b/cirq-core/cirq/contrib/circuitdag/circuit_dag_test.py @@ -0,0 +1,242 @@ +# Copyright 2018 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import random + +import pytest +import networkx + +import cirq + + +class FakeDevice(cirq.Device): + pass + + +def test_wrapper_eq(): + q0, q1 = cirq.LineQubit.range(2) + eq = cirq.testing.EqualsTester() + eq.add_equality_group(cirq.contrib.CircuitDag.make_node(cirq.X(q0))) + eq.add_equality_group(cirq.contrib.CircuitDag.make_node(cirq.X(q0))) + eq.add_equality_group(cirq.contrib.CircuitDag.make_node(cirq.Y(q0))) + eq.add_equality_group(cirq.contrib.CircuitDag.make_node(cirq.X(q1))) + + +def test_wrapper_cmp(): + u0 = cirq.contrib.Unique(0) + u1 = cirq.contrib.Unique(1) + # The ordering of Unique instances is unpredictable + u0, u1 = (u1, u0) if u1 < u0 else (u0, u1) + assert u0 == u0 + assert u0 != u1 + assert u0 < u1 + assert u1 > u0 + assert u0 <= u0 + assert u0 <= u1 + assert u0 >= u0 + assert u1 >= u0 + + +def test_wrapper_cmp_failure(): + with pytest.raises(TypeError): + _ = object() < cirq.contrib.Unique(1) + with pytest.raises(TypeError): + _ = cirq.contrib.Unique(1) < object() + + +def test_wrapper_repr(): + q0 = cirq.LineQubit(0) + + node = cirq.contrib.CircuitDag.make_node(cirq.X(q0)) + expected = f'cirq.contrib.Unique({id(node)}, cirq.X(cirq.LineQubit(0)))' + assert repr(node) == expected + + +def test_init(): + dag = cirq.contrib.CircuitDag() + assert networkx.dag.is_directed_acyclic_graph(dag) + assert list(dag.nodes()) == [] + assert list(dag.edges()) == [] + + +def test_append(): + q0 = cirq.LineQubit(0) + dag = cirq.contrib.CircuitDag() + dag.append(cirq.X(q0)) + dag.append(cirq.Y(q0)) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 2 + assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] + + +def test_two_identical_ops(): + q0 = cirq.LineQubit(0) + dag = cirq.contrib.CircuitDag() + dag.append(cirq.X(q0)) + dag.append(cirq.Y(q0)) + dag.append(cirq.X(q0)) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 3 + assert set((n1.val, n2.val) for n1, n2 in dag.edges()) == { + (cirq.X(q0), cirq.Y(q0)), + (cirq.X(q0), cirq.X(q0)), + (cirq.Y(q0), cirq.X(q0)), + } + + +def test_from_ops(): + q0 = cirq.LineQubit(0) + dag = cirq.contrib.CircuitDag.from_ops(cirq.X(q0), cirq.Y(q0)) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 2 + assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] + + +def test_from_circuit(): + q0 = cirq.LineQubit(0) + circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q0)) + dag = cirq.contrib.CircuitDag.from_circuit(circuit) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert len(dag.nodes()) == 2 + assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))] + assert sorted(circuit.all_qubits()) == sorted(dag.all_qubits()) + + +def test_to_empty_circuit(): + circuit = cirq.Circuit() + dag = cirq.contrib.CircuitDag.from_circuit(circuit) + assert networkx.dag.is_directed_acyclic_graph(dag) + assert circuit == dag.to_circuit() + + +def test_to_circuit(): + q0 = cirq.LineQubit(0) + circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q0)) + dag = cirq.contrib.CircuitDag.from_circuit(circuit) + + assert networkx.dag.is_directed_acyclic_graph(dag) + # Only one possible output circuit for this simple case + assert circuit == dag.to_circuit() + + cirq.testing.assert_allclose_up_to_global_phase( + circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7 + ) + + +def test_equality(): + q0, q1 = cirq.LineQubit.range(2) + circuit1 = cirq.Circuit( + cirq.X(q0), cirq.Y(q0), cirq.Z(q1), cirq.CZ(q0, q1), cirq.X(q1), cirq.Y(q1), cirq.Z(q0) + ) + circuit2 = cirq.Circuit( + cirq.Z(q1), cirq.X(q0), cirq.Y(q0), cirq.CZ(q0, q1), cirq.Z(q0), cirq.X(q1), cirq.Y(q1) + ) + circuit3 = cirq.Circuit( + cirq.X(q0), + cirq.Y(q0), + cirq.Z(q1), + cirq.CZ(q0, q1), + cirq.X(q1), + cirq.Y(q1), + cirq.Z(q0) ** 0.5, + ) + circuit4 = cirq.Circuit( + cirq.X(q0), cirq.Y(q0), cirq.Z(q1), cirq.CZ(q0, q1), cirq.X(q1), cirq.Y(q1) + ) + + eq = cirq.testing.EqualsTester() + eq.make_equality_group( + lambda: cirq.contrib.CircuitDag.from_circuit(circuit1), + lambda: cirq.contrib.CircuitDag.from_circuit(circuit2), + ) + eq.add_equality_group(cirq.contrib.CircuitDag.from_circuit(circuit3)) + eq.add_equality_group(cirq.contrib.CircuitDag.from_circuit(circuit4)) + + +def test_larger_circuit(): + q0, q1, q2, q3 = [ + cirq.GridQubit(0, 5), + cirq.GridQubit(1, 5), + cirq.GridQubit(2, 5), + cirq.GridQubit(3, 5), + ] + # This circuit does not have CZ gates on adjacent qubits because the order + # dag.to_circuit() would append them is non-deterministic. + circuit = cirq.Circuit( + cirq.X(q0), + cirq.CZ(q1, q2), + cirq.CZ(q0, q1), + cirq.Y(q0), + cirq.Z(q0), + cirq.CZ(q1, q2), + cirq.X(q0), + cirq.Y(q0), + cirq.CZ(q0, q1), + cirq.T(q3), + strategy=cirq.InsertStrategy.EARLIEST, + ) + + dag = cirq.contrib.CircuitDag.from_circuit(circuit) + + assert networkx.dag.is_directed_acyclic_graph(dag) + # Operation order within a moment is non-deterministic + # but text diagrams still look the same. + desired = """ +(0, 5): ───X───@───Y───Z───X───Y───@─── + │ │ +(1, 5): ───@───@───@───────────────@─── + │ │ +(2, 5): ───@───────@─────────────────── + +(3, 5): ───T─────────────────────────── +""" + cirq.testing.assert_has_diagram(circuit, desired) + cirq.testing.assert_has_diagram(dag.to_circuit(), desired) + + cirq.testing.assert_allclose_up_to_global_phase( + circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7 + ) + + +@pytest.mark.parametrize('circuit', [cirq.testing.random_circuit(10, 10, 0.5) for _ in range(3)]) +def test_is_maximalist(circuit): + dag = cirq.contrib.CircuitDag.from_circuit(circuit) + transitive_closure = networkx.dag.transitive_closure(dag) + assert cirq.contrib.CircuitDag(incoming_graph_data=transitive_closure) == dag + assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(dag.ordered_nodes(), 2)) + + +def _get_circuits_and_is_blockers(): + qubits = cirq.LineQubit.range(10) + circuits = [cirq.testing.random_circuit(qubits, 10, 0.5) for _ in range(1)] + edges = [ + set(qubit_pair) for qubit_pair in itertools.combinations(qubits, 2) if random.random() > 0.5 + ] + not_on_edge = lambda op: len(op.qubits) > 1 and set(op.qubits) not in edges + is_blockers = [lambda op: False, not_on_edge] + return itertools.product(circuits, is_blockers) + + +@pytest.mark.parametrize('circuit, is_blocker', _get_circuits_and_is_blockers()) +def test_findall_nodes_until_blocked(circuit, is_blocker): + dag = cirq.contrib.CircuitDag.from_circuit(circuit) + all_nodes = list(dag.ordered_nodes()) + found_nodes = list(dag.findall_nodes_until_blocked(is_blocker)) + assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(found_nodes, 2)) + + blocking_nodes = set(node for node in all_nodes if is_blocker(node.val)) + blocked_nodes = blocking_nodes.union(*(dag.succ[node] for node in blocking_nodes)) + expected_nodes = set(all_nodes) - blocked_nodes + assert sorted(found_nodes) == sorted(expected_nodes) diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_dag.py b/cirq-core/cirq/contrib/paulistring/pauli_string_dag.py index 1196431b1c9..0aaed13d856 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_dag.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_dag.py @@ -15,6 +15,7 @@ from typing import cast from cirq import circuits, ops, protocols +from cirq.contrib import circuitdag def pauli_string_reorder_pred(op1: ops.Operation, op2: ops.Operation) -> bool: @@ -23,5 +24,5 @@ def pauli_string_reorder_pred(op1: ops.Operation, op2: ops.Operation) -> bool: return protocols.commutes(ps1, ps2) -def pauli_string_dag_from_circuit(circuit: circuits.Circuit) -> circuits.CircuitDag: - return circuits.CircuitDag.from_circuit(circuit, pauli_string_reorder_pred) +def pauli_string_dag_from_circuit(circuit: circuits.Circuit) -> circuitdag.CircuitDag: + return circuitdag.CircuitDag.from_circuit(circuit, pauli_string_reorder_pred) diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_optimize.py b/cirq-core/cirq/contrib/paulistring/pauli_string_optimize.py index 9223f900c07..4a102686355 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_optimize.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_optimize.py @@ -15,10 +15,11 @@ import networkx from cirq import circuits, linalg +from cirq.contrib import circuitdag from cirq.contrib.paulistring.pauli_string_dag import pauli_string_dag_from_circuit -from cirq.ops import PauliStringGateOperation from cirq.contrib.paulistring.recombine import move_pauli_strings_into_circuit from cirq.contrib.paulistring.separate import convert_and_separate_circuit +from cirq.ops import PauliStringGateOperation def pauli_string_optimized_circuit( @@ -48,7 +49,7 @@ def assert_no_multi_qubit_pauli_strings(circuit: circuits.Circuit) -> None: assert len(op.pauli_string) == 1, 'Multi qubit Pauli string left over' -def merge_equal_strings(string_dag: circuits.CircuitDag) -> None: +def merge_equal_strings(string_dag: circuitdag.CircuitDag) -> None: for node in tuple(string_dag.nodes()): if node not in string_dag.nodes(): # Node was removed @@ -65,7 +66,7 @@ def merge_equal_strings(string_dag: circuits.CircuitDag) -> None: node.val = node.val.merged_with(other_node.val) -def remove_negligible_strings(string_dag: circuits.CircuitDag, atol=1e-8) -> None: +def remove_negligible_strings(string_dag: circuitdag.CircuitDag, atol=1e-8) -> None: for node in tuple(string_dag.nodes()): if linalg.all_near_zero_mod(node.val.exponent_relative, 2, atol=atol): string_dag.remove_node(node) diff --git a/cirq-core/cirq/contrib/paulistring/recombine.py b/cirq-core/cirq/contrib/paulistring/recombine.py index d40d03fe7c4..5780a8d8965 100644 --- a/cirq-core/cirq/contrib/paulistring/recombine.py +++ b/cirq-core/cirq/contrib/paulistring/recombine.py @@ -15,7 +15,7 @@ from typing import Any, Callable, Iterable, Sequence, Tuple, Union, cast, List from cirq import circuits, ops, protocols - +from cirq.contrib import circuitdag from cirq.contrib.paulistring.pauli_string_dag import ( pauli_string_reorder_pred, pauli_string_dag_from_circuit, @@ -26,7 +26,7 @@ def _sorted_best_string_placements( possible_nodes: Iterable[Any], output_ops: Sequence[ops.Operation], key: Callable[[Any], ops.PauliStringPhasor] = lambda node: node.val, -) -> List[Tuple[ops.PauliStringPhasor, int, circuits.Unique[ops.PauliStringPhasor]]]: +) -> List[Tuple[ops.PauliStringPhasor, int, circuitdag.Unique[ops.PauliStringPhasor]]]: sort_key = lambda placement: (-len(placement[0].pauli_string), placement[1]) @@ -65,10 +65,10 @@ def _sorted_best_string_placements( def move_pauli_strings_into_circuit( - circuit_left: Union[circuits.Circuit, circuits.CircuitDag], circuit_right: circuits.Circuit + circuit_left: Union[circuits.Circuit, circuitdag.CircuitDag], circuit_right: circuits.Circuit ) -> circuits.Circuit: - if isinstance(circuit_left, circuits.CircuitDag): - string_dag = circuits.CircuitDag(pauli_string_reorder_pred, circuit_left) + if isinstance(circuit_left, circuitdag.CircuitDag): + string_dag = circuitdag.CircuitDag(pauli_string_reorder_pred, circuit_left) else: string_dag = pauli_string_dag_from_circuit(cast(circuits.Circuit, circuit_left)) output_ops = list(circuit_right.all_operations()) diff --git a/cirq-core/cirq/contrib/routing/greedy.py b/cirq-core/cirq/contrib/routing/greedy.py index 8b563d148b8..ad2e20a9908 100644 --- a/cirq-core/cirq/contrib/routing/greedy.py +++ b/cirq-core/cirq/contrib/routing/greedy.py @@ -31,6 +31,7 @@ from cirq import circuits, ops, value import cirq.contrib.acquaintance as cca +from cirq.contrib import circuitdag from cirq.contrib.routing.initialization import get_initial_mapping from cirq.contrib.routing.swap_network import SwapNetwork from cirq.contrib.routing.utils import get_time_slices, ops_are_consistent_with_device_graph @@ -104,9 +105,10 @@ def __init__( max_search_radius: int = 1, max_num_empty_steps: int = 5, initial_mapping: Optional[Dict[ops.Qid, ops.Qid]] = None, - can_reorder: Callable[ - [ops.Operation, ops.Operation], bool - ] = circuits.circuit_dag._disjoint_qubits, + can_reorder: Callable[[ops.Operation, ops.Operation], bool] = lambda op1, op2: not set( + op1.qubits + ) + & set(op2.qubits), random_state: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ): @@ -119,7 +121,7 @@ def __init__( for b, d in neighbor_distances.items() } - self.remaining_dag = circuits.CircuitDag.from_circuit(circuit, can_reorder=can_reorder) + self.remaining_dag = circuitdag.CircuitDag.from_circuit(circuit, can_reorder=can_reorder) self.logical_qubits = list(self.remaining_dag.all_qubits()) self.physical_qubits = list(self.device_graph.nodes) self.edge_sets: Dict[int, List[Sequence[QidPair]]] = {} diff --git a/cirq-core/cirq/contrib/routing/utils.py b/cirq-core/cirq/contrib/routing/utils.py index 189a3df7755..5627261ee3f 100644 --- a/cirq-core/cirq/contrib/routing/utils.py +++ b/cirq-core/cirq/contrib/routing/utils.py @@ -20,6 +20,7 @@ from cirq import circuits, ops import cirq.contrib.acquaintance as cca +from cirq.contrib.circuitdag import CircuitDag from cirq.contrib.routing.swap_network import SwapNetwork if TYPE_CHECKING: @@ -28,7 +29,7 @@ BINARY_OP_PREDICATE = Callable[[ops.Operation, ops.Operation], bool] -def get_time_slices(dag: circuits.CircuitDag) -> List[nx.Graph]: +def get_time_slices(dag: CircuitDag) -> List[nx.Graph]: """Slices the DAG into logical graphs. Each time slice is a graph whose vertices are qubits and whose edges @@ -60,7 +61,7 @@ def is_valid_routing( swap_network: SwapNetwork, *, equals: BINARY_OP_PREDICATE = operator.eq, - can_reorder: BINARY_OP_PREDICATE = circuits.circuit_dag._disjoint_qubits, + can_reorder: BINARY_OP_PREDICATE = lambda op1, op2: not set(op1.qubits) & set(op2.qubits), ) -> bool: """Determines whether a swap network is consistent with a given circuit. @@ -75,7 +76,7 @@ def is_valid_routing( Raises: ValueError: If equals operator or can_reorder throws a ValueError. """ - circuit_dag = circuits.CircuitDag.from_circuit(circuit, can_reorder=can_reorder) + circuit_dag = CircuitDag.from_circuit(circuit, can_reorder=can_reorder) logical_operations = swap_network.get_logical_operations() try: return cca.is_topologically_sorted(circuit_dag, logical_operations, equals)