Skip to content

Commit

Permalink
Implement GlobalPhaseGate (quantumlib#4697)
Browse files Browse the repository at this point in the history
Implements GlobalPhaseOperation in terms of a GateOperation on a new class GlobalPhaseGate.

Mostly involved moving existing functions from the operation to the gate, and then having the operation call those methods under the hood.
  • Loading branch information
daxfohl authored and MichaelBroughton committed Jan 22, 2022
1 parent 9e94b9e commit 185b53a
Show file tree
Hide file tree
Showing 33 changed files with 253 additions and 98 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@
generalized_amplitude_damp,
GeneralizedAmplitudeDampingChannel,
givens,
GlobalPhaseGate,
GlobalPhaseOperation,
global_phase_operation,
H,
HPowGate,
I,
Expand Down
8 changes: 4 additions & 4 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ def default_namer(label_entity):
diagram.write(0, i, name)
first_annotation_row = max(label_map.values(), default=0) + 1

if any(isinstance(op.untagged, cirq.GlobalPhaseOperation) for op in self.all_operations()):
if any(isinstance(op.gate, cirq.GlobalPhaseGate) for op in self.all_operations()):
diagram.write(0, max(label_map.values(), default=0) + 1, 'global phase:')
first_annotation_row += 1

Expand Down Expand Up @@ -2359,7 +2359,7 @@ def _get_moment_annotations(
if op.qubits:
continue
op = op.untagged
if isinstance(op, ops.GlobalPhaseOperation):
if isinstance(op.gate, ops.GlobalPhaseGate):
continue
if isinstance(op, CircuitOperation):
for m in op.circuit:
Expand Down Expand Up @@ -2493,8 +2493,8 @@ def _draw_moment_in_diagram(


def _get_global_phase_and_tags_for_op(op: 'cirq.Operation') -> Tuple[Optional[complex], List[Any]]:
if isinstance(op.untagged, ops.GlobalPhaseOperation):
return complex(op.untagged.coefficient), list(op.tags)
if isinstance(op.gate, ops.GlobalPhaseGate):
return complex(op.gate.coefficient), list(op.tags)
elif isinstance(op.untagged, CircuitOperation):
op_phase, op_tags = _get_global_phase_and_tags_for_ops(op.untagged.circuit.all_operations())
return op_phase, list(op.tags) + op_tags
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,11 @@ def test_string_format():
assert str(op0) == f"[ ]"

fc0_global_phase_inner = cirq.FrozenCircuit(
cirq.GlobalPhaseOperation(1j), cirq.GlobalPhaseOperation(1j)
cirq.global_phase_operation(1j), cirq.global_phase_operation(1j)
)
op0_global_phase_inner = cirq.CircuitOperation(fc0_global_phase_inner)
fc0_global_phase_outer = cirq.FrozenCircuit(
op0_global_phase_inner, cirq.GlobalPhaseOperation(1j)
op0_global_phase_inner, cirq.global_phase_operation(1j)
)
op0_global_phase_outer = cirq.CircuitOperation(fc0_global_phase_outer)
assert (
Expand Down
8 changes: 5 additions & 3 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2568,7 +2568,7 @@ def test_diagram_wgate_none_precision(circuit_cls):
@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_diagram_global_phase(circuit_cls):
qa = cirq.NamedQubit('a')
global_phase = cirq.GlobalPhaseOperation(coefficient=1j)
global_phase = cirq.global_phase_operation(coefficient=1j)
c = circuit_cls([global_phase])
cirq.testing.assert_has_diagram(
c, "\n\nglobal phase: 0.5pi", use_unicode_characters=False, precision=2
Expand Down Expand Up @@ -2601,7 +2601,9 @@ def test_diagram_global_phase(circuit_cls):

c = circuit_cls(
cirq.X(cirq.LineQubit(2)),
cirq.CircuitOperation(circuit_cls(cirq.GlobalPhaseOperation(-1).with_tags("tag")).freeze()),
cirq.CircuitOperation(
circuit_cls(cirq.global_phase_operation(-1).with_tags("tag")).freeze()
),
)
cirq.testing.assert_has_diagram(
c,
Expand Down Expand Up @@ -5131,7 +5133,7 @@ def _circuit_diagram_info_(self, args) -> str:
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
CustomOperationAnnotation("a"),
cirq.GlobalPhaseOperation(1j),
cirq.global_phase_operation(1j),
),
),
"""
Expand Down
10 changes: 5 additions & 5 deletions cirq-core/cirq/interop/quirk/cells/scalar_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@


def generate_all_scalar_cell_makers() -> Iterator[CellMaker]:
yield _scalar("NeGate", ops.GlobalPhaseOperation(-1))
yield _scalar("i", ops.GlobalPhaseOperation(1j))
yield _scalar("-i", ops.GlobalPhaseOperation(-1j))
yield _scalar("√i", ops.GlobalPhaseOperation(1j ** 0.5))
yield _scalar("√-i", ops.GlobalPhaseOperation((-1j) ** 0.5))
yield _scalar("NeGate", ops.global_phase_operation(-1))
yield _scalar("i", ops.global_phase_operation(1j))
yield _scalar("-i", ops.global_phase_operation(-1j))
yield _scalar("√i", ops.global_phase_operation(1j ** 0.5))
yield _scalar("√-i", ops.global_phase_operation((-1j) ** 0.5))


def _scalar(identifier: str, operation: 'cirq.Operation') -> CellMaker:
Expand Down
12 changes: 7 additions & 5 deletions cirq-core/cirq/interop/quirk/cells/scalar_cells_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@ def test_scalar_operations():
assert_url_to_circuit_returns('{"cols":[["…"]]}', cirq.Circuit())

assert_url_to_circuit_returns(
'{"cols":[["NeGate"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(-1))
'{"cols":[["NeGate"]]}', cirq.Circuit(cirq.global_phase_operation(-1))
)

assert_url_to_circuit_returns('{"cols":[["i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j)))
assert_url_to_circuit_returns('{"cols":[["i"]]}', cirq.Circuit(cirq.global_phase_operation(1j)))

assert_url_to_circuit_returns('{"cols":[["-i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(-1j)))
assert_url_to_circuit_returns(
'{"cols":[["-i"]]}', cirq.Circuit(cirq.global_phase_operation(-1j))
)

assert_url_to_circuit_returns(
'{"cols":[["√i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j ** 0.5))
'{"cols":[["√i"]]}', cirq.Circuit(cirq.global_phase_operation(1j ** 0.5))
)

assert_url_to_circuit_returns(
'{"cols":[["√-i"]]}', cirq.Circuit(cirq.GlobalPhaseOperation(1j ** -0.5))
'{"cols":[["√-i"]]}', cirq.Circuit(cirq.global_phase_operation(1j ** -0.5))
)
1 change: 1 addition & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _parallel_gate_op(gate, qubits):
'GateOperation': cirq.GateOperation,
'Gateset': cirq.Gateset,
'GeneralizedAmplitudeDampingChannel': cirq.GeneralizedAmplitudeDampingChannel,
'GlobalPhaseGate': cirq.GlobalPhaseGate,
'GlobalPhaseOperation': cirq.GlobalPhaseOperation,
'GridInteractionLayer': GridInteractionLayer,
'GridParallelXEBMetadata': GridParallelXEBMetadata,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def _decompose_(self, qubits):

a, b = qubits
return [
ops.GlobalPhaseOperation(self.global_phase),
ops.global_phase_operation(self.global_phase),
ops.MatrixGate(self.single_qubit_operations_before[0]).on(a),
ops.MatrixGate(self.single_qubit_operations_before[1]).on(b),
np.exp(1j * ops.X(a) * ops.X(b) * self.interaction_coefficients[0]),
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@
)

from cirq.ops.global_phase_op import (
GlobalPhaseGate,
GlobalPhaseOperation,
global_phase_operation,
)

from cirq.ops.kraus_channel import (
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _decompose_(self, qubits):
return NotImplemented
result = [PAULI_GATES[p].on(q) for p, q in zip(self.pauli_mask, qubits) if p]
if self.coefficient != 1:
result.append(global_phase_op.GlobalPhaseOperation(self.coefficient))
result.append(global_phase_op.global_phase_operation(self.coefficient))
return result

def _is_parameterized_(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/diagonal_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
# diagonal gate for sub-system like controlled gate, it is no longer equivalent. Hence,
# we add global phase.
decomposed_circ: List[Any] = [
global_phase_op.GlobalPhaseOperation(np.exp(1j * hat_angles[0]))
global_phase_op.global_phase_operation(np.exp(1j * hat_angles[0]))
]
for i, bit_flip in _gen_gray_code(n):
decomposed_circ.extend(self._decompose_for_basis(i, bit_flip, -hat_angles[i], qubits))
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __repr__(self):

def __str__(self) -> str:
qubits = ', '.join(str(e) for e in self.qubits)
return f'{self.gate}({qubits})'
return f'{self.gate}({qubits})' if qubits else str(self.gate)

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['gate', 'qubits'])
Expand Down
5 changes: 3 additions & 2 deletions cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool
g = item if isinstance(item, raw_types.Gate) else item.gate
assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'

if isinstance(g, global_phase_op.GlobalPhaseGate):
return self._accept_global_phase_op

if g in self._instance_gate_families:
assert item in self._instance_gate_families[g], (
f"{item} instance matches {self._instance_gate_families[g]} but "
Expand Down Expand Up @@ -396,8 +399,6 @@ def _validate_operation(self, op: raw_types.Operation) -> bool:
lambda q: cast(circuit_operation.CircuitOperation, op).qubit_map.get(q, q)
)
return self.validate(op_circuit)
elif isinstance(op, global_phase_op.GlobalPhaseOperation):
return self._accept_global_phase_op
else:
return False

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_gate_family_eq():
(cirq.SingleQubitGate(), False),
(cirq.X ** 0.5, False),
(None, False),
(cirq.GlobalPhaseOperation(1j), False),
(cirq.global_phase_operation(1j), False),
],
),
(
Expand All @@ -144,7 +144,7 @@ def test_gate_family_eq():
(CustomX ** 3, True),
(CustomX ** sympy.Symbol('theta'), False),
(None, False),
(cirq.GlobalPhaseOperation(1j), False),
(cirq.global_phase_operation(1j), False),
],
),
(
Expand Down Expand Up @@ -255,7 +255,7 @@ def get_ops(use_circuit_op, use_global_phase):
)
yield [circuit_op, recursive_circuit_op]
if use_global_phase:
yield cirq.GlobalPhaseOperation(1j)
yield cirq.global_phase_operation(1j)

def assert_validate_and_contains_consistent(gateset, op_tree, result):
assert all(op in gateset for op in cirq.flatten_to_ops(op_tree)) is result
Expand Down
65 changes: 51 additions & 14 deletions cirq-core/cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,69 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A no-qubit global phase operation."""
from typing import Any, Dict, Tuple, TYPE_CHECKING
from typing import Any, Dict, Sequence, Tuple, TYPE_CHECKING

import numpy as np

from cirq import value, protocols
from cirq.ops import raw_types
from cirq._compat import deprecated_class
from cirq.ops import gate_operation, raw_types

if TYPE_CHECKING:
import cirq


@value.value_equality(approximate=True)
class GlobalPhaseOperation(raw_types.Operation):
@deprecated_class(deadline='v0.16', fix='Use cirq.global_phase_operation')
class GlobalPhaseOperation(gate_operation.GateOperation):
def __init__(self, coefficient: value.Scalar, atol: float = 1e-8) -> None:
if abs(1 - abs(coefficient)) > atol:
raise ValueError(f'Coefficient is not unitary: {coefficient!r}')
self.coefficient = coefficient

@property
def qubits(self) -> Tuple['cirq.Qid', ...]:
return ()
gate = GlobalPhaseGate(coefficient, atol)
super().__init__(gate, [])

def with_qubits(self, *new_qubits) -> 'GlobalPhaseOperation':
if new_qubits:
raise ValueError(f'{self!r} applies to 0 qubits but new_qubits={new_qubits!r}.')
return self

@property
def coefficient(self) -> value.Scalar:
return self.gate.coefficient # type: ignore

@coefficient.setter
def coefficient(self, coefficient: value.Scalar):
# coverage: ignore
self.gate._coefficient = coefficient # type: ignore

def __str__(self) -> str:
return str(self.coefficient)

def __repr__(self) -> str:
return f'cirq.GlobalPhaseOperation({self.coefficient!r})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['coefficient'])


@value.value_equality(approximate=True)
class GlobalPhaseGate(raw_types.Gate):
def __init__(self, coefficient: value.Scalar, atol: float = 1e-8) -> None:
if abs(1 - abs(coefficient)) > atol:
raise ValueError(f'Coefficient is not unitary: {coefficient!r}')
self._coefficient = coefficient

@property
def coefficient(self) -> value.Scalar:
return self._coefficient

def _value_equality_values_(self) -> Any:
return self.coefficient

def _has_unitary_(self) -> bool:
return True

def __pow__(self, power):
def __pow__(self, power) -> 'cirq.GlobalPhaseGate':
if isinstance(power, (int, float)):
return GlobalPhaseOperation(self.coefficient ** power)
return GlobalPhaseGate(self.coefficient ** power)
return NotImplemented

def _unitary_(self) -> np.ndarray:
Expand All @@ -60,7 +87,7 @@ def _apply_unitary_(self, args) -> np.ndarray:
def _has_stabilizer_effect_(self) -> bool:
return True

def _act_on_(self, args: 'cirq.ActOnArgs'):
def _act_on_(self, args: 'cirq.ActOnArgs', qubits):
from cirq.sim import clifford

if isinstance(args, clifford.ActOnCliffordTableauArgs):
Expand All @@ -78,7 +105,17 @@ def __str__(self) -> str:
return str(self.coefficient)

def __repr__(self) -> str:
return f'cirq.GlobalPhaseOperation({self.coefficient!r})'
return f'cirq.GlobalPhaseGate({self.coefficient!r})'

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
return f'cirq.global_phase_operation({self.coefficient!r})'

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['coefficient'])

def _qid_shape_(self) -> Tuple[int, ...]:
return tuple()


def global_phase_operation(coefficient: value.Scalar, atol: float = 1e-8) -> 'cirq.GateOperation':
return GlobalPhaseGate(coefficient, atol)()
Loading

0 comments on commit 185b53a

Please sign in to comment.