diff --git a/qiskit/circuit/equivalence.py b/qiskit/circuit/equivalence.py index 08525905f457..da5a60eb000d 100644 --- a/qiskit/circuit/equivalence.py +++ b/qiskit/circuit/equivalence.py @@ -45,11 +45,12 @@ def __init__(self, *, base=None): if base is None: self._graph = rx.PyDiGraph() self._key_to_node_index = {} - self._rule_count = 0 + # Some unique identifier for rules. + self._rule_id = 0 else: self._graph = base._graph.copy() self._key_to_node_index = copy.deepcopy(base._key_to_node_index) - self._rule_count = base._rule_count + self._rule_id = base._rule_id @property def graph(self) -> rx.PyDiGraph: @@ -104,12 +105,12 @@ def add_equivalence(self, gate, equivalent_circuit): ( self._set_default_node(source), target, - EdgeData(index=self._rule_count, num_gates=len(sources), rule=equiv, source=source), + EdgeData(index=self._rule_id, num_gates=len(sources), rule=equiv, source=source), ) for source in sources ] self._graph.add_edges_from(edges) - self._rule_count += 1 + self._rule_id += 1 def has_entry(self, gate): """Check if a library contains any decompositions for gate. @@ -142,10 +143,15 @@ def set_entry(self, gate, entry): _raise_if_shape_mismatch(gate, equiv) _raise_if_param_mismatch(gate.params, equiv.parameters) - key = Key(name=gate.name, num_qubits=gate.num_qubits) - equivs = [Equivalence(params=gate.params.copy(), circuit=equiv.copy()) for equiv in entry] - - self._graph[self._set_default_node(key)] = NodeData(key=key, equivs=equivs) + node_index = self._set_default_node(Key(name=gate.name, num_qubits=gate.num_qubits)) + # Remove previous equivalences of this node, leaving in place any later equivalences that + # were added that use `gate`. + self._graph[node_index].equivs.clear() + for (parent, child, _) in self._graph.in_edges(node_index): + # `child` should always be ourselves, but there might be parallel edges. + self._graph.remove_edge(parent, child) + for equivalence in entry: + self.add_equivalence(gate, equivalence) def get_entry(self, gate): """Gets the set of QuantumCircuits circuits from the library which diff --git a/releasenotes/notes/fix-equivalence-setentry-5a30b0790666fcf2.yaml b/releasenotes/notes/fix-equivalence-setentry-5a30b0790666fcf2.yaml new file mode 100644 index 000000000000..9df1eb36c265 --- /dev/null +++ b/releasenotes/notes/fix-equivalence-setentry-5a30b0790666fcf2.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Calling :meth:`.EquivalenceLibrary.set_entry` will now correctly update the internal graph + object of the library. Previously, the metadata would be updated, but the graph structure would + be unaltered, meaning that users like :class:`.BasisTranslator` would still use the old rules. + Fixed `#11958 `__. diff --git a/test/python/circuit/test_equivalence.py b/test/python/circuit/test_equivalence.py index bc2f3c957582..59ecd8287394 100644 --- a/test/python/circuit/test_equivalence.py +++ b/test/python/circuit/test_equivalence.py @@ -103,24 +103,42 @@ def test_add_double_entry(self): self.assertEqual(entry[1], second_equiv) def test_set_entry(self): - """Verify setting an entry overrides any previously added.""" + """Verify setting an entry overrides any previously added, without affecting entries that + depended on the set entry.""" eq_lib = EquivalenceLibrary() - gate = OneQubitZeroParamGate() - first_equiv = QuantumCircuit(1) - first_equiv.h(0) - - eq_lib.add_equivalence(gate, first_equiv) - - second_equiv = QuantumCircuit(1) - second_equiv.append(U2Gate(0, np.pi), [0]) - - eq_lib.set_entry(gate, [second_equiv]) - - entry = eq_lib.get_entry(gate) - - self.assertEqual(len(entry), 1) - self.assertEqual(entry[0], second_equiv) + gates = {key: Gate(key, 1, []) for key in "abcd"} + target = Gate("target", 1, []) + + old = QuantumCircuit(1) + old.append(gates["a"], [0]) + old.append(gates["b"], [0]) + eq_lib.add_equivalence(target, old) + + outbound = QuantumCircuit(1) + outbound.append(target, [0]) + eq_lib.add_equivalence(gates["c"], outbound) + + self.assertEqual(eq_lib.get_entry(target), [old]) + self.assertEqual(eq_lib.get_entry(gates["c"]), [outbound]) + # Assert the underlying graph structure is correct as well. + gate_indices = {eq_lib.graph[node].key.name: node for node in eq_lib.graph.node_indices()} + self.assertTrue(eq_lib.graph.has_edge(gate_indices["a"], gate_indices["target"])) + self.assertTrue(eq_lib.graph.has_edge(gate_indices["b"], gate_indices["target"])) + self.assertTrue(eq_lib.graph.has_edge(gate_indices["target"], gate_indices["c"])) + + new = QuantumCircuit(1) + new.append(gates["d"], [0]) + eq_lib.set_entry(target, [new]) + + self.assertEqual(eq_lib.get_entry(target), [new]) + self.assertEqual(eq_lib.get_entry(gates["c"]), [outbound]) + # Assert the underlying graph structure is correct as well. + gate_indices = {eq_lib.graph[node].key.name: node for node in eq_lib.graph.node_indices()} + self.assertFalse(eq_lib.graph.has_edge(gate_indices["a"], gate_indices["target"])) + self.assertFalse(eq_lib.graph.has_edge(gate_indices["b"], gate_indices["target"])) + self.assertTrue(eq_lib.graph.has_edge(gate_indices["d"], gate_indices["target"])) + self.assertTrue(eq_lib.graph.has_edge(gate_indices["target"], gate_indices["c"])) def test_raise_if_gate_entry_shape_mismatch(self): """Verify we raise if adding a circuit and gate with different shapes."""