Skip to content

Commit

Permalink
Move CircuitDag to contrib (#5481)
Browse files Browse the repository at this point in the history
* Move CircuitDag to contrib

- Moves CircuitDag to contrib.
- cirq.CircuitDag is now deprecated and now must be changed
to cirq.contrib.CircuitDag.
- Also moves the related cirq.Unique class

Note: the tests use a high number of cirq.Unique instances which
are also moved, so unable to verify count of deprecation messages
in many tests.
  • Loading branch information
dstrain115 authored Jun 13, 2022
1 parent a823872 commit 39795e1
Show file tree
Hide file tree
Showing 13 changed files with 634 additions and 125 deletions.
4 changes: 3 additions & 1 deletion cirq-core/cirq/circuits/circuit_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import functools
import networkx

from cirq import ops
from cirq import _compat, ops
from cirq.circuits import circuit

if TYPE_CHECKING:
Expand All @@ -25,6 +25,7 @@
T = TypeVar('T')


@_compat.deprecated_class(deadline='v0.16', fix='Use cirq contrib.Unique instead.')
@functools.total_ordering
class Unique(Generic[T]):
"""A wrapper for a value that doesn't compare equal to other instances.
Expand Down Expand Up @@ -55,6 +56,7 @@ def _disjoint_qubits(op1: 'cirq.Operation', op2: 'cirq.Operation') -> bool:
return not set(op1.qubits) & set(op2.qubits)


@_compat.deprecated_class(deadline='v0.16', fix='Use cirq contrib.CircuitDag instead.')
class CircuitDag(networkx.DiGraph):
"""A representation of a Circuit as a directed acyclic graph.
Expand Down
246 changes: 142 additions & 104 deletions cirq-core/cirq/circuits/circuit_dag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@ class FakeDevice(cirq.Device):
def test_wrapper_eq():
q0, q1 = cirq.LineQubit.range(2)
eq = cirq.testing.EqualsTester()
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0)))
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0)))
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.Y(q0)))
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q1)))

with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16', count=4):
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0)))
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q0)))
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.Y(q0)))
eq.add_equality_group(cirq.CircuitDag.make_node(cirq.X(q1)))


def test_wrapper_cmp():
u0 = cirq.Unique(0)
u1 = cirq.Unique(1)
with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16', count=2):
u0 = cirq.Unique(0)
u1 = cirq.Unique(1)
# The ordering of Unique instances is unpredictable
u0, u1 = (u1, u0) if u1 < u0 else (u0, u1)
assert u0 == u0
Expand All @@ -50,88 +53,107 @@ def test_wrapper_cmp():


def test_wrapper_cmp_failure():
with pytest.raises(TypeError):
_ = object() < cirq.Unique(1)
with pytest.raises(TypeError):
_ = cirq.Unique(1) < object()
with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16', count=2):
with pytest.raises(TypeError):
_ = object() < cirq.Unique(1)
with pytest.raises(TypeError):
_ = cirq.Unique(1) < object()


def test_wrapper_repr():
q0 = cirq.LineQubit(0)

node = cirq.CircuitDag.make_node(cirq.X(q0))
assert repr(node) == 'cirq.Unique(' + str(id(node)) + ', cirq.X(cirq.LineQubit(0)))'
with cirq.testing.assert_deprecated('Use cirq contrib.Unique', deadline='v0.16'):
node = cirq.CircuitDag.make_node(cirq.X(q0))
assert repr(node) == 'cirq.Unique(' + str(id(node)) + ', cirq.X(cirq.LineQubit(0)))'


def test_init():
dag = cirq.CircuitDag()
assert networkx.dag.is_directed_acyclic_graph(dag)
assert list(dag.nodes()) == []
assert list(dag.edges()) == []
with cirq.testing.assert_deprecated('Use cirq contrib.CircuitDag', deadline='v0.16', count=1):
dag = cirq.CircuitDag()
assert networkx.dag.is_directed_acyclic_graph(dag)
assert list(dag.nodes()) == []
assert list(dag.edges()) == []


def test_append():
q0 = cirq.LineQubit(0)
dag = cirq.CircuitDag()
dag.append(cirq.X(q0))
dag.append(cirq.Y(q0))
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 2
assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))]
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
dag = cirq.CircuitDag()
dag.append(cirq.X(q0))
dag.append(cirq.Y(q0))
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 2
assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))]


def test_two_identical_ops():
q0 = cirq.LineQubit(0)
dag = cirq.CircuitDag()
dag.append(cirq.X(q0))
dag.append(cirq.Y(q0))
dag.append(cirq.X(q0))
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 3
assert set((n1.val, n2.val) for n1, n2 in dag.edges()) == {
(cirq.X(q0), cirq.Y(q0)),
(cirq.X(q0), cirq.X(q0)),
(cirq.Y(q0), cirq.X(q0)),
}
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
dag = cirq.CircuitDag()
dag.append(cirq.X(q0))
dag.append(cirq.Y(q0))
dag.append(cirq.X(q0))
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 3
assert set((n1.val, n2.val) for n1, n2 in dag.edges()) == {
(cirq.X(q0), cirq.Y(q0)),
(cirq.X(q0), cirq.X(q0)),
(cirq.Y(q0), cirq.X(q0)),
}


