Skip to content

Commit

Permalink
Support decomposition of parameterized gates in cirq.PhasedXPowGate (#…
Browse files Browse the repository at this point in the history
…5083)

- Adds support for decomposing parameterized `cirq.PhasedXPowGate`s.
- Also modifies the decomposition to respect global phase, so that the decomposition is valid for controlled variants as well. 
- Part of #4858
  • Loading branch information
tanujkhattar authored Mar 16, 2022
1 parent 765ccfe commit 28f90b0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
4 changes: 1 addition & 3 deletions cirq-core/cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
assert len(qubits) == 1
q = qubits[0]
z = cirq.Z(q) ** self._phase_exponent
x = cirq.X(q) ** self._exponent
if protocols.is_parameterized(z):
return NotImplemented
x = cirq.XPowGate(exponent=self._exponent, global_shift=self.global_shift).on(q)
return z ** -1, x, z

@property
Expand Down
39 changes: 23 additions & 16 deletions cirq-core/cirq/ops/phased_x_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,11 @@
],
)
def test_phased_x_consistent_protocols(phase_exponent):
# If there is no global_shift, the gate is global phase insensitive.
cirq.testing.assert_implements_consistent_protocols(
cirq.PhasedXPowGate(phase_exponent=phase_exponent, exponent=1.0),
ignoring_global_phase=False,
)
cirq.testing.assert_implements_consistent_protocols(
cirq.PhasedXPowGate(phase_exponent=phase_exponent, exponent=1.0, global_shift=0.1),
ignoring_global_phase=True,
)


Expand Down Expand Up @@ -171,26 +168,36 @@ def test_str_repr():
)


@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_parameterize(resolve_fn):
@pytest.mark.parametrize(
'resolve_fn, global_shift', [(cirq.resolve_parameters, 0), (cirq.resolve_parameters_once, 0.1)]
)
def test_parameterize(resolve_fn, global_shift):
parameterized_gate = cirq.PhasedXPowGate(
exponent=sympy.Symbol('a'), phase_exponent=sympy.Symbol('b')
exponent=sympy.Symbol('a'), phase_exponent=sympy.Symbol('b'), global_shift=global_shift
)
assert cirq.pow(parameterized_gate, 5) == cirq.PhasedXPowGate(
exponent=sympy.Symbol('a') * 5, phase_exponent=sympy.Symbol('b')
)
assert (
cirq.decompose_once_with_qubits(parameterized_gate, [cirq.LineQubit(0)], NotImplemented)
is NotImplemented
exponent=sympy.Symbol('a') * 5, phase_exponent=sympy.Symbol('b'), global_shift=global_shift
)
assert cirq.unitary(parameterized_gate, default=None) is None
assert cirq.is_parameterized(parameterized_gate)
q = cirq.NamedQubit("q")
parameterized_decomposed_circuit = cirq.Circuit(cirq.decompose(parameterized_gate(q)))
for resolver in cirq.Linspace('a', 0, 2, 10) * cirq.Linspace('b', 0, 2, 10):
resolved_gate = resolve_fn(parameterized_gate, resolver)
assert resolved_gate == cirq.PhasedXPowGate(
exponent=resolver.value_of('a'),
phase_exponent=resolver.value_of('b'),
global_shift=global_shift,
)
np.testing.assert_allclose(
cirq.unitary(resolved_gate(q)),
cirq.unitary(resolve_fn(parameterized_decomposed_circuit, resolver)),
atol=1e-8,
)

resolver = cirq.ParamResolver({'a': 0.1, 'b': 0.2})
resolved_gate = resolve_fn(parameterized_gate, resolver)
assert resolved_gate == cirq.PhasedXPowGate(exponent=0.1, phase_exponent=0.2)

unparameterized_gate = cirq.PhasedXPowGate(exponent=0.1, phase_exponent=0.2)
unparameterized_gate = cirq.PhasedXPowGate(
exponent=0.1, phase_exponent=0.2, global_shift=global_shift
)
assert not cirq.is_parameterized(unparameterized_gate)
assert cirq.is_parameterized(unparameterized_gate ** sympy.Symbol('a'))
assert cirq.is_parameterized(unparameterized_gate ** (sympy.Symbol('a') + 1))
Expand Down

0 comments on commit 28f90b0

Please sign in to comment.