From c7ccbd4a1819f049c9054fbff51b0bf72c6aa483 Mon Sep 17 00:00:00 2001 From: Julien Gacon Date: Tue, 5 Nov 2024 22:55:02 +0100 Subject: [PATCH] Fix CC for comparison with rotation gates on angles k*pi (#13399) --- crates/accelerate/src/commutation_checker.rs | 4 +- .../circuit/test_commutation_checker.py | 54 +++++++++++++------ 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/crates/accelerate/src/commutation_checker.rs b/crates/accelerate/src/commutation_checker.rs index 16fcc5eca8fb..fe242c73422f 100644 --- a/crates/accelerate/src/commutation_checker.rs +++ b/crates/accelerate/src/commutation_checker.rs @@ -43,6 +43,8 @@ static SUPPORTED_OP: Lazy> = Lazy::new(|| { ]) }); +const TWOPI: f64 = 2.0 * std::f64::consts::PI; + // map rotation gates to their generators, or to ``None`` if we cannot currently efficiently // represent the generator in Rust and store the commutation relation in the commutation dictionary static SUPPORTED_ROTATIONS: Lazy>> = Lazy::new(|| { @@ -632,7 +634,7 @@ fn map_rotation<'a>( // commute with everything, and we simply return the operation with the flag that // it commutes trivially if let Param::Float(angle) = params[0] { - if (angle % std::f64::consts::PI).abs() < tol { + if (angle % TWOPI).abs() < tol { return (op, params, true); }; }; diff --git a/test/python/circuit/test_commutation_checker.py b/test/python/circuit/test_commutation_checker.py index 8bb3f6939add..9759b5bffd1e 100644 --- a/test/python/circuit/test_commutation_checker.py +++ b/test/python/circuit/test_commutation_checker.py @@ -16,7 +16,7 @@ from test import QiskitTestCase # pylint: disable=wrong-import-order import numpy as np -from ddt import data, ddt +from ddt import idata, ddt from qiskit import ClassicalRegister from qiskit.circuit import ( @@ -52,9 +52,25 @@ SGate, XGate, ZGate, + HGate, ) from qiskit.dagcircuit import DAGOpNode +ROTATION_GATES = [ + RXGate, + RYGate, + RZGate, + PhaseGate, + CRXGate, + CRYGate, + CRZGate, + CPhaseGate, + RXXGate, + RYYGate, + RZZGate, + RZXGate, +] + class NewGateCX(Gate): """A dummy class containing an cx gate unknown to the commutation checker's library.""" @@ -373,20 +389,7 @@ def test_serialization(self): cc2.commute_nodes(dop1, dop2) self.assertEqual(cc2.num_cached_entries(), 1) - @data( - RXGate, - RYGate, - RZGate, - PhaseGate, - CRXGate, - CRYGate, - CRZGate, - CPhaseGate, - RXXGate, - RYYGate, - RZZGate, - RZXGate, - ) + @idata(ROTATION_GATES) def test_cutoff_angles(self, gate_cls): """Check rotations with a small enough angle are cut off.""" max_power = 30 @@ -406,6 +409,27 @@ def test_cutoff_angles(self, gate_cls): else: self.assertFalse(scc.commute(generic_gate, [0, 1], [], gate, qargs, [])) + @idata(ROTATION_GATES) + def test_rotation_mod_2pi(self, gate_cls): + """Test the rotations modulo 2pi commute with any gate.""" + generic_gate = HGate() # does not commute with any rotation gate + even = np.arange(-6, 7, 2) + + with self.subTest(msg="even multiples"): + for multiple in even: + gate = gate_cls(multiple * np.pi) + self.assertTrue( + scc.commute(generic_gate, [0], [], gate, list(range(gate.num_qubits)), []) + ) + + odd = np.arange(-5, 6, 2) + with self.subTest(msg="odd multiples"): + for multiple in odd: + gate = gate_cls(multiple * np.pi) + self.assertFalse( + scc.commute(generic_gate, [0], [], gate, list(range(gate.num_qubits)), []) + ) + if __name__ == "__main__": unittest.main()