def test_from_ops():
q0 = cirq.LineQubit(0)
dag = cirq.CircuitDag.from_ops(cirq.X(q0), cirq.Y(q0))
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 2
assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))]
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
dag = cirq.CircuitDag.from_ops(cirq.X(q0), cirq.Y(q0))
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 2
assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))]


def test_from_circuit():
q0 = cirq.LineQubit(0)
circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q0))
dag = cirq.CircuitDag.from_circuit(circuit)
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 2
assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))]
assert sorted(circuit.all_qubits()) == sorted(dag.all_qubits())
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
dag = cirq.CircuitDag.from_circuit(circuit)
assert networkx.dag.is_directed_acyclic_graph(dag)
assert len(dag.nodes()) == 2
assert [(n1.val, n2.val) for n1, n2 in dag.edges()] == [(cirq.X(q0), cirq.Y(q0))]
assert sorted(circuit.all_qubits()) == sorted(dag.all_qubits())


def test_to_empty_circuit():
circuit = cirq.Circuit()
dag = cirq.CircuitDag.from_circuit(circuit)
assert networkx.dag.is_directed_acyclic_graph(dag)
assert circuit == dag.to_circuit()
with cirq.testing.assert_deprecated('Use cirq contrib.CircuitDag', deadline='v0.16'):
dag = cirq.CircuitDag.from_circuit(circuit)
assert networkx.dag.is_directed_acyclic_graph(dag)
assert circuit == dag.to_circuit()


def test_to_circuit():
q0 = cirq.LineQubit(0)
circuit = cirq.Circuit(cirq.X(q0), cirq.Y(q0))
dag = cirq.CircuitDag.from_circuit(circuit)
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
dag = cirq.CircuitDag.from_circuit(circuit)

assert networkx.dag.is_directed_acyclic_graph(dag)
# Only one possible output circuit for this simple case
assert circuit == dag.to_circuit()
assert networkx.dag.is_directed_acyclic_graph(dag)
# Only one possible output circuit for this simple case
assert circuit == dag.to_circuit()

cirq.testing.assert_allclose_up_to_global_phase(
circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7
)
cirq.testing.assert_allclose_up_to_global_phase(
circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7
)


def test_equality():
Expand All @@ -156,43 +178,49 @@ def test_equality():
)

eq = cirq.testing.EqualsTester()
eq.make_equality_group(
lambda: cirq.CircuitDag.from_circuit(circuit1),
lambda: cirq.CircuitDag.from_circuit(circuit2),
)
eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit3))
eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit4))
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
eq.make_equality_group(
lambda: cirq.CircuitDag.from_circuit(circuit1),
lambda: cirq.CircuitDag.from_circuit(circuit2),
)
eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit3))
eq.add_equality_group(cirq.CircuitDag.from_circuit(circuit4))


def test_larger_circuit():
q0, q1, q2, q3 = [
cirq.GridQubit(0, 5),
cirq.GridQubit(1, 5),
cirq.GridQubit(2, 5),
cirq.GridQubit(3, 5),
]
# This circuit does not have CZ gates on adjacent qubits because the order
# dag.to_circuit() would append them is non-deterministic.
circuit = cirq.Circuit(
cirq.X(q0),
cirq.CZ(q1, q2),
cirq.CZ(q0, q1),
cirq.Y(q0),
cirq.Z(q0),
cirq.CZ(q1, q2),
cirq.X(q0),
cirq.Y(q0),
cirq.CZ(q0, q1),
cirq.T(q3),
strategy=cirq.InsertStrategy.EARLIEST,
)

dag = cirq.CircuitDag.from_circuit(circuit)

