diff --git a/cirq-core/cirq/ops/phased_x_gate.py b/cirq-core/cirq/ops/phased_x_gate.py index c2c647aee10..cf83219fbcd 100644 --- a/cirq-core/cirq/ops/phased_x_gate.py +++ b/cirq-core/cirq/ops/phased_x_gate.py @@ -80,9 +80,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE': assert len(qubits) == 1 q = qubits[0] z = cirq.Z(q) ** self._phase_exponent - x = cirq.X(q) ** self._exponent - if protocols.is_parameterized(z): - return NotImplemented + x = cirq.XPowGate(exponent=self._exponent, global_shift=self.global_shift).on(q) return z ** -1, x, z @property diff --git a/cirq-core/cirq/ops/phased_x_gate_test.py b/cirq-core/cirq/ops/phased_x_gate_test.py index 8cfcc144c4d..04e8420cafd 100644 --- a/cirq-core/cirq/ops/phased_x_gate_test.py +++ b/cirq-core/cirq/ops/phased_x_gate_test.py @@ -35,14 +35,11 @@ ], ) def test_phased_x_consistent_protocols(phase_exponent): - # If there is no global_shift, the gate is global phase insensitive. cirq.testing.assert_implements_consistent_protocols( cirq.PhasedXPowGate(phase_exponent=phase_exponent, exponent=1.0), - ignoring_global_phase=False, ) cirq.testing.assert_implements_consistent_protocols( cirq.PhasedXPowGate(phase_exponent=phase_exponent, exponent=1.0, global_shift=0.1), - ignoring_global_phase=True, ) @@ -171,26 +168,36 @@ def test_str_repr(): ) -@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once]) -def test_parameterize(resolve_fn): +@pytest.mark.parametrize( + 'resolve_fn, global_shift', [(cirq.resolve_parameters, 0), (cirq.resolve_parameters_once, 0.1)] +) +def test_parameterize(resolve_fn, global_shift): parameterized_gate = cirq.PhasedXPowGate( - exponent=sympy.Symbol('a'), phase_exponent=sympy.Symbol('b') + exponent=sympy.Symbol('a'), phase_exponent=sympy.Symbol('b'), global_shift=global_shift ) assert cirq.pow(parameterized_gate, 5) == cirq.PhasedXPowGate( - exponent=sympy.Symbol('a') * 5, phase_exponent=sympy.Symbol('b') - ) - assert ( - cirq.decompose_once_with_qubits(parameterized_gate, [cirq.LineQubit(0)], NotImplemented) - is NotImplemented + exponent=sympy.Symbol('a') * 5, phase_exponent=sympy.Symbol('b'), global_shift=global_shift ) assert cirq.unitary(parameterized_gate, default=None) is None assert cirq.is_parameterized(parameterized_gate) + q = cirq.NamedQubit("q") + parameterized_decomposed_circuit = cirq.Circuit(cirq.decompose(parameterized_gate(q))) + for resolver in cirq.Linspace('a', 0, 2, 10) * cirq.Linspace('b', 0, 2, 10): + resolved_gate = resolve_fn(parameterized_gate, resolver) + assert resolved_gate == cirq.PhasedXPowGate( + exponent=resolver.value_of('a'), + phase_exponent=resolver.value_of('b'), + global_shift=global_shift, + ) + np.testing.assert_allclose( + cirq.unitary(resolved_gate(q)), + cirq.unitary(resolve_fn(parameterized_decomposed_circuit, resolver)), + atol=1e-8, + ) - resolver = cirq.ParamResolver({'a': 0.1, 'b': 0.2}) - resolved_gate = resolve_fn(parameterized_gate, resolver) - assert resolved_gate == cirq.PhasedXPowGate(exponent=0.1, phase_exponent=0.2) - - unparameterized_gate = cirq.PhasedXPowGate(exponent=0.1, phase_exponent=0.2) + unparameterized_gate = cirq.PhasedXPowGate( + exponent=0.1, phase_exponent=0.2, global_shift=global_shift + ) assert not cirq.is_parameterized(unparameterized_gate) assert cirq.is_parameterized(unparameterized_gate ** sympy.Symbol('a')) assert cirq.is_parameterized(unparameterized_gate ** (sympy.Symbol('a') + 1))