diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index fdaa81e3b4ca..1f4acde9934a 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -4350,9 +4350,7 @@ def _format(operand): let mut new_layer = self.copy_empty_like(py, vars_mode)?; - for (node, _) in op_nodes { - new_layer.push_back(py, node.clone())?; - } + new_layer.extend(py, op_nodes.iter().map(|(inst, _)| (*inst).clone()))?; let new_layer_op_nodes = new_layer.op_nodes(false).filter_map(|node_index| { match new_layer.dag.node_weight(node_index) { @@ -6347,6 +6345,143 @@ impl DAGCircuit { Err(DAGCircuitError::new_err("Specified node is not an op node")) } } + + /// Extends the DAG with valid instances of [PackedInstruction] + pub fn extend(&mut self, py: Python, iter: I) -> PyResult> + where + I: IntoIterator, + { + // Create HashSets to keep track of each bit/var's last node + let mut qubit_last_nodes: HashMap = HashMap::default(); + let mut clbit_last_nodes: HashMap = HashMap::default(); + // TODO: Refactor once Vars are in rust + // Dict [ Var: (int, VarWeight)] + let vars_last_nodes: Bound = PyDict::new_bound(py); + + // Consume into iterator to obtain size hint + let iter = iter.into_iter(); + // Store new nodes to return + let mut new_nodes = Vec::with_capacity(iter.size_hint().1.unwrap_or_default()); + for instr in iter { + let op_name = instr.op.name(); + let (all_cbits, vars): (Vec, Option>) = { + if self.may_have_additional_wires(py, &instr) { + let mut clbits: HashSet = + HashSet::from_iter(self.cargs_interner.get(instr.clbits).iter().copied()); + let (additional_clbits, additional_vars) = + self.additional_wires(py, instr.op.view(), instr.condition())?; + for clbit in additional_clbits { + clbits.insert(clbit); + } + (clbits.into_iter().collect(), Some(additional_vars)) + } else { + (self.cargs_interner.get(instr.clbits).to_vec(), None) + } + }; + + // Increment the operation count + self.increment_op(op_name); + + // Get the correct qubit indices + let qubits_id = instr.qubits; + + // Insert op-node to graph. + let new_node = self.dag.add_node(NodeType::Operation(instr)); + new_nodes.push(new_node); + + // Check all the qubits in this instruction. + for qubit in self.qargs_interner.get(qubits_id) { + // Retrieve each qubit's last node + let qubit_last_node = *qubit_last_nodes.entry(*qubit).or_insert_with(|| { + // If the qubit is not in the last nodes collection, the edge between the output node and its predecessor. + // Then, store the predecessor's NodeIndex in the last nodes collection. + let output_node = self.qubit_io_map[qubit.0 as usize][1]; + let (edge_id, predecessor_node) = self + .dag + .edges_directed(output_node, Incoming) + .next() + .map(|edge| (edge.id(), edge.source())) + .unwrap(); + self.dag.remove_edge(edge_id); + predecessor_node + }); + qubit_last_nodes + .entry(*qubit) + .and_modify(|val| *val = new_node); + self.dag + .add_edge(qubit_last_node, new_node, Wire::Qubit(*qubit)); + } + + // Check all the clbits in this instruction. + for clbit in all_cbits { + let clbit_last_node = *clbit_last_nodes.entry(clbit).or_insert_with(|| { + // If the qubit is not in the last nodes collection, the edge between the output node and its predecessor. + // Then, store the predecessor's NodeIndex in the last nodes collection. + let output_node = self.clbit_io_map[clbit.0 as usize][1]; + let (edge_id, predecessor_node) = self + .dag + .edges_directed(output_node, Incoming) + .next() + .map(|edge| (edge.id(), edge.source())) + .unwrap(); + self.dag.remove_edge(edge_id); + predecessor_node + }); + clbit_last_nodes + .entry(clbit) + .and_modify(|val| *val = new_node); + self.dag + .add_edge(clbit_last_node, new_node, Wire::Clbit(clbit)); + } + + // If available, check all the vars in this instruction + for var in vars.iter().flatten() { + let var_last_node = if let Some(result) = vars_last_nodes.get_item(var)? { + let node: usize = result.extract()?; + vars_last_nodes.del_item(var)?; + NodeIndex::new(node) + } else { + // If the var is not in the last nodes collection, the edge between the output node and its predecessor. + // Then, store the predecessor's NodeIndex in the last nodes collection. + let output_node = self.var_output_map.get(py, var).unwrap(); + let (edge_id, predecessor_node) = self + .dag + .edges_directed(output_node, Incoming) + .next() + .map(|edge| (edge.id(), edge.source())) + .unwrap(); + self.dag.remove_edge(edge_id); + predecessor_node + }; + + vars_last_nodes.set_item(var, new_node.index())?; + self.dag + .add_edge(var_last_node, new_node, Wire::Var(var.clone_ref(py))); + } + } + + // Add the output_nodes back to qargs + for (qubit, node) in qubit_last_nodes { + let output_node = self.qubit_io_map[qubit.0 as usize][1]; + self.dag.add_edge(node, output_node, Wire::Qubit(qubit)); + } + + // Add the output_nodes back to cargs + for (clbit, node) in clbit_last_nodes { + let output_node = self.clbit_io_map[clbit.0 as usize][1]; + self.dag.add_edge(node, output_node, Wire::Clbit(clbit)); + } + + // Add the output_nodes back to vars + for item in vars_last_nodes.items() { + let (var, node): (PyObject, usize) = item.extract()?; + let output_node = self.var_output_map.get(py, &var).unwrap(); + self.dag + .add_edge(NodeIndex::new(node), output_node, Wire::Var(var)); + } + + Ok(new_nodes) + } } /// Add to global phase. Global phase can only be Float or ParameterExpression so this