assert networkx.dag.is_directed_acyclic_graph(dag)
# Operation order within a moment is non-deterministic
# but text diagrams still look the same.
desired = """
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
q0, q1, q2, q3 = [
cirq.GridQubit(0, 5),
cirq.GridQubit(1, 5),
cirq.GridQubit(2, 5),
cirq.GridQubit(3, 5),
]
# This circuit does not have CZ gates on adjacent qubits because the order
# dag.to_circuit() would append them is non-deterministic.
circuit = cirq.Circuit(
cirq.X(q0),
cirq.CZ(q1, q2),
cirq.CZ(q0, q1),
cirq.Y(q0),
cirq.Z(q0),
cirq.CZ(q1, q2),
cirq.X(q0),
cirq.Y(q0),
cirq.CZ(q0, q1),
cirq.T(q3),
strategy=cirq.InsertStrategy.EARLIEST,
)

dag = cirq.CircuitDag.from_circuit(circuit)

assert networkx.dag.is_directed_acyclic_graph(dag)
# Operation order within a moment is non-deterministic
# but text diagrams still look the same.
desired = """
(0, 5): ───X───@───Y───Z───X───Y───@───
│ │
(1, 5): ───@───@───@───────────────@───
Expand All @@ -201,20 +229,26 @@ def test_larger_circuit():
(3, 5): ───T───────────────────────────
"""
cirq.testing.assert_has_diagram(circuit, desired)
cirq.testing.assert_has_diagram(dag.to_circuit(), desired)
cirq.testing.assert_has_diagram(circuit, desired)
cirq.testing.assert_has_diagram(dag.to_circuit(), desired)

cirq.testing.assert_allclose_up_to_global_phase(
circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7
)
cirq.testing.assert_allclose_up_to_global_phase(
circuit.unitary(), dag.to_circuit().unitary(), atol=1e-7
)


@pytest.mark.parametrize('circuit', [cirq.testing.random_circuit(10, 10, 0.5) for _ in range(3)])
def test_is_maximalist(circuit):
dag = cirq.CircuitDag.from_circuit(circuit)
transitive_closure = networkx.dag.transitive_closure(dag)
assert cirq.CircuitDag(incoming_graph_data=transitive_closure) == dag
assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(dag.ordered_nodes(), 2))
# This creates a number of Unique classes so the count is not consistent.
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
dag = cirq.CircuitDag.from_circuit(circuit)
transitive_closure = networkx.dag.transitive_closure(dag)
assert cirq.CircuitDag(incoming_graph_data=transitive_closure) == dag
assert not any(
dag.has_edge(b, a) for a, b in itertools.combinations(dag.ordered_nodes(), 2)
)


def _get_circuits_and_is_blockers():
Expand All @@ -230,12 +264,16 @@ def _get_circuits_and_is_blockers():

@pytest.mark.parametrize('circuit, is_blocker', _get_circuits_and_is_blockers())
def test_findall_nodes_until_blocked(circuit, is_blocker):
dag = cirq.CircuitDag.from_circuit(circuit)
all_nodes = list(dag.ordered_nodes())
found_nodes = list(dag.findall_nodes_until_blocked(is_blocker))
assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(found_nodes, 2))

blocking_nodes = set(node for node in all_nodes if is_blocker(node.val))
blocked_nodes = blocking_nodes.union(*(dag.succ[node] for node in blocking_nodes))
expected_nodes = set(all_nodes) - blocked_nodes
assert sorted(found_nodes) == sorted(expected_nodes)
# This creates a number of Unique classes so the count is not consistent.
with cirq.testing.assert_deprecated(
'Use cirq contrib.CircuitDag', deadline='v0.16', count=None
):
dag = cirq.CircuitDag.from_circuit(circuit)
all_nodes = list(dag.ordered_nodes())
found_nodes = list(dag.findall_nodes_until_blocked(is_blocker))
assert not any(dag.has_edge(b, a) for a, b in itertools.combinations(found_nodes, 2))

blocking_nodes = set(node for node in all_nodes if is_blocker(node.val))
blocked_nodes = blocking_nodes.union(*(dag.succ[node] for node in blocking_nodes))
expected_nodes = set(all_nodes) - blocked_nodes
assert sorted(found_nodes) == sorted(expected_nodes)
1 change: 1 addition & 0 deletions cirq-core/cirq/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
from cirq.contrib import quirk
from cirq.contrib.qcircuit import circuit_to_latex_using_qcircuit
from cirq.contrib import json
from cirq.contrib.circuitdag import CircuitDag, Unique
5 changes: 3 additions & 2 deletions cirq-core/cirq/contrib/acquaintance/inspection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@

from typing import FrozenSet, Sequence, Set, TYPE_CHECKING

from cirq import circuits, devices
from cirq import devices

from cirq.contrib.acquaintance.executor import AcquaintanceOperation, ExecutionStrategy
from cirq.contrib.acquaintance.mutation_utils import expose_acquaintance_gates
from cirq.contrib.acquaintance.permutation import LogicalIndex, LogicalMapping
from cirq.contrib import circuitdag

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -59,7 +60,7 @@ def get_acquaintance_dag(strategy: 'cirq.Circuit', initial_mapping: LogicalMappi
for op in moment.operations
if isinstance(op, AcquaintanceOperation)
)
return circuits.CircuitDag.from_ops(acquaintance_ops)
return circuitdag.CircuitDag.from_ops(acquaintance_ops)


def get_logical_acquaintance_opportunities(
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/contrib/acquaintance/topological_sort_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
[
(dag, tuple(cca.random_topological_sort(dag)))
for dag in [
cirq.CircuitDag.from_circuit(cirq.testing.random_circuit(10, 10, 0.5)) for _ in range(5)
cirq.contrib.CircuitDag.from_circuit(cirq.testing.random_circuit(10, 10, 0.5))
for _ in range(5)
]
for _ in range(5)
],
Expand Down
Loading

0 comments on commit 39795e1

Please sign in to comment.